Files
g0v0-server/app/router/v2/beatmap.py

193 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import hashlib
import json
from typing import Annotated
from app.calculator import get_calculator
from app.calculators.performance import ConvertError
from app.database import (
Beatmap,
BeatmapModel,
User,
)
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.user import get_current_user
from app.helpers.asset_proxy_helper import asset_proxy_response
from app.models.mods import APIMod, int_to_mods
from app.models.performance import (
DifficultyAttributes,
DifficultyAttributesUnion,
)
from app.models.score import (
GameMode,
)
from app.utils import api_doc
from .router import router
from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError
from sqlmodel import col, select
@router.get(
    "/beatmaps/lookup",
    tags=["谱面"],
    name="查询单个谱面",
    responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
    description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
)
@asset_proxy_response
async def lookup_beatmap(
    db: Database,
    current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
    fetcher: Fetcher,
    id: Annotated[int | None, Query(alias="id", description="谱面 ID")] = None,
    md5: Annotated[str | None, Query(alias="checksum", description="谱面文件 MD5")] = None,
    filename: Annotated[str | None, Query(alias="filename", description="谱面文件名")] = None,
):
    """Look up a single beatmap by ID and/or MD5 checksum.

    Responds 400 when none of ``id`` / ``checksum`` / ``filename`` is supplied, and 404
    when the lookup raises ``httpx.HTTPError`` or yields no beatmap.

    NOTE(review): ``filename`` only counts towards the "at least one identifier" check;
    it is never passed to ``Beatmap.get_or_fetch``, so a filename-only request cannot
    actually be resolved by name.
    """
    if all(identifier is None for identifier in (id, md5, filename)):
        raise HTTPException(
            status_code=400,
            detail="At least one of 'id', 'checksum', or 'filename' must be provided.",
        )

    try:
        found = await Beatmap.get_or_fetch(db, fetcher, bid=id, md5=md5)
    except HTTPError:
        # Upstream fetch failures are reported to the client as "not found".
        raise HTTPException(status_code=404, detail="Beatmap not found")
    if found is None:
        raise HTTPException(status_code=404, detail="Beatmap not found")

    # NOTE(review): the user is reloaded before serialising — presumably so that
    # user-dependent fields read by the transform are current; confirm.
    await db.refresh(current_user)
    return await BeatmapModel.transform(
        found,
        user=current_user,
        includes=BeatmapModel.TRANSFORMER_INCLUDES,
    )
@router.get(
    "/beatmaps/{beatmap_id}",
    tags=["谱面"],
    name="获取谱面详情",
    responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
    description="获取单个谱面详情。",
)
@asset_proxy_response
async def get_beatmap(
    db: Database,
    beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
    current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
    fetcher: Fetcher,
):
    """Return one beatmap by ID via ``Beatmap.get_or_fetch``.

    Any ``httpx.HTTPError`` raised while loading the beatmap, refreshing the user or
    serialising the result is reported to the client as 404 "Beatmap not found".
    """
    try:
        target = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
        # NOTE(review): presumably reloads the user so the transform sees current state.
        await db.refresh(current_user)
        payload = await BeatmapModel.transform(
            target,
            user=current_user,
            includes=BeatmapModel.TRANSFORMER_INCLUDES,
        )
    except HTTPError:
        raise HTTPException(status_code=404, detail="Beatmap not found")
    return payload
@router.get(
    "/beatmaps/",
    tags=["谱面"],
    name="批量获取谱面",
    responses={
        200: api_doc(
            "谱面列表", {"beatmaps": list[BeatmapModel]}, BeatmapModel.TRANSFORMER_INCLUDES, name="BatchBeatmapResponse"
        )
    },
    description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
)
@asset_proxy_response
async def batch_get_beatmaps(
    db: Database,
    beatmap_ids: Annotated[
        list[int],
        Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
    ],
    current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
    fetcher: Fetcher,
):
    """Return several beatmaps in one response.

    Without ``ids[]``, the 50 most recently updated beatmaps stored in the database are
    returned. Otherwise the first 50 distinct requested IDs are resolved:
    - rows already in the database are read with a single query;
    - the remaining IDs are fetched individually via ``Beatmap.get_or_fetch``;
    - an ID whose fetch raises (or does not yield a ``Beatmap``) is silently omitted.

    Returns:
        ``{"beatmaps": [...]}`` with each beatmap passed through ``BeatmapModel.transform``.
    """
    if not beatmap_ids:
        beatmaps = list(
            (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
        )
    else:
        # Deduplicate (keeping request order) and enforce the documented 50-ID cap before
        # querying. Otherwise IDs cut off by the query's LIMIT would be classed as "missing"
        # and fetched anyway, making the cap meaningless.
        requested_ids = list(dict.fromkeys(beatmap_ids))[:50]
        beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(requested_ids)))).all())

        found_ids = {bm.id for bm in beatmaps}
        missing_ids = [bid for bid in requested_ids if bid not in found_ids]

        # return_exceptions=True: one failed fetch drops that ID instead of failing the batch.
        # NOTE(review): these calls share one `db` session concurrently; SQLAlchemy's
        # AsyncSession is documented as unsafe for concurrent tasks — confirm get_or_fetch
        # tolerates this.
        fetched = await asyncio.gather(
            *(Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in missing_ids),
            return_exceptions=True,
        )
        beatmaps.extend(result for result in fetched if isinstance(result, Beatmap))

    for beatmap in beatmaps:
        await db.refresh(beatmap)
    await db.refresh(current_user)
    return {
        "beatmaps": [
            await BeatmapModel.transform(bm, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
            for bm in beatmaps
        ]
    }
@router.post(
    "/beatmaps/{beatmap_id}/attributes",
    tags=["谱面"],
    name="计算谱面属性",
    response_model=DifficultyAttributesUnion,
    description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"),
)
async def get_beatmap_attributes(
    db: Database,
    beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
    current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
    mods: Annotated[
        list[str],
        Query(
            default_factory=list,
            description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
        ),
    ],
    redis: Redis,
    fetcher: Fetcher,
    ruleset: Annotated[GameMode | None, Query(description="指定 ruleset为空则使用谱面自身模式")] = None,
    ruleset_id: Annotated[int | None, Query(description="以数字指定 ruleset (与 ruleset 二选一)")] = None,
):
    """Compute, or return cached, difficulty attributes for a beatmap.

    ``mods`` takes one of two forms:
    - an integer bitmask, read from the first element only when it is all decimal digits;
    - a list whose entries are each a JSON mod object or a bare acronym.

    ``ruleset`` takes precedence over ``ruleset_id``. When neither is given, the beatmap's own
    mode is used, which requires loading the beatmap.

    Results are looked up in Redis under a key built from the beatmap ID, the ruleset and an
    MD5 of the sorted mod list.

    Raises:
        HTTPException(400): a mod entry is JSON but not an object with a string
            ``acronym``, or the calculator raises ``ConvertError``.
        HTTPException(404): the beatmap cannot be fetched.
        HTTPException(422): the calculator cannot handle the requested ruleset.
    """
    mods_ = []
    # isdecimal (not isdigit): isdigit accepts characters like "²" that int() rejects.
    if mods and mods[0].isdecimal():
        mods_ = int_to_mods(int(mods[0]))
    else:
        for raw_mod in mods:
            try:
                parsed = json.loads(raw_mod)
            except json.JSONDecodeError:
                # Not JSON: treat the entry as a bare acronym such as "DT".
                mods_.append(APIMod(acronym=raw_mod, settings={}))
                continue
            # Valid JSON that is not a mod object (e.g. "64", "null", "{}") would otherwise
            # crash the sort below with TypeError/KeyError and surface as a 500.
            if not isinstance(parsed, dict) or not isinstance(parsed.get("acronym"), str):
                raise HTTPException(status_code=400, detail=f"Invalid mod: {raw_mod}")
            mods_.append(parsed)
    # Sort so the cache key does not depend on the order the mods were supplied in.
    mods_.sort(key=lambda x: x["acronym"])

    if ruleset_id is not None and ruleset is None:
        ruleset = GameMode.from_int(ruleset_id)
    if ruleset is None:
        try:
            beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
        except HTTPError:
            # Match the other beatmap endpoints: an upstream fetch failure is a 404, not a 500.
            raise HTTPException(status_code=404, detail="Beatmap not found") from None
        ruleset = beatmap_db.mode

    key = (
        f"beatmap:{beatmap_id}:{ruleset}:"
        f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
    )
    # A single GET avoids the race where the key expires between EXISTS and GET, which
    # would hand None to model_validate_json.
    cached = await redis.get(key)
    if cached is not None:
        # NOTE(review): cached entries are validated as DifficultyAttributes while the route's
        # response_model is DifficultyAttributesUnion — confirm no ruleset-specific fields are lost.
        return DifficultyAttributes.model_validate_json(cached)  # pyright: ignore[reportArgumentType]

    if await get_calculator().can_calculate_difficulty(ruleset) is False:
        raise HTTPException(status_code=422, detail="Cannot calculate difficulty for the specified ruleset")
    try:
        return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
    except HTTPStatusError:
        raise HTTPException(status_code=404, detail="Beatmap not found") from None
    except ConvertError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e