feat(score): support calculate pp

This commit is contained in:
MingxuanGame
2025-07-29 02:56:21 +00:00
parent a78a889c5d
commit 223fa99692
8 changed files with 238 additions and 42 deletions

59
app/calculator.py Normal file
View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from app.database.score import Score
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod
from app.models.score import GameMode
import rosu_pp_py as rosu
def calculate_beatmap_attribute(
beatmap: str,
gamemode: GameMode | None = None,
mods: int | list[APIMod] | list[str] = 0,
) -> BeatmapAttributes:
map = rosu.Beatmap(content=beatmap)
if gamemode is not None:
map.convert(gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
diff = rosu.Difficulty(mods=mods).calculate(map)
return BeatmapAttributes(
star_rating=diff.stars,
max_combo=diff.max_combo,
aim_difficulty=diff.aim,
aim_difficult_slider_count=diff.aim_difficult_slider_count,
speed_difficulty=diff.speed,
speed_note_count=diff.speed_note_count,
slider_factor=diff.slider_factor,
aim_difficult_strain_count=diff.aim_difficult_strain_count,
speed_difficult_strain_count=diff.speed_difficult_strain_count,
mono_stamina_factor=diff.stamina,
)
def calculate_pp(
score: Score,
beatmap: str,
) -> float:
map = rosu.Beatmap(content=beatmap)
map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType]
if map.is_suspicious():
return 0.0
perf = rosu.Performance(
mods=score.mods,
lazer=True,
accuracy=score.accuracy,
combo=score.max_combo,
large_tick_hits=score.nlarge_tick_hit or 0,
slider_end_hits=score.nslider_tail_hit or 0,
small_tick_hits=score.nsmall_tick_hit or 0,
n_geki=score.ngeki,
n_katu=score.nkatu,
n300=score.n300,
n100=score.n100,
n50=score.n50,
misses=score.nmiss,
hitresult_priority=rosu.HitResultPriority.Fastest,
)
attrs = perf.calculate(map)
return attrs.pp

View File

@@ -49,7 +49,7 @@ class ScoreBase(SQLModel):
mods: list[APIMod] = Field(sa_column=Column(JSON)) mods: list[APIMod] = Field(sa_column=Column(JSON))
passed: bool passed: bool
playlist_item_id: int | None = Field(default=None) # multiplayer playlist_item_id: int | None = Field(default=None) # multiplayer
pp: float pp: float = Field(default=0.0)
preserve: bool = Field(default=True) preserve: bool = Field(default=True)
rank: Rank rank: Rank
room_id: int | None = Field(default=None) # multiplayer room_id: int | None = Field(default=None) # multiplayer
@@ -87,7 +87,9 @@ class Score(ScoreBase, table=True):
ngeki: int = Field(exclude=True) ngeki: int = Field(exclude=True)
nkatu: int = Field(exclude=True) nkatu: int = Field(exclude=True)
nlarge_tick_miss: int | None = Field(default=None, exclude=True) nlarge_tick_miss: int | None = Field(default=None, exclude=True)
nlarge_tick_hit: int | None = Field(default=None, exclude=True)
nslider_tail_hit: int | None = Field(default=None, exclude=True) nslider_tail_hit: int | None = Field(default=None, exclude=True)
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
gamemode: GameMode = Field(index=True) gamemode: GameMode = Field(index=True)
# optional # optional
@@ -176,6 +178,10 @@ class ScoreResp(ScoreBase):
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None: if score.nslider_tail_hit is not None:
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
if score.nsmall_tick_hit is not None:
s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
if score.nlarge_tick_hit is not None:
s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
# s.user = await convert_db_user_to_api_user(score.user) # s.user = await convert_db_user_to_api_user(score.user)
s.rank_global = ( s.rank_global = (
await get_score_position_by_id( await get_score_position_by_id(

View File

@@ -4,6 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient from httpx import AsyncClient
from loguru import logger from loguru import logger
import redis
class OsuDotDirectFetcher(BaseFetcher): class OsuDotDirectFetcher(BaseFetcher):
@@ -17,3 +18,12 @@ class OsuDotDirectFetcher(BaseFetcher):
) )
response.raise_for_status() response.raise_for_status()
return response.text return response.text
async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int
) -> str:
if redis.exists(f"beatmap:{beatmap_id}:raw"):
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id)
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw

View File

@@ -105,3 +105,65 @@ def mods_to_int(mods: list[APIMod]) -> int:
for mod in mods: for mod in mods:
sum_ |= API_MOD_TO_LEGACY.get(mod["acronym"], 0) sum_ |= API_MOD_TO_LEGACY.get(mod["acronym"], 0)
return sum_ return sum_
NO_CHECK = "DO_NO_CHECK"
# FIXME: 这里为空表示了两种情况mod 没有配置项;任何时候都可以获得 pp
# 如果是后者,则 mod 更新的时候可能会误判。
COMMON_CONFIG: dict[str, dict] = {
"EZ": {"retries": 2},
"NF": {},
"HT": {"speed_change": 0.75, "adjust_pitch": NO_CHECK},
"DC": {"speed_change": 0.75},
"HR": {},
"SD": {},
"PF": {},
"HD": {},
"DT": {"speed_change": 1.5, "adjust_pitch": NO_CHECK},
"NC": {"speed_change": 1.5},
"FL": {"size_multiplier": 1.0, "combo_based_size": True},
"AC": {},
"MU": {},
"TD": {},
}
RANKED_MODS: dict[int, dict[str, dict]] = {
0: COMMON_CONFIG,
1: COMMON_CONFIG,
2: COMMON_CONFIG,
3: COMMON_CONFIG,
}
# osu
RANKED_MODS[0]["HD"]["only_fade_approach_circles"] = False
RANKED_MODS[0]["FL"]["follow_delay"] = 1.0
RANKED_MODS[0]["BL"] = {}
RANKED_MODS[0]["NS"] = {}
RANKED_MODS[0]["SO"] = {}
RANKED_MODS[0]["TC"] = {}
# taiko
del RANKED_MODS[1]["EZ"]["retries"]
# catch
RANKED_MODS[2]["NS"] = {}
# mania
del RANKED_MODS[3]["HR"]
RANKED_MODS[3]["FL"]["combo_based_size"] = False
RANKED_MODS[3]["MR"] = {}
for i in range(4, 10):
RANKED_MODS[3][f"{i}K"] = {}
def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
ranked_mods = RANKED_MODS[ruleset_id]
for mod in mods:
mod["settings"] = mod.get("settings", {})
if (settings := ranked_mods.get(mod["acronym"])) is None:
return False
if settings == {}:
continue
for setting, value in mod["settings"].items():
if (expected_value := settings.get(setting)) is None:
return False
if expected_value != NO_CHECK and value != expected_value:
return False
return True

View File

@@ -4,6 +4,7 @@ import asyncio
import hashlib import hashlib
import json import json
from app.calculator import calculate_beatmap_attribute
from app.database import ( from app.database import (
Beatmap, Beatmap,
BeatmapResp, BeatmapResp,
@@ -20,7 +21,6 @@ from app.models.score import (
INT_TO_MODE, INT_TO_MODE,
GameMode, GameMode,
) )
from app.utils import calculate_beatmap_attribute
from .api_router import router from .api_router import router
@@ -157,7 +157,7 @@ async def get_beatmap_attributes(
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
try: try:
resp = await fetcher.get_beatmap_raw(beatmap) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
try: try:
attr = await asyncio.get_event_loop().run_in_executor( attr = await asyncio.get_event_loop().run_in_executor(
None, calculate_beatmap_attribute, resp, ruleset, mods_ None, calculate_beatmap_attribute, resp, ruleset, mods_

View File

@@ -1,14 +1,20 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import datetime import datetime
from app.calculator import calculate_pp
from app.database import ( from app.database import (
User as DBUser, User as DBUser,
) )
from app.database.beatmap import Beatmap
from app.database.score import Score, ScoreResp from app.database.score import Score, ScoreResp
from app.database.score_token import ScoreToken, ScoreTokenResp from app.database.score_token import ScoreToken, ScoreTokenResp
from app.dependencies.database import get_db from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_can_get_pp
from app.models.score import ( from app.models.score import (
INT_TO_MODE, INT_TO_MODE,
GameMode, GameMode,
@@ -21,6 +27,7 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from redis import Redis
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlmodel import col, select, true from sqlmodel import col, select, true
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -187,6 +194,8 @@ async def submit_solo_score(
info: SoloScoreSubmissionInfo, info: SoloScoreSubmissionInfo,
current_user: DBUser = Depends(get_current_user), current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
): ):
if not info.passed: if not info.passed:
info.rank = Rank.F info.rank = Rank.F
@@ -214,6 +223,13 @@ async def submit_solo_score(
if not score: if not score:
raise HTTPException(status_code=404, detail="Score not found") raise HTTPException(status_code=404, detail="Score not found")
else: else:
beatmap_status = (
await db.exec(
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
)
).first()
if beatmap_status is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
score = Score( score = Score(
accuracy=info.accuracy, accuracy=info.accuracy,
max_combo=info.max_combo, max_combo=info.max_combo,
@@ -231,7 +247,6 @@ async def submit_solo_score(
preserve=info.passed, preserve=info.passed,
map_md5=score_token.beatmap.checksum, map_md5=score_token.beatmap.checksum,
has_replay=False, has_replay=False,
pp=info.pp,
type="solo", type="solo",
n300=info.statistics.get(HitResult.GREAT, 0), n300=info.statistics.get(HitResult.GREAT, 0),
n100=info.statistics.get(HitResult.OK, 0), n100=info.statistics.get(HitResult.OK, 0),
@@ -239,7 +254,25 @@ async def submit_solo_score(
nmiss=info.statistics.get(HitResult.MISS, 0), nmiss=info.statistics.get(HitResult.MISS, 0),
ngeki=info.statistics.get(HitResult.PERFECT, 0), ngeki=info.statistics.get(HitResult.PERFECT, 0),
nkatu=info.statistics.get(HitResult.GOOD, 0), nkatu=info.statistics.get(HitResult.GOOD, 0),
nlarge_tick_miss=info.statistics.get(HitResult.LARGE_TICK_MISS, 0),
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
) )
if (
info.passed
and beatmap_status
in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}
and mods_can_get_pp(info.ruleset_id, info.mods)
):
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
pp = await asyncio.get_event_loop().run_in_executor(
None, calculate_pp, score, beatmap_raw
)
score.pp = pp
db.add(score) db.add(score)
await db.commit() await db.commit()
await db.refresh(score) await db.refresh(score)

View File

@@ -8,9 +8,6 @@ from app.database import (
LazerUserStatistics, LazerUserStatistics,
User as DBUser, User as DBUser,
) )
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod
from app.models.score import GameMode
from app.models.user import ( from app.models.user import (
Country, Country,
Cover, Cover,
@@ -26,8 +23,6 @@ from app.models.user import (
UserAchievement, UserAchievement,
) )
import rosu_pp_py as rosu
def unix_timestamp_to_windows(timestamp: int) -> int: def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp.""" """Convert a Unix timestamp to a Windows timestamp."""
@@ -407,15 +402,33 @@ async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") ->
current_season_stats=None, current_season_stats=None,
daily_challenge_user_stats=DailyChallengeStats( daily_challenge_user_stats=DailyChallengeStats(
user_id=user_id, user_id=user_id,
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best if db_user.daily_challenge_stats else 0, daily_streak_best=db_user.daily_challenge_stats.daily_streak_best
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current if db_user.daily_challenge_stats else 0, if db_user.daily_challenge_stats
last_update=db_user.daily_challenge_stats.last_update if db_user.daily_challenge_stats else None, else 0,
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak if db_user.daily_challenge_stats else None, daily_streak_current=db_user.daily_challenge_stats.daily_streak_current
playcount=db_user.daily_challenge_stats.playcount if db_user.daily_challenge_stats else 0, if db_user.daily_challenge_stats
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements if db_user.daily_challenge_stats else 0, else 0,
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements if db_user.daily_challenge_stats else 0, last_update=db_user.daily_challenge_stats.last_update
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best if db_user.daily_challenge_stats else 0, if db_user.daily_challenge_stats
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current if db_user.daily_challenge_stats else 0, else None,
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak
if db_user.daily_challenge_stats
else None,
playcount=db_user.daily_challenge_stats.playcount
if db_user.daily_challenge_stats
else 0,
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements
if db_user.daily_challenge_stats
else 0,
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements
if db_user.daily_challenge_stats
else 0,
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best
if db_user.daily_challenge_stats
else 0,
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current
if db_user.daily_challenge_stats
else 0,
), ),
groups=[], groups=[],
monthly_playcounts=monthly_playcounts, monthly_playcounts=monthly_playcounts,
@@ -450,26 +463,3 @@ def get_country_name(country_code: str) -> str:
# 可以添加更多国家 # 可以添加更多国家
} }
return country_names.get(country_code, "Unknown") return country_names.get(country_code, "Unknown")
def calculate_beatmap_attribute(
beatmap: str,
gamemode: GameMode | None = None,
mods: int | list[APIMod] | list[str] = 0,
) -> BeatmapAttributes:
map = rosu.Beatmap(content=beatmap)
if gamemode is not None:
map.convert(gamemode.to_rosu(), mods)
diff = rosu.Difficulty(mods=mods).calculate(map)
return BeatmapAttributes(
star_rating=diff.stars,
max_combo=diff.max_combo,
aim_difficulty=diff.aim,
aim_difficult_slider_count=diff.aim_difficult_slider_count,
speed_difficulty=diff.speed,
speed_note_count=diff.speed_note_count,
slider_factor=diff.slider_factor,
aim_difficult_strain_count=diff.aim_difficult_strain_count,
speed_difficult_strain_count=diff.speed_difficult_strain_count,
mono_stamina_factor=diff.stamina,
)

View File

@@ -0,0 +1,36 @@
"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator
Revision ID: dc4d25c428c7
Revises:
Create Date: 2025-07-29 01:43:40.221070
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "dc4d25c428c7"
down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True))
op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scores", "nsmall_tick_hit")
op.drop_column("scores", "nlarge_tick_hit")
# ### end Alembic commands ###