refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -1,14 +1,15 @@
from datetime import datetime
from typing import TYPE_CHECKING, NotRequired, Self, TypedDict
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
from app.config import settings
from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode
from .user import BASE_INCLUDES, User, UserResp
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap_playcounts import BeatmapPlaycounts
from .user import User, UserDict
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel
from sqlalchemy import JSON, Boolean, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
@@ -17,7 +18,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .beatmap import Beatmap, BeatmapResp
from .beatmap import Beatmap, BeatmapDict
from .favourite_beatmapset import FavouriteBeatmapset
@@ -68,8 +69,99 @@ class BeatmapTranslationText(BaseModel):
id: int | None = None
class BeatmapsetBase(SQLModel):
class BeatmapsetDict(TypedDict):
id: int
artist: str
artist_unicode: str
covers: BeatmapCovers | None
creator: str
nsfw: bool
preview_url: str
source: str
spotlight: bool
title: str
title_unicode: str
track_id: int | None
user_id: int
video: bool
current_nominations: list[BeatmapNomination] | None
description: BeatmapDescription | None
pack_tags: list[str]
bpm: NotRequired[float]
can_be_hyped: NotRequired[bool]
discussion_locked: NotRequired[bool]
last_updated: NotRequired[datetime]
ranked_date: NotRequired[datetime | None]
storyboard: NotRequired[bool]
submitted_date: NotRequired[datetime]
tags: NotRequired[str]
discussion_enabled: NotRequired[bool]
legacy_thread_url: NotRequired[str | None]
status: NotRequired[str]
ranked: NotRequired[int]
is_scoreable: NotRequired[bool]
favourite_count: NotRequired[int]
genre_id: NotRequired[int]
hype: NotRequired[BeatmapHype]
language_id: NotRequired[int]
play_count: NotRequired[int]
availability: NotRequired[BeatmapAvailability]
beatmaps: NotRequired[list["BeatmapDict"]]
has_favourited: NotRequired[bool]
recent_favourites: NotRequired[list[UserDict]]
genre: NotRequired[BeatmapTranslationText]
language: NotRequired[BeatmapTranslationText]
nominations: NotRequired["BeatmapNominations"]
ratings: NotRequired[list[int]]
class BeatmapsetModel(DatabaseModel[BeatmapsetDict]):
BEATMAPSET_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
"availability",
"has_favourited",
"bpm",
"deleted_atcan_be_hyped",
"discussion_locked",
"is_scoreable",
"last_updated",
"legacy_thread_url",
"ranked",
"ranked_date",
"submitted_date",
"tags",
"rating",
"storyboard",
]
API_INCLUDES: ClassVar[list[str]] = [
*BEATMAPSET_TRANSFORMER_INCLUDES,
"beatmaps.current_user_playcount",
"beatmaps.current_user_tag_ids",
"beatmaps.max_combo",
"current_nominations",
"current_user_attributes",
"description",
"genre",
"language",
"pack_tags",
"ratings",
"recent_favourites",
"related_tags",
"related_users",
"user",
"version_count",
*[
f"beatmaps.{inc}"
for inc in {
"failtimes",
"owners",
"top_tag_ids",
}
],
]
# Beatmapset
id: int = Field(default=None, primary_key=True, index=True)
artist: str = Field(index=True)
artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
@@ -77,41 +169,285 @@ class BeatmapsetBase(SQLModel):
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
preview_url: str
source: str = Field(default="")
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
title: str = Field(index=True)
title_unicode: str = Field(index=True)
track_id: int | None = Field(default=None, index=True) # feature artist?
user_id: int = Field(index=True)
video: bool = Field(sa_column=Column(Boolean, index=True))
# optional
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
current_nominations: OnDemand[list[BeatmapNomination] | None] = Field(None, sa_column=Column(JSON))
description: OnDemand[BeatmapDescription | None] = Field(default=None, sa_column=Column(JSON))
# TODO: discussions: list[BeatmapsetDiscussion] = None
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
# TODO: events: Optional[list[BeatmapsetEvent]] = None
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
pack_tags: OnDemand[list[str]] = Field(default=[], sa_column=Column(JSON))
# TODO: related_users: Optional[list[User]] = None
# TODO: user: Optional[User] = Field(default=None)
track_id: int | None = Field(default=None, index=True) # feature artist?
# BeatmapsetExtended
bpm: float = Field(default=0.0)
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
tags: str = Field(default="", sa_column=Column(Text))
bpm: OnDemand[float] = Field(default=0.0)
can_be_hyped: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
discussion_locked: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
ranked_date: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime, index=True))
storyboard: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
tags: OnDemand[str] = Field(default="", sa_column=Column(Text))
@ondemand
@staticmethod
async def legacy_thread_url(
_session: AsyncSession,
_beatmapset: "Beatmapset",
) -> str | None:
return None
@included
@staticmethod
async def discussion_enabled(
_session: AsyncSession,
_beatmapset: "Beatmapset",
) -> bool:
return True
@included
@staticmethod
async def status(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> str:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.name.lower()
return beatmap_status.name.lower()
@included
@staticmethod
async def ranked(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.value
return beatmap_status.value
@ondemand
@staticmethod
async def is_scoreable(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> bool:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard:
return True
return beatmap_status.has_leaderboard()
@included
@staticmethod
async def favourite_count(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
from .favourite_beatmapset import FavouriteBeatmapset
count = await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
return count.one()
@included
@staticmethod
async def genre_id(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
return beatmapset.beatmap_genre.value
@ondemand
@staticmethod
async def hype(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapHype:
return BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required)
@included
@staticmethod
async def language_id(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
return beatmapset.beatmap_language.value
@included
@staticmethod
async def play_count(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
from .beatmap import Beatmap
playcount = await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
)
)
return int(playcount.first() or 0)
@ondemand
@staticmethod
async def availability(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapAvailability:
return BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
)
@ondemand
@staticmethod
async def beatmaps(
_session: AsyncSession,
beatmapset: "Beatmapset",
includes: list[str] | None = None,
user: "User | None" = None,
) -> list["BeatmapDict"]:
from .beatmap import BeatmapModel
return [
await BeatmapModel.transform(
beatmap, includes=(includes or []) + BeatmapModel.BEATMAP_TRANSFORMER_INCLUDES, user=user
)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
]
# @ondemand
# @staticmethod
# async def current_nominations(
# _session: AsyncSession,
# beatmapset: "Beatmapset",
# ) -> list[BeatmapNomination] | None:
# return beatmapset.current_nominations or []
@ondemand
@staticmethod
async def has_favourited(
session: AsyncSession,
beatmapset: "Beatmapset",
user: User | None = None,
) -> bool:
from .favourite_beatmapset import FavouriteBeatmapset
if session is None:
return False
query = select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
if user is not None:
query = query.where(FavouriteBeatmapset.user_id == user.id)
existing = (await session.exec(query)).first()
return existing is not None
@ondemand
@staticmethod
async def recent_favourites(
session: AsyncSession,
beatmapset: "Beatmapset",
includes: list[str] | None = None,
) -> list[UserDict]:
from .favourite_beatmapset import FavouriteBeatmapset
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
return [
await User.transform(
(await favourite.awaitable_attrs.user),
includes=includes,
)
for favourite in recent_favourites
]
@ondemand
@staticmethod
async def genre(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapTranslationText:
return BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
)
@ondemand
@staticmethod
async def language(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapTranslationText:
return BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
)
@ondemand
@staticmethod
async def nominations(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapNominations:
return BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
)
# @ondemand
# @staticmethod
# async def user(
# session: AsyncSession,
# beatmapset: Beatmapset,
# includes: list[str] | None = None,
# ) -> dict[str, Any] | None:
# db_user = await session.get(User, beatmapset.user_id)
# if not db_user:
# return None
# return await UserResp.transform(db_user, includes=includes)
@ondemand
@staticmethod
async def ratings(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> list[int]:
# Provide a stable default shape if no session is available
if session is None:
return []
from .beatmapset_ratings import BeatmapRating
beatmapset_all_ratings = (
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
).all()
ratings_list = [0] * 11
for rating in beatmapset_all_ratings:
ratings_list[rating.rating] += 1
return ratings_list
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
class Beatmapset(AsyncAttrs, BeatmapsetModel, table=True):
__tablename__: str = "beatmapsets"
id: int = Field(default=None, primary_key=True, index=True)
# Beatmapset
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
@@ -130,29 +466,45 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
d = resp.model_dump()
if resp.nominations:
d["nominations_required"] = resp.nominations.required
d["nominations_current"] = resp.nominations.current
if resp.hype:
d["hype_current"] = resp.hype.current
d["hype_required"] = resp.hype.required
if resp.genre_id:
d["beatmap_genre"] = Genre(resp.genre_id)
elif resp.genre:
d["beatmap_genre"] = Genre(resp.genre.id)
if resp.language_id:
d["beatmap_language"] = Language(resp.language_id)
elif resp.language:
d["beatmap_language"] = Language(resp.language.id)
async def from_resp_no_save(cls, resp: BeatmapsetDict) -> "Beatmapset":
# make a shallow copy so we can mutate safely
d: dict[str, Any] = dict(resp)
# nominations = resp.get("nominations")
# if nominations is not None:
# d["nominations_required"] = nominations.required
# d["nominations_current"] = nominations.current
hype = resp.get("hype")
if hype is not None:
d["hype_current"] = hype.current
d["hype_required"] = hype.required
genre_id = resp.get("genre_id")
genre = resp.get("genre")
if genre_id is not None:
d["beatmap_genre"] = Genre(genre_id)
elif genre is not None:
d["beatmap_genre"] = Genre(genre.id)
language_id = resp.get("language_id")
language = resp.get("language")
if language_id is not None:
d["beatmap_language"] = Language(language_id)
elif language is not None:
d["beatmap_language"] = Language(language.id)
availability = resp.get("availability")
ranked = resp.get("ranked")
if ranked is None:
raise ValueError("ranked field is required")
beatmapset = Beatmapset.model_validate(
{
**d,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"availability_info": resp.availability.more_information,
"download_disabled": resp.availability.download_disabled or False,
"beatmap_status": BeatmapRankStatus(ranked),
"availability_info": availability.more_information if availability is not None else None,
"download_disabled": bool(availability.download_disabled) if availability is not None else False,
}
)
return beatmapset
@@ -161,17 +513,19 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
async def from_resp(
cls,
session: AsyncSession,
resp: "BeatmapsetResp",
resp: BeatmapsetDict,
from_: int = 0,
) -> "Beatmapset":
from .beatmap import Beatmap
beatmapset_id = resp["id"]
beatmapset = await cls.from_resp_no_save(resp)
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
if not (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first():
session.add(beatmapset)
await session.commit()
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one()
beatmaps = resp.get("beatmaps", [])
await Beatmap.from_resp_batch(session, beatmaps, from_=from_)
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))).one()
return beatmapset
@classmethod
@@ -183,170 +537,5 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
resp = await fetcher.get_beatmapset(sid)
beatmapset = await cls.from_resp(session, resp)
await get_beatmapset_update_service().add(resp)
await session.refresh(beatmapset)
return beatmapset
class BeatmapsetResp(BeatmapsetBase):
id: int
beatmaps: list["BeatmapResp"] = Field(default_factory=list)
discussion_enabled: bool = True
status: str
ranked: int
legacy_thread_url: str | None = ""
is_scoreable: bool
hype: BeatmapHype | None = None
availability: BeatmapAvailability
genre: BeatmapTranslationText | None = None
genre_id: int
language: BeatmapTranslationText | None = None
language_id: int
nominations: BeatmapNominations | None = None
has_favourited: bool = False
favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list)
play_count: int = 0
@field_validator(
"nsfw",
"spotlight",
"video",
"can_be_hyped",
"discussion_locked",
"storyboard",
"discussion_enabled",
"is_scoreable",
"has_favourited",
mode="before",
)
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
if isinstance(v, int):
return bool(v)
return v
@model_validator(mode="after")
def fix_genre_language(self) -> Self:
if self.genre is None:
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
if self.language is None:
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
return self
@classmethod
async def from_db(
cls,
beatmapset: Beatmapset,
include: list[str] = [],
session: AsyncSession | None = None,
user: User | None = None,
) -> "BeatmapsetResp":
from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
update = {
"beatmaps": [
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
],
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"genre_id": beatmapset.beatmap_genre.value,
"language_id": beatmapset.beatmap_language.value,
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"is_scoreable": beatmapset.beatmap_status.has_leaderboard(),
**beatmapset.model_dump(),
}
if session is not None:
# 从数据库读取对应谱面集的评分
from .beatmapset_ratings import BeatmapRating
beatmapset_all_ratings = (
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
).all()
ratings_list = [0] * 11
for rating in beatmapset_all_ratings:
ratings_list[rating.rating] += 1
update["ratings"] = ratings_list
else:
# 返回非空值避免客户端崩溃
if update.get("ratings") is None:
update["ratings"] = []
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
update["ranked"] = BeatmapRankStatus.APPROVED.value
else:
update["status"] = beatmap_status.name.lower()
update["ranked"] = beatmap_status.value
if session and user:
existing_favourite = (
await session.exec(
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).first()
update["has_favourited"] = existing_favourite is not None
if session and "recent_favourites" in include:
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
update["recent_favourites"] = [
await UserResp.from_db(
await favourite.awaitable_attrs.user,
session=session,
include=BASE_INCLUDES,
)
for favourite in recent_favourites
]
if session:
update["favourite_count"] = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).one()
update["play_count"] = (
await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
)
)
).first() or 0
return cls.model_validate(
update,
)
class SearchBeatmapsetsResp(SQLModel):
beatmapsets: list[BeatmapsetResp]
total: int
cursor: dict[str, int | float | str] | None = None
cursor_string: str | None = None