feat(pp-calculator): support other pp calculators (#57)
New configurations:
- CALCULATOR="rosu": specific pp calculator
- CALCULATOR_CONFIG='{}': argument passed through into calculator
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
import importlib
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.calculators.performance import PerformanceCalculator
|
||||
from app.config import settings
|
||||
from app.log import log
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod, parse_enum_to_str
|
||||
from app.models.score import GameMode
|
||||
|
||||
from osupyparser import HitObject, OsuFile
|
||||
@@ -16,23 +15,32 @@ from redis.asyncio import Redis
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
logger = log("Calculator")
|
||||
|
||||
try:
|
||||
import rosu_pp_py as rosu
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"rosu-pp-py is not installed. "
|
||||
"Please install it.\n"
|
||||
" Official: uv add rosu-pp-py\n"
|
||||
" ppy-sb: uv add git+https://github.com/ppy-sb/rosu-pp-py.git"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.score import Score
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
|
||||
logger = log("Calculator")
|
||||
|
||||
CALCULATOR: PerformanceCalculator | None = None
|
||||
|
||||
|
||||
def init_calculator():
|
||||
global CALCULATOR
|
||||
try:
|
||||
module = importlib.import_module(f"app.calculators.performance.{settings.calculator}")
|
||||
CALCULATOR = module.PerformanceCalculator(**settings.calculator_config)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ImportError(f"Failed to import performance calculator for {settings.calculator}") from e
|
||||
return CALCULATOR
|
||||
|
||||
|
||||
def get_calculator() -> PerformanceCalculator:
|
||||
if CALCULATOR is None:
|
||||
raise RuntimeError("Performance calculator is not initialized")
|
||||
return CALCULATOR
|
||||
|
||||
|
||||
def clamp[T: int | float](n: T, min_value: T, max_value: T) -> T:
|
||||
if n < min_value:
|
||||
return min_value
|
||||
@@ -42,29 +50,6 @@ def clamp[T: int | float](n: T, min_value: T, max_value: T) -> T:
|
||||
return n
|
||||
|
||||
|
||||
def calculate_beatmap_attribute(
|
||||
beatmap: str,
|
||||
gamemode: GameMode | None = None,
|
||||
mods: int | list[APIMod] | list[str] = 0,
|
||||
) -> BeatmapAttributes:
|
||||
map = rosu.Beatmap(content=beatmap)
|
||||
if gamemode is not None:
|
||||
map.convert(gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
|
||||
diff = rosu.Difficulty(mods=mods).calculate(map)
|
||||
return BeatmapAttributes(
|
||||
star_rating=diff.stars,
|
||||
max_combo=diff.max_combo,
|
||||
aim_difficulty=diff.aim,
|
||||
aim_difficult_slider_count=diff.aim_difficult_slider_count,
|
||||
speed_difficulty=diff.speed,
|
||||
speed_note_count=diff.speed_note_count,
|
||||
slider_factor=diff.slider_factor,
|
||||
aim_difficult_strain_count=diff.aim_difficult_strain_count,
|
||||
speed_difficult_strain_count=diff.speed_difficult_strain_count,
|
||||
mono_stamina_factor=diff.stamina,
|
||||
)
|
||||
|
||||
|
||||
async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> float:
|
||||
from app.database.beatmap import BannedBeatmaps
|
||||
|
||||
@@ -83,41 +68,13 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
||||
except Exception:
|
||||
logger.exception(f"Error checking if beatmap {score.beatmap_id} is suspicious")
|
||||
|
||||
# 使用线程池执行计算密集型操作以避免阻塞事件循环
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _calculate_pp_sync():
|
||||
map = rosu.Beatmap(content=beatmap)
|
||||
mods = deepcopy(score.mods.copy())
|
||||
parse_enum_to_str(int(score.gamemode), mods)
|
||||
map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
|
||||
perf = rosu.Performance(
|
||||
mods=mods,
|
||||
lazer=True,
|
||||
accuracy=clamp(score.accuracy * 100, 0, 100),
|
||||
combo=score.max_combo,
|
||||
large_tick_hits=score.nlarge_tick_hit or 0,
|
||||
slider_end_hits=score.nslider_tail_hit or 0,
|
||||
small_tick_hits=score.nsmall_tick_hit or 0,
|
||||
n_geki=score.ngeki,
|
||||
n_katu=score.nkatu,
|
||||
n300=score.n300,
|
||||
n100=score.n100,
|
||||
n50=score.n50,
|
||||
misses=score.nmiss,
|
||||
)
|
||||
return perf.calculate(map)
|
||||
|
||||
# 在线程池中执行计算
|
||||
attrs = await loop.run_in_executor(None, _calculate_pp_sync)
|
||||
attrs = await get_calculator().calculate_performance(beatmap, score)
|
||||
pp = attrs.pp
|
||||
|
||||
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
|
||||
if settings.suspicious_score_check and ((attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 3000):
|
||||
if settings.suspicious_score_check and (pp > 3000):
|
||||
logger.warning(
|
||||
f"User {score.user_id} played {score.beatmap_id} "
|
||||
f"(star={attrs.difficulty.stars}) with {pp=} "
|
||||
f"with {pp=} "
|
||||
f"acc={score.accuracy}. The score is suspicious and return 0pp"
|
||||
f"({score.id=})"
|
||||
)
|
||||
@@ -170,74 +127,6 @@ async def pre_fetch_and_calculate_pp(
|
||||
return await calculate_pp(score, beatmap_raw, session), True
|
||||
|
||||
|
||||
async def batch_calculate_pp(
|
||||
scores_data: list[tuple["Score", int]], session: AsyncSession, redis, fetcher
|
||||
) -> list[float]:
|
||||
"""
|
||||
批量计算PP:适用于重新计算或批量处理场景
|
||||
Args:
|
||||
scores_data: [(score, beatmap_id), ...] 的列表
|
||||
Returns:
|
||||
对应的PP值列表
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from app.database.beatmap import BannedBeatmaps
|
||||
|
||||
if not scores_data:
|
||||
return []
|
||||
|
||||
# 提取所有唯一的beatmap_id
|
||||
unique_beatmap_ids = list({beatmap_id for _, beatmap_id in scores_data})
|
||||
|
||||
# 批量检查被封禁的beatmap
|
||||
banned_beatmaps = set()
|
||||
if settings.suspicious_score_check:
|
||||
banned_results = await session.exec(
|
||||
select(BannedBeatmaps.beatmap_id).where(col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids))
|
||||
)
|
||||
banned_beatmaps = set(banned_results.all())
|
||||
|
||||
# 并发获取所有需要的beatmap原始文件
|
||||
async def fetch_beatmap_safe(beatmap_id: int) -> tuple[int, str | None]:
|
||||
if beatmap_id in banned_beatmaps:
|
||||
return beatmap_id, None
|
||||
try:
|
||||
content = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||
return beatmap_id, content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}")
|
||||
return beatmap_id, None
|
||||
|
||||
# 并发获取所有beatmap文件
|
||||
fetch_tasks = [fetch_beatmap_safe(bid) for bid in unique_beatmap_ids]
|
||||
fetch_results = await asyncio.gather(*fetch_tasks, return_exceptions=True)
|
||||
|
||||
# 构建beatmap_id -> content的映射
|
||||
beatmap_contents = {}
|
||||
for result in fetch_results:
|
||||
if isinstance(result, tuple):
|
||||
beatmap_id, content = result
|
||||
beatmap_contents[beatmap_id] = content
|
||||
|
||||
# 为每个score计算PP
|
||||
pp_results = []
|
||||
for score, beatmap_id in scores_data:
|
||||
beatmap_content = beatmap_contents.get(beatmap_id)
|
||||
if beatmap_content is None:
|
||||
pp_results.append(0.0)
|
||||
continue
|
||||
|
||||
try:
|
||||
pp = await calculate_pp(score, beatmap_content, session)
|
||||
pp_results.append(pp)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to calculate PP for score {score.id}: {e}")
|
||||
pp_results.append(0.0)
|
||||
|
||||
return pp_results
|
||||
|
||||
|
||||
# https://osu.ppy.sh/wiki/Gameplay/Score/Total_score
|
||||
def calculate_level_to_score(n: int) -> float:
|
||||
if n <= 100:
|
||||
|
||||
Reference in New Issue
Block a user