refactor(beatmap,beatmapset): use to ensure beatmap exists

This commit is contained in:
MingxuanGame
2025-08-08 11:54:43 +00:00
parent 9ddcf9ec7b
commit 0ac4f1f516
5 changed files with 64 additions and 50 deletions

View File

@@ -13,6 +13,8 @@ from sqlmodel import Field, Relationship, SQLModel, col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
@@ -185,6 +187,16 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
return beatmapset
@classmethod
async def get_or_fetch(
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
) -> "Beatmapset":
beatmapset = await session.get(Beatmapset, sid)
if not beatmapset:
resp = await fetcher.get_beatmapset(sid)
beatmapset = await cls.from_resp(session, resp)
return beatmapset
class BeatmapsetResp(BeatmapsetBase):
id: int

View File

@@ -12,7 +12,7 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError
from httpx import HTTPError
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -24,22 +24,10 @@ async def lookup_beatmapset(
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmapset_id = (
await db.exec(select(Beatmap.beatmapset_id).where(Beatmap.id == beatmap_id))
).first()
if not beatmapset_id:
try:
resp = await fetcher.get_beatmap(beatmap_id)
await Beatmap.from_resp(db, resp)
await db.refresh(current_user)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
beatmapset = (
await db.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))
).first()
if not beatmapset:
raise HTTPException(status_code=404, detail="Beatmapset not found")
resp = await BeatmapsetResp.from_db(beatmapset, session=db, user=current_user)
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=db, user=current_user
)
return resp
@@ -50,18 +38,13 @@ async def get_beatmapset(
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
if not beatmapset:
try:
resp = await fetcher.get_beatmapset(sid)
await Beatmapset.from_resp(db, resp)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
else:
resp = await BeatmapsetResp.from_db(
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, sid)
return await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
return resp
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])

View File

@@ -43,6 +43,7 @@ from app.models.score import (
from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from httpx import HTTPError
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
@@ -86,12 +87,11 @@ async def submit_score(
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
beatmap_status = (
await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap))
).first()
if beatmap_status is None:
try:
db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
ranked = beatmap_status in {
ranked = db_beatmap.beatmap_status in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}

View File

@@ -11,6 +11,7 @@ from app.database.multiplayer_event import MultiplayerEvent
from app.database.playlists import Playlist
from app.database.relationship import Relationship, RelationshipType
from app.dependencies.database import engine, get_redis
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from app.log import logger
from app.models.mods import APIMod
@@ -44,8 +45,9 @@ from app.models.score import GameMode
from .hub import Client, Hub
from httpx import HTTPError
from sqlalchemy import update
from sqlmodel import col, select
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
GAMEPLAY_LOAD_TIMEOUT = 30
@@ -191,11 +193,25 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
session.add(db_room)
await session.commit()
await session.refresh(db_room)
item = room.playlist[0]
item.owner_id = client.user_id
room.room_id = db_room.id
starts_at = db_room.starts_at or datetime.now(UTC)
beatmap_exists = await session.exec(
select(exists().where(col(Beatmap.id) == item.beatmap_id))
)
if not beatmap_exists.one():
fetcher = await get_fetcher()
try:
resp = await fetcher.get_beatmap(item.beatmap_id)
await Beatmap.from_resp(session, resp)
except HTTPError:
raise InvokeException(
"Failed to fetch beatmap, please retry later"
)
await Playlist.add_to_db(item, db_room.id, session)
server_room = ServerMultiplayerRoom(
room=room,
category=RoomCategory.NORMAL,
@@ -372,6 +388,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
)
async def validate_styles(self, room: ServerMultiplayerRoom):
fetcher = await get_fetcher()
if not room.queue.current_item.freestyle:
for user in room.room.users:
await self.change_user_style(
@@ -381,9 +398,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
user,
)
async with AsyncSession(engine) as session:
beatmap = await session.get(Beatmap, room.queue.current_item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
try:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=room.queue.current_item.beatmap_id
)
except HTTPError:
raise InvokeException("Current item beatmap not found")
beatmap_ids = (
await session.exec(
select(Beatmap.id, Beatmap.mode).where(

View File

@@ -11,6 +11,7 @@ from app.database import Beatmap, User
from app.database.score import Score
from app.database.score_token import ScoreToken
from app.dependencies.database import engine
from app.dependencies.fetcher import get_fetcher
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
@@ -179,15 +180,13 @@ class SpectatorHub(Hub[StoreClientState]):
return
if state.beatmap_id is None or state.ruleset_id is None:
return
fetcher = await get_fetcher()
async with AsyncSession(engine) as session:
async with session.begin():
beatmap = (
await session.exec(
select(Beatmap).where(Beatmap.id == state.beatmap_id)
)
).first()
if not beatmap:
return
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=state.beatmap_id
)
user = (
await session.exec(select(User).where(User.id == user_id))
).first()
@@ -237,16 +236,16 @@ class SpectatorHub(Hub[StoreClientState]):
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
score = store.score
assert store.beatmap_status is not None
assert store.state is not None
assert store.score is not None
if not score or not store.score_token:
if (
score is None
or store.score_token is None
or store.beatmap_status is None
or store.state is None
):
return
if (
BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED
) and any(
k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items()
):
) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()):
await self._process_score(store, client)
store.state = None
store.beatmap_status = None