refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -40,18 +40,13 @@ class BatchGetResp(BaseModel):
|
||||
tags=["谱面"],
|
||||
name="查询单个谱面",
|
||||
response_model=BeatmapResp,
|
||||
description=(
|
||||
"根据谱面 ID / MD5 / 文件名 查询单个谱面。"
|
||||
"至少提供 id / checksum / filename 之一。"
|
||||
),
|
||||
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
|
||||
)
|
||||
async def lookup_beatmap(
|
||||
db: Database,
|
||||
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
|
||||
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
|
||||
filename: str | None = Query(
|
||||
default=None, alias="filename", description="谱面文件名"
|
||||
),
|
||||
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -96,43 +91,23 @@ async def get_beatmap(
|
||||
tags=["谱面"],
|
||||
name="批量获取谱面",
|
||||
response_model=BatchGetResp,
|
||||
description=(
|
||||
"批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
|
||||
"为空时按最近更新时间返回。"
|
||||
),
|
||||
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
|
||||
)
|
||||
async def batch_get_beatmaps(
|
||||
db: Database,
|
||||
beatmap_ids: list[int] = Query(
|
||||
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
|
||||
),
|
||||
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if not beatmap_ids:
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
)
|
||||
).all()
|
||||
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
|
||||
else:
|
||||
beatmaps = list(
|
||||
(
|
||||
await db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50)
|
||||
)
|
||||
).all()
|
||||
)
|
||||
not_found_beatmaps = [
|
||||
bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]
|
||||
]
|
||||
beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
|
||||
not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
|
||||
beatmaps.extend(
|
||||
beatmap
|
||||
for beatmap in await asyncio.gather(
|
||||
*[
|
||||
Beatmap.get_or_fetch(db, fetcher, bid=bid)
|
||||
for bid in not_found_beatmaps
|
||||
],
|
||||
*[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(beatmap, Beatmap)
|
||||
@@ -140,12 +115,7 @@ async def batch_get_beatmaps(
|
||||
for beatmap in beatmaps:
|
||||
await db.refresh(beatmap)
|
||||
|
||||
return BatchGetResp(
|
||||
beatmaps=[
|
||||
await BeatmapResp.from_db(bm, session=db, user=current_user)
|
||||
for bm in beatmaps
|
||||
]
|
||||
)
|
||||
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -163,12 +133,8 @@ async def get_beatmap_attributes(
|
||||
default_factory=list,
|
||||
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
|
||||
),
|
||||
ruleset: GameMode | None = Query(
|
||||
default=None, description="指定 ruleset;为空则使用谱面自身模式"
|
||||
),
|
||||
ruleset_id: int | None = Query(
|
||||
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
|
||||
),
|
||||
ruleset: GameMode | None = Query(default=None, description="指定 ruleset;为空则使用谱面自身模式"),
|
||||
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -187,16 +153,11 @@ async def get_beatmap_attributes(
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||
ruleset = beatmap_db.mode
|
||||
key = (
|
||||
f"beatmap:{beatmap_id}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
try:
|
||||
return await calculate_beatmap_attributes(
|
||||
beatmap_id, ruleset, mods_, redis, fetcher
|
||||
)
|
||||
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
Reference in New Issue
Block a user