diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 09d8900..bf08ff0 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -19,6 +19,7 @@ from typing import ( from app.database.beatmap import Beatmap from app.dependencies.database import engine +from app.dependencies.fetcher import get_fetcher from app.exception import InvokeException from .mods import APIMod @@ -518,8 +519,11 @@ class MultiplayerQueue: raise InvokeException("Freestyle items cannot have allowed mods") async with AsyncSession(engine) as session: + fetcher = await get_fetcher() async with session: - beatmap = await session.get(Beatmap, item.beatmap_id) + beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=item.beatmap_id + ) if beatmap is None: raise InvokeException("Beatmap not found") if item.beatmap_checksum != beatmap.checksum: @@ -543,10 +547,11 @@ class MultiplayerQueue: raise InvokeException("Freestyle items cannot have allowed mods") async with AsyncSession(engine) as session: + fetcher = await get_fetcher() async with session: - beatmap = await session.get(Beatmap, item.beatmap_id) - if beatmap is None: - raise InvokeException("Beatmap not found") + beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=item.beatmap_id + ) if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 7dfd0f9..6800246 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -77,6 +77,7 @@ async def batch_get_beatmaps( b_ids: list[int] = Query(alias="ids[]", default_factory=list), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), + fetcher: Fetcher = Depends(get_fetcher), ): if not b_ids: # select 50 beatmaps by last_updated @@ -86,9 +87,27 @@ async def batch_get_beatmaps( ) ).all() else: - beatmaps = ( - await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) - ).all() + beatmaps = list( + ( + await db.exec( + select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50) + ) + ).all() + ) + not_found_beatmaps = [ + bid for bid in b_ids if bid not in [bm.id for bm in beatmaps] + ] + beatmaps.extend( + beatmap + for beatmap in await asyncio.gather( + *[ + Beatmap.get_or_fetch(db, fetcher, bid=bid) + for bid in not_found_beatmaps + ], + return_exceptions=True, + ) + if isinstance(beatmap, Beatmap) + ) return BatchGetResp( beatmaps=[ diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index efaabd9..b7aa14c 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -211,7 +211,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException( "Failed to fetch beatmap, please retry later" ) - await Playlist.add_to_db(item, db_room.id, session) + await Playlist.add_to_db(item, room.room_id, session) server_room = ServerMultiplayerRoom( room=room,