diff --git a/app/database/score.py b/app/database/score.py
index da204c0..bfc7840 100644
--- a/app/database/score.py
+++ b/app/database/score.py
@@ -518,7 +518,7 @@ async def get_user_best_pp(
session: AsyncSession,
user: int,
mode: GameMode,
- limit: int = 200,
+ limit: int = 1000,
) -> Sequence[PPBestScore]:
return (
await session.exec(
diff --git a/app/fetcher/__init__.py b/app/fetcher/__init__.py
index 7e74fc9..6d8e46d 100644
--- a/app/fetcher/__init__.py
+++ b/app/fetcher/__init__.py
@@ -2,10 +2,10 @@ from __future__ import annotations
from .beatmap import BeatmapFetcher
from .beatmapset import BeatmapsetFetcher
-from .osu_dot_direct import OsuDotDirectFetcher
+from .osu_dot_direct import BeatmapRawFetcher
-class Fetcher(BeatmapFetcher, BeatmapsetFetcher, OsuDotDirectFetcher):
+class Fetcher(BeatmapFetcher, BeatmapsetFetcher, BeatmapRawFetcher):
"""A class that combines all fetchers for easy access."""
pass
diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py
index 6e18435..529eae4 100644
--- a/app/fetcher/osu_dot_direct.py
+++ b/app/fetcher/osu_dot_direct.py
@@ -3,21 +3,38 @@ from __future__ import annotations
from ._base import BaseFetcher
from httpx import AsyncClient
+from httpx._models import Response
from loguru import logger
import redis.asyncio as redis
+urls = [
+ "https://osu.ppy.sh/osu/{beatmap_id}",
+ "https://osu.direct/api/osu/{beatmap_id}",
+ "https://catboy.best/osu/{beatmap_id}",
+]
-class OsuDotDirectFetcher(BaseFetcher):
+
+class BeatmapRawFetcher(BaseFetcher):
async def get_beatmap_raw(self, beatmap_id: int) -> str:
- logger.opt(colors=True).debug(
- f"[OsuDotDirectFetcher] get_beatmap_raw: {beatmap_id}"
- )
+ for url in urls:
+ req_url = url.format(beatmap_id=beatmap_id)
+ logger.opt(colors=True).debug(
+ f"[BeatmapRawFetcher] get_beatmap_raw: {req_url}"
+ )
+ resp = await self._request(req_url)
+ if resp.status_code == 429:
+ continue
+ elif resp.status_code < 400:
+ return resp.text
+ else:
+ resp.raise_for_status()
+
+ async def _request(self, url: str) -> Response:
async with AsyncClient() as client:
response = await client.get(
- f"https://osu.direct/api/osu/{beatmap_id}/raw",
+ url,
)
- response.raise_for_status()
- return response.text
+ return response
async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int
diff --git a/app/service/pp_recalculate.py b/app/service/pp_recalculate.py
new file mode 100644
index 0000000..c6b2ad5
--- /dev/null
+++ b/app/service/pp_recalculate.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+import asyncio
+import math
+
+from app.calculator import (
+ calculate_pp,
+ calculate_weighted_acc,
+ calculate_weighted_pp,
+ clamp,
+)
+from app.config import settings
+from app.database import UserStatistics
+from app.database.beatmap import Beatmap
+from app.database.pp_best_score import PPBestScore
+from app.database.score import Score
+from app.dependencies.database import engine, get_redis
+from app.dependencies.fetcher import get_fetcher
+from app.fetcher import Fetcher
+from app.log import logger
+from app.models.mods import mods_can_get_pp
+from app.models.score import MODE_TO_INT, GameMode
+
+from httpx import HTTPError
+from redis.asyncio import Redis
+from sqlmodel import col, delete, select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+
+async def recalculate_all_players_pp():
+ async with AsyncSession(engine, autoflush=False) as session:
+ fetcher = await get_fetcher()
+ redis = get_redis()
+ for mode in GameMode:
+ await session.execute(
+ delete(PPBestScore).where(col(PPBestScore.gamemode) == mode)
+ )
+ logger.info(f"Recalculating PP for mode: {mode}")
+ statistics_list = (
+ await session.exec(
+ select(UserStatistics).where(UserStatistics.mode == mode)
+ )
+ ).all()
+ await asyncio.gather(
+ *[
+ _recalculate_pp(statistics, session, fetcher, redis)
+ for statistics in statistics_list
+ ]
+ )
+ await session.commit()
+ logger.success(
+ f"Recalculated PP for mode: {mode}, total: {len(statistics_list)}"
+ )
+
+
+async def _recalculate_pp(
+ statistics: UserStatistics, session: AsyncSession, fetcher: Fetcher, redis: Redis
+):
+ scores = (
+ await session.exec(
+ select(Score).where(
+ Score.user_id == statistics.user_id,
+ Score.gamemode == statistics.mode,
+ col(Score.passed).is_(True),
+ )
+ )
+ ).all()
+ score_list: list[tuple[float, float]] = []
+ prev: dict[int, PPBestScore] = {}
+ for score in scores:
+ time = 10
+ beatmap_id = score.beatmap_id
+ while time > 0:
+ try:
+ db_beatmap = await Beatmap.get_or_fetch(
+ session, fetcher, bid=beatmap_id
+ )
+ except HTTPError:
+ time -= 1
+ await asyncio.sleep(2)
+ continue
+ ranked = (
+ db_beatmap.beatmap_status.has_pp()
+ | settings.enable_all_beatmap_leaderboard
+ )
+ if not ranked or not mods_can_get_pp(
+ MODE_TO_INT[score.gamemode], score.mods
+ ):
+ score.pp = 0
+ break
+ try:
+ beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
+ pp = await asyncio.get_event_loop().run_in_executor(
+ None, calculate_pp, score, beatmap_raw
+ )
+ logger.info(f"{score.user_id} {score.id} pp: {pp}")
+ score.pp = pp
+ if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp:
+ best_score = PPBestScore(
+ user_id=statistics.user_id,
+ beatmap_id=beatmap_id,
+ acc=score.accuracy,
+ score_id=score.id,
+ pp=pp,
+ gamemode=score.gamemode,
+ )
+ prev[score.beatmap_id] = best_score
+ score_list.append((score.pp, score.accuracy))
+ break
+ except HTTPError:
+ time -= 1
+ await asyncio.sleep(2)
+ continue
+ if time <= 0:
+ logger.error(f"Failed to fetch beatmap {beatmap_id} after 10 attempts")
+ score.pp = 0
+ # according to pp desc
+ score_list.sort(key=lambda x: x[0], reverse=True)
+ pp_sum = 0
+ acc_sum = 0
+ for i, s in enumerate(score_list):
+ pp_sum += calculate_weighted_pp(s[0], i)
+ acc_sum += calculate_weighted_acc(s[1], i)
+ if len(score_list):
+ # https://github.com/ppy/osu-queue-score-statistics/blob/c538ae/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/UserTotalPerformanceAggregateHelper.cs#L41-L45
+ acc_sum *= 100 / (20 * (1 - math.pow(0.95, len(score_list))))
+ acc_sum = clamp(acc_sum, 0.0, 100.0)
+ statistics.pp = pp_sum
+ statistics.hit_accuracy = acc_sum
+ for best_score in prev.values():
+ session.add(best_score)
diff --git a/main.py b/main.py
index b5bd9d6..5cde52a 100644
--- a/main.py
+++ b/main.py
@@ -2,6 +2,7 @@ from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import datetime
+import os
from app.config import settings
from app.dependencies.database import engine, redis_client
@@ -20,6 +21,7 @@ from app.router.redirect import redirect_router
from app.service.calculate_all_user_rank import calculate_user_rank
from app.service.daily_challenge import daily_challenge_job
from app.service.osu_rx_statistics import create_rx_statistics
+from app.service.pp_recalculate import recalculate_all_players_pp
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
@@ -31,9 +33,11 @@ import sentry_sdk
@asynccontextmanager
async def lifespan(app: FastAPI):
# on startup
+ await get_fetcher() # 初始化 fetcher
+ if os.environ.get("RECALCULATE_PP", "false").lower() == "true":
+ await recalculate_all_players_pp()
await create_rx_statistics()
await calculate_user_rank(True)
- await get_fetcher() # 初始化 fetcher
init_scheduler()
await daily_challenge_job()
# on shutdown