feat(score): support calculate pp
This commit is contained in:
59
app/calculator.py
Normal file
59
app/calculator.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.score import Score
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import GameMode
|
||||
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
def calculate_beatmap_attribute(
|
||||
beatmap: str,
|
||||
gamemode: GameMode | None = None,
|
||||
mods: int | list[APIMod] | list[str] = 0,
|
||||
) -> BeatmapAttributes:
|
||||
map = rosu.Beatmap(content=beatmap)
|
||||
if gamemode is not None:
|
||||
map.convert(gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
|
||||
diff = rosu.Difficulty(mods=mods).calculate(map)
|
||||
return BeatmapAttributes(
|
||||
star_rating=diff.stars,
|
||||
max_combo=diff.max_combo,
|
||||
aim_difficulty=diff.aim,
|
||||
aim_difficult_slider_count=diff.aim_difficult_slider_count,
|
||||
speed_difficulty=diff.speed,
|
||||
speed_note_count=diff.speed_note_count,
|
||||
slider_factor=diff.slider_factor,
|
||||
aim_difficult_strain_count=diff.aim_difficult_strain_count,
|
||||
speed_difficult_strain_count=diff.speed_difficult_strain_count,
|
||||
mono_stamina_factor=diff.stamina,
|
||||
)
|
||||
|
||||
|
||||
def calculate_pp(
|
||||
score: Score,
|
||||
beatmap: str,
|
||||
) -> float:
|
||||
map = rosu.Beatmap(content=beatmap)
|
||||
map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType]
|
||||
if map.is_suspicious():
|
||||
return 0.0
|
||||
perf = rosu.Performance(
|
||||
mods=score.mods,
|
||||
lazer=True,
|
||||
accuracy=score.accuracy,
|
||||
combo=score.max_combo,
|
||||
large_tick_hits=score.nlarge_tick_hit or 0,
|
||||
slider_end_hits=score.nslider_tail_hit or 0,
|
||||
small_tick_hits=score.nsmall_tick_hit or 0,
|
||||
n_geki=score.ngeki,
|
||||
n_katu=score.nkatu,
|
||||
n300=score.n300,
|
||||
n100=score.n100,
|
||||
n50=score.n50,
|
||||
misses=score.nmiss,
|
||||
hitresult_priority=rosu.HitResultPriority.Fastest,
|
||||
)
|
||||
attrs = perf.calculate(map)
|
||||
return attrs.pp
|
||||
@@ -49,7 +49,7 @@ class ScoreBase(SQLModel):
|
||||
mods: list[APIMod] = Field(sa_column=Column(JSON))
|
||||
passed: bool
|
||||
playlist_item_id: int | None = Field(default=None) # multiplayer
|
||||
pp: float
|
||||
pp: float = Field(default=0.0)
|
||||
preserve: bool = Field(default=True)
|
||||
rank: Rank
|
||||
room_id: int | None = Field(default=None) # multiplayer
|
||||
@@ -87,7 +87,9 @@ class Score(ScoreBase, table=True):
|
||||
ngeki: int = Field(exclude=True)
|
||||
nkatu: int = Field(exclude=True)
|
||||
nlarge_tick_miss: int | None = Field(default=None, exclude=True)
|
||||
nlarge_tick_hit: int | None = Field(default=None, exclude=True)
|
||||
nslider_tail_hit: int | None = Field(default=None, exclude=True)
|
||||
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
|
||||
# optional
|
||||
@@ -176,6 +178,10 @@ class ScoreResp(ScoreBase):
|
||||
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
|
||||
if score.nslider_tail_hit is not None:
|
||||
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
|
||||
if score.nsmall_tick_hit is not None:
|
||||
s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
|
||||
if score.nlarge_tick_hit is not None:
|
||||
s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
|
||||
# s.user = await convert_db_user_to_api_user(score.user)
|
||||
s.rank_global = (
|
||||
await get_score_position_by_id(
|
||||
|
||||
@@ -4,6 +4,7 @@ from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
from loguru import logger
|
||||
import redis
|
||||
|
||||
|
||||
class OsuDotDirectFetcher(BaseFetcher):
|
||||
@@ -17,3 +18,12 @@ class OsuDotDirectFetcher(BaseFetcher):
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
async def get_or_fetch_beatmap_raw(
|
||||
self, redis: redis.Redis, beatmap_id: int
|
||||
) -> str:
|
||||
if redis.exists(f"beatmap:{beatmap_id}:raw"):
|
||||
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
|
||||
raw = await self.get_beatmap_raw(beatmap_id)
|
||||
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
|
||||
return raw
|
||||
|
||||
@@ -105,3 +105,65 @@ def mods_to_int(mods: list[APIMod]) -> int:
|
||||
for mod in mods:
|
||||
sum_ |= API_MOD_TO_LEGACY.get(mod["acronym"], 0)
|
||||
return sum_
|
||||
|
||||
|
||||
NO_CHECK = "DO_NO_CHECK"
|
||||
|
||||
# FIXME: 这里为空表示了两种情况:mod 没有配置项;任何时候都可以获得 pp
|
||||
# 如果是后者,则 mod 更新的时候可能会误判。
|
||||
COMMON_CONFIG: dict[str, dict] = {
|
||||
"EZ": {"retries": 2},
|
||||
"NF": {},
|
||||
"HT": {"speed_change": 0.75, "adjust_pitch": NO_CHECK},
|
||||
"DC": {"speed_change": 0.75},
|
||||
"HR": {},
|
||||
"SD": {},
|
||||
"PF": {},
|
||||
"HD": {},
|
||||
"DT": {"speed_change": 1.5, "adjust_pitch": NO_CHECK},
|
||||
"NC": {"speed_change": 1.5},
|
||||
"FL": {"size_multiplier": 1.0, "combo_based_size": True},
|
||||
"AC": {},
|
||||
"MU": {},
|
||||
"TD": {},
|
||||
}
|
||||
|
||||
RANKED_MODS: dict[int, dict[str, dict]] = {
|
||||
0: COMMON_CONFIG,
|
||||
1: COMMON_CONFIG,
|
||||
2: COMMON_CONFIG,
|
||||
3: COMMON_CONFIG,
|
||||
}
|
||||
# osu
|
||||
RANKED_MODS[0]["HD"]["only_fade_approach_circles"] = False
|
||||
RANKED_MODS[0]["FL"]["follow_delay"] = 1.0
|
||||
RANKED_MODS[0]["BL"] = {}
|
||||
RANKED_MODS[0]["NS"] = {}
|
||||
RANKED_MODS[0]["SO"] = {}
|
||||
RANKED_MODS[0]["TC"] = {}
|
||||
# taiko
|
||||
del RANKED_MODS[1]["EZ"]["retries"]
|
||||
# catch
|
||||
RANKED_MODS[2]["NS"] = {}
|
||||
# mania
|
||||
del RANKED_MODS[3]["HR"]
|
||||
RANKED_MODS[3]["FL"]["combo_based_size"] = False
|
||||
RANKED_MODS[3]["MR"] = {}
|
||||
for i in range(4, 10):
|
||||
RANKED_MODS[3][f"{i}K"] = {}
|
||||
|
||||
|
||||
def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
|
||||
ranked_mods = RANKED_MODS[ruleset_id]
|
||||
for mod in mods:
|
||||
mod["settings"] = mod.get("settings", {})
|
||||
if (settings := ranked_mods.get(mod["acronym"])) is None:
|
||||
return False
|
||||
if settings == {}:
|
||||
continue
|
||||
for setting, value in mod["settings"].items():
|
||||
if (expected_value := settings.get(setting)) is None:
|
||||
return False
|
||||
if expected_value != NO_CHECK and value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from app.calculator import calculate_beatmap_attribute
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
BeatmapResp,
|
||||
@@ -20,7 +21,6 @@ from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
)
|
||||
from app.utils import calculate_beatmap_attribute
|
||||
|
||||
from .api_router import router
|
||||
|
||||
@@ -157,7 +157,7 @@ async def get_beatmap_attributes(
|
||||
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
|
||||
try:
|
||||
resp = await fetcher.get_beatmap_raw(beatmap)
|
||||
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
|
||||
try:
|
||||
attr = await asyncio.get_event_loop().run_in_executor(
|
||||
None, calculate_beatmap_attribute, resp, ruleset, mods_
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
from app.calculator import calculate_pp
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.database.score_token import ScoreToken, ScoreTokenResp
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import mods_can_get_pp
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
@@ -21,6 +27,7 @@ from .api_router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select, true
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -187,6 +194,8 @@ async def submit_solo_score(
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
@@ -214,6 +223,13 @@ async def submit_solo_score(
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
beatmap_status = (
|
||||
await db.exec(
|
||||
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
|
||||
)
|
||||
).first()
|
||||
if beatmap_status is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
score = Score(
|
||||
accuracy=info.accuracy,
|
||||
max_combo=info.max_combo,
|
||||
@@ -231,7 +247,6 @@ async def submit_solo_score(
|
||||
preserve=info.passed,
|
||||
map_md5=score_token.beatmap.checksum,
|
||||
has_replay=False,
|
||||
pp=info.pp,
|
||||
type="solo",
|
||||
n300=info.statistics.get(HitResult.GREAT, 0),
|
||||
n100=info.statistics.get(HitResult.OK, 0),
|
||||
@@ -239,7 +254,25 @@ async def submit_solo_score(
|
||||
nmiss=info.statistics.get(HitResult.MISS, 0),
|
||||
ngeki=info.statistics.get(HitResult.PERFECT, 0),
|
||||
nkatu=info.statistics.get(HitResult.GOOD, 0),
|
||||
nlarge_tick_miss=info.statistics.get(HitResult.LARGE_TICK_MISS, 0),
|
||||
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
|
||||
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
|
||||
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
|
||||
)
|
||||
if (
|
||||
info.passed
|
||||
and beatmap_status
|
||||
in {
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
}
|
||||
and mods_can_get_pp(info.ruleset_id, info.mods)
|
||||
):
|
||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
|
||||
pp = await asyncio.get_event_loop().run_in_executor(
|
||||
None, calculate_pp, score, beatmap_raw
|
||||
)
|
||||
score.pp = pp
|
||||
db.add(score)
|
||||
await db.commit()
|
||||
await db.refresh(score)
|
||||
|
||||
64
app/utils.py
64
app/utils.py
@@ -8,9 +8,6 @@ from app.database import (
|
||||
LazerUserStatistics,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import (
|
||||
Country,
|
||||
Cover,
|
||||
@@ -26,8 +23,6 @@ from app.models.user import (
|
||||
UserAchievement,
|
||||
)
|
||||
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
def unix_timestamp_to_windows(timestamp: int) -> int:
|
||||
"""Convert a Unix timestamp to a Windows timestamp."""
|
||||
@@ -407,15 +402,33 @@ async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") ->
|
||||
current_season_stats=None,
|
||||
daily_challenge_user_stats=DailyChallengeStats(
|
||||
user_id=user_id,
|
||||
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best if db_user.daily_challenge_stats else 0,
|
||||
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current if db_user.daily_challenge_stats else 0,
|
||||
last_update=db_user.daily_challenge_stats.last_update if db_user.daily_challenge_stats else None,
|
||||
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak if db_user.daily_challenge_stats else None,
|
||||
playcount=db_user.daily_challenge_stats.playcount if db_user.daily_challenge_stats else 0,
|
||||
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements if db_user.daily_challenge_stats else 0,
|
||||
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements if db_user.daily_challenge_stats else 0,
|
||||
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best if db_user.daily_challenge_stats else 0,
|
||||
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current if db_user.daily_challenge_stats else 0,
|
||||
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
last_update=db_user.daily_challenge_stats.last_update
|
||||
if db_user.daily_challenge_stats
|
||||
else None,
|
||||
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak
|
||||
if db_user.daily_challenge_stats
|
||||
else None,
|
||||
playcount=db_user.daily_challenge_stats.playcount
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
),
|
||||
groups=[],
|
||||
monthly_playcounts=monthly_playcounts,
|
||||
@@ -450,26 +463,3 @@ def get_country_name(country_code: str) -> str:
|
||||
# 可以添加更多国家
|
||||
}
|
||||
return country_names.get(country_code, "Unknown")
|
||||
|
||||
|
||||
def calculate_beatmap_attribute(
|
||||
beatmap: str,
|
||||
gamemode: GameMode | None = None,
|
||||
mods: int | list[APIMod] | list[str] = 0,
|
||||
) -> BeatmapAttributes:
|
||||
map = rosu.Beatmap(content=beatmap)
|
||||
if gamemode is not None:
|
||||
map.convert(gamemode.to_rosu(), mods)
|
||||
diff = rosu.Difficulty(mods=mods).calculate(map)
|
||||
return BeatmapAttributes(
|
||||
star_rating=diff.stars,
|
||||
max_combo=diff.max_combo,
|
||||
aim_difficulty=diff.aim,
|
||||
aim_difficult_slider_count=diff.aim_difficult_slider_count,
|
||||
speed_difficulty=diff.speed,
|
||||
speed_note_count=diff.speed_note_count,
|
||||
slider_factor=diff.slider_factor,
|
||||
aim_difficult_strain_count=diff.aim_difficult_strain_count,
|
||||
speed_difficult_strain_count=diff.speed_difficult_strain_count,
|
||||
mono_stamina_factor=diff.stamina,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator
|
||||
|
||||
Revision ID: dc4d25c428c7
|
||||
Revises:
|
||||
Create Date: 2025-07-29 01:43:40.221070
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dc4d25c428c7"
|
||||
down_revision: str | Sequence[str] | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True))
|
||||
op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("scores", "nsmall_tick_hit")
|
||||
op.drop_column("scores", "nlarge_tick_hit")
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user