From f34ed53a5591dcae11788602b518c977ac5f431d Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 2 Oct 2025 16:37:42 +0000 Subject: [PATCH] fix(beatmap): fix `beatmap.beatmapset` is None when it from `from_resp` --- app/database/beatmap.py | 14 +++++++++----- app/database/beatmapset.py | 1 + 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 0f4e565..ab849e6 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True): if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first(): session.add(beatmap) await session.commit() - beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).first() - assert beatmap is not None, "Beatmap should not be None after commit" + beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one() return beatmap @classmethod @@ -124,9 +123,14 @@ class Beatmap(BeatmapBase, table=True): bid: int | None = None, md5: str | None = None, ) -> "Beatmap": - beatmap = ( - await session.exec(select(Beatmap).where(Beatmap.id == bid if bid is not None else Beatmap.checksum == md5)) - ).first() + stmt = select(Beatmap) + if bid is not None: + stmt = stmt.where(Beatmap.id == bid) + elif md5 is not None: + stmt = stmt.where(Beatmap.checksum == md5) + else: + raise ValueError("Either bid or md5 must be provided") + beatmap = (await session.exec(stmt)).first() if not beatmap: resp = await fetcher.get_beatmap(bid, md5) r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 1ef14f5..8754a02 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -167,6 +167,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): session.add(beatmapset) await session.commit() await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) + beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one() return beatmapset @classmethod