refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -1,11 +1,11 @@
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.playlist import PlaylistItem
from .beatmap import Beatmap, BeatmapResp
from ._base import DatabaseModel, ondemand
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
from sqlmodel import (
JSON,
@@ -15,7 +15,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
func,
select,
)
@@ -23,18 +22,34 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .room import Room
from .score import ScoreDict
class PlaylistBase(SQLModel, UTCBaseModel):
id: int = Field(index=True)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
class PlaylistDict(TypedDict):
id: int
room_id: int
beatmap_id: int
created_at: datetime | None
ruleset_id: int
expired: bool = Field(default=False)
playlist_order: int = Field(default=0)
played_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
allowed_mods: list[APIMod]
required_mods: list[APIMod]
freestyle: bool
expired: bool
owner_id: int
playlist_order: int
played_at: datetime | None
beatmap: NotRequired["BeatmapDict"]
scores: NotRequired[list[dict[str, Any]]]
class PlaylistModel(DatabaseModel[PlaylistDict]):
id: int = Field(index=True)
room_id: int = Field(foreign_key="rooms.id")
beatmap_id: int = Field(
foreign_key="beatmaps.id",
)
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
ruleset_id: int
allowed_mods: list[APIMod] = Field(
default_factory=list,
sa_column=Column(JSON),
@@ -43,16 +58,46 @@ class PlaylistBase(SQLModel, UTCBaseModel):
default_factory=list,
sa_column=Column(JSON),
)
beatmap_id: int = Field(
foreign_key="beatmaps.id",
)
freestyle: bool = Field(default=False)
expired: bool = Field(default=False)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
playlist_order: int = Field(default=0)
played_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
)
@ondemand
@staticmethod
async def beatmap(_session: AsyncSession, playlist: "Playlist", includes: list[str] | None = None) -> BeatmapDict:
return await BeatmapModel.transform(playlist.beatmap, includes=includes)
@ondemand
@staticmethod
async def scores(session: AsyncSession, playlist: "Playlist") -> list["ScoreDict"]:
from .score import Score, ScoreModel
scores = (
await session.exec(
select(Score).where(
Score.playlist_item_id == playlist.id,
Score.room_id == playlist.room_id,
)
)
).all()
result: list[ScoreDict] = []
for score in scores:
result.append(
await ScoreModel.transform(
score,
)
)
return result
class Playlist(PlaylistBase, table=True):
class Playlist(PlaylistModel, table=True):
__tablename__: str = "room_playlists"
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
room_id: int = Field(foreign_key="rooms.id", exclude=True)
beatmap: Beatmap = Relationship(
sa_relationship_kwargs={
@@ -60,7 +105,6 @@ class Playlist(PlaylistBase, table=True):
}
)
room: "Room" = Relationship()
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
updated_at: datetime | None = Field(
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
)
@@ -121,15 +165,3 @@ class Playlist(PlaylistBase, table=True):
raise ValueError("Playlist item not found")
await session.delete(db_playlist)
await session.commit()
class PlaylistResp(PlaylistBase):
beatmap: BeatmapResp | None = None
@classmethod
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
data = playlist.model_dump()
if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
resp = cls.model_validate(data)
return resp