Files
g0v0-server/app/router/v2/beatmap.py
MingxuanGame 2c81e22749 feat(calculator): support generate PerformanceAttributes & DifficultyAttributes from JSON Schema (#59)
Prepare for custom rulesets.

Schema Genetator: https://github.com/GooGuTeam/custom-rulesets/tree/main/CustomRulesetMetadataGenerator

```bash
dotnet -- schemas path/to/rulesets -o schema.json
```

```bash
python scripts/generate_ruleset_attributes.py schema.json 
```
2025-10-25 19:10:53 +08:00

179 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import hashlib
import json
from typing import Annotated
from app.calculators.performance import ConvertError
from app.database import Beatmap, BeatmapResp, User
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.user import get_current_user
from app.helpers.asset_proxy_helper import asset_proxy_response
from app.models.mods import APIMod, int_to_mods
from app.models.performance import (
DifficultyAttributes,
DifficultyAttributesUnion,
)
from app.models.score import (
GameMode,
)
from .router import router
from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from sqlmodel import col, select
class BatchGetResp(BaseModel):
"""批量获取谱面返回模型。
返回字段说明:
- beatmaps: 谱面详细信息列表。"""
beatmaps: list[BeatmapResp]
@router.get(
"/beatmaps/lookup",
tags=["谱面"],
name="查询单个谱面",
response_model=BeatmapResp,
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
)
@asset_proxy_response
async def lookup_beatmap(
db: Database,
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
id: Annotated[int | None, Query(alias="id", description="谱面 ID")] = None,
md5: Annotated[str | None, Query(alias="checksum", description="谱面文件 MD5")] = None,
filename: Annotated[str | None, Query(alias="filename", description="谱面文件名")] = None,
):
if id is None and md5 is None and filename is None:
raise HTTPException(
status_code=400,
detail="At least one of 'id', 'checksum', or 'filename' must be provided.",
)
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=id, md5=md5)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
await db.refresh(current_user)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
@router.get(
"/beatmaps/{beatmap_id}",
tags=["谱面"],
name="获取谱面详情",
response_model=BeatmapResp,
description="获取单个谱面详情。",
)
@asset_proxy_response
async def get_beatmap(
db: Database,
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@router.get(
"/beatmaps/",
tags=["谱面"],
name="批量获取谱面",
response_model=BatchGetResp,
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
)
@asset_proxy_response
async def batch_get_beatmaps(
db: Database,
beatmap_ids: Annotated[
list[int],
Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
):
if not beatmap_ids:
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
else:
beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
beatmaps.extend(
beatmap
for beatmap in await asyncio.gather(
*[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
return_exceptions=True,
)
if isinstance(beatmap, Beatmap)
)
for beatmap in beatmaps:
await db.refresh(beatmap)
await db.refresh(current_user)
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
@router.post(
"/beatmaps/{beatmap_id}/attributes",
tags=["谱面"],
name="计算谱面属性",
response_model=DifficultyAttributesUnion,
description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"),
)
async def get_beatmap_attributes(
db: Database,
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
mods: Annotated[
list[str],
Query(
default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
),
],
redis: Redis,
fetcher: Fetcher,
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset为空则使用谱面自身模式")] = None,
ruleset_id: Annotated[int | None, Query(description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3)] = None,
):
mods_ = []
if mods and mods[0].isdigit():
mods_ = int_to_mods(int(mods[0]))
else:
for i in mods:
try:
mods_.append(json.loads(i))
except json.JSONDecodeError:
mods_.append(APIMod(acronym=i, settings={}))
mods_.sort(key=lambda x: x["acronym"])
if ruleset_id is not None and ruleset is None:
ruleset = GameMode.from_int(ruleset_id)
if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode
key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
)
if await redis.exists(key):
return DifficultyAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")
except ConvertError as e:
raise HTTPException(status_code=400, detail=str(e))