From ebbc0b825213febbab8e83963b236ed750900a0d Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 14 Aug 2025 06:50:17 +0000 Subject: [PATCH] feat(score): allow to recalculate all score pp --- app/database/score.py | 2 +- app/fetcher/__init__.py | 4 +- app/fetcher/osu_dot_direct.py | 31 ++++++-- app/service/pp_recalculate.py | 131 ++++++++++++++++++++++++++++++++++ main.py | 6 +- 5 files changed, 163 insertions(+), 11 deletions(-) create mode 100644 app/service/pp_recalculate.py diff --git a/app/database/score.py b/app/database/score.py index da204c0..bfc7840 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -518,7 +518,7 @@ async def get_user_best_pp( session: AsyncSession, user: int, mode: GameMode, - limit: int = 200, + limit: int = 1000, ) -> Sequence[PPBestScore]: return ( await session.exec( diff --git a/app/fetcher/__init__.py b/app/fetcher/__init__.py index 7e74fc9..6d8e46d 100644 --- a/app/fetcher/__init__.py +++ b/app/fetcher/__init__.py @@ -2,10 +2,10 @@ from __future__ import annotations from .beatmap import BeatmapFetcher from .beatmapset import BeatmapsetFetcher -from .osu_dot_direct import OsuDotDirectFetcher +from .osu_dot_direct import BeatmapRawFetcher -class Fetcher(BeatmapFetcher, BeatmapsetFetcher, OsuDotDirectFetcher): +class Fetcher(BeatmapFetcher, BeatmapsetFetcher, BeatmapRawFetcher): """A class that combines all fetchers for easy access.""" pass diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index 6e18435..529eae4 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -3,21 +3,38 @@ from __future__ import annotations from ._base import BaseFetcher from httpx import AsyncClient +from httpx._models import Response from loguru import logger import redis.asyncio as redis +urls = [ + "https://osu.ppy.sh/osu/{beatmap_id}", + "https://osu.direct/api/osu/{beatmap_id}", + "https://catboy.best/osu/{beatmap_id}", +] -class OsuDotDirectFetcher(BaseFetcher): + +class BeatmapRawFetcher(BaseFetcher): async def get_beatmap_raw(self, beatmap_id: int) -> str: - logger.opt(colors=True).debug( - f"[OsuDotDirectFetcher] get_beatmap_raw: {beatmap_id}" - ) + for url in urls: + req_url = url.format(beatmap_id=beatmap_id) + logger.opt(colors=True).debug( + f"[BeatmapRawFetcher] get_beatmap_raw: {req_url}" + ) + resp = await self._request(req_url) + if resp.status_code == 429: + continue + elif resp.status_code < 400: + return resp.text + else: + resp.raise_for_status() + + async def _request(self, url: str) -> Response: async with AsyncClient() as client: response = await client.get( - f"https://osu.direct/api/osu/{beatmap_id}/raw", + url, ) - response.raise_for_status() - return response.text + return response async def get_or_fetch_beatmap_raw( self, redis: redis.Redis, beatmap_id: int diff --git a/app/service/pp_recalculate.py b/app/service/pp_recalculate.py new file mode 100644 index 0000000..c6b2ad5 --- /dev/null +++ b/app/service/pp_recalculate.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import asyncio +import math + +from app.calculator import ( + calculate_pp, + calculate_weighted_acc, + calculate_weighted_pp, + clamp, +) +from app.config import settings +from app.database import UserStatistics +from app.database.beatmap import Beatmap +from app.database.pp_best_score import PPBestScore +from app.database.score import Score +from app.dependencies.database import engine, get_redis +from app.dependencies.fetcher import get_fetcher +from app.fetcher import Fetcher +from app.log import logger +from app.models.mods import mods_can_get_pp +from app.models.score import MODE_TO_INT, GameMode + +from httpx import HTTPError +from redis.asyncio import Redis +from sqlmodel import col, delete, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +async def recalculate_all_players_pp(): + async with AsyncSession(engine, autoflush=False) as session: + fetcher = await get_fetcher() + redis = get_redis() + for mode in GameMode: + await session.execute( + delete(PPBestScore).where(col(PPBestScore.gamemode) == mode) + ) + logger.info(f"Recalculating PP for mode: {mode}") + statistics_list = ( + await session.exec( + select(UserStatistics).where(UserStatistics.mode == mode) + ) + ).all() + await asyncio.gather( + *[ + _recalculate_pp(statistics, session, fetcher, redis) + for statistics in statistics_list + ] + ) + await session.commit() + logger.success( + f"Recalculated PP for mode: {mode}, total: {len(statistics_list)}" + ) + + +async def _recalculate_pp( + statistics: UserStatistics, session: AsyncSession, fetcher: Fetcher, redis: Redis +): + scores = ( + await session.exec( + select(Score).where( + Score.user_id == statistics.user_id, + Score.gamemode == statistics.mode, + col(Score.passed).is_(True), + ) + ) + ).all() + score_list: list[tuple[float, float]] = [] + prev: dict[int, PPBestScore] = {} + for score in scores: + time = 10 + beatmap_id = score.beatmap_id + while time > 0: + try: + db_beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=beatmap_id + ) + except HTTPError: + time -= 1 + await asyncio.sleep(2) + continue + ranked = ( + db_beatmap.beatmap_status.has_pp() + | settings.enable_all_beatmap_leaderboard + ) + if not ranked or not mods_can_get_pp( + MODE_TO_INT[score.gamemode], score.mods + ): + score.pp = 0 + break + try: + beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) + pp = await asyncio.get_event_loop().run_in_executor( + None, calculate_pp, score, beatmap_raw + ) + logger.info(f"{score.user_id} {score.id} pp: {pp}") + score.pp = pp + if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp: + best_score = PPBestScore( + user_id=statistics.user_id, + beatmap_id=beatmap_id, + acc=score.accuracy, + score_id=score.id, + pp=pp, + gamemode=score.gamemode, + ) + prev[score.beatmap_id] = best_score + score_list.append((score.pp, score.accuracy)) + break + except HTTPError: + time -= 1 + await asyncio.sleep(2) + continue + if time <= 0: + logger.error(f"Failed to fetch beatmap {beatmap_id} after 10 attempts") + score.pp = 0 + # according to pp desc + score_list.sort(key=lambda x: x[0], reverse=True) + pp_sum = 0 + acc_sum = 0 + for i, s in enumerate(score_list): + pp_sum += calculate_weighted_pp(s[0], i) + acc_sum += calculate_weighted_acc(s[1], i) + if len(score_list): + # https://github.com/ppy/osu-queue-score-statistics/blob/c538ae/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/UserTotalPerformanceAggregateHelper.cs#L41-L45 + acc_sum *= 100 / (20 * (1 - math.pow(0.95, len(score_list)))) + acc_sum = clamp(acc_sum, 0.0, 100.0) + statistics.pp = pp_sum + statistics.hit_accuracy = acc_sum + for best_score in prev.values(): + session.add(best_score) diff --git a/main.py b/main.py index b5bd9d6..5cde52a 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ from __future__ import annotations from contextlib import asynccontextmanager from datetime import datetime +import os from app.config import settings from app.dependencies.database import engine, redis_client @@ -20,6 +21,7 @@ from app.router.redirect import redirect_router from app.service.calculate_all_user_rank import calculate_user_rank from app.service.daily_challenge import daily_challenge_job from app.service.osu_rx_statistics import create_rx_statistics +from app.service.pp_recalculate import recalculate_all_players_pp from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError @@ -31,9 +33,11 @@ import sentry_sdk @asynccontextmanager async def lifespan(app: FastAPI): # on startup + await get_fetcher() # 初始化 fetcher + if os.environ.get("RECALCULATE_PP", "false").lower() == "true": + await recalculate_all_players_pp() await create_rx_statistics() await calculate_user_rank(True) - await get_fetcher() # 初始化 fetcher init_scheduler() await daily_challenge_job() # on shutdown