diff --git a/app/database/beatmap.py b/app/database/beatmap.py index b959770..b8ae152 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -7,6 +7,7 @@ from .beatmapset import Beatmapset, BeatmapsetResp from sqlalchemy import DECIMAL, Column, DateTime from sqlmodel import VARCHAR, Field, Relationship, SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession class BeatmapOwner(SQLModel): @@ -22,7 +23,6 @@ class BeatmapBase(SQLModel): difficulty_rating: float = Field( default=0.0, sa_column=Column(DECIMAL(precision=10, scale=6)) ) - beatmap_status: BeatmapRankStatus total_length: int user_id: int version: str @@ -59,9 +59,49 @@ class Beatmap(BeatmapBase, table=True): __tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) + beatmap_status: BeatmapRankStatus # optional beatmapset: Beatmapset = Relationship(back_populates="beatmaps") + @classmethod + async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap": + d = resp.model_dump() + del d["beatmapset"] + beatmap = Beatmap.model_validate( + { + **d, + "beatmapset_id": resp.beatmapset_id, + "id": resp.id, + "beatmap_status": BeatmapRankStatus(resp.ranked), + } + ) + session.add(beatmap) + await session.commit() + return beatmap + + @classmethod + async def from_resp_batch( + cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0 + ) -> list["Beatmap"]: + beatmaps = [] + for resp in inp: + if resp.id == from_: + continue + d = resp.model_dump() + del d["beatmapset"] + beatmap = Beatmap.model_validate( + { + **d, + "beatmapset_id": resp.beatmapset_id, + "id": resp.id, + "beatmap_status": BeatmapRankStatus(resp.ranked), + } + ) + session.add(beatmap) + beatmaps.append(beatmap) + await session.commit() + return beatmaps + class BeatmapResp(BeatmapBase): id: int diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 4141212..4a3ced9 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, TypedDict, cast from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.score import GameMode @@ -7,6 +7,7 @@ from app.models.score import GameMode from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text from sqlmodel import Field, Relationship, SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .beatmap import Beatmap, BeatmapResp @@ -64,11 +65,11 @@ class BeatmapNominations(SQLModel): required: int | None = Field(default=None) -class BeatmapNomination(SQLModel): +class BeatmapNomination(TypedDict): beatmapset_id: int reset: bool user_id: int - rulesets: list[GameMode] | None = None + rulesets: dict[str, GameMode] | None class BeatmapDescription(SQLModel): @@ -150,20 +151,52 @@ class Beatmapset(BeatmapsetBase, table=True): availability_info: str | None = Field(default=None) download_disabled: bool = Field(default=False) + @classmethod + async def from_resp( + cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0 + ) -> "Beatmapset": + from .beatmap import Beatmap + + d = resp.model_dump() + update = {} + if resp.nominations: + update["nominations_required"] = resp.nominations.required + update["nominations_current"] = resp.nominations.current + if resp.hype: + update["hype_current"] = resp.hype.current + update["hype_required"] = resp.hype.required + if resp.genre: + update["beatmap_genre"] = Genre(resp.genre.id) + if resp.language: + update["beatmap_language"] = Language(resp.language.id) + beatmapset = Beatmapset.model_validate( + { + **d, + "id": resp.id, + "beatmap_status": BeatmapRankStatus(resp.ranked), + "availability_info": resp.availability.more_information, + "download_disabled": resp.availability.download_disabled or False, + } + ) + session.add(beatmapset) + await session.commit() + await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) + return beatmapset + class BeatmapsetResp(BeatmapsetBase): id: int - beatmaps: list["BeatmapResp"] + beatmaps: list["BeatmapResp"] = Field(default_factory=list) discussion_enabled: bool = True status: str ranked: int legacy_thread_url: str = "" is_scoreable: bool - hype: BeatmapHype + hype: BeatmapHype | None = None availability: BeatmapAvailability - genre: BeatmapTranslationText - language: BeatmapTranslationText - nominations: BeatmapNominations + genre: BeatmapTranslationText | None = None + language: BeatmapTranslationText | None = None + nominations: BeatmapNominations | None = None @classmethod def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 288660c..fe09139 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -1,5 +1,10 @@ from __future__ import annotations +import json + +from app.config import settings + +from pydantic import BaseModel from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession @@ -8,10 +13,16 @@ try: import redis except ImportError: redis = None -from app.config import settings + + +def json_serializer(value): + if isinstance(value, BaseModel | SQLModel): + return value.model_dump_json() + return json.dumps(value) + # 数据库引擎 -engine = create_async_engine(settings.DATABASE_URL) +engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) # Redis 连接 if redis: diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py index 86a5e27..30a1b82 100644 --- a/app/fetcher/beatmapset.py +++ b/app/fetcher/beatmapset.py @@ -8,7 +8,7 @@ from httpx import AsyncClient class BeatmapsetFetcher(BaseFetcher): - async def get_beatmap_set(self, beatmap_set_id: int) -> BeatmapsetResp: + async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp: async with AsyncClient() as client: response = await client.get( f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}", diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 4731994..f6b8f1a 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -8,11 +8,14 @@ from app.database import ( from app.database.beatmapset import Beatmapset from app.database.score import Score, ScoreResp from app.dependencies.database import get_db +from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user +from app.fetcher import Fetcher from .api_router import router from fastapi import Depends, HTTPException, Query +from httpx import HTTPStatusError from pydantic import BaseModel from sqlalchemy.orm import joinedload from sqlmodel import col, select @@ -24,6 +27,7 @@ async def get_beatmap( bid: int, current_user: DBUser = Depends(get_current_user), db: AsyncSession = Depends(get_db), + fetcher: Fetcher = Depends(get_fetcher), ): beatmap = ( await db.exec( @@ -33,8 +37,20 @@ async def get_beatmap( ) ).first() if not beatmap: - raise HTTPException(status_code=404, detail="Beatmap not found") - return BeatmapResp.from_db(beatmap) + try: + resp = await fetcher.get_beatmap(bid) + r = await db.exec( + select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id) + ) + if not r.first(): + set_resp = await fetcher.get_beatmapset(resp.beatmapset_id) + await Beatmapset.from_resp(db, set_resp, from_=resp.id) + await Beatmap.from_resp(db, resp) + except HTTPStatusError: + raise HTTPException(status_code=404, detail="Beatmap not found") + else: + resp = BeatmapResp.from_db(beatmap) + return resp class BatchGetResp(BaseModel): diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index eceb19b..db2dd77 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -6,11 +6,14 @@ from app.database import ( User as DBUser, ) from app.dependencies.database import get_db +from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user +from app.fetcher import Fetcher from .api_router import router from fastapi import Depends, HTTPException +from httpx import HTTPStatusError from sqlalchemy.orm import selectinload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -21,6 +24,7 @@ async def get_beatmapset( sid: int, current_user: DBUser = Depends(get_current_user), db: AsyncSession = Depends(get_db), + fetcher: Fetcher = Depends(get_fetcher), ): beatmapset = ( await db.exec( @@ -30,5 +34,11 @@ async def get_beatmapset( ) ).first() if not beatmapset: - raise HTTPException(status_code=404, detail="Beatmapset not found") - return BeatmapsetResp.from_db(beatmapset) + try: + resp = await fetcher.get_beatmapset(sid) + await Beatmapset.from_resp(db, resp) + except HTTPStatusError: + raise HTTPException(status_code=404, detail="Beatmapset not found") + else: + resp = BeatmapsetResp.from_db(beatmapset) + return resp