feat(pp-calculator): support other pp calculators (#57)

New configurations:

- CALCULATOR="rosu": specific pp calculator
- CALCULATOR_CONFIG='{}': argument passed through into calculator
This commit is contained in:
MingxuanGame
2025-10-18 19:10:53 +08:00
committed by GitHub
parent 563a30d28f
commit 8790ccad64
16 changed files with 496 additions and 189 deletions

View File

@@ -1,13 +1,12 @@
import asyncio
from copy import deepcopy
from enum import Enum
import importlib
import math
from typing import TYPE_CHECKING
from app.calculators.performance import PerformanceCalculator
from app.config import settings
from app.log import log
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod, parse_enum_to_str
from app.models.score import GameMode
from osupyparser import HitObject, OsuFile
@@ -16,23 +15,32 @@ from redis.asyncio import Redis
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
logger = log("Calculator")
try:
import rosu_pp_py as rosu
except ImportError:
raise ImportError(
"rosu-pp-py is not installed. "
"Please install it.\n"
" Official: uv add rosu-pp-py\n"
" ppy-sb: uv add git+https://github.com/ppy-sb/rosu-pp-py.git"
)
if TYPE_CHECKING:
from app.database.score import Score
from app.fetcher import Fetcher
logger = log("Calculator")
CALCULATOR: PerformanceCalculator | None = None
def init_calculator():
global CALCULATOR
try:
module = importlib.import_module(f"app.calculators.performance.{settings.calculator}")
CALCULATOR = module.PerformanceCalculator(**settings.calculator_config)
except (ImportError, AttributeError) as e:
raise ImportError(f"Failed to import performance calculator for {settings.calculator}") from e
return CALCULATOR
def get_calculator() -> PerformanceCalculator:
if CALCULATOR is None:
raise RuntimeError("Performance calculator is not initialized")
return CALCULATOR
def clamp[T: int | float](n: T, min_value: T, max_value: T) -> T:
if n < min_value:
return min_value
@@ -42,29 +50,6 @@ def clamp[T: int | float](n: T, min_value: T, max_value: T) -> T:
return n
def calculate_beatmap_attribute(
beatmap: str,
gamemode: GameMode | None = None,
mods: int | list[APIMod] | list[str] = 0,
) -> BeatmapAttributes:
map = rosu.Beatmap(content=beatmap)
if gamemode is not None:
map.convert(gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
diff = rosu.Difficulty(mods=mods).calculate(map)
return BeatmapAttributes(
star_rating=diff.stars,
max_combo=diff.max_combo,
aim_difficulty=diff.aim,
aim_difficult_slider_count=diff.aim_difficult_slider_count,
speed_difficulty=diff.speed,
speed_note_count=diff.speed_note_count,
slider_factor=diff.slider_factor,
aim_difficult_strain_count=diff.aim_difficult_strain_count,
speed_difficult_strain_count=diff.speed_difficult_strain_count,
mono_stamina_factor=diff.stamina,
)
async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> float:
from app.database.beatmap import BannedBeatmaps
@@ -83,41 +68,13 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
except Exception:
logger.exception(f"Error checking if beatmap {score.beatmap_id} is suspicious")
# 使用线程池执行计算密集型操作以避免阻塞事件循环
loop = asyncio.get_event_loop()
def _calculate_pp_sync():
map = rosu.Beatmap(content=beatmap)
mods = deepcopy(score.mods.copy())
parse_enum_to_str(int(score.gamemode), mods)
map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
perf = rosu.Performance(
mods=mods,
lazer=True,
accuracy=clamp(score.accuracy * 100, 0, 100),
combo=score.max_combo,
large_tick_hits=score.nlarge_tick_hit or 0,
slider_end_hits=score.nslider_tail_hit or 0,
small_tick_hits=score.nsmall_tick_hit or 0,
n_geki=score.ngeki,
n_katu=score.nkatu,
n300=score.n300,
n100=score.n100,
n50=score.n50,
misses=score.nmiss,
)
return perf.calculate(map)
# 在线程池中执行计算
attrs = await loop.run_in_executor(None, _calculate_pp_sync)
attrs = await get_calculator().calculate_performance(beatmap, score)
pp = attrs.pp
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
if settings.suspicious_score_check and ((attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 3000):
if settings.suspicious_score_check and (pp > 3000):
logger.warning(
f"User {score.user_id} played {score.beatmap_id} "
f"(star={attrs.difficulty.stars}) with {pp=} "
f"with {pp=} "
f"acc={score.accuracy}. The score is suspicious and return 0pp"
f"({score.id=})"
)
@@ -170,74 +127,6 @@ async def pre_fetch_and_calculate_pp(
return await calculate_pp(score, beatmap_raw, session), True
async def batch_calculate_pp(
scores_data: list[tuple["Score", int]], session: AsyncSession, redis, fetcher
) -> list[float]:
"""
批量计算PP适用于重新计算或批量处理场景
Args:
scores_data: [(score, beatmap_id), ...] 的列表
Returns:
对应的PP值列表
"""
import asyncio
from app.database.beatmap import BannedBeatmaps
if not scores_data:
return []
# 提取所有唯一的beatmap_id
unique_beatmap_ids = list({beatmap_id for _, beatmap_id in scores_data})
# 批量检查被封禁的beatmap
banned_beatmaps = set()
if settings.suspicious_score_check:
banned_results = await session.exec(
select(BannedBeatmaps.beatmap_id).where(col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids))
)
banned_beatmaps = set(banned_results.all())
# 并发获取所有需要的beatmap原始文件
async def fetch_beatmap_safe(beatmap_id: int) -> tuple[int, str | None]:
if beatmap_id in banned_beatmaps:
return beatmap_id, None
try:
content = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
return beatmap_id, content
except Exception as e:
logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}")
return beatmap_id, None
# 并发获取所有beatmap文件
fetch_tasks = [fetch_beatmap_safe(bid) for bid in unique_beatmap_ids]
fetch_results = await asyncio.gather(*fetch_tasks, return_exceptions=True)
# 构建beatmap_id -> content的映射
beatmap_contents = {}
for result in fetch_results:
if isinstance(result, tuple):
beatmap_id, content = result
beatmap_contents[beatmap_id] = content
# 为每个score计算PP
pp_results = []
for score, beatmap_id in scores_data:
beatmap_content = beatmap_contents.get(beatmap_id)
if beatmap_content is None:
pp_results.append(0.0)
continue
try:
pp = await calculate_pp(score, beatmap_content, session)
pp_results.append(pp)
except Exception as e:
logger.error(f"Failed to calculate PP for score {score.id}: {e}")
pp_results.append(0.0)
return pp_results
# https://osu.ppy.sh/wiki/Gameplay/Score/Total_score
def calculate_level_to_score(n: int) -> float:
if n <= 100:

View File

@@ -0,0 +1,3 @@
from ._base import CalculateError, ConvertError, DifficultyError, PerformanceCalculator, PerformanceError
__all__ = ["CalculateError", "ConvertError", "DifficultyError", "PerformanceCalculator", "PerformanceError"]

View File

@@ -0,0 +1,37 @@
import abc
from typing import TYPE_CHECKING
from app.models.mods import APIMod
from app.models.performance import BeatmapAttributes, PerformanceAttributes
from app.models.score import GameMode
if TYPE_CHECKING:
from app.database.score import Score
class CalculateError(Exception):
"""An error occurred during performance calculation."""
class DifficultyError(CalculateError):
"""The difficulty could not be calculated."""
class ConvertError(DifficultyError):
"""A beatmap cannot be converted to the specified game mode."""
class PerformanceError(CalculateError):
"""The performance could not be calculated."""
class PerformanceCalculator(abc.ABC):
@abc.abstractmethod
async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes:
raise NotImplementedError
@abc.abstractmethod
async def calculate_difficulty(
self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None
) -> BeatmapAttributes:
raise NotImplementedError

View File

@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING
from app.models.mods import APIMod
from app.models.performance import (
DIFFICULTY_CLASS,
PERFORMANCE_CLASS,
BeatmapAttributes,
PerformanceAttributes,
)
from app.models.score import GameMode
from ._base import (
CalculateError,
DifficultyError,
PerformanceCalculator as BasePerformanceCalculator,
PerformanceError,
)
from httpx import AsyncClient, HTTPError
if TYPE_CHECKING:
from app.database.score import Score
class PerformanceCalculator(BasePerformanceCalculator):
def __init__(self, server_url: str = "http://localhost:5225") -> None:
self.server_url = server_url
async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes:
# https://github.com/GooGuTeam/osu-performance-server#post-performance
async with AsyncClient() as client:
try:
resp = await client.post(
f"{self.server_url}/performance",
json={
"beatmap_id": score.beatmap_id,
"beatmap_file": beatmap_raw,
"checksum": score.map_md5,
"accuracy": score.accuracy,
"combo": score.max_combo,
"mods": score.mods,
"statistics": {
"great": score.n300,
"ok": score.n100,
"meh": score.n50,
"miss": score.nmiss,
"perfect": score.ngeki,
"good": score.nkatu,
"large_tick_hit": score.nlarge_tick_hit or 0,
"large_tick_miss": score.nlarge_tick_miss or 0,
"small_tick_hit": score.nsmall_tick_hit or 0,
"slider_tail_hit": score.nslider_tail_hit or 0,
},
"ruleset": score.gamemode.value,
},
)
if resp.status_code != 200:
raise PerformanceError(f"Failed to calculate performance: {resp.text}")
data = resp.json()
return PERFORMANCE_CLASS.get(score.gamemode, PerformanceAttributes).model_validate(data)
except HTTPError as e:
raise PerformanceError(f"Failed to calculate performance: {e}") from e
except Exception as e:
raise CalculateError(f"Unknown error: {e}") from e
async def calculate_difficulty(
self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None
) -> BeatmapAttributes:
# https://github.com/GooGuTeam/osu-performance-server#post-difficulty
async with AsyncClient() as client:
try:
resp = await client.post(
f"{self.server_url}/difficulty",
json={
"beatmap_file": beatmap_raw,
"mods": mods or [],
"ruleset": int(gamemode) if gamemode else None,
},
)
if resp.status_code != 200:
raise DifficultyError(f"Failed to calculate difficulty: {resp.text}")
data = resp.json()
ruleset_id = data.pop("ruleset", "osu")
return DIFFICULTY_CLASS.get(GameMode(ruleset_id), BeatmapAttributes).model_validate(data)
except HTTPError as e:
raise DifficultyError(f"Failed to calculate difficulty: {e}") from e
except Exception as e:
raise DifficultyError(f"Unknown error: {e}") from e

View File

@@ -0,0 +1,169 @@
from asyncio import get_event_loop
from copy import deepcopy
from typing import TYPE_CHECKING
from app.calculator import clamp
from app.models.mods import APIMod, parse_enum_to_str
from app.models.performance import (
DIFFICULTY_CLASS,
PERFORMANCE_CLASS,
BeatmapAttributes,
ManiaPerformanceAttributes,
OsuBeatmapAttributes,
OsuPerformanceAttributes,
PerformanceAttributes,
TaikoBeatmapAttributes,
TaikoPerformanceAttributes,
)
from app.models.score import GameMode
from ._base import (
CalculateError,
ConvertError,
DifficultyError,
PerformanceCalculator as BasePerformanceCalculator,
PerformanceError,
)
if TYPE_CHECKING:
from app.database.score import Score
try:
import rosu_pp_py as rosu
except ImportError:
raise ImportError(
"rosu-pp-py is not installed. "
"Please install it.\n"
" Official: uv add rosu-pp-py\n"
" gu: uv add git+https://github.com/GooGuTeam/gu-pp-py.git"
)
class PerformanceCalculator(BasePerformanceCalculator):
@classmethod
def _to_rosu_mode(cls, mode: GameMode) -> rosu.GameMode:
return {
GameMode.OSU: rosu.GameMode.Osu,
GameMode.TAIKO: rosu.GameMode.Taiko,
GameMode.FRUITS: rosu.GameMode.Catch,
GameMode.MANIA: rosu.GameMode.Mania,
GameMode.OSURX: rosu.GameMode.Osu,
GameMode.OSUAP: rosu.GameMode.Osu,
GameMode.TAIKORX: rosu.GameMode.Taiko,
GameMode.FRUITSRX: rosu.GameMode.Catch,
}[mode]
@classmethod
def _from_rosu_mode(cls, mode: rosu.GameMode) -> GameMode:
return {
rosu.GameMode.Osu: GameMode.OSU,
rosu.GameMode.Taiko: GameMode.TAIKO,
rosu.GameMode.Catch: GameMode.FRUITS,
rosu.GameMode.Mania: GameMode.MANIA,
}[mode]
@classmethod
def _perf_attr_to_model(cls, attr: rosu.PerformanceAttributes, gamemode: GameMode) -> PerformanceAttributes:
attr_class = PERFORMANCE_CLASS.get(gamemode, PerformanceAttributes)
if attr_class is OsuPerformanceAttributes:
return OsuPerformanceAttributes(
pp=attr.pp,
aim=attr.pp_aim or 0,
speed=attr.pp_speed or 0,
accuracy=attr.pp_accuracy or 0,
flashlight=attr.pp_flashlight or 0,
effective_miss_count=attr.effective_miss_count or 0,
speed_deviation=attr.speed_deviation,
)
elif attr_class is TaikoPerformanceAttributes:
return TaikoPerformanceAttributes(
pp=attr.pp,
difficulty=attr.pp_difficulty or 0,
accuracy=attr.pp_accuracy or 0,
estimated_unstable_rate=attr.estimated_unstable_rate,
)
elif attr_class is ManiaPerformanceAttributes:
return ManiaPerformanceAttributes(
pp=attr.pp,
difficulty=attr.pp_difficulty or 0,
)
else:
return PerformanceAttributes(pp=attr.pp)
async def calculate_performance(self, beatmap_raw: str, score: "Score") -> PerformanceAttributes:
try:
map = rosu.Beatmap(content=beatmap_raw)
mods = deepcopy(score.mods.copy())
parse_enum_to_str(int(score.gamemode), mods)
map.convert(self._to_rosu_mode(score.gamemode), mods) # pyright: ignore[reportArgumentType]
perf = rosu.Performance(
mods=mods,
lazer=True,
accuracy=clamp(score.accuracy * 100, 0, 100),
combo=score.max_combo,
large_tick_hits=score.nlarge_tick_hit or 0,
slider_end_hits=score.nslider_tail_hit or 0,
small_tick_hits=score.nsmall_tick_hit or 0,
n_geki=score.ngeki,
n_katu=score.nkatu,
n300=score.n300,
n100=score.n100,
n50=score.n50,
misses=score.nmiss,
)
attr = await get_event_loop().run_in_executor(None, perf.calculate, map)
return self._perf_attr_to_model(attr, score.gamemode.to_base_ruleset())
except rosu.ParseError as e: # pyright: ignore[reportAttributeAccessIssue]
raise PerformanceError(f"Beatmap parse error: {e}")
except Exception as e:
raise CalculateError(f"Unknown error: {e}") from e
@classmethod
def _diff_attr_to_model(cls, diff: rosu.DifficultyAttributes, gamemode: GameMode) -> BeatmapAttributes:
attr_class = DIFFICULTY_CLASS.get(gamemode, BeatmapAttributes)
if attr_class is OsuBeatmapAttributes:
return OsuBeatmapAttributes(
star_rating=diff.stars,
max_combo=diff.max_combo,
aim_difficulty=diff.aim or 0,
aim_difficult_slider_count=diff.aim_difficult_slider_count or 0,
speed_difficulty=diff.speed or 0,
speed_note_count=diff.speed_note_count or 0,
slider_factor=diff.slider_factor or 0,
aim_difficult_strain_count=diff.aim_difficult_strain_count or 0,
speed_difficult_strain_count=diff.speed_difficult_strain_count or 0,
flashlight_difficulty=diff.flashlight or 0,
)
elif attr_class is TaikoBeatmapAttributes:
return TaikoBeatmapAttributes(
star_rating=diff.stars,
max_combo=diff.max_combo,
rhythm_difficulty=diff.rhythm or 0,
mono_stamina_factor=diff.stamina or 0,
)
else:
return BeatmapAttributes(
star_rating=diff.stars,
max_combo=diff.max_combo,
)
async def calculate_difficulty(
self, beatmap_raw: str, mods: list[APIMod] | None = None, gamemode: GameMode | None = None
) -> BeatmapAttributes:
try:
map = rosu.Beatmap(content=beatmap_raw)
if gamemode is not None:
map.convert(self._to_rosu_mode(gamemode), mods) # pyright: ignore[reportArgumentType]
diff_calculator = rosu.Difficulty(mods=mods)
diff = await get_event_loop().run_in_executor(None, diff_calculator.calculate, map)
return self._diff_attr_to_model(
diff, gamemode.to_base_ruleset() if gamemode else self._from_rosu_mode(diff.mode)
)
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise ConvertError(f"Beatmap convert error: {e}")
except rosu.ParseError as e: # pyright: ignore[reportAttributeAccessIssue]
raise DifficultyError(f"Beatmap parse error: {e}")
except Exception as e:
raise CalculateError(f"Unknown error: {e}") from e

View File

@@ -103,6 +103,23 @@ STORAGE_SETTINGS='{
"s3_public_url_base": "https://your-custom-domain.com"
}'
```
""",
"表现计算设置": """配置表现分计算器及其参数。
### rosu-pp-py (默认)
```bash
CALCULATOR="rosu"
CALCULATOR_CONFIG='{}'
```
### [osu-performance-server](https://github.com/GooGuTeam/osu-performance-server)
```bash
CALCULATOR="performance_server"
CALCULATOR_CONFIG='{
"server_url": "http://localhost:5225"
}'
""",
}
},
@@ -486,6 +503,21 @@ STORAGE_SETTINGS='{
"游戏设置",
]
# 表现计算设置
calculator: Annotated[
Literal["rosu", "performance_server"],
Field(default="rosu", description="表现分计算器"),
"表现计算设置",
]
calculator_config: Annotated[
dict[str, Any],
Field(
default={},
description="表现分计算器配置 (JSON 格式),具体配置项请参考上方",
),
"表现计算设置",
]
# 谱面缓存设置
enable_beatmap_preload: Annotated[
bool,
@@ -612,7 +644,7 @@ STORAGE_SETTINGS='{
# 反作弊设置
suspicious_score_check: Annotated[
bool,
Field(default=True, description="启用可疑分数检查(star>25&acc<80 或 pp>3000"),
Field(default=True, description="启用可疑分数检查pp>3000"),
"反作弊设置",
]
banned_name: Annotated[

View File

@@ -1,14 +1,14 @@
import asyncio
from datetime import datetime
import hashlib
from typing import TYPE_CHECKING
from app.calculator import calculate_beatmap_attribute
from app.calculator import get_calculator
from app.config import settings
from app.database.beatmap_tags import BeatmapTagVote
from app.database.failtime import FailTime, FailTimeResp
from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from app.models.performance import DIFFICULTY_CLASS, BeatmapAttributes
from app.models.score import GameMode
from .beatmap_playcounts import BeatmapPlaycounts
@@ -247,10 +247,13 @@ async def calculate_beatmap_attributes(
redis: Redis,
fetcher: "Fetcher",
):
attr_class = DIFFICULTY_CLASS.get(ruleset, BeatmapAttributes)
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key))
return attr_class.model_validate_json(await redis.get(key))
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
attr = await asyncio.get_event_loop().run_in_executor(None, calculate_beatmap_attribute, resp, ruleset, mods_)
attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset)
await redis.set(key, attr.model_dump_json())
return attr

View File

@@ -69,23 +69,6 @@ class Language(IntEnum):
OTHER = 14
class BeatmapAttributes(BaseModel):
star_rating: float
max_combo: int
# osu
aim_difficulty: float | None = None
aim_difficult_slider_count: float | None = None
speed_difficulty: float | None = None
speed_note_count: float | None = None
slider_factor: float | None = None
aim_difficult_strain_count: float | None = None
speed_difficult_strain_count: float | None = None
# taiko
mono_stamina_factor: float | None = None
def _parse_list(v: Any):
if isinstance(v, str):
return v.split(".")

81
app/models/performance.py Normal file
View File

@@ -0,0 +1,81 @@
from app.models.score import GameMode
from pydantic import BaseModel
class PerformanceAttributes(BaseModel):
pp: float
# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceAttributes.cs
class OsuPerformanceAttributes(PerformanceAttributes):
aim: float
speed: float
accuracy: float
flashlight: float
effective_miss_count: float
speed_deviation: float | None = None
# 2025 Q3 update
# combo_based_estimated_miss_count: int
# score_based_estimated_miss_count: int | None = None
# aim_estimated_slider_breaks: int
# speed_estimated_slider_breaks: int
# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Taiko/Difficulty/TaikoPerformanceAttributes.cs
class TaikoPerformanceAttributes(PerformanceAttributes):
difficulty: float
accuracy: float
estimated_unstable_rate: float | None = None
# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Mania/Difficulty/ManiaPerformanceAttributes.cs
class ManiaPerformanceAttributes(PerformanceAttributes):
difficulty: float
PERFORMANCE_CLASS: dict[GameMode, type[PerformanceAttributes]] = {
GameMode.OSU: OsuPerformanceAttributes,
GameMode.MANIA: ManiaPerformanceAttributes,
GameMode.TAIKO: TaikoPerformanceAttributes,
}
class BeatmapAttributes(BaseModel):
star_rating: float
max_combo: int
# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs
class OsuBeatmapAttributes(BeatmapAttributes):
aim_difficulty: float
aim_difficult_slider_count: float
speed_difficulty: float
speed_note_count: float
flashlight_difficulty: float | None = None
slider_factor: float
aim_difficult_strain_count: float
speed_difficult_strain_count: float
# 2025 Q3 update
# aim_top_weighted_slider_factor: float
# speed_top_weighted_slider_factor: float
# nested_score_per_object: float
# legacy_score_base_multiplier: float
# maximum_legacy_combo_score: float
# https://github.com/ppy/osu/blob/9ebc5b0a35452e50bd408af1db62cfc22a57b1f4/osu.Game.Rulesets.Taiko/Difficulty/TaikoDifficultyAttributes.cs
class TaikoBeatmapAttributes(BeatmapAttributes):
rhythm_difficulty: float
mono_stamina_factor: float
# 2025 Q3 update
# consistency_factor: float
DIFFICULTY_CLASS: dict[GameMode, type[BeatmapAttributes]] = {
GameMode.OSU: OsuBeatmapAttributes,
GameMode.TAIKO: TaikoBeatmapAttributes,
}

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypedDict, cast
from typing import Literal, TypedDict, cast
from app.config import settings
@@ -7,9 +7,6 @@ from .mods import API_MODS, APIMod
from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator
if TYPE_CHECKING:
import rosu_pp_py as rosu
class GameMode(str, Enum):
OSU = "osu"
@@ -21,20 +18,6 @@ class GameMode(str, Enum):
TAIKORX = "taikorx"
FRUITSRX = "fruitsrx"
def to_rosu(self) -> "rosu.GameMode":
import rosu_pp_py as rosu
return {
GameMode.OSU: rosu.GameMode.Osu,
GameMode.TAIKO: rosu.GameMode.Taiko,
GameMode.FRUITS: rosu.GameMode.Catch,
GameMode.MANIA: rosu.GameMode.Mania,
GameMode.OSURX: rosu.GameMode.Osu,
GameMode.OSUAP: rosu.GameMode.Osu,
GameMode.TAIKORX: rosu.GameMode.Taiko,
GameMode.FRUITSRX: rosu.GameMode.Catch,
}[self]
def __int__(self) -> int:
return {
GameMode.OSU: 0,

View File

@@ -10,6 +10,7 @@ from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.mods import int_to_mods
from app.models.performance import OsuBeatmapAttributes
from app.models.score import GameMode
from .router import AllStrModel, router
@@ -193,7 +194,12 @@ async def get_beatmaps(
redis,
fetcher,
)
results.append(await V1Beatmap.from_db(session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty))
aim_diff = None
speed_diff = None
if isinstance(attrs, OsuBeatmapAttributes):
aim_diff = attrs.aim_difficulty
speed_diff = attrs.speed_difficulty
results.append(await V1Beatmap.from_db(session, beatmap, aim_diff, speed_diff))
continue
except Exception:
...

View File

@@ -3,14 +3,15 @@ import hashlib
import json
from typing import Annotated
from app.calculators.performance import ConvertError
from app.database import Beatmap, BeatmapResp, User
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.user import get_current_user
from app.helpers.asset_proxy_helper import asset_proxy_response
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod, int_to_mods
from app.models.performance import BeatmapAttributes, OsuBeatmapAttributes, TaikoBeatmapAttributes
from app.models.score import (
GameMode,
)
@@ -20,7 +21,6 @@ from .router import router
from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
import rosu_pp_py as rosu
from sqlmodel import col, select
@@ -127,7 +127,7 @@ async def batch_get_beatmaps(
"/beatmaps/{beatmap_id}/attributes",
tags=["谱面"],
name="计算谱面属性",
response_model=BeatmapAttributes,
response_model=BeatmapAttributes | OsuBeatmapAttributes | TaikoBeatmapAttributes,
description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"),
)
async def get_beatmap_attributes(
@@ -171,5 +171,5 @@ async def get_beatmap_attributes(
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise HTTPException(status_code=400, detail=str(e)) from e
except ConvertError as e:
raise HTTPException(status_code=400, detail=str(e))