feat(pp-calculator): support other pp calculators (#57)
New configurations:
- CALCULATOR="rosu": specific pp calculator
- CALCULATOR_CONFIG='{}': argument passed through into calculator
This commit is contained in:
3
app/calculators/performance/__init__.py
Normal file
3
app/calculators/performance/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from ._base import CalculateError, ConvertError, DifficultyError, PerformanceCalculator, PerformanceError
|
||||
|
||||
__all__ = ["CalculateError", "ConvertError", "DifficultyError", "PerformanceCalculator", "PerformanceError"]
|
||||
37
app/calculators/performance/_base.py
Normal file
37
app/calculators/performance/_base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import abc
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.mods import APIMod
|
||||
from app.models.performance import BeatmapAttributes, PerformanceAttributes
|
||||
from app.models.score import GameMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.score import Score
|
||||
|
||||
|
||||
class CalculateError(Exception):
|
||||
"""An error occurred during performance calculation."""
|
||||
|
||||
|
||||
class DifficultyError(CalculateError):
|
||||
"""The difficulty could not be calculated."""
|
||||
|
||||
|
||||
class ConvertError(DifficultyError):
|
||||
"""A beatmap cannot be converted to the specified game mode."""
|
||||
|
||||
|
||||
class PerformanceError(CalculateError):
|
||||
"""The performance could not be calculated."""
|
||||
|
||||
|
||||
class PerformanceCalculator(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def calculate_difficulty(
|
||||
self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None
|
||||
) -> BeatmapAttributes:
|
||||
raise NotImplementedError
|
||||
88
app/calculators/performance/performance_server.py
Normal file
88
app/calculators/performance/performance_server.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.mods import APIMod
|
||||
from app.models.performance import (
|
||||
DIFFICULTY_CLASS,
|
||||
PERFORMANCE_CLASS,
|
||||
BeatmapAttributes,
|
||||
PerformanceAttributes,
|
||||
)
|
||||
from app.models.score import GameMode
|
||||
|
||||
from ._base import (
|
||||
CalculateError,
|
||||
DifficultyError,
|
||||
PerformanceCalculator as BasePerformanceCalculator,
|
||||
PerformanceError,
|
||||
)
|
||||
|
||||
from httpx import AsyncClient, HTTPError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.score import Score
|
||||
|
||||
|
||||
class PerformanceCalculator(BasePerformanceCalculator):
|
||||
def __init__(self, server_url: str = "http://localhost:5225") -> None:
|
||||
self.server_url = server_url
|
||||
|
||||
async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes:
|
||||
# https://github.com/GooGuTeam/osu-performance-server#post-performance
|
||||
async with AsyncClient() as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{self.server_url}/performance",
|
||||
json={
|
||||
"beatmap_id": score.beatmap_id,
|
||||
"beatmap_file": beatmap_raw,
|
||||
"checksum": score.map_md5,
|
||||
"accuracy": score.accuracy,
|
||||
"combo": score.max_combo,
|
||||
"mods": score.mods,
|
||||
"statistics": {
|
||||
"great": score.n300,
|
||||
"ok": score.n100,
|
||||
"meh": score.n50,
|
||||
"miss": score.nmiss,
|
||||
"perfect": score.ngeki,
|
||||
"good": score.nkatu,
|
||||
"large_tick_hit": score.nlarge_tick_hit or 0,
|
||||
"large_tick_miss": score.nlarge_tick_miss or 0,
|
||||
"small_tick_hit": score.nsmall_tick_hit or 0,
|
||||
"slider_tail_hit": score.nslider_tail_hit or 0,
|
||||
},
|
||||
"ruleset": score.gamemode.value,
|
||||
},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise PerformanceError(f"Failed to calculate performance: {resp.text}")
|
||||
data = resp.json()
|
||||
return PERFORMANCE_CLASS.get(score.gamemode, PerformanceAttributes).model_validate(data)
|
||||
except HTTPError as e:
|
||||
raise PerformanceError(f"Failed to calculate performance: {e}") from e
|
||||
except Exception as e:
|
||||
raise CalculateError(f"Unknown error: {e}") from e
|
||||
|
||||
async def calculate_difficulty(
|
||||
self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None
|
||||
) -> BeatmapAttributes:
|
||||
# https://github.com/GooGuTeam/osu-performance-server#post-difficulty
|
||||
async with AsyncClient() as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{self.server_url}/difficulty",
|
||||
json={
|
||||
"beatmap_file": beatmap_raw,
|
||||
"mods": mods or [],
|
||||
"ruleset": int(gamemode) if gamemode else None,
|
||||
},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise DifficultyError(f"Failed to calculate difficulty: {resp.text}")
|
||||
data = resp.json()
|
||||
ruleset_id = data.pop("ruleset", "osu")
|
||||
return DIFFICULTY_CLASS.get(GameMode(ruleset_id), BeatmapAttributes).model_validate(data)
|
||||
except HTTPError as e:
|
||||
raise DifficultyError(f"Failed to calculate difficulty: {e}") from e
|
||||
except Exception as e:
|
||||
raise DifficultyError(f"Unknown error: {e}") from e
|
||||
169
app/calculators/performance/rosu.py
Normal file
169
app/calculators/performance/rosu.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from asyncio import get_event_loop
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.calculator import clamp
|
||||
from app.models.mods import APIMod, parse_enum_to_str
|
||||
from app.models.performance import (
|
||||
DIFFICULTY_CLASS,
|
||||
PERFORMANCE_CLASS,
|
||||
BeatmapAttributes,
|
||||
ManiaPerformanceAttributes,
|
||||
OsuBeatmapAttributes,
|
||||
OsuPerformanceAttributes,
|
||||
PerformanceAttributes,
|
||||
TaikoBeatmapAttributes,
|
||||
TaikoPerformanceAttributes,
|
||||
)
|
||||
from app.models.score import GameMode
|
||||
|
||||
from ._base import (
|
||||
CalculateError,
|
||||
ConvertError,
|
||||
DifficultyError,
|
||||
PerformanceCalculator as BasePerformanceCalculator,
|
||||
PerformanceError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.score import Score
|
||||
|
||||
try:
|
||||
import rosu_pp_py as rosu
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"rosu-pp-py is not installed. "
|
||||
"Please install it.\n"
|
||||
" Official: uv add rosu-pp-py\n"
|
||||
" gu: uv add git+https://github.com/GooGuTeam/gu-pp-py.git"
|
||||
)
|
||||
|
||||
|
||||
class PerformanceCalculator(BasePerformanceCalculator):
|
||||
@classmethod
|
||||
def _to_rosu_mode(cls, mode: GameMode) -> rosu.GameMode:
|
||||
return {
|
||||
GameMode.OSU: rosu.GameMode.Osu,
|
||||
GameMode.TAIKO: rosu.GameMode.Taiko,
|
||||
GameMode.FRUITS: rosu.GameMode.Catch,
|
||||
GameMode.MANIA: rosu.GameMode.Mania,
|
||||
GameMode.OSURX: rosu.GameMode.Osu,
|
||||
GameMode.OSUAP: rosu.GameMode.Osu,
|
||||
GameMode.TAIKORX: rosu.GameMode.Taiko,
|
||||
GameMode.FRUITSRX: rosu.GameMode.Catch,
|
||||
}[mode]
|
||||
|
||||
@classmethod
|
||||
def _from_rosu_mode(cls, mode: rosu.GameMode) -> GameMode:
|
||||
return {
|
||||
rosu.GameMode.Osu: GameMode.OSU,
|
||||
rosu.GameMode.Taiko: GameMode.TAIKO,
|
||||
rosu.GameMode.Catch: GameMode.FRUITS,
|
||||
rosu.GameMode.Mania: GameMode.MANIA,
|
||||
}[mode]
|
||||
|
||||
@classmethod
|
||||
def _perf_attr_to_model(cls, attr: rosu.PerformanceAttributes, gamemode: GameMode) -> PerformanceAttributes:
|
||||
attr_class = PERFORMANCE_CLASS.get(gamemode, PerformanceAttributes)
|
||||
|
||||
if attr_class is OsuPerformanceAttributes:
|
||||
return OsuPerformanceAttributes(
|
||||
pp=attr.pp,
|
||||
aim=attr.pp_aim or 0,
|
||||
speed=attr.pp_speed or 0,
|
||||
accuracy=attr.pp_accuracy or 0,
|
||||
flashlight=attr.pp_flashlight or 0,
|
||||
effective_miss_count=attr.effective_miss_count or 0,
|
||||
speed_deviation=attr.speed_deviation,
|
||||
)
|
||||
elif attr_class is TaikoPerformanceAttributes:
|
||||
return TaikoPerformanceAttributes(
|
||||
pp=attr.pp,
|
||||
difficulty=attr.pp_difficulty or 0,
|
||||
accuracy=attr.pp_accuracy or 0,
|
||||
estimated_unstable_rate=attr.estimated_unstable_rate,
|
||||
)
|
||||
elif attr_class is ManiaPerformanceAttributes:
|
||||
return ManiaPerformanceAttributes(
|
||||
pp=attr.pp,
|
||||
difficulty=attr.pp_difficulty or 0,
|
||||
)
|
||||
else:
|
||||
return PerformanceAttributes(pp=attr.pp)
|
||||
|
||||
async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes:
|
||||
try:
|
||||
map = rosu.Beatmap(content=beatmap_raw)
|
||||
mods = deepcopy(score.mods.copy())
|
||||
parse_enum_to_str(int(score.gamemode), mods)
|
||||
map.convert(self._to_rosu_mode(score.gamemode), mods) # pyright: ignore[reportArgumentType]
|
||||
perf = rosu.Performance(
|
||||
mods=mods,
|
||||
lazer=True,
|
||||
accuracy=clamp(score.accuracy * 100, 0, 100),
|
||||
combo=score.max_combo,
|
||||
large_tick_hits=score.nlarge_tick_hit or 0,
|
||||
slider_end_hits=score.nslider_tail_hit or 0,
|
||||
small_tick_hits=score.nsmall_tick_hit or 0,
|
||||
n_geki=score.ngeki,
|
||||
n_katu=score.nkatu,
|
||||
n300=score.n300,
|
||||
n100=score.n100,
|
||||
n50=score.n50,
|
||||
misses=score.nmiss,
|
||||
)
|
||||
attr = await get_event_loop().run_in_executor(None, perf.calculate, map)
|
||||
return self._perf_attr_to_model(attr, score.gamemode.to_base_ruleset())
|
||||
except rosu.ParseError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
raise PerformanceError(f"Beatmap parse error: {e}")
|
||||
except Exception as e:
|
||||
raise CalculateError(f"Unknown error: {e}") from e
|
||||
|
||||
@classmethod
|
||||
def _diff_attr_to_model(cls, diff: rosu.DifficultyAttributes, gamemode: GameMode) -> BeatmapAttributes:
|
||||
attr_class = DIFFICULTY_CLASS.get(gamemode, BeatmapAttributes)
|
||||
|
||||
if attr_class is OsuBeatmapAttributes:
|
||||
return OsuBeatmapAttributes(
|
||||
star_rating=diff.stars,
|
||||
max_combo=diff.max_combo,
|
||||
aim_difficulty=diff.aim or 0,
|
||||
aim_difficult_slider_count=diff.aim_difficult_slider_count or 0,
|
||||
speed_difficulty=diff.speed or 0,
|
||||
speed_note_count=diff.speed_note_count or 0,
|
||||
slider_factor=diff.slider_factor or 0,
|
||||
aim_difficult_strain_count=diff.aim_difficult_strain_count or 0,
|
||||
speed_difficult_strain_count=diff.speed_difficult_strain_count or 0,
|
||||
flashlight_difficulty=diff.flashlight or 0,
|
||||
)
|
||||
elif attr_class is TaikoBeatmapAttributes:
|
||||
return TaikoBeatmapAttributes(
|
||||
star_rating=diff.stars,
|
||||
max_combo=diff.max_combo,
|
||||
rhythm_difficulty=diff.rhythm or 0,
|
||||
mono_stamina_factor=diff.stamina or 0,
|
||||
)
|
||||
else:
|
||||
return BeatmapAttributes(
|
||||
star_rating=diff.stars,
|
||||
max_combo=diff.max_combo,
|
||||
)
|
||||
|
||||
async def calculate_difficulty(
|
||||
self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None
|
||||
) -> BeatmapAttributes:
|
||||
try:
|
||||
map = rosu.Beatmap(content=beatmap_raw)
|
||||
if gamemode is not None:
|
||||
map.convert(self._to_rosu_mode(gamemode), mods) # pyright: ignore[reportArgumentType]
|
||||
diff_calculator = rosu.Difficulty(mods=mods)
|
||||
diff = await get_event_loop().run_in_executor(None, diff_calculator.calculate, map)
|
||||
return self._diff_attr_to_model(
|
||||
diff, gamemode.to_base_ruleset() if gamemode else self._from_rosu_mode(diff.mode)
|
||||
)
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
raise ConvertError(f"Beatmap convert error: {e}")
|
||||
except rosu.ParseError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
raise DifficultyError(f"Beatmap parse error: {e}")
|
||||
except Exception as e:
|
||||
raise CalculateError(f"Unknown error: {e}") from e
|
||||
Reference in New Issue
Block a user