diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 0f4e565..ab849e6 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True): if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first(): session.add(beatmap) await session.commit() - beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).first() - assert beatmap is not None, "Beatmap should not be None after commit" + beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one() return beatmap @classmethod @@ -124,9 +123,14 @@ class Beatmap(BeatmapBase, table=True): bid: int | None = None, md5: str | None = None, ) -> "Beatmap": - beatmap = ( - await session.exec(select(Beatmap).where(Beatmap.id == bid if bid is not None else Beatmap.checksum == md5)) - ).first() + stmt = select(Beatmap) + if bid is not None: + stmt = stmt.where(Beatmap.id == bid) + elif md5 is not None: + stmt = stmt.where(Beatmap.checksum == md5) + else: + raise ValueError("Either bid or md5 must be provided") + beatmap = (await session.exec(stmt)).first() if not beatmap: resp = await fetcher.get_beatmap(bid, md5) r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 1ef14f5..8754a02 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -167,6 +167,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): session.add(beatmapset) await session.commit() await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) + beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one() return beatmapset @classmethod