refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -1,117 +1,339 @@
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.calculator import get_calculator
|
||||
from app.config import settings
|
||||
from app.database.beatmap_tags import BeatmapTagVote
|
||||
from app.database.failtime import FailTime, FailTimeResp
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
from app.models.performance import DifficultyAttributesUnion
|
||||
from app.models.score import GameMode
|
||||
|
||||
from ._base import DatabaseModel, OnDemand, included, ondemand
|
||||
from .beatmap_playcounts import BeatmapPlaycounts
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
from .beatmap_tags import BeatmapTagVote
|
||||
from .beatmapset import Beatmapset, BeatmapsetDict, BeatmapsetModel
|
||||
from .failtime import FailTime, FailTimeResp
|
||||
from .user import User, UserDict, UserModel
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
from .user import User
|
||||
|
||||
|
||||
class BeatmapOwner(SQLModel):
|
||||
id: int
|
||||
username: str
|
||||
|
||||
|
||||
class BeatmapBase(SQLModel):
|
||||
# Beatmap
|
||||
url: str
|
||||
class BeatmapDict(TypedDict):
|
||||
beatmapset_id: int
|
||||
difficulty_rating: float
|
||||
id: int
|
||||
mode: GameMode
|
||||
total_length: int
|
||||
user_id: int
|
||||
version: str
|
||||
url: str
|
||||
|
||||
checksum: NotRequired[str]
|
||||
max_combo: NotRequired[int | None]
|
||||
ar: NotRequired[float]
|
||||
cs: NotRequired[float]
|
||||
drain: NotRequired[float]
|
||||
accuracy: NotRequired[float]
|
||||
bpm: NotRequired[float]
|
||||
count_circles: NotRequired[int]
|
||||
count_sliders: NotRequired[int]
|
||||
count_spinners: NotRequired[int]
|
||||
deleted_at: NotRequired[datetime | None]
|
||||
hit_length: NotRequired[int]
|
||||
last_updated: NotRequired[datetime]
|
||||
|
||||
status: NotRequired[str]
|
||||
beatmapset: NotRequired[BeatmapsetDict]
|
||||
current_user_playcount: NotRequired[int]
|
||||
current_user_tag_ids: NotRequired[list[int]]
|
||||
failtimes: NotRequired[FailTimeResp]
|
||||
top_tag_ids: NotRequired[list[dict[str, int]]]
|
||||
user: NotRequired[UserDict]
|
||||
convert: NotRequired[bool]
|
||||
is_scoreable: NotRequired[bool]
|
||||
mode_int: NotRequired[int]
|
||||
ranked: NotRequired[int]
|
||||
playcount: NotRequired[int]
|
||||
passcount: NotRequired[int]
|
||||
|
||||
|
||||
class BeatmapModel(DatabaseModel[BeatmapDict]):
|
||||
BEATMAP_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
|
||||
"checksum",
|
||||
"accuracy",
|
||||
"ar",
|
||||
"bpm",
|
||||
"convert",
|
||||
"count_circles",
|
||||
"count_sliders",
|
||||
"count_spinners",
|
||||
"cs",
|
||||
"deleted_at",
|
||||
"drain",
|
||||
"hit_length",
|
||||
"is_scoreable",
|
||||
"last_updated",
|
||||
"mode_int",
|
||||
"passcount",
|
||||
"playcount",
|
||||
"ranked",
|
||||
"url",
|
||||
]
|
||||
DEFAULT_API_INCLUDES: ClassVar[list[str]] = [
|
||||
"beatmapset.ratings",
|
||||
"current_user_playcount",
|
||||
"failtimes",
|
||||
"max_combo",
|
||||
"owners",
|
||||
]
|
||||
TRANSFORMER_INCLUDES: ClassVar[list[str]] = [*DEFAULT_API_INCLUDES, *BEATMAP_TRANSFORMER_INCLUDES]
|
||||
|
||||
# Beatmap
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
difficulty_rating: float = Field(default=0.0, index=True)
|
||||
id: int = Field(primary_key=True, index=True)
|
||||
mode: GameMode
|
||||
total_length: int
|
||||
user_id: int = Field(index=True)
|
||||
version: str = Field(index=True)
|
||||
|
||||
url: OnDemand[str]
|
||||
# optional
|
||||
checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
|
||||
current_user_playcount: int = Field(default=0)
|
||||
max_combo: int | None = Field(default=0)
|
||||
# TODO: failtimes, owners
|
||||
checksum: OnDemand[str] = Field(sa_column=Column(VARCHAR(32), index=True))
|
||||
max_combo: OnDemand[int | None] = Field(default=0)
|
||||
# TODO: owners
|
||||
|
||||
# BeatmapExtended
|
||||
ar: float = Field(default=0.0)
|
||||
cs: float = Field(default=0.0)
|
||||
drain: float = Field(default=0.0) # hp
|
||||
accuracy: float = Field(default=0.0) # od
|
||||
bpm: float = Field(default=0.0)
|
||||
count_circles: int = Field(default=0)
|
||||
count_sliders: int = Field(default=0)
|
||||
count_spinners: int = Field(default=0)
|
||||
deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
hit_length: int = Field(default=0)
|
||||
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
ar: OnDemand[float] = Field(default=0.0)
|
||||
cs: OnDemand[float] = Field(default=0.0)
|
||||
drain: OnDemand[float] = Field(default=0.0) # hp
|
||||
accuracy: OnDemand[float] = Field(default=0.0) # od
|
||||
bpm: OnDemand[float] = Field(default=0.0)
|
||||
count_circles: OnDemand[int] = Field(default=0)
|
||||
count_sliders: OnDemand[int] = Field(default=0)
|
||||
count_spinners: OnDemand[int] = Field(default=0)
|
||||
deleted_at: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime))
|
||||
hit_length: OnDemand[int] = Field(default=0)
|
||||
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def status(_session: AsyncSession, beatmap: "Beatmap") -> str:
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap.beatmap_status.has_leaderboard():
|
||||
return BeatmapRankStatus.APPROVED.name.lower()
|
||||
return beatmap.beatmap_status.name.lower()
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def beatmapset(
|
||||
_session: AsyncSession,
|
||||
beatmap: "Beatmap",
|
||||
includes: list[str] | None = None,
|
||||
) -> BeatmapsetDict | None:
|
||||
if beatmap.beatmapset is not None:
|
||||
return await BeatmapsetModel.transform(
|
||||
beatmap.beatmapset, includes=(includes or []) + Beatmapset.BEATMAPSET_TRANSFORMER_INCLUDES
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_user_playcount(_session: AsyncSession, beatmap: "Beatmap", user: "User") -> int:
|
||||
playcount = (
|
||||
await _session.exec(
|
||||
select(BeatmapPlaycounts.playcount).where(
|
||||
BeatmapPlaycounts.beatmap_id == beatmap.id, BeatmapPlaycounts.user_id == user.id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
return int(playcount or 0)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_user_tag_ids(_session: AsyncSession, beatmap: "Beatmap", user: "User | None" = None) -> list[int]:
|
||||
if user is None:
|
||||
return []
|
||||
tag_ids = (
|
||||
await _session.exec(
|
||||
select(BeatmapTagVote.tag_id).where(
|
||||
BeatmapTagVote.beatmap_id == beatmap.id,
|
||||
BeatmapTagVote.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
return list(tag_ids)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def failtimes(_session: AsyncSession, beatmap: "Beatmap") -> FailTimeResp:
|
||||
if beatmap.failtimes is not None:
|
||||
return FailTimeResp.from_db(beatmap.failtimes)
|
||||
return FailTimeResp()
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def top_tag_ids(_session: AsyncSession, beatmap: "Beatmap") -> list[dict[str, int]]:
|
||||
all_votes = (
|
||||
await _session.exec(
|
||||
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||
.group_by(col(BeatmapTagVote.tag_id))
|
||||
.having(func.count() > settings.beatmap_tag_top_count)
|
||||
)
|
||||
).all()
|
||||
top_tag_ids: list[dict[str, int]] = []
|
||||
for id, votes in all_votes:
|
||||
top_tag_ids.append({"tag_id": id, "count": votes})
|
||||
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
|
||||
return top_tag_ids
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def user(
|
||||
_session: AsyncSession,
|
||||
beatmap: "Beatmap",
|
||||
includes: list[str] | None = None,
|
||||
) -> UserDict | None:
|
||||
from .user import User
|
||||
|
||||
user = await _session.get(User, beatmap.user_id)
|
||||
if user is None:
|
||||
return None
|
||||
return await UserModel.transform(user, includes=includes)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def convert(_session: AsyncSession, _beatmap: "Beatmap") -> bool:
|
||||
return False
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def is_scoreable(_session: AsyncSession, beatmap: "Beatmap") -> bool:
|
||||
beatmap_status = beatmap.beatmap_status
|
||||
if settings.enable_all_beatmap_leaderboard:
|
||||
return True
|
||||
return beatmap_status.has_leaderboard()
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def mode_int(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||
return int(beatmap.mode)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def ranked(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||
beatmap_status = beatmap.beatmap_status
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
return BeatmapRankStatus.APPROVED.value
|
||||
return beatmap_status.value
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def playcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||
result = (
|
||||
await _session.exec(
|
||||
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
|
||||
)
|
||||
).first()
|
||||
return int(result or 0)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def passcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||
from .score import Score
|
||||
|
||||
return (
|
||||
await _session.exec(
|
||||
select(func.count())
|
||||
.select_from(Score)
|
||||
.where(
|
||||
Score.beatmap_id == beatmap.id,
|
||||
col(Score.passed).is_(True),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
|
||||
class Beatmap(BeatmapBase, table=True):
|
||||
class Beatmap(AsyncAttrs, BeatmapModel, table=True):
|
||||
__tablename__: str = "beatmaps"
|
||||
id: int = Field(primary_key=True, index=True)
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
|
||||
beatmap_status: BeatmapRankStatus = Field(index=True)
|
||||
# optional
|
||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
|
||||
beatmapset: "Beatmapset" = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
|
||||
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@classmethod
|
||||
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
d = resp.model_dump()
|
||||
del d["beatmapset"]
|
||||
async def from_resp_no_save(cls, _session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
|
||||
d = {k: v for k, v in resp.items() if k != "beatmapset"}
|
||||
beatmapset_id = resp.get("beatmapset_id")
|
||||
bid = resp.get("id")
|
||||
ranked = resp.get("ranked")
|
||||
if beatmapset_id is None or bid is None or ranked is None:
|
||||
raise ValueError("beatmapset_id, id and ranked are required")
|
||||
beatmap = cls.model_validate(
|
||||
{
|
||||
**d,
|
||||
"beatmapset_id": resp.beatmapset_id,
|
||||
"id": resp.id,
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
"beatmapset_id": beatmapset_id,
|
||||
"id": bid,
|
||||
"beatmap_status": BeatmapRankStatus(ranked),
|
||||
}
|
||||
)
|
||||
return beatmap
|
||||
|
||||
@classmethod
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
async def from_resp(cls, session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
|
||||
beatmap = await cls.from_resp_no_save(session, resp)
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
resp_id = resp.get("id")
|
||||
if resp_id is None:
|
||||
raise ValueError("id is required")
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp_id))).first():
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
|
||||
return (await session.exec(select(Beatmap).where(Beatmap.id == resp_id))).one()
|
||||
|
||||
@classmethod
|
||||
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
|
||||
async def from_resp_batch(cls, session: AsyncSession, inp: list[BeatmapDict], from_: int = 0) -> list["Beatmap"]:
|
||||
beatmaps = []
|
||||
for resp in inp:
|
||||
if resp.id == from_:
|
||||
for resp_dict in inp:
|
||||
bid = resp_dict.get("id")
|
||||
if bid == from_ or bid is None:
|
||||
continue
|
||||
d = resp.model_dump()
|
||||
del d["beatmapset"]
|
||||
|
||||
beatmapset_id = resp_dict.get("beatmapset_id")
|
||||
ranked = resp_dict.get("ranked")
|
||||
if beatmapset_id is None or ranked is None:
|
||||
continue
|
||||
|
||||
# 创建 beatmap 字典,移除 beatmapset
|
||||
d = {k: v for k, v in resp_dict.items() if k != "beatmapset"}
|
||||
|
||||
beatmap = Beatmap.model_validate(
|
||||
{
|
||||
**d,
|
||||
"beatmapset_id": resp.beatmapset_id,
|
||||
"id": resp.id,
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
"beatmapset_id": beatmapset_id,
|
||||
"id": bid,
|
||||
"beatmap_status": BeatmapRankStatus(ranked),
|
||||
}
|
||||
)
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == bid))).first():
|
||||
session.add(beatmap)
|
||||
beatmaps.append(beatmap)
|
||||
await session.commit()
|
||||
for beatmap in beatmaps:
|
||||
await session.refresh(beatmap)
|
||||
return beatmaps
|
||||
|
||||
@classmethod
|
||||
@@ -132,10 +354,14 @@ class Beatmap(BeatmapBase, table=True):
|
||||
beatmap = (await session.exec(stmt)).first()
|
||||
if not beatmap:
|
||||
resp = await fetcher.get_beatmap(bid, md5)
|
||||
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
|
||||
beatmapset_id = resp.get("beatmapset_id")
|
||||
if beatmapset_id is None:
|
||||
raise ValueError("beatmapset_id is required")
|
||||
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == beatmapset_id))
|
||||
if not r.first():
|
||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
||||
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
|
||||
set_resp = await fetcher.get_beatmapset(beatmapset_id)
|
||||
resp_id = resp.get("id")
|
||||
await Beatmapset.from_resp(session, set_resp, from_=resp_id or 0)
|
||||
return await Beatmap.from_resp(session, resp)
|
||||
return beatmap
|
||||
|
||||
@@ -145,97 +371,6 @@ class APIBeatmapTag(BaseModel):
|
||||
count: int
|
||||
|
||||
|
||||
class BeatmapResp(BeatmapBase):
|
||||
id: int
|
||||
beatmapset_id: int
|
||||
beatmapset: BeatmapsetResp | None = None
|
||||
convert: bool = False
|
||||
is_scoreable: bool
|
||||
status: str
|
||||
mode_int: int
|
||||
ranked: int
|
||||
url: str = ""
|
||||
playcount: int = 0
|
||||
passcount: int = 0
|
||||
failtimes: FailTimeResp | None = None
|
||||
top_tag_ids: list[APIBeatmapTag] | None = None
|
||||
current_user_tag_ids: list[int] | None = None
|
||||
is_deleted: bool = False
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
beatmap: Beatmap,
|
||||
query_mode: GameMode | None = None,
|
||||
from_set: bool = False,
|
||||
session: AsyncSession | None = None,
|
||||
user: "User | None" = None,
|
||||
) -> "BeatmapResp":
|
||||
from .score import Score
|
||||
|
||||
beatmap_ = beatmap.model_dump()
|
||||
beatmap_status = beatmap.beatmap_status
|
||||
if query_mode is not None and beatmap.mode != query_mode:
|
||||
beatmap_["convert"] = True
|
||||
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
else:
|
||||
beatmap_["status"] = beatmap_status.name.lower()
|
||||
beatmap_["ranked"] = beatmap_status.value
|
||||
beatmap_["mode_int"] = int(beatmap.mode)
|
||||
beatmap_["is_deleted"] = beatmap.deleted_at is not None
|
||||
if not from_set:
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
|
||||
if beatmap.failtimes is not None:
|
||||
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
|
||||
else:
|
||||
beatmap_["failtimes"] = FailTimeResp()
|
||||
if session:
|
||||
beatmap_["playcount"] = (
|
||||
await session.exec(
|
||||
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
|
||||
)
|
||||
).first() or 0
|
||||
beatmap_["passcount"] = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(Score)
|
||||
.where(
|
||||
Score.beatmap_id == beatmap.id,
|
||||
col(Score.passed).is_(True),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
all_votes = (
|
||||
await session.exec(
|
||||
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||
.group_by(col(BeatmapTagVote.tag_id))
|
||||
.having(func.count() > settings.beatmap_tag_top_count)
|
||||
)
|
||||
).all()
|
||||
top_tag_ids: list[dict[str, int]] = []
|
||||
for id, votes in all_votes:
|
||||
top_tag_ids.append({"tag_id": id, "count": votes})
|
||||
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
|
||||
beatmap_["top_tag_ids"] = top_tag_ids
|
||||
|
||||
if user is not None:
|
||||
beatmap_["current_user_tag_ids"] = (
|
||||
await session.exec(
|
||||
select(BeatmapTagVote.tag_id)
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||
.where(BeatmapTagVote.user_id == user.id)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmap_["current_user_tag_ids"] = []
|
||||
return cls.model_validate(beatmap_)
|
||||
|
||||
|
||||
class BannedBeatmaps(SQLModel, table=True):
|
||||
__tablename__: str = "banned_beatmaps"
|
||||
id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
Reference in New Issue
Block a user