feat(score): allow to recalculate all score pp

This commit is contained in:
MingxuanGame
2025-08-14 06:50:17 +00:00
parent c8b6c1fa0e
commit ebbc0b8252
5 changed files with 163 additions and 11 deletions

View File

@@ -518,7 +518,7 @@ async def get_user_best_pp(
session: AsyncSession, session: AsyncSession,
user: int, user: int,
mode: GameMode, mode: GameMode,
limit: int = 200, limit: int = 1000,
) -> Sequence[PPBestScore]: ) -> Sequence[PPBestScore]:
return ( return (
await session.exec( await session.exec(

View File

@@ -2,10 +2,10 @@ from __future__ import annotations
from .beatmap import BeatmapFetcher from .beatmap import BeatmapFetcher
from .beatmapset import BeatmapsetFetcher from .beatmapset import BeatmapsetFetcher
from .osu_dot_direct import OsuDotDirectFetcher from .osu_dot_direct import BeatmapRawFetcher
class Fetcher(BeatmapFetcher, BeatmapsetFetcher, OsuDotDirectFetcher): class Fetcher(BeatmapFetcher, BeatmapsetFetcher, BeatmapRawFetcher):
"""A class that combines all fetchers for easy access.""" """A class that combines all fetchers for easy access."""
pass pass

View File

@@ -3,21 +3,38 @@ from __future__ import annotations
from ._base import BaseFetcher from ._base import BaseFetcher
from httpx import AsyncClient from httpx import AsyncClient
from httpx._models import Response
from loguru import logger from loguru import logger
import redis.asyncio as redis import redis.asyncio as redis
urls = [
"https://osu.ppy.sh/osu/{beatmap_id}",
"https://osu.direct/api/osu/{beatmap_id}",
"https://catboy.best/osu/{beatmap_id}",
]
class OsuDotDirectFetcher(BaseFetcher):
class BeatmapRawFetcher(BaseFetcher):
async def get_beatmap_raw(self, beatmap_id: int) -> str: async def get_beatmap_raw(self, beatmap_id: int) -> str:
logger.opt(colors=True).debug( for url in urls:
f"<blue>[OsuDotDirectFetcher]</blue> get_beatmap_raw: <y>{beatmap_id}</y>" req_url = url.format(beatmap_id=beatmap_id)
) logger.opt(colors=True).debug(
f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>"
)
resp = await self._request(req_url)
if resp.status_code == 429:
continue
elif resp.status_code < 400:
return resp.text
else:
resp.raise_for_status()
async def _request(self, url: str) -> Response:
async with AsyncClient() as client: async with AsyncClient() as client:
response = await client.get( response = await client.get(
f"https://osu.direct/api/osu/{beatmap_id}/raw", url,
) )
response.raise_for_status() return response
return response.text
async def get_or_fetch_beatmap_raw( async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int self, redis: redis.Redis, beatmap_id: int

View File

@@ -0,0 +1,131 @@
from __future__ import annotations
import asyncio
import math
from app.calculator import (
calculate_pp,
calculate_weighted_acc,
calculate_weighted_pp,
clamp,
)
from app.config import settings
from app.database import UserStatistics
from app.database.beatmap import Beatmap
from app.database.pp_best_score import PPBestScore
from app.database.score import Score
from app.dependencies.database import engine, get_redis
from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher
from app.log import logger
from app.models.mods import mods_can_get_pp
from app.models.score import MODE_TO_INT, GameMode
from httpx import HTTPError
from redis.asyncio import Redis
from sqlmodel import col, delete, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def recalculate_all_players_pp():
async with AsyncSession(engine, autoflush=False) as session:
fetcher = await get_fetcher()
redis = get_redis()
for mode in GameMode:
await session.execute(
delete(PPBestScore).where(col(PPBestScore.gamemode) == mode)
)
logger.info(f"Recalculating PP for mode: {mode}")
statistics_list = (
await session.exec(
select(UserStatistics).where(UserStatistics.mode == mode)
)
).all()
await asyncio.gather(
*[
_recalculate_pp(statistics, session, fetcher, redis)
for statistics in statistics_list
]
)
await session.commit()
logger.success(
f"Recalculated PP for mode: {mode}, total: {len(statistics_list)}"
)
async def _recalculate_pp(
statistics: UserStatistics, session: AsyncSession, fetcher: Fetcher, redis: Redis
):
scores = (
await session.exec(
select(Score).where(
Score.user_id == statistics.user_id,
Score.gamemode == statistics.mode,
col(Score.passed).is_(True),
)
)
).all()
score_list: list[tuple[float, float]] = []
prev: dict[int, PPBestScore] = {}
for score in scores:
time = 10
beatmap_id = score.beatmap_id
while time > 0:
try:
db_beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=beatmap_id
)
except HTTPError:
time -= 1
await asyncio.sleep(2)
continue
ranked = (
db_beatmap.beatmap_status.has_pp()
| settings.enable_all_beatmap_leaderboard
)
if not ranked or not mods_can_get_pp(
MODE_TO_INT[score.gamemode], score.mods
):
score.pp = 0
break
try:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
pp = await asyncio.get_event_loop().run_in_executor(
None, calculate_pp, score, beatmap_raw
)
logger.info(f"{score.user_id} {score.id} pp: {pp}")
score.pp = pp
if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp:
best_score = PPBestScore(
user_id=statistics.user_id,
beatmap_id=beatmap_id,
acc=score.accuracy,
score_id=score.id,
pp=pp,
gamemode=score.gamemode,
)
prev[score.beatmap_id] = best_score
score_list.append((score.pp, score.accuracy))
break
except HTTPError:
time -= 1
await asyncio.sleep(2)
continue
if time <= 0:
logger.error(f"Failed to fetch beatmap {beatmap_id} after 10 attempts")
score.pp = 0
# according to pp desc
score_list.sort(key=lambda x: x[0], reverse=True)
pp_sum = 0
acc_sum = 0
for i, s in enumerate(score_list):
pp_sum += calculate_weighted_pp(s[0], i)
acc_sum += calculate_weighted_acc(s[1], i)
if len(score_list):
# https://github.com/ppy/osu-queue-score-statistics/blob/c538ae/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/UserTotalPerformanceAggregateHelper.cs#L41-L45
acc_sum *= 100 / (20 * (1 - math.pow(0.95, len(score_list))))
acc_sum = clamp(acc_sum, 0.0, 100.0)
statistics.pp = pp_sum
statistics.hit_accuracy = acc_sum
for best_score in prev.values():
session.add(best_score)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
import os
from app.config import settings from app.config import settings
from app.dependencies.database import engine, redis_client from app.dependencies.database import engine, redis_client
@@ -20,6 +21,7 @@ from app.router.redirect import redirect_router
from app.service.calculate_all_user_rank import calculate_user_rank from app.service.calculate_all_user_rank import calculate_user_rank
from app.service.daily_challenge import daily_challenge_job from app.service.daily_challenge import daily_challenge_job
from app.service.osu_rx_statistics import create_rx_statistics from app.service.osu_rx_statistics import create_rx_statistics
from app.service.pp_recalculate import recalculate_all_players_pp
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
@@ -31,9 +33,11 @@ import sentry_sdk
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# on startup # on startup
await get_fetcher() # 初始化 fetcher
if os.environ.get("RECALCULATE_PP", "false").lower() == "true":
await recalculate_all_players_pp()
await create_rx_statistics() await create_rx_statistics()
await calculate_user_rank(True) await calculate_user_rank(True)
await get_fetcher() # 初始化 fetcher
init_scheduler() init_scheduler()
await daily_challenge_job() await daily_challenge_job()
# on shutdown # on shutdown