refactor(beatmap,beatmapset): use to ensure beatmap exists

This commit is contained in:
MingxuanGame
2025-08-08 11:54:43 +00:00
parent 9ddcf9ec7b
commit 0ac4f1f516
5 changed files with 64 additions and 50 deletions

View File

@@ -13,6 +13,8 @@ from sqlmodel import Field, Relationship, SQLModel, col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from app.fetcher import Fetcher
from .beatmap import Beatmap, BeatmapResp from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset from .favourite_beatmapset import FavouriteBeatmapset
@@ -185,6 +187,16 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
return beatmapset return beatmapset
@classmethod
async def get_or_fetch(
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
) -> "Beatmapset":
beatmapset = await session.get(Beatmapset, sid)
if not beatmapset:
resp = await fetcher.get_beatmapset(sid)
beatmapset = await cls.from_resp(session, resp)
return beatmapset
class BeatmapsetResp(BeatmapsetBase): class BeatmapsetResp(BeatmapsetBase):
id: int id: int

View File

@@ -12,7 +12,7 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query from fastapi import Depends, Form, HTTPException, Query
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError from httpx import HTTPError
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -24,22 +24,10 @@ async def lookup_beatmapset(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
beatmapset_id = ( beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
await db.exec(select(Beatmap.beatmapset_id).where(Beatmap.id == beatmap_id)) resp = await BeatmapsetResp.from_db(
).first() beatmap.beatmapset, session=db, user=current_user
if not beatmapset_id: )
try:
resp = await fetcher.get_beatmap(beatmap_id)
await Beatmap.from_resp(db, resp)
await db.refresh(current_user)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
beatmapset = (
await db.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))
).first()
if not beatmapset:
raise HTTPException(status_code=404, detail="Beatmapset not found")
resp = await BeatmapsetResp.from_db(beatmapset, session=db, user=current_user)
return resp return resp
@@ -50,18 +38,13 @@ async def get_beatmapset(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first() try:
if not beatmapset: beatmapset = await Beatmapset.get_or_fetch(db, fetcher, sid)
try: return await BeatmapsetResp.from_db(
resp = await fetcher.get_beatmapset(sid)
await Beatmapset.from_resp(db, resp)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
else:
resp = await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user beatmapset, session=db, include=["recent_favourites"], user=current_user
) )
return resp except HTTPError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"]) @router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])

View File

@@ -43,6 +43,7 @@ from app.models.score import (
from .api_router import router from .api_router import router
from fastapi import Depends, Form, HTTPException, Query from fastapi import Depends, Form, HTTPException, Query
from httpx import HTTPError
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
@@ -86,12 +87,11 @@ async def submit_score(
if not score: if not score:
raise HTTPException(status_code=404, detail="Score not found") raise HTTPException(status_code=404, detail="Score not found")
else: else:
beatmap_status = ( try:
await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)) db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
).first() except HTTPError:
if beatmap_status is None:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")
ranked = beatmap_status in { ranked = db_beatmap.beatmap_status in {
BeatmapRankStatus.RANKED, BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED, BeatmapRankStatus.APPROVED,
} }

View File

@@ -11,6 +11,7 @@ from app.database.multiplayer_event import MultiplayerEvent
from app.database.playlists import Playlist from app.database.playlists import Playlist
from app.database.relationship import Relationship, RelationshipType from app.database.relationship import Relationship, RelationshipType
from app.dependencies.database import engine, get_redis from app.dependencies.database import engine, get_redis
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException from app.exception import InvokeException
from app.log import logger from app.log import logger
from app.models.mods import APIMod from app.models.mods import APIMod
@@ -44,8 +45,9 @@ from app.models.score import GameMode
from .hub import Client, Hub from .hub import Client, Hub
from httpx import HTTPError
from sqlalchemy import update from sqlalchemy import update
from sqlmodel import col, select from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
GAMEPLAY_LOAD_TIMEOUT = 30 GAMEPLAY_LOAD_TIMEOUT = 30
@@ -191,11 +193,25 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
session.add(db_room) session.add(db_room)
await session.commit() await session.commit()
await session.refresh(db_room) await session.refresh(db_room)
item = room.playlist[0] item = room.playlist[0]
item.owner_id = client.user_id item.owner_id = client.user_id
room.room_id = db_room.id room.room_id = db_room.id
starts_at = db_room.starts_at or datetime.now(UTC) starts_at = db_room.starts_at or datetime.now(UTC)
beatmap_exists = await session.exec(
select(exists().where(col(Beatmap.id) == item.beatmap_id))
)
if not beatmap_exists.one():
fetcher = await get_fetcher()
try:
resp = await fetcher.get_beatmap(item.beatmap_id)
await Beatmap.from_resp(session, resp)
except HTTPError:
raise InvokeException(
"Failed to fetch beatmap, please retry later"
)
await Playlist.add_to_db(item, db_room.id, session) await Playlist.add_to_db(item, db_room.id, session)
server_room = ServerMultiplayerRoom( server_room = ServerMultiplayerRoom(
room=room, room=room,
category=RoomCategory.NORMAL, category=RoomCategory.NORMAL,
@@ -372,6 +388,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
) )
async def validate_styles(self, room: ServerMultiplayerRoom): async def validate_styles(self, room: ServerMultiplayerRoom):
fetcher = await get_fetcher()
if not room.queue.current_item.freestyle: if not room.queue.current_item.freestyle:
for user in room.room.users: for user in room.room.users:
await self.change_user_style( await self.change_user_style(
@@ -381,9 +398,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
user, user,
) )
async with AsyncSession(engine) as session: async with AsyncSession(engine) as session:
beatmap = await session.get(Beatmap, room.queue.current_item.beatmap_id) try:
if beatmap is None: beatmap = await Beatmap.get_or_fetch(
raise InvokeException("Beatmap not found") session, fetcher, bid=room.queue.current_item.beatmap_id
)
except HTTPError:
raise InvokeException("Current item beatmap not found")
beatmap_ids = ( beatmap_ids = (
await session.exec( await session.exec(
select(Beatmap.id, Beatmap.mode).where( select(Beatmap.id, Beatmap.mode).where(

View File

@@ -11,6 +11,7 @@ from app.database import Beatmap, User
from app.database.score import Score from app.database.score import Score
from app.database.score_token import ScoreToken from app.database.score_token import ScoreToken
from app.dependencies.database import engine from app.dependencies.database import engine
from app.dependencies.fetcher import get_fetcher
from app.models.beatmap import BeatmapRankStatus from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
@@ -179,15 +180,13 @@ class SpectatorHub(Hub[StoreClientState]):
return return
if state.beatmap_id is None or state.ruleset_id is None: if state.beatmap_id is None or state.ruleset_id is None:
return return
fetcher = await get_fetcher()
async with AsyncSession(engine) as session: async with AsyncSession(engine) as session:
async with session.begin(): async with session.begin():
beatmap = ( beatmap = await Beatmap.get_or_fetch(
await session.exec( session, fetcher, bid=state.beatmap_id
select(Beatmap).where(Beatmap.id == state.beatmap_id) )
)
).first()
if not beatmap:
return
user = ( user = (
await session.exec(select(User).where(User.id == user_id)) await session.exec(select(User).where(User.id == user_id))
).first() ).first()
@@ -237,16 +236,16 @@ class SpectatorHub(Hub[StoreClientState]):
user_id = int(client.connection_id) user_id = int(client.connection_id)
store = self.get_or_create_state(client) store = self.get_or_create_state(client)
score = store.score score = store.score
assert store.beatmap_status is not None if (
assert store.state is not None score is None
assert store.score is not None or store.score_token is None
if not score or not store.score_token: or store.beatmap_status is None
or store.state is None
):
return return
if ( if (
BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED
) and any( ) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()):
k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items()
):
await self._process_score(store, client) await self._process_score(store, client)
store.state = None store.state = None
store.beatmap_status = None store.beatmap_status = None