From 0ac4f1f516fc29da6333d6f974017cee195bbbf1 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 8 Aug 2025 11:54:43 +0000 Subject: [PATCH] refactor(beatmap,beatmapset): use to ensure beatmap exists --- app/database/beatmapset.py | 12 +++++++++++ app/router/beatmapset.py | 37 +++++++++------------------------- app/router/score.py | 10 ++++----- app/signalr/hub/multiplayer.py | 28 +++++++++++++++++++++---- app/signalr/hub/spectator.py | 27 ++++++++++++------------- 5 files changed, 64 insertions(+), 50 deletions(-) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 3bad7e9..12f3c67 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -13,6 +13,8 @@ from sqlmodel import Field, Relationship, SQLModel, col, func, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: + from app.fetcher import Fetcher + from .beatmap import Beatmap, BeatmapResp from .favourite_beatmapset import FavouriteBeatmapset @@ -185,6 +187,16 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) return beatmapset + @classmethod + async def get_or_fetch( + cls, session: AsyncSession, fetcher: "Fetcher", sid: int + ) -> "Beatmapset": + beatmapset = await session.get(Beatmapset, sid) + if not beatmapset: + resp = await fetcher.get_beatmapset(sid) + beatmapset = await cls.from_resp(session, resp) + return beatmapset + class BeatmapsetResp(BeatmapsetBase): id: int diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index f77c2ed..7280bad 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -12,7 +12,7 @@ from .api_router import router from fastapi import Depends, Form, HTTPException, Query from fastapi.responses import RedirectResponse -from httpx import HTTPStatusError +from httpx import HTTPError from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -24,22 +24,10 @@ async def lookup_beatmapset( db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): - beatmapset_id = ( - await db.exec(select(Beatmap.beatmapset_id).where(Beatmap.id == beatmap_id)) - ).first() - if not beatmapset_id: - try: - resp = await fetcher.get_beatmap(beatmap_id) - await Beatmap.from_resp(db, resp) - await db.refresh(current_user) - except HTTPStatusError: - raise HTTPException(status_code=404, detail="Beatmapset not found") - beatmapset = ( - await db.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id)) - ).first() - if not beatmapset: - raise HTTPException(status_code=404, detail="Beatmapset not found") - resp = await BeatmapsetResp.from_db(beatmapset, session=db, user=current_user) + beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) + resp = await BeatmapsetResp.from_db( + beatmap.beatmapset, session=db, user=current_user + ) return resp @@ -50,18 +38,13 @@ async def get_beatmapset( db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): - beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first() - if not beatmapset: - try: - resp = await fetcher.get_beatmapset(sid) - await Beatmapset.from_resp(db, resp) - except HTTPStatusError: - raise HTTPException(status_code=404, detail="Beatmapset not found") - else: - resp = await BeatmapsetResp.from_db( + try: + beatmapset = await Beatmapset.get_or_fetch(db, fetcher, sid) + return await BeatmapsetResp.from_db( beatmapset, session=db, include=["recent_favourites"], user=current_user ) - return resp + except HTTPError: + raise HTTPException(status_code=404, detail="Beatmapset not found") @router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"]) diff --git a/app/router/score.py b/app/router/score.py index 5db171d..d776b5a 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -43,6 +43,7 @@ from app.models.score import ( from .api_router import router from fastapi import Depends, Form, HTTPException, Query +from httpx import HTTPError from pydantic import BaseModel from redis.asyncio import Redis from sqlalchemy.orm import joinedload @@ -86,12 +87,11 @@ async def submit_score( if not score: raise HTTPException(status_code=404, detail="Score not found") else: - beatmap_status = ( - await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)) - ).first() - if beatmap_status is None: + try: + db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap) + except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") - ranked = beatmap_status in { + ranked = db_beatmap.beatmap_status in { BeatmapRankStatus.RANKED, BeatmapRankStatus.APPROVED, } diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 3688efa..3f081e3 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -11,6 +11,7 @@ from app.database.multiplayer_event import MultiplayerEvent from app.database.playlists import Playlist from app.database.relationship import Relationship, RelationshipType from app.dependencies.database import engine, get_redis +from app.dependencies.fetcher import get_fetcher from app.exception import InvokeException from app.log import logger from app.models.mods import APIMod @@ -44,8 +45,9 @@ from app.models.score import GameMode from .hub import Client, Hub +from httpx import HTTPError from sqlalchemy import update -from sqlmodel import col, select +from sqlmodel import col, exists, select from sqlmodel.ext.asyncio.session import AsyncSession GAMEPLAY_LOAD_TIMEOUT = 30 @@ -191,11 +193,25 @@ class MultiplayerHub(Hub[MultiplayerClientState]): session.add(db_room) await session.commit() await session.refresh(db_room) + item = room.playlist[0] item.owner_id = client.user_id room.room_id = db_room.id starts_at = db_room.starts_at or datetime.now(UTC) + beatmap_exists = await session.exec( + select(exists().where(col(Beatmap.id) == item.beatmap_id)) + ) + if not beatmap_exists.one(): + fetcher = await get_fetcher() + try: + resp = await fetcher.get_beatmap(item.beatmap_id) + await Beatmap.from_resp(session, resp) + except HTTPError: + raise InvokeException( + "Failed to fetch beatmap, please retry later" + ) await Playlist.add_to_db(item, db_room.id, session) + server_room = ServerMultiplayerRoom( room=room, category=RoomCategory.NORMAL, @@ -372,6 +388,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async def validate_styles(self, room: ServerMultiplayerRoom): + fetcher = await get_fetcher() if not room.queue.current_item.freestyle: for user in room.room.users: await self.change_user_style( @@ -381,9 +398,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]): user, ) async with AsyncSession(engine) as session: - beatmap = await session.get(Beatmap, room.queue.current_item.beatmap_id) - if beatmap is None: - raise InvokeException("Beatmap not found") + try: + beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=room.queue.current_item.beatmap_id + ) + except HTTPError: + raise InvokeException("Current item beatmap not found") beatmap_ids = ( await session.exec( select(Beatmap.id, Beatmap.mode).where( diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index b9a3c99..d5a12ff 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -11,6 +11,7 @@ from app.database import Beatmap, User from app.database.score import Score from app.database.score_token import ScoreToken from app.dependencies.database import engine +from app.dependencies.fetcher import get_fetcher from app.models.beatmap import BeatmapRankStatus from app.models.mods import mods_to_int from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics @@ -179,15 +180,13 @@ class SpectatorHub(Hub[StoreClientState]): return if state.beatmap_id is None or state.ruleset_id is None: return + + fetcher = await get_fetcher() async with AsyncSession(engine) as session: async with session.begin(): - beatmap = ( - await session.exec( - select(Beatmap).where(Beatmap.id == state.beatmap_id) - ) - ).first() - if not beatmap: - return + beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=state.beatmap_id + ) user = ( await session.exec(select(User).where(User.id == user_id)) ).first() @@ -237,16 +236,16 @@ class SpectatorHub(Hub[StoreClientState]): user_id = int(client.connection_id) store = self.get_or_create_state(client) score = store.score - assert store.beatmap_status is not None - assert store.state is not None - assert store.score is not None - if not score or not store.score_token: + if ( + score is None + or store.score_token is None + or store.beatmap_status is None + or store.state is None + ): return if ( BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED - ) and any( - k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items() - ): + ) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()): await self._process_score(store, client) store.state = None store.beatmap_status = None