refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
from app.models.playlist import PlaylistItem
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from ._base import DatabaseModel, ondemand
|
||||
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
@@ -15,7 +15,6 @@ from sqlmodel import (
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
@@ -23,18 +22,34 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .room import Room
|
||||
from .score import ScoreDict
|
||||
|
||||
|
||||
class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
id: int = Field(index=True)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
class PlaylistDict(TypedDict):
|
||||
id: int
|
||||
room_id: int
|
||||
beatmap_id: int
|
||||
created_at: datetime | None
|
||||
ruleset_id: int
|
||||
expired: bool = Field(default=False)
|
||||
playlist_order: int = Field(default=0)
|
||||
played_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True)),
|
||||
default=None,
|
||||
allowed_mods: list[APIMod]
|
||||
required_mods: list[APIMod]
|
||||
freestyle: bool
|
||||
expired: bool
|
||||
owner_id: int
|
||||
playlist_order: int
|
||||
played_at: datetime | None
|
||||
beatmap: NotRequired["BeatmapDict"]
|
||||
scores: NotRequired[list[dict[str, Any]]]
|
||||
|
||||
|
||||
class PlaylistModel(DatabaseModel[PlaylistDict]):
|
||||
id: int = Field(index=True)
|
||||
room_id: int = Field(foreign_key="rooms.id")
|
||||
beatmap_id: int = Field(
|
||||
foreign_key="beatmaps.id",
|
||||
)
|
||||
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
|
||||
ruleset_id: int
|
||||
allowed_mods: list[APIMod] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
@@ -43,16 +58,46 @@ class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
beatmap_id: int = Field(
|
||||
foreign_key="beatmaps.id",
|
||||
)
|
||||
freestyle: bool = Field(default=False)
|
||||
expired: bool = Field(default=False)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
playlist_order: int = Field(default=0)
|
||||
played_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True)),
|
||||
default=None,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def beatmap(_session: AsyncSession, playlist: "Playlist", includes: list[str] | None = None) -> BeatmapDict:
|
||||
return await BeatmapModel.transform(playlist.beatmap, includes=includes)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def scores(session: AsyncSession, playlist: "Playlist") -> list["ScoreDict"]:
|
||||
from .score import Score, ScoreModel
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(Score).where(
|
||||
Score.playlist_item_id == playlist.id,
|
||||
Score.room_id == playlist.room_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
result: list[ScoreDict] = []
|
||||
for score in scores:
|
||||
result.append(
|
||||
await ScoreModel.transform(
|
||||
score,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class Playlist(PlaylistBase, table=True):
|
||||
class Playlist(PlaylistModel, table=True):
|
||||
__tablename__: str = "room_playlists"
|
||||
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
||||
room_id: int = Field(foreign_key="rooms.id", exclude=True)
|
||||
|
||||
beatmap: Beatmap = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
@@ -60,7 +105,6 @@ class Playlist(PlaylistBase, table=True):
|
||||
}
|
||||
)
|
||||
room: "Room" = Relationship()
|
||||
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
|
||||
updated_at: datetime | None = Field(
|
||||
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
|
||||
)
|
||||
@@ -121,15 +165,3 @@ class Playlist(PlaylistBase, table=True):
|
||||
raise ValueError("Playlist item not found")
|
||||
await session.delete(db_playlist)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class PlaylistResp(PlaylistBase):
|
||||
beatmap: BeatmapResp | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
|
||||
data = playlist.model_dump()
|
||||
if "beatmap" in include:
|
||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
|
||||
resp = cls.model_validate(data)
|
||||
return resp
|
||||
|
||||
Reference in New Issue
Block a user