211 lines
6.9 KiB
Python
211 lines
6.9 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
|
||
from app.calculator import calculate_beatmap_attribute
|
||
from app.database import Beatmap, BeatmapResp, User
|
||
from app.dependencies.database import get_db, get_redis
|
||
from app.dependencies.fetcher import get_fetcher
|
||
from app.dependencies.user import get_current_user
|
||
from app.fetcher import Fetcher
|
||
from app.models.beatmap import BeatmapAttributes
|
||
from app.models.mods import APIMod, int_to_mods
|
||
from app.models.score import (
|
||
INT_TO_MODE,
|
||
GameMode,
|
||
)
|
||
|
||
from .router import router
|
||
|
||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||
from httpx import HTTPError, HTTPStatusError
|
||
from pydantic import BaseModel
|
||
from redis.asyncio import Redis
|
||
import rosu_pp_py as rosu
|
||
from sqlmodel import col, select
|
||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||
|
||
|
||
class BatchGetResp(BaseModel):
    """Response body of the batch beatmap endpoint.

    Fields:
    - beatmaps: the detailed entries for every beatmap in the batch.
    """

    # One serialized BeatmapResp per beatmap returned by the batch query.
    beatmaps: list[BeatmapResp]
|
||
|
||
|
||
@router.get(
    "/beatmaps/lookup",
    tags=["谱面"],
    name="查询单个谱面",
    response_model=BeatmapResp,
    description=(
        "根据谱面 ID / MD5 / 文件名 查询单个谱面。"
        "至少提供 id / checksum / filename 之一。"
    ),
)
async def lookup_beatmap(
    id: int | None = Query(default=None, alias="id", description="谱面 ID"),
    md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
    filename: str | None = Query(
        default=None, alias="filename", description="谱面文件名"
    ),
    current_user: User = Security(get_current_user, scopes=["public"]),
    db: AsyncSession = Depends(get_db),
    fetcher: Fetcher = Depends(get_fetcher),
):
    """Look up a single beatmap by ID and/or MD5 checksum.

    Args:
        id: beatmap ID (query parameter ``id``).
        md5: beatmap file MD5 (query parameter ``checksum``).
        filename: beatmap file name (query parameter ``filename``); only
            counted by the "at least one parameter" check, see NOTE below.

    Returns:
        The serialized beatmap for ``current_user``.

    Raises:
        HTTPException(400): none of ``id`` / ``checksum`` / ``filename`` given.
        HTTPException(404): ``Beatmap.get_or_fetch`` raised an
            ``httpx.HTTPError`` or returned ``None``.
    """
    if id is None and md5 is None and filename is None:
        raise HTTPException(
            status_code=400,
            detail="At least one of 'id', 'checksum', or 'filename' must be provided.",
        )

    # NOTE(review): `filename` satisfies the check above but is never passed to
    # the lookup below, so a filename-only request calls get_or_fetch with
    # bid=None and md5=None. Confirm that behaviour / implement filename lookup.
    try:
        beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=id, md5=md5)
    except HTTPError as e:
        # Keep the upstream error chained so logs show why the lookup failed.
        raise HTTPException(status_code=404, detail="Beatmap not found") from e

    if beatmap is None:
        raise HTTPException(status_code=404, detail="Beatmap not found")

    return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
||
|
||
|
||
@router.get(
    "/beatmaps/{beatmap_id}",
    tags=["谱面"],
    name="获取谱面详情",
    response_model=BeatmapResp,
    description="获取单个谱面详情。",
)
async def get_beatmap(
    beatmap_id: int = Path(..., description="谱面 ID"),
    current_user: User = Security(get_current_user, scopes=["public"]),
    db: AsyncSession = Depends(get_db),
    fetcher: Fetcher = Depends(get_fetcher),
):
    """Return the details of a single beatmap.

    Args:
        beatmap_id: ID of the beatmap (path parameter).

    Returns:
        The serialized beatmap for ``current_user``.

    Raises:
        HTTPException(404): ``Beatmap.get_or_fetch`` raised an
            ``httpx.HTTPError`` or returned ``None``.
    """
    # Only the load can fail with an upstream HTTP error; keep the try narrow.
    try:
        beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
    except HTTPError as e:
        raise HTTPException(status_code=404, detail="Beatmap not found") from e

    # Same guard as lookup_beatmap: never hand None to the serializer.
    if beatmap is None:
        raise HTTPException(status_code=404, detail="Beatmap not found")

    return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
||
|
||
|
||
@router.get(
    "/beatmaps/",
    tags=["谱面"],
    name="批量获取谱面",
    response_model=BatchGetResp,
    description=(
        "批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
        "为空时按最近更新时间返回。"
    ),
)
async def batch_get_beatmaps(
    beatmap_ids: list[int] = Query(
        alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
    ),
    current_user: User = Security(get_current_user, scopes=["public"]),
    db: AsyncSession = Depends(get_db),
    fetcher: Fetcher = Depends(get_fetcher),
):
    """Return several beatmaps in one response.

    With no ``ids[]`` the 50 most recently updated beatmaps in the database
    are returned. Otherwise the first 50 distinct requested IDs are loaded
    from the database in one query, and any that are missing are fetched one
    at a time through ``Beatmap.get_or_fetch``. IDs whose fetch raises, or
    does not yield a ``Beatmap``, are omitted from the response.

    Args:
        beatmap_ids: requested beatmap IDs (query parameter ``ids[]``).

    Returns:
        ``BatchGetResp`` with one serialized entry per beatmap found.
    """
    max_beatmaps = 50

    if not beatmap_ids:
        beatmaps = list(
            (
                await db.exec(
                    select(Beatmap)
                    .order_by(col(Beatmap.last_updated).desc())
                    .limit(max_beatmaps)
                )
            ).all()
        )
    else:
        # De-duplicate (keeping request order) and apply the documented cap
        # before querying, so every requested beatmap that is already stored
        # comes back from this single query instead of being fetched again,
        # and the same ID is never fetched twice.
        requested_ids = list(dict.fromkeys(beatmap_ids))[:max_beatmaps]
        beatmaps = list(
            (
                await db.exec(
                    select(Beatmap).where(col(Beatmap.id).in_(requested_ids))
                )
            ).all()
        )
        found_ids = {bm.id for bm in beatmaps}

        # Fetch sequentially: every call shares `db`, and a single
        # AsyncSession must not be used from concurrent tasks.
        for bid in requested_ids:
            if bid in found_ids:
                continue
            try:
                fetched = await Beatmap.get_or_fetch(db, fetcher, bid=bid)
            except Exception:
                # Deliberately best-effort: a beatmap that cannot be fetched is
                # left out of the batch instead of failing the whole request.
                continue
            if isinstance(fetched, Beatmap):
                beatmaps.append(fetched)

    for beatmap in beatmaps:
        await db.refresh(beatmap)

    return BatchGetResp(
        beatmaps=[
            await BeatmapResp.from_db(bm, session=db, user=current_user)
            for bm in beatmaps
        ]
    )
|
||
|
||
|
||
def _parse_mods(mods: list[str]) -> list:
    """Turn the raw ``mods`` query values into a list of mods sorted by acronym.

    If the first value is a decimal integer it is treated as a bitmask and
    expanded with ``int_to_mods`` (any further values are ignored). Otherwise
    each value is parsed as JSON; values that are not valid JSON are used as a
    bare acronym with empty settings.

    Raises:
        HTTPException(400): a value is valid JSON but not an object with a
            string ``"acronym"`` (it could not be sorted or used as a mod).
    """
    # isdecimal(), not isdigit(): isdigit() also accepts characters such as
    # "²" that int() rejects.
    if mods and mods[0].isdecimal():
        parsed = int_to_mods(int(mods[0]))
    else:
        parsed = []
        for raw in mods:
            try:
                value = json.loads(raw)
            except json.JSONDecodeError:
                parsed.append(APIMod(acronym=raw, settings={}))
                continue
            if not isinstance(value, dict) or not isinstance(
                value.get("acronym"), str
            ):
                raise HTTPException(status_code=400, detail=f"Invalid mod: {raw}")
            parsed.append(value)
    # Sorting makes the cache key below independent of the request's mod order.
    parsed.sort(key=lambda m: m["acronym"])
    return parsed


@router.post(
    "/beatmaps/{beatmap_id}/attributes",
    tags=["谱面"],
    name="计算谱面属性",
    response_model=BeatmapAttributes,
    description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"),
)
async def get_beatmap_attributes(
    beatmap_id: int = Path(..., description="谱面 ID"),
    current_user: User = Security(get_current_user, scopes=["public"]),
    mods: list[str] = Query(
        default_factory=list,
        description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
    ),
    ruleset: GameMode | None = Query(
        default=None, description="指定 ruleset;为空则使用谱面自身模式"
    ),
    ruleset_id: int | None = Query(
        default=None, description="以数字指定 ruleset (与 ruleset 二选一)"
    ),
    redis: Redis = Depends(get_redis),
    db: AsyncSession = Depends(get_db),
    fetcher: Fetcher = Depends(get_fetcher),
):
    """Compute difficulty attributes for a beatmap, with Redis caching.

    The ruleset is taken from ``ruleset``, else from ``ruleset_id`` via
    ``INT_TO_MODE``, else from the stored beatmap's ``mode``. Results are cached
    under a key built from the beatmap ID, ruleset and an MD5 of the sorted mods.

    Raises:
        HTTPException(400): invalid mod value, unknown ``ruleset_id``, or
            ``rosu.ConvertError`` while computing attributes.
        HTTPException(404): the beatmap could not be loaded or fetched.
    """
    mods_ = _parse_mods(mods)

    if ruleset_id is not None and ruleset is None:
        try:
            ruleset = INT_TO_MODE[ruleset_id]
        except (KeyError, IndexError) as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid ruleset_id: {ruleset_id}"
            ) from e

    if ruleset is None:
        try:
            beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
        except HTTPError as e:
            raise HTTPException(status_code=404, detail="Beatmap not found") from e
        if beatmap_db is None:
            raise HTTPException(status_code=404, detail="Beatmap not found")
        ruleset = beatmap_db.mode

    # Key format is unchanged so existing cache entries stay valid.
    key = (
        f"beatmap:{beatmap_id}:{ruleset}:"
        f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
    )
    # A single GET instead of EXISTS + GET: one round trip, and no window in
    # which the key can vanish between the two commands.
    cached = await redis.get(key)
    if cached is not None:
        return BeatmapAttributes.model_validate_json(cached)  # pyright: ignore[reportArgumentType]

    try:
        resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
    except HTTPStatusError as e:
        raise HTTPException(status_code=404, detail="Beatmap not found") from e

    # The calculation runs in the default executor so it does not block the
    # event loop.
    try:
        attr = await asyncio.get_running_loop().run_in_executor(
            None, calculate_beatmap_attribute, resp, ruleset, mods_
        )
    except rosu.ConvertError as e:  # pyright: ignore[reportAttributeAccessIssue]
        raise HTTPException(status_code=400, detail=str(e)) from e

    await redis.set(key, attr.model_dump_json())
    return attr
|