refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -1,117 +1,339 @@
from datetime import datetime
import hashlib
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.calculator import get_calculator
from app.config import settings
from app.database.beatmap_tags import BeatmapTagVote
from app.database.failtime import FailTime, FailTimeResp
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from app.models.performance import DifficultyAttributesUnion
from app.models.score import GameMode
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import Beatmapset, BeatmapsetResp
from .beatmap_tags import BeatmapTagVote
from .beatmapset import Beatmapset, BeatmapsetDict, BeatmapsetModel
from .failtime import FailTime, FailTimeResp
from .user import User, UserDict, UserModel
from pydantic import BaseModel, TypeAdapter
from redis.asyncio import Redis
from sqlalchemy import Column, DateTime
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .user import User
class BeatmapOwner(SQLModel):
id: int
username: str
class BeatmapBase(SQLModel):
# Beatmap
url: str
class BeatmapDict(TypedDict):
beatmapset_id: int
difficulty_rating: float
id: int
mode: GameMode
total_length: int
user_id: int
version: str
url: str
checksum: NotRequired[str]
max_combo: NotRequired[int | None]
ar: NotRequired[float]
cs: NotRequired[float]
drain: NotRequired[float]
accuracy: NotRequired[float]
bpm: NotRequired[float]
count_circles: NotRequired[int]
count_sliders: NotRequired[int]
count_spinners: NotRequired[int]
deleted_at: NotRequired[datetime | None]
hit_length: NotRequired[int]
last_updated: NotRequired[datetime]
status: NotRequired[str]
beatmapset: NotRequired[BeatmapsetDict]
current_user_playcount: NotRequired[int]
current_user_tag_ids: NotRequired[list[int]]
failtimes: NotRequired[FailTimeResp]
top_tag_ids: NotRequired[list[dict[str, int]]]
user: NotRequired[UserDict]
convert: NotRequired[bool]
is_scoreable: NotRequired[bool]
mode_int: NotRequired[int]
ranked: NotRequired[int]
playcount: NotRequired[int]
passcount: NotRequired[int]
class BeatmapModel(DatabaseModel[BeatmapDict]):
BEATMAP_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
"checksum",
"accuracy",
"ar",
"bpm",
"convert",
"count_circles",
"count_sliders",
"count_spinners",
"cs",
"deleted_at",
"drain",
"hit_length",
"is_scoreable",
"last_updated",
"mode_int",
"passcount",
"playcount",
"ranked",
"url",
]
DEFAULT_API_INCLUDES: ClassVar[list[str]] = [
"beatmapset.ratings",
"current_user_playcount",
"failtimes",
"max_combo",
"owners",
]
TRANSFORMER_INCLUDES: ClassVar[list[str]] = [*DEFAULT_API_INCLUDES, *BEATMAP_TRANSFORMER_INCLUDES]
# Beatmap
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
difficulty_rating: float = Field(default=0.0, index=True)
id: int = Field(primary_key=True, index=True)
mode: GameMode
total_length: int
user_id: int = Field(index=True)
version: str = Field(index=True)
url: OnDemand[str]
# optional
checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
current_user_playcount: int = Field(default=0)
max_combo: int | None = Field(default=0)
# TODO: failtimes, owners
checksum: OnDemand[str] = Field(sa_column=Column(VARCHAR(32), index=True))
max_combo: OnDemand[int | None] = Field(default=0)
# TODO: owners
# BeatmapExtended
ar: float = Field(default=0.0)
cs: float = Field(default=0.0)
drain: float = Field(default=0.0) # hp
accuracy: float = Field(default=0.0) # od
bpm: float = Field(default=0.0)
count_circles: int = Field(default=0)
count_sliders: int = Field(default=0)
count_spinners: int = Field(default=0)
deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
hit_length: int = Field(default=0)
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ar: OnDemand[float] = Field(default=0.0)
cs: OnDemand[float] = Field(default=0.0)
drain: OnDemand[float] = Field(default=0.0) # hp
accuracy: OnDemand[float] = Field(default=0.0) # od
bpm: OnDemand[float] = Field(default=0.0)
count_circles: OnDemand[int] = Field(default=0)
count_sliders: OnDemand[int] = Field(default=0)
count_spinners: OnDemand[int] = Field(default=0)
deleted_at: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime))
hit_length: OnDemand[int] = Field(default=0)
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
@included
@staticmethod
async def status(_session: AsyncSession, beatmap: "Beatmap") -> str:
if settings.enable_all_beatmap_leaderboard and not beatmap.beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.name.lower()
return beatmap.beatmap_status.name.lower()
@ondemand
@staticmethod
async def beatmapset(
_session: AsyncSession,
beatmap: "Beatmap",
includes: list[str] | None = None,
) -> BeatmapsetDict | None:
if beatmap.beatmapset is not None:
return await BeatmapsetModel.transform(
beatmap.beatmapset, includes=(includes or []) + Beatmapset.BEATMAPSET_TRANSFORMER_INCLUDES
)
@ondemand
@staticmethod
async def current_user_playcount(_session: AsyncSession, beatmap: "Beatmap", user: "User") -> int:
playcount = (
await _session.exec(
select(BeatmapPlaycounts.playcount).where(
BeatmapPlaycounts.beatmap_id == beatmap.id, BeatmapPlaycounts.user_id == user.id
)
)
).first()
return int(playcount or 0)
@ondemand
@staticmethod
async def current_user_tag_ids(_session: AsyncSession, beatmap: "Beatmap", user: "User | None" = None) -> list[int]:
if user is None:
return []
tag_ids = (
await _session.exec(
select(BeatmapTagVote.tag_id).where(
BeatmapTagVote.beatmap_id == beatmap.id,
BeatmapTagVote.user_id == user.id,
)
)
).all()
return list(tag_ids)
@ondemand
@staticmethod
async def failtimes(_session: AsyncSession, beatmap: "Beatmap") -> FailTimeResp:
if beatmap.failtimes is not None:
return FailTimeResp.from_db(beatmap.failtimes)
return FailTimeResp()
@ondemand
@staticmethod
async def top_tag_ids(_session: AsyncSession, beatmap: "Beatmap") -> list[dict[str, int]]:
all_votes = (
await _session.exec(
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.group_by(col(BeatmapTagVote.tag_id))
.having(func.count() > settings.beatmap_tag_top_count)
)
).all()
top_tag_ids: list[dict[str, int]] = []
for id, votes in all_votes:
top_tag_ids.append({"tag_id": id, "count": votes})
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
return top_tag_ids
@ondemand
@staticmethod
async def user(
_session: AsyncSession,
beatmap: "Beatmap",
includes: list[str] | None = None,
) -> UserDict | None:
from .user import User
user = await _session.get(User, beatmap.user_id)
if user is None:
return None
return await UserModel.transform(user, includes=includes)
@ondemand
@staticmethod
async def convert(_session: AsyncSession, _beatmap: "Beatmap") -> bool:
return False
@ondemand
@staticmethod
async def is_scoreable(_session: AsyncSession, beatmap: "Beatmap") -> bool:
beatmap_status = beatmap.beatmap_status
if settings.enable_all_beatmap_leaderboard:
return True
return beatmap_status.has_leaderboard()
@ondemand
@staticmethod
async def mode_int(_session: AsyncSession, beatmap: "Beatmap") -> int:
return int(beatmap.mode)
@ondemand
@staticmethod
async def ranked(_session: AsyncSession, beatmap: "Beatmap") -> int:
beatmap_status = beatmap.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.value
return beatmap_status.value
@ondemand
@staticmethod
async def playcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
result = (
await _session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
)
).first()
return int(result or 0)
@ondemand
@staticmethod
async def passcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
from .score import Score
return (
await _session.exec(
select(func.count())
.select_from(Score)
.where(
Score.beatmap_id == beatmap.id,
col(Score.passed).is_(True),
)
)
).one()
class Beatmap(BeatmapBase, table=True):
class Beatmap(AsyncAttrs, BeatmapModel, table=True):
__tablename__: str = "beatmaps"
id: int = Field(primary_key=True, index=True)
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus = Field(index=True)
# optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
beatmapset: "Beatmapset" = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump()
del d["beatmapset"]
async def from_resp_no_save(cls, _session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
d = {k: v for k, v in resp.items() if k != "beatmapset"}
beatmapset_id = resp.get("beatmapset_id")
bid = resp.get("id")
ranked = resp.get("ranked")
if beatmapset_id is None or bid is None or ranked is None:
raise ValueError("beatmapset_id, id and ranked are required")
beatmap = cls.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"beatmapset_id": beatmapset_id,
"id": bid,
"beatmap_status": BeatmapRankStatus(ranked),
}
)
return beatmap
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
async def from_resp(cls, session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
beatmap = await cls.from_resp_no_save(session, resp)
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
resp_id = resp.get("id")
if resp_id is None:
raise ValueError("id is required")
if not (await session.exec(select(exists()).where(Beatmap.id == resp_id))).first():
session.add(beatmap)
await session.commit()
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
return (await session.exec(select(Beatmap).where(Beatmap.id == resp_id))).one()
@classmethod
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
async def from_resp_batch(cls, session: AsyncSession, inp: list[BeatmapDict], from_: int = 0) -> list["Beatmap"]:
beatmaps = []
for resp in inp:
if resp.id == from_:
for resp_dict in inp:
bid = resp_dict.get("id")
if bid == from_ or bid is None:
continue
d = resp.model_dump()
del d["beatmapset"]
beatmapset_id = resp_dict.get("beatmapset_id")
ranked = resp_dict.get("ranked")
if beatmapset_id is None or ranked is None:
continue
# 创建 beatmap 字典,移除 beatmapset
d = {k: v for k, v in resp_dict.items() if k != "beatmapset"}
beatmap = Beatmap.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"beatmapset_id": beatmapset_id,
"id": bid,
"beatmap_status": BeatmapRankStatus(ranked),
}
)
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
if not (await session.exec(select(exists()).where(Beatmap.id == bid))).first():
session.add(beatmap)
beatmaps.append(beatmap)
await session.commit()
for beatmap in beatmaps:
await session.refresh(beatmap)
return beatmaps
@classmethod
@@ -132,10 +354,14 @@ class Beatmap(BeatmapBase, table=True):
beatmap = (await session.exec(stmt)).first()
if not beatmap:
resp = await fetcher.get_beatmap(bid, md5)
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
beatmapset_id = resp.get("beatmapset_id")
if beatmapset_id is None:
raise ValueError("beatmapset_id is required")
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == beatmapset_id))
if not r.first():
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
set_resp = await fetcher.get_beatmapset(beatmapset_id)
resp_id = resp.get("id")
await Beatmapset.from_resp(session, set_resp, from_=resp_id or 0)
return await Beatmap.from_resp(session, resp)
return beatmap
@@ -145,97 +371,6 @@ class APIBeatmapTag(BaseModel):
count: int
class BeatmapResp(BeatmapBase):
id: int
beatmapset_id: int
beatmapset: BeatmapsetResp | None = None
convert: bool = False
is_scoreable: bool
status: str
mode_int: int
ranked: int
url: str = ""
playcount: int = 0
passcount: int = 0
failtimes: FailTimeResp | None = None
top_tag_ids: list[APIBeatmapTag] | None = None
current_user_tag_ids: list[int] | None = None
is_deleted: bool = False
@classmethod
async def from_db(
cls,
beatmap: Beatmap,
query_mode: GameMode | None = None,
from_set: bool = False,
session: AsyncSession | None = None,
user: "User | None" = None,
) -> "BeatmapResp":
from .score import Score
beatmap_ = beatmap.model_dump()
beatmap_status = beatmap.beatmap_status
if query_mode is not None and beatmap.mode != query_mode:
beatmap_["convert"] = True
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
else:
beatmap_["status"] = beatmap_status.name.lower()
beatmap_["ranked"] = beatmap_status.value
beatmap_["mode_int"] = int(beatmap.mode)
beatmap_["is_deleted"] = beatmap.deleted_at is not None
if not from_set:
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
if beatmap.failtimes is not None:
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
else:
beatmap_["failtimes"] = FailTimeResp()
if session:
beatmap_["playcount"] = (
await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
)
).first() or 0
beatmap_["passcount"] = (
await session.exec(
select(func.count())
.select_from(Score)
.where(
Score.beatmap_id == beatmap.id,
col(Score.passed).is_(True),
)
)
).one()
all_votes = (
await session.exec(
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.group_by(col(BeatmapTagVote.tag_id))
.having(func.count() > settings.beatmap_tag_top_count)
)
).all()
top_tag_ids: list[dict[str, int]] = []
for id, votes in all_votes:
top_tag_ids.append({"tag_id": id, "count": votes})
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
beatmap_["top_tag_ids"] = top_tag_ids
if user is not None:
beatmap_["current_user_tag_ids"] = (
await session.exec(
select(BeatmapTagVote.tag_id)
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.where(BeatmapTagVote.user_id == user.id)
)
).all()
else:
beatmap_["current_user_tag_ids"] = []
return cls.model_validate(beatmap_)
class BannedBeatmaps(SQLModel, table=True):
__tablename__: str = "banned_beatmaps"
id: int | None = Field(primary_key=True, index=True, default=None)