refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
from datetime import timedelta
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.models.score import GameMode
|
||||
from app.utils import utcnow
|
||||
|
||||
from ._base import DatabaseModel, included, ondemand
|
||||
from .rank_history import RankHistory
|
||||
|
||||
from pydantic import field_validator
|
||||
@@ -15,7 +16,6 @@ from sqlmodel import (
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
@@ -23,10 +23,40 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User, UserResp
|
||||
from .user import User, UserDict
|
||||
|
||||
|
||||
class UserStatisticsBase(SQLModel):
|
||||
class UserStatisticsDict(TypedDict):
|
||||
mode: GameMode
|
||||
count_100: int
|
||||
count_300: int
|
||||
count_50: int
|
||||
count_miss: int
|
||||
pp: float
|
||||
ranked_score: int
|
||||
hit_accuracy: float
|
||||
total_score: int
|
||||
total_hits: int
|
||||
maximum_combo: int
|
||||
play_count: int
|
||||
play_time: int
|
||||
replays_watched_by_others: int
|
||||
is_ranked: bool
|
||||
level: NotRequired[dict[str, int]]
|
||||
global_rank: NotRequired[int | None]
|
||||
grade_counts: NotRequired[dict[str, int]]
|
||||
rank_change_since_30_days: NotRequired[int]
|
||||
country_rank: NotRequired[int | None]
|
||||
user: NotRequired["UserDict"]
|
||||
|
||||
|
||||
class UserStatisticsModel(DatabaseModel[UserStatisticsDict]):
|
||||
RANKING_INCLUDES: ClassVar[list[str]] = [
|
||||
"user.country",
|
||||
"user.cover",
|
||||
"user.team",
|
||||
]
|
||||
|
||||
mode: GameMode = Field(index=True)
|
||||
count_100: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
count_300: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
@@ -57,8 +87,63 @@ class UserStatisticsBase(SQLModel):
|
||||
return GameMode.OSU
|
||||
return v
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def level(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
|
||||
return {
|
||||
"current": int(statistics.level_current),
|
||||
"progress": int(math.fmod(statistics.level_current, 1) * 100),
|
||||
}
|
||||
|
||||
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
||||
@included
|
||||
@staticmethod
|
||||
async def global_rank(session: AsyncSession, statistics: "UserStatistics") -> int | None:
|
||||
return await get_rank(session, statistics)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def grade_counts(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
|
||||
return {
|
||||
"ss": statistics.grade_ss,
|
||||
"ssh": statistics.grade_ssh,
|
||||
"s": statistics.grade_s,
|
||||
"sh": statistics.grade_sh,
|
||||
"a": statistics.grade_a,
|
||||
}
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def rank_change_since_30_days(session: AsyncSession, statistics: "UserStatistics") -> int:
|
||||
global_rank = await get_rank(session, statistics)
|
||||
rank_best = (
|
||||
await session.exec(
|
||||
select(func.max(RankHistory.rank)).where(
|
||||
RankHistory.date > utcnow() - timedelta(days=30),
|
||||
RankHistory.user_id == statistics.user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if rank_best is None or global_rank is None:
|
||||
return 0
|
||||
return rank_best - global_rank
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def country_rank(
|
||||
session: AsyncSession, statistics: "UserStatistics", user_country: str | None = None
|
||||
) -> int | None:
|
||||
return await get_rank(session, statistics, user_country)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def user(_session: AsyncSession, statistics: "UserStatistics") -> "UserDict":
|
||||
from .user import UserModel
|
||||
|
||||
user_instance = await statistics.awaitable_attrs.user
|
||||
return await UserModel.transform(user_instance)
|
||||
|
||||
|
||||
class UserStatistics(AsyncAttrs, UserStatisticsModel, table=True):
|
||||
__tablename__: str = "lazer_user_statistics"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(
|
||||
@@ -80,74 +165,6 @@ class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
||||
user: "User" = Relationship(back_populates="statistics")
|
||||
|
||||
|
||||
class UserStatisticsResp(UserStatisticsBase):
|
||||
user: "UserResp | None" = None
|
||||
rank_change_since_30_days: int | None = 0
|
||||
global_rank: int | None = Field(default=None)
|
||||
country_rank: int | None = Field(default=None)
|
||||
grade_counts: dict[str, int] = Field(
|
||||
default_factory=lambda: {
|
||||
"ss": 0,
|
||||
"ssh": 0,
|
||||
"s": 0,
|
||||
"sh": 0,
|
||||
"a": 0,
|
||||
}
|
||||
)
|
||||
level: dict[str, int] = Field(
|
||||
default_factory=lambda: {
|
||||
"current": 1,
|
||||
"progress": 0,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
obj: UserStatistics,
|
||||
session: AsyncSession,
|
||||
user_country: str | None = None,
|
||||
include: list[str] = [],
|
||||
) -> "UserStatisticsResp":
|
||||
s = cls.model_validate(obj.model_dump())
|
||||
s.grade_counts = {
|
||||
"ss": obj.grade_ss,
|
||||
"ssh": obj.grade_ssh,
|
||||
"s": obj.grade_s,
|
||||
"sh": obj.grade_sh,
|
||||
"a": obj.grade_a,
|
||||
}
|
||||
s.level = {
|
||||
"current": int(obj.level_current),
|
||||
"progress": int(math.fmod(obj.level_current, 1) * 100),
|
||||
}
|
||||
if "user" in include:
|
||||
from .user import RANKING_INCLUDES, UserResp
|
||||
|
||||
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
|
||||
s.user = user
|
||||
user_country = user.country_code
|
||||
|
||||
s.global_rank = await get_rank(session, obj)
|
||||
s.country_rank = await get_rank(session, obj, user_country)
|
||||
|
||||
if "rank_change_since_30_days" in include:
|
||||
rank_best = (
|
||||
await session.exec(
|
||||
select(func.max(RankHistory.rank)).where(
|
||||
RankHistory.date > utcnow() - timedelta(days=30),
|
||||
RankHistory.user_id == obj.user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if rank_best is None or s.global_rank is None:
|
||||
s.rank_change_since_30_days = 0
|
||||
else:
|
||||
s.rank_change_since_30_days = rank_best - s.global_rank
|
||||
|
||||
return s
|
||||
|
||||
|
||||
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
|
||||
from .user import User
|
||||
|
||||
@@ -164,7 +181,6 @@ async def get_rank(session: AsyncSession, statistics: UserStatistics, country: s
|
||||
query = query.join(User).where(User.country_code == country)
|
||||
|
||||
subq = query.subquery()
|
||||
|
||||
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
|
||||
|
||||
rank = result.first()
|
||||
|
||||
Reference in New Issue
Block a user