feat(score): support calculate pp

This commit is contained in:
MingxuanGame
2025-07-29 02:56:21 +00:00
parent a78a889c5d
commit 223fa99692
8 changed files with 238 additions and 42 deletions

59
app/calculator.py Normal file
View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from app.database.score import Score
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod
from app.models.score import GameMode
import rosu_pp_py as rosu
def calculate_beatmap_attribute(
beatmap: str,
gamemode: GameMode | None = None,
mods: int | list[APIMod] | list[str] = 0,
) -> BeatmapAttributes:
map = rosu.Beatmap(content=beatmap)
if gamemode is not None:
map.convert(gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
diff = rosu.Difficulty(mods=mods).calculate(map)
return BeatmapAttributes(
star_rating=diff.stars,
max_combo=diff.max_combo,
aim_difficulty=diff.aim,
aim_difficult_slider_count=diff.aim_difficult_slider_count,
speed_difficulty=diff.speed,
speed_note_count=diff.speed_note_count,
slider_factor=diff.slider_factor,
aim_difficult_strain_count=diff.aim_difficult_strain_count,
speed_difficult_strain_count=diff.speed_difficult_strain_count,
mono_stamina_factor=diff.stamina,
)
def calculate_pp(
score: Score,
beatmap: str,
) -> float:
map = rosu.Beatmap(content=beatmap)
map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType]
if map.is_suspicious():
return 0.0
perf = rosu.Performance(
mods=score.mods,
lazer=True,
accuracy=score.accuracy,
combo=score.max_combo,
large_tick_hits=score.nlarge_tick_hit or 0,
slider_end_hits=score.nslider_tail_hit or 0,
small_tick_hits=score.nsmall_tick_hit or 0,
n_geki=score.ngeki,
n_katu=score.nkatu,
n300=score.n300,
n100=score.n100,
n50=score.n50,
misses=score.nmiss,
hitresult_priority=rosu.HitResultPriority.Fastest,
)
attrs = perf.calculate(map)
return attrs.pp

View File

@@ -49,7 +49,7 @@ class ScoreBase(SQLModel):
mods: list[APIMod] = Field(sa_column=Column(JSON))
passed: bool
playlist_item_id: int | None = Field(default=None) # multiplayer
pp: float
pp: float = Field(default=0.0)
preserve: bool = Field(default=True)
rank: Rank
room_id: int | None = Field(default=None) # multiplayer
@@ -87,7 +87,9 @@ class Score(ScoreBase, table=True):
ngeki: int = Field(exclude=True)
nkatu: int = Field(exclude=True)
nlarge_tick_miss: int | None = Field(default=None, exclude=True)
nlarge_tick_hit: int | None = Field(default=None, exclude=True)
nslider_tail_hit: int | None = Field(default=None, exclude=True)
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
gamemode: GameMode = Field(index=True)
# optional
@@ -176,6 +178,10 @@ class ScoreResp(ScoreBase):
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None:
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
if score.nsmall_tick_hit is not None:
s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
if score.nlarge_tick_hit is not None:
s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
# s.user = await convert_db_user_to_api_user(score.user)
s.rank_global = (
await get_score_position_by_id(

View File

@@ -4,6 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient
from loguru import logger
import redis
class OsuDotDirectFetcher(BaseFetcher):
@@ -17,3 +18,12 @@ class OsuDotDirectFetcher(BaseFetcher):
)
response.raise_for_status()
return response.text
async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int
) -> str:
if redis.exists(f"beatmap:{beatmap_id}:raw"):
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id)
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw

View File

@@ -105,3 +105,65 @@ def mods_to_int(mods: list[APIMod]) -> int:
for mod in mods:
sum_ |= API_MOD_TO_LEGACY.get(mod["acronym"], 0)
return sum_
NO_CHECK = "DO_NO_CHECK"
# FIXME: 这里为空表示了两种情况mod 没有配置项;任何时候都可以获得 pp
# 如果是后者,则 mod 更新的时候可能会误判。
COMMON_CONFIG: dict[str, dict] = {
"EZ": {"retries": 2},
"NF": {},
"HT": {"speed_change": 0.75, "adjust_pitch": NO_CHECK},
"DC": {"speed_change": 0.75},
"HR": {},
"SD": {},
"PF": {},
"HD": {},
"DT": {"speed_change": 1.5, "adjust_pitch": NO_CHECK},
"NC": {"speed_change": 1.5},
"FL": {"size_multiplier": 1.0, "combo_based_size": True},
"AC": {},
"MU": {},
"TD": {},
}
RANKED_MODS: dict[int, dict[str, dict]] = {
0: COMMON_CONFIG,
1: COMMON_CONFIG,
2: COMMON_CONFIG,
3: COMMON_CONFIG,
}
# osu
RANKED_MODS[0]["HD"]["only_fade_approach_circles"] = False
RANKED_MODS[0]["FL"]["follow_delay"] = 1.0
RANKED_MODS[0]["BL"] = {}
RANKED_MODS[0]["NS"] = {}
RANKED_MODS[0]["SO"] = {}
RANKED_MODS[0]["TC"] = {}
# taiko
del RANKED_MODS[1]["EZ"]["retries"]
# catch
RANKED_MODS[2]["NS"] = {}
# mania
del RANKED_MODS[3]["HR"]
RANKED_MODS[3]["FL"]["combo_based_size"] = False
RANKED_MODS[3]["MR"] = {}
for i in range(4, 10):
RANKED_MODS[3][f"{i}K"] = {}
def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
ranked_mods = RANKED_MODS[ruleset_id]
for mod in mods:
mod["settings"] = mod.get("settings", {})
if (settings := ranked_mods.get(mod["acronym"])) is None:
return False
if settings == {}:
continue
for setting, value in mod["settings"].items():
if (expected_value := settings.get(setting)) is None:
return False
if expected_value != NO_CHECK and value != expected_value:
return False
return True

View File

@@ -4,6 +4,7 @@ import asyncio
import hashlib
import json
from app.calculator import calculate_beatmap_attribute
from app.database import (
Beatmap,
BeatmapResp,
@@ -20,7 +21,6 @@ from app.models.score import (
INT_TO_MODE,
GameMode,
)
from app.utils import calculate_beatmap_attribute
from .api_router import router
@@ -157,7 +157,7 @@ async def get_beatmap_attributes(
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
try:
resp = await fetcher.get_beatmap_raw(beatmap)
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
try:
attr = await asyncio.get_event_loop().run_in_executor(
None, calculate_beatmap_attribute, resp, ruleset, mods_

View File

@@ -1,14 +1,20 @@
from __future__ import annotations
import asyncio
import datetime
from app.calculator import calculate_pp
from app.database import (
User as DBUser,
)
from app.database.beatmap import Beatmap
from app.database.score import Score, ScoreResp
from app.database.score_token import ScoreToken, ScoreTokenResp
from app.dependencies.database import get_db
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_can_get_pp
from app.models.score import (
INT_TO_MODE,
GameMode,
@@ -21,6 +27,7 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel
from redis import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, select, true
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -187,6 +194,8 @@ async def submit_solo_score(
info: SoloScoreSubmissionInfo,
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
if not info.passed:
info.rank = Rank.F
@@ -214,6 +223,13 @@ async def submit_solo_score(
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
beatmap_status = (
await db.exec(
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
)
).first()
if beatmap_status is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
score = Score(
accuracy=info.accuracy,
max_combo=info.max_combo,
@@ -231,7 +247,6 @@ async def submit_solo_score(
preserve=info.passed,
map_md5=score_token.beatmap.checksum,
has_replay=False,
pp=info.pp,
type="solo",
n300=info.statistics.get(HitResult.GREAT, 0),
n100=info.statistics.get(HitResult.OK, 0),
@@ -239,7 +254,25 @@ async def submit_solo_score(
nmiss=info.statistics.get(HitResult.MISS, 0),
ngeki=info.statistics.get(HitResult.PERFECT, 0),
nkatu=info.statistics.get(HitResult.GOOD, 0),
nlarge_tick_miss=info.statistics.get(HitResult.LARGE_TICK_MISS, 0),
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
)
if (
info.passed
and beatmap_status
in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}
and mods_can_get_pp(info.ruleset_id, info.mods)
):
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
pp = await asyncio.get_event_loop().run_in_executor(
None, calculate_pp, score, beatmap_raw
)
score.pp = pp
db.add(score)
await db.commit()
await db.refresh(score)

View File

@@ -8,9 +8,6 @@ from app.database import (
LazerUserStatistics,
User as DBUser,
)
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod
from app.models.score import GameMode
from app.models.user import (
Country,
Cover,
@@ -26,8 +23,6 @@ from app.models.user import (
UserAchievement,
)
import rosu_pp_py as rosu
def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp."""
@@ -407,15 +402,33 @@ async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") ->
current_season_stats=None,
daily_challenge_user_stats=DailyChallengeStats(
user_id=user_id,
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best if db_user.daily_challenge_stats else 0,
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current if db_user.daily_challenge_stats else 0,
last_update=db_user.daily_challenge_stats.last_update if db_user.daily_challenge_stats else None,
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak if db_user.daily_challenge_stats else None,
playcount=db_user.daily_challenge_stats.playcount if db_user.daily_challenge_stats else 0,
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements if db_user.daily_challenge_stats else 0,
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements if db_user.daily_challenge_stats else 0,
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best if db_user.daily_challenge_stats else 0,
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current if db_user.daily_challenge_stats else 0,
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best
if db_user.daily_challenge_stats
else 0,
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current
if db_user.daily_challenge_stats
else 0,
last_update=db_user.daily_challenge_stats.last_update
if db_user.daily_challenge_stats
else None,
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak
if db_user.daily_challenge_stats
else None,
playcount=db_user.daily_challenge_stats.playcount
if db_user.daily_challenge_stats
else 0,
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements
if db_user.daily_challenge_stats
else 0,
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements
if db_user.daily_challenge_stats
else 0,
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best
if db_user.daily_challenge_stats
else 0,
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current
if db_user.daily_challenge_stats
else 0,
),
groups=[],
monthly_playcounts=monthly_playcounts,
@@ -450,26 +463,3 @@ def get_country_name(country_code: str) -> str:
# 可以添加更多国家
}
return country_names.get(country_code, "Unknown")
def calculate_beatmap_attribute(
beatmap: str,
gamemode: GameMode | None = None,
mods: int | list[APIMod] | list[str] = 0,
) -> BeatmapAttributes:
map = rosu.Beatmap(content=beatmap)
if gamemode is not None:
map.convert(gamemode.to_rosu(), mods)
diff = rosu.Difficulty(mods=mods).calculate(map)
return BeatmapAttributes(
star_rating=diff.stars,
max_combo=diff.max_combo,
aim_difficulty=diff.aim,
aim_difficult_slider_count=diff.aim_difficult_slider_count,
speed_difficulty=diff.speed,
speed_note_count=diff.speed_note_count,
slider_factor=diff.slider_factor,
aim_difficult_strain_count=diff.aim_difficult_strain_count,
speed_difficult_strain_count=diff.speed_difficult_strain_count,
mono_stamina_factor=diff.stamina,
)

View File

@@ -0,0 +1,36 @@
"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator
Revision ID: dc4d25c428c7
Revises:
Create Date: 2025-07-29 01:43:40.221070
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "dc4d25c428c7"
down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True))
op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scores", "nsmall_tick_hit")
op.drop_column("scores", "nlarge_tick_hit")
# ### end Alembic commands ###