From 223fa9969289eed6d7c6aeddf5582f089fb68ded Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Tue, 29 Jul 2025 02:56:21 +0000 Subject: [PATCH] feat(score): support calculate pp --- app/calculator.py | 59 +++++++++++++++++ app/database/score.py | 8 ++- app/fetcher/osu_dot_direct.py | 10 +++ app/models/mods.py | 62 ++++++++++++++++++ app/router/beatmap.py | 4 +- app/router/score.py | 37 ++++++++++- app/utils.py | 64 ++++++++----------- ..._score_add_nlarge_tick_hit_nsmall_tick_.py | 36 +++++++++++ 8 files changed, 238 insertions(+), 42 deletions(-) create mode 100644 app/calculator.py create mode 100644 migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py diff --git a/app/calculator.py b/app/calculator.py new file mode 100644 index 0000000..40fe7d1 --- /dev/null +++ b/app/calculator.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from app.database.score import Score +from app.models.beatmap import BeatmapAttributes +from app.models.mods import APIMod +from app.models.score import GameMode + +import rosu_pp_py as rosu + + +def calculate_beatmap_attribute( + beatmap: str, + gamemode: GameMode | None = None, + mods: int | list[APIMod] | list[str] = 0, +) -> BeatmapAttributes: + map = rosu.Beatmap(content=beatmap) + if gamemode is not None: + map.convert(gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType] + diff = rosu.Difficulty(mods=mods).calculate(map) + return BeatmapAttributes( + star_rating=diff.stars, + max_combo=diff.max_combo, + aim_difficulty=diff.aim, + aim_difficult_slider_count=diff.aim_difficult_slider_count, + speed_difficulty=diff.speed, + speed_note_count=diff.speed_note_count, + slider_factor=diff.slider_factor, + aim_difficult_strain_count=diff.aim_difficult_strain_count, + speed_difficult_strain_count=diff.speed_difficult_strain_count, + mono_stamina_factor=diff.stamina, + ) + + +def calculate_pp( + score: Score, + beatmap: str, +) -> float: + map = rosu.Beatmap(content=beatmap) + map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType] + if map.is_suspicious(): + return 0.0 + perf = rosu.Performance( + mods=score.mods, + lazer=True, + accuracy=score.accuracy, + combo=score.max_combo, + large_tick_hits=score.nlarge_tick_hit or 0, + slider_end_hits=score.nslider_tail_hit or 0, + small_tick_hits=score.nsmall_tick_hit or 0, + n_geki=score.ngeki, + n_katu=score.nkatu, + n300=score.n300, + n100=score.n100, + n50=score.n50, + misses=score.nmiss, + hitresult_priority=rosu.HitResultPriority.Fastest, + ) + attrs = perf.calculate(map) + return attrs.pp diff --git a/app/database/score.py b/app/database/score.py index 180096d..779b21d 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -49,7 +49,7 @@ class ScoreBase(SQLModel): mods: list[APIMod] = Field(sa_column=Column(JSON)) passed: bool playlist_item_id: int | None = Field(default=None) # multiplayer - pp: float + pp: float = Field(default=0.0) preserve: bool = Field(default=True) rank: Rank room_id: int | None = Field(default=None) # multiplayer @@ -87,7 +87,9 @@ class Score(ScoreBase, table=True): ngeki: int = Field(exclude=True) nkatu: int = Field(exclude=True) nlarge_tick_miss: int | None = Field(default=None, exclude=True) + nlarge_tick_hit: int | None = Field(default=None, exclude=True) nslider_tail_hit: int | None = Field(default=None, exclude=True) + nsmall_tick_hit: int | None = Field(default=None, exclude=True) gamemode: GameMode = Field(index=True) # optional @@ -176,6 +178,10 @@ class ScoreResp(ScoreBase): s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss if score.nslider_tail_hit is not None: s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit + if score.nsmall_tick_hit is not None: + s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit + if score.nlarge_tick_hit is not None: + s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit # s.user = await convert_db_user_to_api_user(score.user) s.rank_global = ( await get_score_position_by_id( diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index eeeddec..08b8dfc 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -4,6 +4,7 @@ from ._base import BaseFetcher from httpx import AsyncClient from loguru import logger +import redis class OsuDotDirectFetcher(BaseFetcher): @@ -17,3 +18,12 @@ class OsuDotDirectFetcher(BaseFetcher): ) response.raise_for_status() return response.text + + async def get_or_fetch_beatmap_raw( + self, redis: redis.Redis, beatmap_id: int + ) -> str: + if redis.exists(f"beatmap:{beatmap_id}:raw"): + return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] + raw = await self.get_beatmap_raw(beatmap_id) + redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) + return raw diff --git a/app/models/mods.py b/app/models/mods.py index 7b5e78d..abcd2cd 100644 --- a/app/models/mods.py +++ b/app/models/mods.py @@ -105,3 +105,65 @@ def mods_to_int(mods: list[APIMod]) -> int: for mod in mods: sum_ |= API_MOD_TO_LEGACY.get(mod["acronym"], 0) return sum_ + + +NO_CHECK = "DO_NO_CHECK" + +# FIXME: 这里为空表示了两种情况:mod 没有配置项;任何时候都可以获得 pp +# 如果是后者,则 mod 更新的时候可能会误判。 +COMMON_CONFIG: dict[str, dict] = { + "EZ": {"retries": 2}, + "NF": {}, + "HT": {"speed_change": 0.75, "adjust_pitch": NO_CHECK}, + "DC": {"speed_change": 0.75}, + "HR": {}, + "SD": {}, + "PF": {}, + "HD": {}, + "DT": {"speed_change": 1.5, "adjust_pitch": NO_CHECK}, + "NC": {"speed_change": 1.5}, + "FL": {"size_multiplier": 1.0, "combo_based_size": True}, + "AC": {}, + "MU": {}, + "TD": {}, +} + +RANKED_MODS: dict[int, dict[str, dict]] = { + 0: COMMON_CONFIG, + 1: COMMON_CONFIG, + 2: COMMON_CONFIG, + 3: COMMON_CONFIG, +} +# osu +RANKED_MODS[0]["HD"]["only_fade_approach_circles"] = False +RANKED_MODS[0]["FL"]["follow_delay"] = 1.0 +RANKED_MODS[0]["BL"] = {} +RANKED_MODS[0]["NS"] = {} +RANKED_MODS[0]["SO"] = {} +RANKED_MODS[0]["TC"] = {} +# taiko +del RANKED_MODS[1]["EZ"]["retries"] +# catch +RANKED_MODS[2]["NS"] = {} +# mania +del RANKED_MODS[3]["HR"] +RANKED_MODS[3]["FL"]["combo_based_size"] = False +RANKED_MODS[3]["MR"] = {} +for i in range(4, 10): + RANKED_MODS[3][f"{i}K"] = {} + + +def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool: + ranked_mods = RANKED_MODS[ruleset_id] + for mod in mods: + mod["settings"] = mod.get("settings", {}) + if (settings := ranked_mods.get(mod["acronym"])) is None: + return False + if settings == {}: + continue + for setting, value in mod["settings"].items(): + if (expected_value := settings.get(setting)) is None: + return False + if expected_value != NO_CHECK and value != expected_value: + return False + return True diff --git a/app/router/beatmap.py b/app/router/beatmap.py index cf59148..71d554f 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -4,6 +4,7 @@ import asyncio import hashlib import json +from app.calculator import calculate_beatmap_attribute from app.database import ( Beatmap, BeatmapResp, @@ -20,7 +21,6 @@ from app.models.score import ( INT_TO_MODE, GameMode, ) -from app.utils import calculate_beatmap_attribute from .api_router import router @@ -157,7 +157,7 @@ async def get_beatmap_attributes( return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] try: - resp = await fetcher.get_beatmap_raw(beatmap) + resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) try: attr = await asyncio.get_event_loop().run_in_executor( None, calculate_beatmap_attribute, resp, ruleset, mods_ diff --git a/app/router/score.py b/app/router/score.py index cc1629a..dda4866 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,14 +1,20 @@ from __future__ import annotations +import asyncio import datetime +from app.calculator import calculate_pp from app.database import ( User as DBUser, ) +from app.database.beatmap import Beatmap from app.database.score import Score, ScoreResp from app.database.score_token import ScoreToken, ScoreTokenResp -from app.dependencies.database import get_db +from app.dependencies.database import get_db, get_redis +from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user +from app.models.beatmap import BeatmapRankStatus +from app.models.mods import mods_can_get_pp from app.models.score import ( INT_TO_MODE, GameMode, @@ -21,6 +27,7 @@ from .api_router import router from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel +from redis import Redis from sqlalchemy.orm import joinedload from sqlmodel import col, select, true from sqlmodel.ext.asyncio.session import AsyncSession @@ -187,6 +194,8 @@ async def submit_solo_score( info: SoloScoreSubmissionInfo, current_user: DBUser = Depends(get_current_user), db: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), + fetcher=Depends(get_fetcher), ): if not info.passed: info.rank = Rank.F @@ -214,6 +223,13 @@ async def submit_solo_score( if not score: raise HTTPException(status_code=404, detail="Score not found") else: + beatmap_status = ( + await db.exec( + select(Beatmap.beatmap_status).where(Beatmap.id == beatmap) + ) + ).first() + if beatmap_status is None: + raise HTTPException(status_code=404, detail="Beatmap not found") score = Score( accuracy=info.accuracy, max_combo=info.max_combo, @@ -231,7 +247,6 @@ async def submit_solo_score( preserve=info.passed, map_md5=score_token.beatmap.checksum, has_replay=False, - pp=info.pp, type="solo", n300=info.statistics.get(HitResult.GREAT, 0), n100=info.statistics.get(HitResult.OK, 0), @@ -239,7 +254,25 @@ async def submit_solo_score( nmiss=info.statistics.get(HitResult.MISS, 0), ngeki=info.statistics.get(HitResult.PERFECT, 0), nkatu=info.statistics.get(HitResult.GOOD, 0), + nlarge_tick_miss=info.statistics.get(HitResult.LARGE_TICK_MISS, 0), + nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0), + nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0), + nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0), ) + if ( + info.passed + and beatmap_status + in { + BeatmapRankStatus.RANKED, + BeatmapRankStatus.APPROVED, + } + and mods_can_get_pp(info.ruleset_id, info.mods) + ): + beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) + pp = await asyncio.get_event_loop().run_in_executor( + None, calculate_pp, score, beatmap_raw + ) + score.pp = pp db.add(score) await db.commit() await db.refresh(score) diff --git a/app/utils.py b/app/utils.py index 1d7cece..9008706 100644 --- a/app/utils.py +++ b/app/utils.py @@ -8,9 +8,6 @@ from app.database import ( LazerUserStatistics, User as DBUser, ) -from app.models.beatmap import BeatmapAttributes -from app.models.mods import APIMod -from app.models.score import GameMode from app.models.user import ( Country, Cover, @@ -26,8 +23,6 @@ from app.models.user import ( UserAchievement, ) -import rosu_pp_py as rosu - def unix_timestamp_to_windows(timestamp: int) -> int: """Convert a Unix timestamp to a Windows timestamp.""" @@ -407,15 +402,33 @@ async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> current_season_stats=None, daily_challenge_user_stats=DailyChallengeStats( user_id=user_id, - daily_streak_best=db_user.daily_challenge_stats.daily_streak_best if db_user.daily_challenge_stats else 0, - daily_streak_current=db_user.daily_challenge_stats.daily_streak_current if db_user.daily_challenge_stats else 0, - last_update=db_user.daily_challenge_stats.last_update if db_user.daily_challenge_stats else None, - last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak if db_user.daily_challenge_stats else None, - playcount=db_user.daily_challenge_stats.playcount if db_user.daily_challenge_stats else 0, - top_10p_placements=db_user.daily_challenge_stats.top_10p_placements if db_user.daily_challenge_stats else 0, - top_50p_placements=db_user.daily_challenge_stats.top_50p_placements if db_user.daily_challenge_stats else 0, - weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best if db_user.daily_challenge_stats else 0, - weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current if db_user.daily_challenge_stats else 0, + daily_streak_best=db_user.daily_challenge_stats.daily_streak_best + if db_user.daily_challenge_stats + else 0, + daily_streak_current=db_user.daily_challenge_stats.daily_streak_current + if db_user.daily_challenge_stats + else 0, + last_update=db_user.daily_challenge_stats.last_update + if db_user.daily_challenge_stats + else None, + last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak + if db_user.daily_challenge_stats + else None, + playcount=db_user.daily_challenge_stats.playcount + if db_user.daily_challenge_stats + else 0, + top_10p_placements=db_user.daily_challenge_stats.top_10p_placements + if db_user.daily_challenge_stats + else 0, + top_50p_placements=db_user.daily_challenge_stats.top_50p_placements + if db_user.daily_challenge_stats + else 0, + weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best + if db_user.daily_challenge_stats + else 0, + weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current + if db_user.daily_challenge_stats + else 0, ), groups=[], monthly_playcounts=monthly_playcounts, @@ -450,26 +463,3 @@ def get_country_name(country_code: str) -> str: # 可以添加更多国家 } return country_names.get(country_code, "Unknown") - - -def calculate_beatmap_attribute( - beatmap: str, - gamemode: GameMode | None = None, - mods: int | list[APIMod] | list[str] = 0, -) -> BeatmapAttributes: - map = rosu.Beatmap(content=beatmap) - if gamemode is not None: - map.convert(gamemode.to_rosu(), mods) - diff = rosu.Difficulty(mods=mods).calculate(map) - return BeatmapAttributes( - star_rating=diff.stars, - max_combo=diff.max_combo, - aim_difficulty=diff.aim, - aim_difficult_slider_count=diff.aim_difficult_slider_count, - speed_difficulty=diff.speed, - speed_note_count=diff.speed_note_count, - slider_factor=diff.slider_factor, - aim_difficult_strain_count=diff.aim_difficult_strain_count, - speed_difficult_strain_count=diff.speed_difficult_strain_count, - mono_stamina_factor=diff.stamina, - ) diff --git a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py b/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py new file mode 100644 index 0000000..d90ec3d --- /dev/null +++ b/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py @@ -0,0 +1,36 @@ +"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator + +Revision ID: dc4d25c428c7 +Revises: +Create Date: 2025-07-29 01:43:40.221070 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "dc4d25c428c7" +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True)) + op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("scores", "nsmall_tick_hit") + op.drop_column("scores", "nlarge_tick_hit") + # ### end Alembic commands ###