diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 46fdd96..d821e27 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -114,19 +114,25 @@ class Beatmap(BeatmapBase, table=True): @classmethod async def get_or_fetch( - cls, session: AsyncSession, bid: int, fetcher: "Fetcher" + cls, + session: AsyncSession, + fetcher: "Fetcher", + bid: int | None = None, + md5: str | None = None, ) -> "Beatmap": beatmap = ( await session.exec( select(Beatmap) - .where(Beatmap.id == bid) + .where( + Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 + ) .options( joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] ) ) ).first() if not beatmap: - resp = await fetcher.get_beatmap(bid) + resp = await fetcher.get_beatmap(bid, md5) r = await session.exec( select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id) ) diff --git a/app/fetcher/beatmap.py b/app/fetcher/beatmap.py index 8e770f1..dabfb68 100644 --- a/app/fetcher/beatmap.py +++ b/app/fetcher/beatmap.py @@ -8,11 +8,20 @@ from httpx import AsyncClient class BeatmapFetcher(BaseFetcher): - async def get_beatmap(self, beatmap_id: int) -> BeatmapResp: + async def get_beatmap( + self, beatmap_id: int | None = None, beatmap_checksum: str | None = None + ) -> BeatmapResp: + if beatmap_id: + params = {"id": beatmap_id} + elif beatmap_checksum: + params = {"checksum": beatmap_checksum} + else: + raise ValueError("Either beatmap_id or beatmap_checksum must be provided.") async with AsyncClient() as client: response = await client.get( - f"https://osu.ppy.sh/api/v2/beatmaps/{beatmap_id}", + "https://osu.ppy.sh/api/v2/beatmaps/lookup", headers=self.header, + params=params, ) response.raise_for_status() return BeatmapResp.model_validate(response.json()) diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 4cf717e..cf59148 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -34,6 +34,31 @@ from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession +@router.get("/beatmaps/lookup", tags=["beatmap"], response_model=BeatmapResp) +async def lookup_beatmap( + id: int | None = Query(default=None, alias="id"), + md5: str | None = Query(default=None, alias="checksum"), + filename: str | None = Query(default=None, alias="filename"), + current_user: DBUser = Depends(get_current_user), + db: AsyncSession = Depends(get_db), + fetcher: Fetcher = Depends(get_fetcher), +): + if id is None and md5 is None and filename is None: + raise HTTPException( + status_code=400, + detail="At least one of 'id', 'checksum', or 'filename' must be provided.", + ) + try: + beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=id, md5=md5) + except HTTPError: + raise HTTPException(status_code=404, detail="Beatmap not found") + + if beatmap is None: + raise HTTPException(status_code=404, detail="Beatmap not found") + + return BeatmapResp.from_db(beatmap) + + @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) async def get_beatmap( bid: int, @@ -42,7 +67,7 @@ async def get_beatmap( fetcher: Fetcher = Depends(get_fetcher), ): try: - beatmap = await Beatmap.get_or_fetch(db, bid, fetcher) + beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) return BeatmapResp.from_db(beatmap) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -122,7 +147,7 @@ async def get_beatmap_attributes( if ruleset_id is not None and ruleset is None: ruleset = INT_TO_MODE[ruleset_id] if ruleset is None: - beatmap_db = await Beatmap.get_or_fetch(db, beatmap, fetcher) + beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap) ruleset = beatmap_db.mode key = ( f"beatmap:{beatmap}:{ruleset}:"