From 8790ccad649ba4d621ab3a942302ead2b85955e6 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 18 Oct 2025 19:10:53 +0800 Subject: [PATCH] feat(pp-calculator): support other pp calculators (#57) New configurations: - CALCULATOR="rosu": specific pp calculator - CALCULATOR_CONFIG='{}': argument passed through into calculator --- .github/scripts/generate_config_doc.py | 17 +- app/calculator.py | 163 +++-------------- app/calculators/performance/__init__.py | 3 + app/calculators/performance/_base.py | 37 ++++ .../performance/performance_server.py | 88 +++++++++ app/calculators/performance/rosu.py | 169 ++++++++++++++++++ app/config.py | 34 +++- app/database/beatmap.py | 13 +- app/models/beatmap.py | 17 -- app/models/performance.py | 81 +++++++++ app/models/score.py | 19 +- app/router/v1/beatmap.py | 8 +- app/router/v2/beatmap.py | 10 +- docker-compose-osurx.yml | 11 ++ docker-compose.yml | 11 ++ main.py | 4 +- 16 files changed, 496 insertions(+), 189 deletions(-) create mode 100644 app/calculators/performance/__init__.py create mode 100644 app/calculators/performance/_base.py create mode 100644 app/calculators/performance/performance_server.py create mode 100644 app/calculators/performance/rosu.py create mode 100644 app/models/performance.py diff --git a/.github/scripts/generate_config_doc.py b/.github/scripts/generate_config_doc.py index 5c53867..7b9ed4c 100644 --- a/.github/scripts/generate_config_doc.py +++ b/.github/scripts/generate_config_doc.py @@ -1,11 +1,12 @@ import datetime from enum import Enum import importlib.util +from inspect import isclass import json from pathlib import Path import sys from types import NoneType, UnionType -from typing import Any, Union, get_origin +from typing import Any, Literal, Union, get_origin from pydantic import AliasChoices, BaseModel, HttpUrl from pydantic_settings import BaseSettings @@ -64,6 +65,7 @@ BASE_TYPE_MAPPING = { dict: "object", NoneType: "null", HttpUrl: "string (url)", + Any: "any", } @@ -81,9 +83,16 @@ def mapping_type(typ: type) -> str: if len(args) == 1: return f"array[{mapping_type(args[0])}]" return "array" - if issubclass(typ, Enum): + elif get_origin(typ) is dict: + args = typ.__args__ + if len(args) == 2: + return f"object[{mapping_type(args[0])}, {mapping_type(args[1])}]" + return "object" + elif get_origin(typ) is Literal: + return f"enum({', '.join([str(n) for n in typ.__args__])})" + elif isclass(typ) and issubclass(typ, Enum): return f"enum({', '.join([e.value for e in typ])})" - elif issubclass(typ, BaseSettings): + elif isclass(typ) and issubclass(typ, BaseSettings): return typ.__name__ return "unknown" @@ -126,7 +135,7 @@ doc.extend( [ module.SPECTATOR_DOC, "", - f"> 上次生成:{datetime.datetime.now(datetime.UTC).strftime('%Y-%m-%d %H:%M:%S %Z')}" + f"> 上次生成:{datetime.datetime.now(datetime.UTC).strftime('%Y-%m-%d %H:%M:%S %Z')} " f"于提交 {f'[`{commit}`](https://github.com/GooGuTeam/g0v0-server/commit/{commit})' if commit != 'unknown' else 'unknown'}", # noqa: E501 "", "> **注意: 在生产环境中,请务必更改默认的密钥和密码!**", diff --git a/app/calculator.py b/app/calculator.py index ffcabec..a0aef79 100644 --- a/app/calculator.py +++ b/app/calculator.py @@ -1,13 +1,12 @@ import asyncio -from copy import deepcopy from enum import Enum +import importlib import math from typing import TYPE_CHECKING +from app.calculators.performance import PerformanceCalculator from app.config import settings from app.log import log -from app.models.beatmap import BeatmapAttributes -from app.models.mods import APIMod, parse_enum_to_str from app.models.score import GameMode from osupyparser import HitObject, OsuFile @@ -16,23 +15,32 @@ from redis.asyncio import Redis from sqlmodel import col, exists, select from sqlmodel.ext.asyncio.session import AsyncSession -logger = log("Calculator") - -try: - import rosu_pp_py as rosu -except ImportError: - raise ImportError( - "rosu-pp-py is not installed. " - "Please install it.\n" - " Official: uv add rosu-pp-py\n" - " ppy-sb: uv add git+https://github.com/ppy-sb/rosu-pp-py.git" - ) - if TYPE_CHECKING: from app.database.score import Score from app.fetcher import Fetcher +logger = log("Calculator") + +CALCULATOR: PerformanceCalculator | None = None + + +def init_calculator(): + global CALCULATOR + try: + module = importlib.import_module(f"app.calculators.performance.{settings.calculator}") + CALCULATOR = module.PerformanceCalculator(**settings.calculator_config) + except (ImportError, AttributeError) as e: + raise ImportError(f"Failed to import performance calculator for {settings.calculator}") from e + return CALCULATOR + + +def get_calculator() -> PerformanceCalculator: + if CALCULATOR is None: + raise RuntimeError("Performance calculator is not initialized") + return CALCULATOR + + def clamp[T: int | float](n: T, min_value: T, max_value: T) -> T: if n < min_value: return min_value @@ -42,29 +50,6 @@ def clamp[T: int | float](n: T, min_value: T, max_value: T) -> T: return n -def calculate_beatmap_attribute( - beatmap: str, - gamemode: GameMode | None = None, - mods: int | list[APIMod] | list[str] = 0, -) -> BeatmapAttributes: - map = rosu.Beatmap(content=beatmap) - if gamemode is not None: - map.convert(gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType] - diff = rosu.Difficulty(mods=mods).calculate(map) - return BeatmapAttributes( - star_rating=diff.stars, - max_combo=diff.max_combo, - aim_difficulty=diff.aim, - aim_difficult_slider_count=diff.aim_difficult_slider_count, - speed_difficulty=diff.speed, - speed_note_count=diff.speed_note_count, - slider_factor=diff.slider_factor, - aim_difficult_strain_count=diff.aim_difficult_strain_count, - speed_difficult_strain_count=diff.speed_difficult_strain_count, - mono_stamina_factor=diff.stamina, - ) - - async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> float: from app.database.beatmap import BannedBeatmaps @@ -83,41 +68,13 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f except Exception: logger.exception(f"Error checking if beatmap {score.beatmap_id} is suspicious") - # 使用线程池执行计算密集型操作以避免阻塞事件循环 - - loop = asyncio.get_event_loop() - - def _calculate_pp_sync(): - map = rosu.Beatmap(content=beatmap) - mods = deepcopy(score.mods.copy()) - parse_enum_to_str(int(score.gamemode), mods) - map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType] - perf = rosu.Performance( - mods=mods, - lazer=True, - accuracy=clamp(score.accuracy * 100, 0, 100), - combo=score.max_combo, - large_tick_hits=score.nlarge_tick_hit or 0, - slider_end_hits=score.nslider_tail_hit or 0, - small_tick_hits=score.nsmall_tick_hit or 0, - n_geki=score.ngeki, - n_katu=score.nkatu, - n300=score.n300, - n100=score.n100, - n50=score.n50, - misses=score.nmiss, - ) - return perf.calculate(map) - - # 在线程池中执行计算 - attrs = await loop.run_in_executor(None, _calculate_pp_sync) + attrs = await get_calculator().calculate_performance(beatmap, score) pp = attrs.pp - # mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp - if settings.suspicious_score_check and ((attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 3000): + if settings.suspicious_score_check and (pp > 3000): logger.warning( f"User {score.user_id} played {score.beatmap_id} " - f"(star={attrs.difficulty.stars}) with {pp=} " + f"with {pp=} " f"acc={score.accuracy}. The score is suspicious and return 0pp" f"({score.id=})" ) @@ -170,74 +127,6 @@ async def pre_fetch_and_calculate_pp( return await calculate_pp(score, beatmap_raw, session), True -async def batch_calculate_pp( - scores_data: list[tuple["Score", int]], session: AsyncSession, redis, fetcher -) -> list[float]: - """ - 批量计算PP:适用于重新计算或批量处理场景 - Args: - scores_data: [(score, beatmap_id), ...] 的列表 - Returns: - 对应的PP值列表 - """ - import asyncio - - from app.database.beatmap import BannedBeatmaps - - if not scores_data: - return [] - - # 提取所有唯一的beatmap_id - unique_beatmap_ids = list({beatmap_id for _, beatmap_id in scores_data}) - - # 批量检查被封禁的beatmap - banned_beatmaps = set() - if settings.suspicious_score_check: - banned_results = await session.exec( - select(BannedBeatmaps.beatmap_id).where(col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids)) - ) - banned_beatmaps = set(banned_results.all()) - - # 并发获取所有需要的beatmap原始文件 - async def fetch_beatmap_safe(beatmap_id: int) -> tuple[int, str | None]: - if beatmap_id in banned_beatmaps: - return beatmap_id, None - try: - content = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) - return beatmap_id, content - except Exception as e: - logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}") - return beatmap_id, None - - # 并发获取所有beatmap文件 - fetch_tasks = [fetch_beatmap_safe(bid) for bid in unique_beatmap_ids] - fetch_results = await asyncio.gather(*fetch_tasks, return_exceptions=True) - - # 构建beatmap_id -> content的映射 - beatmap_contents = {} - for result in fetch_results: - if isinstance(result, tuple): - beatmap_id, content = result - beatmap_contents[beatmap_id] = content - - # 为每个score计算PP - pp_results = [] - for score, beatmap_id in scores_data: - beatmap_content = beatmap_contents.get(beatmap_id) - if beatmap_content is None: - pp_results.append(0.0) - continue - - try: - pp = await calculate_pp(score, beatmap_content, session) - pp_results.append(pp) - except Exception as e: - logger.error(f"Failed to calculate PP for score {score.id}: {e}") - pp_results.append(0.0) - - return pp_results - - # https://osu.ppy.sh/wiki/Gameplay/Score/Total_score def calculate_level_to_score(n: int) -> float: if n <= 100: diff --git a/app/calculators/performance/__init__.py b/app/calculators/performance/__init__.py new file mode 100644 index 0000000..d217d9c --- /dev/null +++ b/app/calculators/performance/__init__.py @@ -0,0 +1,3 @@ +from ._base import CalculateError, ConvertError, DifficultyError, PerformanceCalculator, PerformanceError + +__all__ = ["CalculateError", "ConvertError", "DifficultyError", "PerformanceCalculator", "PerformanceError"] diff --git a/app/calculators/performance/_base.py b/app/calculators/performance/_base.py new file mode 100644 index 0000000..4aaf481 --- /dev/null +++ b/app/calculators/performance/_base.py @@ -0,0 +1,37 @@ +import abc +from typing import TYPE_CHECKING + +from app.models.mods import APIMod +from app.models.performance import BeatmapAttributes, PerformanceAttributes +from app.models.score import GameMode + +if TYPE_CHECKING: + from app.database.score import Score + + +class CalculateError(Exception): + """An error occurred during performance calculation.""" + + +class DifficultyError(CalculateError): + """The difficulty could not be calculated.""" + + +class ConvertError(DifficultyError): + """A beatmap cannot be converted to the specified game mode.""" + + +class PerformanceError(CalculateError): + """The performance could not be calculated.""" + + +class PerformanceCalculator(abc.ABC): + @abc.abstractmethod + async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes: + raise NotImplementedError + + @abc.abstractmethod + async def calculate_difficulty( + self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None + ) -> BeatmapAttributes: + raise NotImplementedError diff --git a/app/calculators/performance/performance_server.py b/app/calculators/performance/performance_server.py new file mode 100644 index 0000000..20480ee --- /dev/null +++ b/app/calculators/performance/performance_server.py @@ -0,0 +1,88 @@ +from typing import TYPE_CHECKING + +from app.models.mods import APIMod +from app.models.performance import ( + DIFFICULTY_CLASS, + PERFORMANCE_CLASS, + BeatmapAttributes, + PerformanceAttributes, +) +from app.models.score import GameMode + +from ._base import ( + CalculateError, + DifficultyError, + PerformanceCalculator as BasePerformanceCalculator, + PerformanceError, +) + +from httpx import AsyncClient, HTTPError + +if TYPE_CHECKING: + from app.database.score import Score + + +class PerformanceCalculator(BasePerformanceCalculator): + def __init__(self, server_url: str = "http://localhost:5225") -> None: + self.server_url = server_url + + async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes: + # https://github.com/GooGuTeam/osu-performance-server#post-performance + async with AsyncClient() as client: + try: + resp = await client.post( + f"{self.server_url}/performance", + json={ + "beatmap_id": score.beatmap_id, + "beatmap_file": beatmap_raw, + "checksum": score.map_md5, + "accuracy": score.accuracy, + "combo": score.max_combo, + "mods": score.mods, + "statistics": { + "great": score.n300, + "ok": score.n100, + "meh": score.n50, + "miss": score.nmiss, + "perfect": score.ngeki, + "good": score.nkatu, + "large_tick_hit": score.nlarge_tick_hit or 0, + "large_tick_miss": score.nlarge_tick_miss or 0, + "small_tick_hit": score.nsmall_tick_hit or 0, + "slider_tail_hit": score.nslider_tail_hit or 0, + }, + "ruleset": score.gamemode.value, + }, + ) + if resp.status_code != 200: + raise PerformanceError(f"Failed to calculate performance: {resp.text}") + data = resp.json() + return PERFORMANCE_CLASS.get(score.gamemode, PerformanceAttributes).model_validate(data) + except HTTPError as e: + raise PerformanceError(f"Failed to calculate performance: {e}") from e + except Exception as e: + raise CalculateError(f"Unknown error: {e}") from e + + async def calculate_difficulty( + self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None + ) -> BeatmapAttributes: + # https://github.com/GooGuTeam/osu-performance-server#post-difficulty + async with AsyncClient() as client: + try: + resp = await client.post( + f"{self.server_url}/difficulty", + json={ + "beatmap_file": beatmap_raw, + "mods": mods or [], + "ruleset": int(gamemode) if gamemode else None, + }, + ) + if resp.status_code != 200: + raise DifficultyError(f"Failed to calculate difficulty: {resp.text}") + data = resp.json() + ruleset_id = data.pop("ruleset", "osu") + return DIFFICULTY_CLASS.get(GameMode(ruleset_id), BeatmapAttributes).model_validate(data) + except HTTPError as e: + raise DifficultyError(f"Failed to calculate difficulty: {e}") from e + except Exception as e: + raise DifficultyError(f"Unknown error: {e}") from e diff --git a/app/calculators/performance/rosu.py b/app/calculators/performance/rosu.py new file mode 100644 index 0000000..244e1f9 --- /dev/null +++ b/app/calculators/performance/rosu.py @@ -0,0 +1,169 @@ +from asyncio import get_event_loop +from copy import deepcopy +from typing import TYPE_CHECKING + +from app.calculator import clamp +from app.models.mods import APIMod, parse_enum_to_str +from app.models.performance import ( + DIFFICULTY_CLASS, + PERFORMANCE_CLASS, + BeatmapAttributes, + ManiaPerformanceAttributes, + OsuBeatmapAttributes, + OsuPerformanceAttributes, + PerformanceAttributes, + TaikoBeatmapAttributes, + TaikoPerformanceAttributes, +) +from app.models.score import GameMode + +from ._base import ( + CalculateError, + ConvertError, + DifficultyError, + PerformanceCalculator as BasePerformanceCalculator, + PerformanceError, +) + +if TYPE_CHECKING: + from app.database.score import Score + +try: + import rosu_pp_py as rosu +except ImportError: + raise ImportError( + "rosu-pp-py is not installed. " + "Please install it.\n" + " Official: uv add rosu-pp-py\n" + " gu: uv add git+https://github.com/GooGuTeam/gu-pp-py.git" + ) + + +class PerformanceCalculator(BasePerformanceCalculator): + @classmethod + def _to_rosu_mode(cls, mode: GameMode) -> rosu.GameMode: + return { + GameMode.OSU: rosu.GameMode.Osu, + GameMode.TAIKO: rosu.GameMode.Taiko, + GameMode.FRUITS: rosu.GameMode.Catch, + GameMode.MANIA: rosu.GameMode.Mania, + GameMode.OSURX: rosu.GameMode.Osu, + GameMode.OSUAP: rosu.GameMode.Osu, + GameMode.TAIKORX: rosu.GameMode.Taiko, + GameMode.FRUITSRX: rosu.GameMode.Catch, + }[mode] + + @classmethod + def _from_rosu_mode(cls, mode: rosu.GameMode) -> GameMode: + return { + rosu.GameMode.Osu: GameMode.OSU, + rosu.GameMode.Taiko: GameMode.TAIKO, + rosu.GameMode.Catch: GameMode.FRUITS, + rosu.GameMode.Mania: GameMode.MANIA, + }[mode] + + @classmethod + def _perf_attr_to_model(cls, attr: rosu.PerformanceAttributes, gamemode: GameMode) -> PerformanceAttributes: + attr_class = PERFORMANCE_CLASS.get(gamemode, PerformanceAttributes) + + if attr_class is OsuPerformanceAttributes: + return OsuPerformanceAttributes( + pp=attr.pp, + aim=attr.pp_aim or 0, + speed=attr.pp_speed or 0, + accuracy=attr.pp_accuracy or 0, + flashlight=attr.pp_flashlight or 0, + effective_miss_count=attr.effective_miss_count or 0, + speed_deviation=attr.speed_deviation, + ) + elif attr_class is TaikoPerformanceAttributes: + return TaikoPerformanceAttributes( + pp=attr.pp, + difficulty=attr.pp_difficulty or 0, + accuracy=attr.pp_accuracy or 0, + estimated_unstable_rate=attr.estimated_unstable_rate, + ) + elif attr_class is ManiaPerformanceAttributes: + return ManiaPerformanceAttributes( + pp=attr.pp, + difficulty=attr.pp_difficulty or 0, + ) + else: + return PerformanceAttributes(pp=attr.pp) + + async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes: + try: + map = rosu.Beatmap(content=beatmap_raw) + mods = deepcopy(score.mods.copy()) + parse_enum_to_str(int(score.gamemode), mods) + map.convert(self._to_rosu_mode(score.gamemode), mods) # pyright: ignore[reportArgumentType] + perf = rosu.Performance( + mods=mods, + lazer=True, + accuracy=clamp(score.accuracy * 100, 0, 100), + combo=score.max_combo, + large_tick_hits=score.nlarge_tick_hit or 0, + slider_end_hits=score.nslider_tail_hit or 0, + small_tick_hits=score.nsmall_tick_hit or 0, + n_geki=score.ngeki, + n_katu=score.nkatu, + n300=score.n300, + n100=score.n100, + n50=score.n50, + misses=score.nmiss, + ) + attr = await get_event_loop().run_in_executor(None, perf.calculate, map) + return self._perf_attr_to_model(attr, score.gamemode.to_base_ruleset()) + except rosu.ParseError as e: # pyright: ignore[reportAttributeAccessIssue] + raise PerformanceError(f"Beatmap parse error: {e}") + except Exception as e: + raise CalculateError(f"Unknown error: {e}") from e + + @classmethod + def _diff_attr_to_model(cls, diff: rosu.DifficultyAttributes, gamemode: GameMode) -> BeatmapAttributes: + attr_class = DIFFICULTY_CLASS.get(gamemode, BeatmapAttributes) + + if attr_class is OsuBeatmapAttributes: + return OsuBeatmapAttributes( + star_rating=diff.stars, + max_combo=diff.max_combo, + aim_difficulty=diff.aim or 0, + aim_difficult_slider_count=diff.aim_difficult_slider_count or 0, + speed_difficulty=diff.speed or 0, + speed_note_count=diff.speed_note_count or 0, + slider_factor=diff.slider_factor or 0, + aim_difficult_strain_count=diff.aim_difficult_strain_count or 0, + speed_difficult_strain_count=diff.speed_difficult_strain_count or 0, + flashlight_difficulty=diff.flashlight or 0, + ) + elif attr_class is TaikoBeatmapAttributes: + return TaikoBeatmapAttributes( + star_rating=diff.stars, + max_combo=diff.max_combo, + rhythm_difficulty=diff.rhythm or 0, + mono_stamina_factor=diff.stamina or 0, + ) + else: + return BeatmapAttributes( + star_rating=diff.stars, + max_combo=diff.max_combo, + ) + + async def calculate_difficulty( + self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None + ) -> BeatmapAttributes: + try: + map = rosu.Beatmap(content=beatmap_raw) + if gamemode is not None: + map.convert(self._to_rosu_mode(gamemode), mods) # pyright: ignore[reportArgumentType] + diff_calculator = rosu.Difficulty(mods=mods) + diff = await get_event_loop().run_in_executor(None, diff_calculator.calculate, map) + return self._diff_attr_to_model( + diff, gamemode.to_base_ruleset() if gamemode else self._from_rosu_mode(diff.mode) + ) + except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] + raise ConvertError(f"Beatmap convert error: {e}") + except rosu.ParseError as e: # pyright: ignore[reportAttributeAccessIssue] + raise DifficultyError(f"Beatmap parse error: {e}") + except Exception as e: + raise CalculateError(f"Unknown error: {e}") from e diff --git a/app/config.py b/app/config.py index 4559ebf..dddac0e 100644 --- a/app/config.py +++ b/app/config.py @@ -103,6 +103,23 @@ STORAGE_SETTINGS='{ "s3_public_url_base": "https://your-custom-domain.com" }' ``` +""", + "表现计算设置": """配置表现分计算器及其参数。 + +### rosu-pp-py (默认) + +```bash +CALCULATOR="rosu" +CALCULATOR_CONFIG='{}' +``` + +### [osu-performance-server](https://github.com/GooGuTeam/osu-performance-server) + +```bash +CALCULATOR="performance_server" +CALCULATOR_CONFIG='{ + "server_url": "http://localhost:5225" +}' """, } }, @@ -486,6 +503,21 @@ STORAGE_SETTINGS='{ "游戏设置", ] + # 表现计算设置 + calculator: Annotated[ + Literal["rosu", "performance_server"], + Field(default="rosu", description="表现分计算器"), + "表现计算设置", + ] + calculator_config: Annotated[ + dict[str, Any], + Field( + default={}, + description="表现分计算器配置 (JSON 格式),具体配置项请参考上方", + ), + "表现计算设置", + ] + # 谱面缓存设置 enable_beatmap_preload: Annotated[ bool, @@ -612,7 +644,7 @@ STORAGE_SETTINGS='{ # 反作弊设置 suspicious_score_check: Annotated[ bool, - Field(default=True, description="启用可疑分数检查(star>25&acc<80 或 pp>3000)"), + Field(default=True, description="启用可疑分数检查(pp>3000)"), "反作弊设置", ] banned_name: Annotated[ diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 2cfb3a6..072902a 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -1,14 +1,14 @@ -import asyncio from datetime import datetime import hashlib from typing import TYPE_CHECKING -from app.calculator import calculate_beatmap_attribute +from app.calculator import get_calculator from app.config import settings from app.database.beatmap_tags import BeatmapTagVote from app.database.failtime import FailTime, FailTimeResp -from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus +from app.models.beatmap import BeatmapRankStatus from app.models.mods import APIMod +from app.models.performance import DIFFICULTY_CLASS, BeatmapAttributes from app.models.score import GameMode from .beatmap_playcounts import BeatmapPlaycounts @@ -247,10 +247,13 @@ async def calculate_beatmap_attributes( redis: Redis, fetcher: "Fetcher", ): + attr_class = DIFFICULTY_CLASS.get(ruleset, BeatmapAttributes) + key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes" if await redis.exists(key): - return BeatmapAttributes.model_validate_json(await redis.get(key)) + return attr_class.model_validate_json(await redis.get(key)) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) - attr = await asyncio.get_event_loop().run_in_executor(None, calculate_beatmap_attribute, resp, ruleset, mods_) + + attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset) await redis.set(key, attr.model_dump_json()) return attr diff --git a/app/models/beatmap.py b/app/models/beatmap.py index 068a97f..e0c5802 100644 --- a/app/models/beatmap.py +++ b/app/models/beatmap.py @@ -69,23 +69,6 @@ class Language(IntEnum): OTHER = 14 -class BeatmapAttributes(BaseModel): - star_rating: float - max_combo: int - - # osu - aim_difficulty: float | None = None - aim_difficult_slider_count: float | None = None - speed_difficulty: float | None = None - speed_note_count: float | None = None - slider_factor: float | None = None - aim_difficult_strain_count: float | None = None - speed_difficult_strain_count: float | None = None - - # taiko - mono_stamina_factor: float | None = None - - def _parse_list(v: Any): if isinstance(v, str): return v.split(".") diff --git a/app/models/performance.py b/app/models/performance.py new file mode 100644 index 0000000..f75d364 --- /dev/null +++ b/app/models/performance.py @@ -0,0 +1,81 @@ +from app.models.score import GameMode + +from pydantic import BaseModel + + +class PerformanceAttributes(BaseModel): + pp: float + + +# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceAttributes.cs +class OsuPerformanceAttributes(PerformanceAttributes): + aim: float + speed: float + accuracy: float + flashlight: float + effective_miss_count: float + speed_deviation: float | None = None + + # 2025 Q3 update + # combo_based_estimated_miss_count: int + # score_based_estimated_miss_count: int | None = None + # aim_estimated_slider_breaks: int + # speed_estimated_slider_breaks: int + + +# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Taiko/Difficulty/TaikoPerformanceAttributes.cs +class TaikoPerformanceAttributes(PerformanceAttributes): + difficulty: float + accuracy: float + estimated_unstable_rate: float | None = None + + +# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Mania/Difficulty/ManiaPerformanceAttributes.cs +class ManiaPerformanceAttributes(PerformanceAttributes): + difficulty: float + + +PERFORMANCE_CLASS: dict[GameMode, type[PerformanceAttributes]] = { + GameMode.OSU: OsuPerformanceAttributes, + GameMode.MANIA: ManiaPerformanceAttributes, + GameMode.TAIKO: TaikoPerformanceAttributes, +} + + +class BeatmapAttributes(BaseModel): + star_rating: float + max_combo: int + + +# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs +class OsuBeatmapAttributes(BeatmapAttributes): + aim_difficulty: float + aim_difficult_slider_count: float + speed_difficulty: float + speed_note_count: float + flashlight_difficulty: float | None = None + slider_factor: float + aim_difficult_strain_count: float + speed_difficult_strain_count: float + + # 2025 Q3 update + # aim_top_weighted_slider_factor: float + # speed_top_weighted_slider_factor: float + # nested_score_per_object: float + # legacy_score_base_multiplier: float + # maximum_legacy_combo_score: float + + +# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Taiko/Difficulty/TaikoDifficultyAttributes.cs +class TaikoBeatmapAttributes(BeatmapAttributes): + rhythm_difficulty: float + mono_stamina_factor: float + + # 2025 Q3 update + # consistency_factor: float + + +DIFFICULTY_CLASS: dict[GameMode, type[BeatmapAttributes]] = { + GameMode.OSU: OsuBeatmapAttributes, + GameMode.TAIKO: TaikoBeatmapAttributes, +} diff --git a/app/models/score.py b/app/models/score.py index 4102d54..abf2d68 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import TYPE_CHECKING, Literal, TypedDict, cast +from typing import Literal, TypedDict, cast from app.config import settings @@ -7,9 +7,6 @@ from .mods import API_MODS, APIMod from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator -if TYPE_CHECKING: - import rosu_pp_py as rosu - class GameMode(str, Enum): OSU = "osu" @@ -21,20 +18,6 @@ class GameMode(str, Enum): TAIKORX = "taikorx" FRUITSRX = "fruitsrx" - def to_rosu(self) -> "rosu.GameMode": - import rosu_pp_py as rosu - - return { - GameMode.OSU: rosu.GameMode.Osu, - GameMode.TAIKO: rosu.GameMode.Taiko, - GameMode.FRUITS: rosu.GameMode.Catch, - GameMode.MANIA: rosu.GameMode.Mania, - GameMode.OSURX: rosu.GameMode.Osu, - GameMode.OSUAP: rosu.GameMode.Osu, - GameMode.TAIKORX: rosu.GameMode.Taiko, - GameMode.FRUITSRX: rosu.GameMode.Catch, - }[self] - def __int__(self) -> int: return { GameMode.OSU: 0, diff --git a/app/router/v1/beatmap.py b/app/router/v1/beatmap.py index 1df2048..a440d37 100644 --- a/app/router/v1/beatmap.py +++ b/app/router/v1/beatmap.py @@ -10,6 +10,7 @@ from app.dependencies.database import Database, Redis from app.dependencies.fetcher import Fetcher from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.mods import int_to_mods +from app.models.performance import OsuBeatmapAttributes from app.models.score import GameMode from .router import AllStrModel, router @@ -193,7 +194,12 @@ async def get_beatmaps( redis, fetcher, ) - results.append(await V1Beatmap.from_db(session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty)) + aim_diff = None + speed_diff = None + if isinstance(attrs, OsuBeatmapAttributes): + aim_diff = attrs.aim_difficulty + speed_diff = attrs.speed_difficulty + results.append(await V1Beatmap.from_db(session, beatmap, aim_diff, speed_diff)) continue except Exception: ... diff --git a/app/router/v2/beatmap.py b/app/router/v2/beatmap.py index 281e542..31cf816 100644 --- a/app/router/v2/beatmap.py +++ b/app/router/v2/beatmap.py @@ -3,14 +3,15 @@ import hashlib import json from typing import Annotated +from app.calculators.performance import ConvertError from app.database import Beatmap, BeatmapResp, User from app.database.beatmap import calculate_beatmap_attributes from app.dependencies.database import Database, Redis from app.dependencies.fetcher import Fetcher from app.dependencies.user import get_current_user from app.helpers.asset_proxy_helper import asset_proxy_response -from app.models.beatmap import BeatmapAttributes from app.models.mods import APIMod, int_to_mods +from app.models.performance import BeatmapAttributes, OsuBeatmapAttributes, TaikoBeatmapAttributes from app.models.score import ( GameMode, ) @@ -20,7 +21,6 @@ from .router import router from fastapi import HTTPException, Path, Query, Security from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel -import rosu_pp_py as rosu from sqlmodel import col, select @@ -127,7 +127,7 @@ async def batch_get_beatmaps( "/beatmaps/{beatmap_id}/attributes", tags=["谱面"], name="计算谱面属性", - response_model=BeatmapAttributes, + response_model=BeatmapAttributes | OsuBeatmapAttributes | TaikoBeatmapAttributes, description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"), ) async def get_beatmap_attributes( @@ -171,5 +171,5 @@ async def get_beatmap_attributes( return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher) except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmap not found") - except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] - raise HTTPException(status_code=400, detail=str(e)) from e + except ConvertError as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/docker-compose-osurx.yml b/docker-compose-osurx.yml index b06bad1..daffb48 100644 --- a/docker-compose-osurx.yml +++ b/docker-compose-osurx.yml @@ -17,6 +17,8 @@ services: - ENABLE_ALL_MODS_PP=true - ENABLE_SUPPORTER_FOR_ALL_USERS=true - ENABLE_ALL_BEATMAP_LEADERBOARD=true + # - CALCULATOR=performance_server + # - CALCULATOR_CONFIG='{"server_url":"http://performance-server:8080"}' env_file: - .env depends_on: @@ -109,6 +111,15 @@ services: networks: - osu-network + # performance-server: + # image: ghcr.io/googuteam/osu-performance-server-osurx:custom-rulesets + # container_name: performance_server_osurx + # environment: + # - SAVE_BEATMAP_FILES=false + # restart: unless-stopped + # networks: + # - osu-network + volumes: mysql_data: redis_data: diff --git a/docker-compose.yml b/docker-compose.yml index a2a769f..2587b51 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,6 +14,8 @@ services: - MYSQL_HOST=mysql - MYSQL_PORT=3306 - REDIS_URL=redis://redis:6379 + # - CALCULATOR=performance_server + # - CALCULATOR_CONFIG='{"server_url":"http://performance-server:8080"}' env_file: - .env depends_on: @@ -102,6 +104,15 @@ services: - osu-network command: redis-server --appendonly yes + # performance-server: + # image: ghcr.io/googuteam/osu-performance-server:custom-rulesets + # container_name: performance_server + # environment: + # - SAVE_BEATMAP_FILES=false + # restart: unless-stopped + # networks: + # - osu-network + volumes: mysql_data: redis_data: diff --git a/main.py b/main.py index f6e8e87..05efd64 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ from contextlib import asynccontextmanager import json from pathlib import Path +from app.calculator import init_calculator from app.config import settings from app.database import User from app.dependencies.database import ( @@ -58,10 +59,11 @@ import sentry_sdk @asynccontextmanager async def lifespan(app: FastAPI): # noqa: ARG001 # === on startup === - # init mods and achievements + # init mods, achievements and performance calculator init_mods() init_ranked_mods() load_achievements() + init_calculator() # init rate limiter await FastAPILimiter.init(redis_rate_limit_client)