refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -3,7 +3,7 @@ from datetime import date, datetime
import json
import math
import sys
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
from app.calculator import (
calculate_pp_weight,
@@ -15,8 +15,6 @@ from app.calculator import (
pre_fetch_and_calculate_pp,
)
from app.config import settings
from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.database.team import TeamMember
from app.dependencies.database import get_redis
from app.log import log
from app.models.beatmap import BeatmapRankStatus
@@ -39,8 +37,10 @@ from app.models.scoring_mode import ScoringMode
from app.storage import StorageService
from app.utils import utcnow
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import BeatmapsetResp
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import BeatmapsetDict, BeatmapsetModel
from .best_scores import BestScore
from .counts import MonthlyPlaycounts
from .events import Event, EventType
@@ -50,8 +50,9 @@ from .relationship import (
RelationshipType,
)
from .score_token import ScoreToken
from .team import TeamMember
from .total_score_best_scores import TotalScoreBestScore
from .user import User, UserResp
from .user import User, UserDict, UserModel
from pydantic import BaseModel, field_serializer, field_validator
from redis.asyncio import Redis
@@ -80,30 +81,290 @@ if TYPE_CHECKING:
logger = log("Score")
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
class ScoreDict(TypedDict):
beatmap_id: int
id: int
rank: Rank
type: str
user_id: int
accuracy: float
build_id: int | None
ended_at: datetime
has_replay: bool
max_combo: int
passed: bool
pp: float
started_at: datetime
total_score: int
maximum_statistics: ScoreStatistics
mods: list[APIMod]
classic_total_score: int | None
preserve: bool
processed: bool
ranked: bool
playlist_item_id: NotRequired[int | None]
room_id: NotRequired[int | None]
best_id: NotRequired[int | None]
legacy_perfect: NotRequired[bool]
is_perfect_combo: NotRequired[bool]
ruleset_id: NotRequired[int]
statistics: NotRequired[ScoreStatistics]
beatmapset: NotRequired[BeatmapsetDict]
beatmap: NotRequired[BeatmapDict]
current_user_attributes: NotRequired[CurrentUserAttributes]
position: NotRequired[int | None]
scores_around: NotRequired["ScoreAround | None"]
rank_country: NotRequired[int | None]
rank_global: NotRequired[int | None]
user: NotRequired[UserDict]
weight: NotRequired[float | None]
# ScoreResp 字段
legacy_total_score: NotRequired[int]
class ScoreModel(AsyncAttrs, DatabaseModel[ScoreDict]):
# https://github.com/ppy/osu-web/blob/master/app/Transformers/ScoreTransformer.php#L72-L84
MULTIPLAYER_SCORE_INCLUDE: ClassVar[list[str]] = ["playlist_item_id", "room_id", "solo_score_id"]
MULTIPLAYER_BASE_INCLUDES: ClassVar[list[str]] = [
"user.country",
"user.cover",
"user.team",
*MULTIPLAYER_SCORE_INCLUDE,
]
USER_PROFILE_INCLUDES: ClassVar[list[str]] = ["beatmap", "beatmapset", "user"]
# 基本字段
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
rank: Rank
type: str
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
accuracy: float
map_md5: str = Field(max_length=32, index=True)
build_id: int | None = Field(default=None)
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
ended_at: datetime = Field(sa_column=Column(DateTime))
has_replay: bool = Field(sa_column=Column(Boolean))
max_combo: int
mods: list[APIMod] = Field(sa_column=Column(JSON))
passed: bool = Field(sa_column=Column(Boolean))
playlist_item_id: int | None = Field(default=None) # multiplayer
pp: float = Field(default=0.0)
preserve: bool = Field(default=True, sa_column=Column(Boolean))
rank: Rank
room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
type: str
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
processed: bool = False # solo_score
ranked: bool = False
mods: list[APIMod] = Field(sa_column=Column(JSON))
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
# solo
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger))
preserve: bool = Field(default=True, sa_column=Column(Boolean))
processed: bool = Field(default=False)
ranked: bool = Field(default=False)
# multiplayer
playlist_item_id: OnDemand[int | None] = Field(default=None)
room_id: OnDemand[int | None] = Field(default=None)
@included
@staticmethod
async def best_id(
session: AsyncSession,
score: "Score",
) -> int | None:
return await get_best_id(session, score.id)
@included
@staticmethod
async def legacy_perfect(
_session: AsyncSession,
score: "Score",
) -> bool:
await score.awaitable_attrs.beatmap
return score.max_combo == score.beatmap.max_combo
@included
@staticmethod
async def is_perfect_combo(
_session: AsyncSession,
score: "Score",
) -> bool:
await score.awaitable_attrs.beatmap
return score.max_combo == score.beatmap.max_combo
@included
@staticmethod
async def ruleset_id(
_session: AsyncSession,
score: "Score",
) -> int:
return int(score.gamemode)
@included
@staticmethod
async def statistics(
_session: AsyncSession,
score: "Score",
) -> ScoreStatistics:
stats = {
HitResult.MISS: score.nmiss,
HitResult.MEH: score.n50,
HitResult.OK: score.n100,
HitResult.GREAT: score.n300,
HitResult.PERFECT: score.ngeki,
HitResult.GOOD: score.nkatu,
}
if score.nlarge_tick_miss is not None:
stats[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None:
stats[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
if score.nsmall_tick_hit is not None:
stats[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
if score.nlarge_tick_hit is not None:
stats[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
return stats
@ondemand
@staticmethod
async def beatmapset(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> BeatmapsetDict:
await score.awaitable_attrs.beatmap
return await BeatmapsetModel.transform(score.beatmap.beatmapset, includes=includes)
# reorder beatmapset and beatmap
# https://github.com/ppy/osu/blob/d8900defd34690de92be3406003fb3839fc0df1d/osu.Game/Online/API/Requests/Responses/SoloScoreInfo.cs#L111-L112
@ondemand
@staticmethod
async def beatmap(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> BeatmapDict:
await score.awaitable_attrs.beatmap
return await BeatmapModel.transform(score.beatmap, includes=includes)
@ondemand
@staticmethod
async def current_user_attributes(
_session: AsyncSession,
score: "Score",
) -> CurrentUserAttributes:
return CurrentUserAttributes(pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id))
@ondemand
@staticmethod
async def position(
session: AsyncSession,
score: "Score",
) -> int | None:
return await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
@ondemand
@staticmethod
async def scores_around(
session: AsyncSession, _score: "Score", playlist_id: int, room_id: int, is_playlist: bool
) -> "ScoreAround | None":
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
col(PlaylistBestScore.score).has(col(Score.passed).is_(True)) if not is_playlist else True,
)
)
).all()
higher_scores = []
lower_scores = []
for score in scores:
total_score = score.score.total_score
resp = await ScoreModel.transform(score.score, includes=ScoreModel.MULTIPLAYER_BASE_INCLUDES)
if score.total_score > total_score:
higher_scores.append(resp)
elif score.total_score < total_score:
lower_scores.append(resp)
return ScoreAround(
higher=MultiplayerScores(scores=higher_scores),
lower=MultiplayerScores(scores=lower_scores),
)
@ondemand
@staticmethod
async def rank_country(
session: AsyncSession,
score: "Score",
) -> int | None:
return (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
score.gamemode,
score.user,
type=LeaderboardType.COUNTRY,
)
or None
)
@ondemand
@staticmethod
async def rank_global(
session: AsyncSession,
score: "Score",
) -> int | None:
return (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
)
@ondemand
@staticmethod
async def user(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> UserDict:
return await UserModel.transform(score.user, ruleset=score.gamemode, includes=includes or [])
@ondemand
@staticmethod
async def weight(
session: AsyncSession,
score: "Score",
) -> float | None:
best_id = await get_best_id(session, score.id)
if best_id:
return calculate_pp_weight(best_id - 1)
return None
@ondemand
@staticmethod
async def legacy_total_score(
_session: AsyncSession,
_score: "Score",
) -> int:
return 0
@field_validator("maximum_statistics", mode="before")
@classmethod
@@ -151,17 +412,9 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# TODO: current_user_attributes
class Score(ScoreBase, table=True):
class Score(ScoreModel, table=True):
__tablename__: str = "scores"
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
# ScoreStatistics
n300: int = Field(exclude=True)
n100: int = Field(exclude=True)
@@ -175,6 +428,7 @@ class Score(ScoreBase, table=True):
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
gamemode: GameMode = Field(index=True)
pinned_order: int = Field(default=0, exclude=True)
map_md5: str = Field(max_length=32, index=True, exclude=True)
@field_validator("gamemode", mode="before")
@classmethod
@@ -245,9 +499,11 @@ class Score(ScoreBase, table=True):
maximum_statistics=self.maximum_statistics,
)
async def to_resp(self, session: AsyncSession, api_version: int) -> "ScoreResp | LegacyScoreResp":
async def to_resp(
self, session: AsyncSession, api_version: int, includes: list[str] = []
) -> "ScoreDict | LegacyScoreResp":
if api_version >= 20220705:
return await ScoreResp.from_db(session, self)
return await ScoreModel.transform(self, includes=includes)
return await LegacyScoreResp.from_db(session, self)
async def delete(
@@ -270,141 +526,7 @@ class Score(ScoreBase, table=True):
await session.delete(self)
class ScoreResp(ScoreBase):
id: int
user_id: int
is_perfect_combo: bool = False
legacy_perfect: bool = False
legacy_total_score: int = 0 # FIXME
weight: float = 0.0
best_id: int | None = None
ruleset_id: int | None = None
beatmap: BeatmapResp | None = None
beatmapset: BeatmapsetResp | None = None
user: UserResp | None = None
statistics: ScoreStatistics | None = None
maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None
rank_country: int | None = None
position: int | None = None
scores_around: "ScoreAround | None" = None
current_user_attributes: CurrentUserAttributes | None = None
@field_validator(
"has_replay",
"passed",
"preserve",
"is_perfect_combo",
"legacy_perfect",
"processed",
"ranked",
mode="before",
)
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
if isinstance(v, int):
return bool(v)
return v
@field_validator("statistics", "maximum_statistics", mode="before")
@classmethod
def validate_statistics_fields(cls, v):
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
if isinstance(v, dict):
converted = {}
for key, value in v.items():
if isinstance(key, str):
try:
# 尝试将字符串转换为 HitResult 枚举
enum_key = HitResult(key)
converted[enum_key] = value
except ValueError:
# 如果转换失败,跳过这个键值对
continue
else:
converted[key] = value
return converted
return v
@field_serializer("statistics", when_used="json")
def serialize_statistics_fields(self, v):
"""序列化统计字段,确保枚举值正确转换为字符串"""
if isinstance(v, dict):
serialized = {}
for key, value in v.items():
if hasattr(key, "value"):
# 如果是枚举,使用其值
serialized[key.value] = value
else:
# 否则直接使用键
serialized[str(key)] = value
return serialized
return v
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
# 确保 score 对象完全加载,避免懒加载问题
await session.refresh(score)
s = cls.model_validate(score.model_dump())
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = int(score.gamemode)
best_id = await get_best_id(session, score.id)
if best_id:
s.best_id = best_id
s.weight = calculate_pp_weight(best_id - 1)
s.statistics = {
HitResult.MISS: score.nmiss,
HitResult.MEH: score.n50,
HitResult.OK: score.n100,
HitResult.GREAT: score.n300,
HitResult.PERFECT: score.ngeki,
HitResult.GOOD: score.nkatu,
}
if score.nlarge_tick_miss is not None:
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None:
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
if score.nsmall_tick_hit is not None:
s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
if score.nlarge_tick_hit is not None:
s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
s.user = await UserResp.from_db(
score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
)
s.rank_global = (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
)
s.rank_country = (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
score.gamemode,
score.user,
type=LeaderboardType.COUNTRY,
)
or None
)
s.current_user_attributes = CurrentUserAttributes(
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
)
return s
MultiplayScoreDict = ScoreModel.generate_typeddict(tuple(Score.MULTIPLAYER_BASE_INCLUDES)) # pyright: ignore[reportGeneralTypeIssues]
class LegacyStatistics(BaseModel):
@@ -417,31 +539,25 @@ class LegacyStatistics(BaseModel):
class LegacyScoreResp(UTCBaseModel):
accuracy: float
best_id: int
created_at: datetime
id: int
max_combo: int
mode: GameMode
mode_int: int
best_id: int
user_id: int
accuracy: float
mods: list[str] # acronym
passed: bool
score: int
max_combo: int
perfect: bool = False
statistics: LegacyStatistics
passed: bool
pp: float
rank: Rank
created_at: datetime
mode: GameMode
mode_int: int
replay: bool
score: int
statistics: LegacyStatistics
type: str
user_id: int
current_user_attributes: CurrentUserAttributes
user: UserResp
beatmap: BeatmapResp
rank_global: int | None = Field(default=None, exclude=True)
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "LegacyScoreResp":
await session.refresh(score)
async def from_db(cls, session: AsyncSession, score: "Score") -> "LegacyScoreResp":
await score.awaitable_attrs.beatmap
return cls(
accuracy=score.accuracy,
@@ -465,34 +581,13 @@ class LegacyScoreResp(UTCBaseModel):
count_geki=score.ngeki or 0,
count_katu=score.nkatu or 0,
),
type=score.type,
user_id=score.user_id,
current_user_attributes=CurrentUserAttributes(
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
),
user=await UserResp.from_db(
score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
),
beatmap=await BeatmapResp.from_db(score.beatmap),
perfect=score.is_perfect_combo,
rank_global=(
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
),
)
class MultiplayerScores(RespWithCursor):
scores: list[ScoreResp] = Field(default_factory=list)
scores: list[MultiplayScoreDict] = Field(default_factory=list) # pyright: ignore[reportInvalidTypeForm]
params: dict[str, Any] = Field(default_factory=dict)
@@ -842,13 +937,13 @@ async def get_user_best_pp(
# https://github.com/ppy/osu-queue-score-statistics/blob/master/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/PlayValidityHelper.cs
def get_play_length(score: Score, beatmap_length: int):
def get_play_length(score: "Score", beatmap_length: int):
speed_rate = get_speed_rate(score.mods)
length = beatmap_length / speed_rate
return int(min(length, (score.ended_at - score.started_at).total_seconds()))
def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
def calculate_playtime(score: "Score", beatmap_length: int) -> tuple[int, bool]:
total_length = get_play_length(score, beatmap_length)
total_obj_hited = (
score.n300
@@ -937,7 +1032,7 @@ async def process_score(
return score
async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
async def _process_score_pp(score: "Score", session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
if score.pp != 0:
logger.debug(
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
@@ -984,7 +1079,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
)
async def _process_score_events(score: Score, session: AsyncSession):
async def _process_score_events(score: "Score", session: AsyncSession):
total_users = (await session.exec(select(func.count()).select_from(User))).one()
rank_global = await get_score_position_by_id(
session,
@@ -1088,7 +1183,7 @@ async def _process_statistics(
session: AsyncSession,
redis: Redis,
user: User,
score: Score,
score: "Score",
score_token: int,
beatmap_length: int,
beatmap_status: BeatmapRankStatus,
@@ -1318,7 +1413,7 @@ async def process_user(
redis: Redis,
fetcher: "Fetcher",
user: User,
score: Score,
score: "Score",
score_token: int,
beatmap_length: int,
beatmap_status: BeatmapRankStatus,