refactor(database): 优化数据库关联对象的载入 (#10)

This commit is contained in:
MingxuanGame
2025-07-31 20:11:22 +08:00
committed by GitHub
parent 1281e75bb1
commit be401e8885
13 changed files with 73 additions and 166 deletions

View File

@@ -8,7 +8,6 @@ from app.models.score import MODE_TO_INT, GameMode
from .beatmapset import Beatmapset, BeatmapsetResp
from sqlalchemy import DECIMAL, Column, DateTime
from sqlalchemy.orm import joinedload
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -67,7 +66,9 @@ class Beatmap(BeatmapBase, table=True):
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus
# optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
beatmapset: Beatmapset = Relationship(
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
@property
def can_ranked(self) -> bool:
@@ -88,13 +89,7 @@ class Beatmap(BeatmapBase, table=True):
session.add(beatmap)
await session.commit()
beatmap = (
await session.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(Beatmap.id == resp.id)
)
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
).first()
assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap
@@ -132,13 +127,9 @@ class Beatmap(BeatmapBase, table=True):
) -> "Beatmap":
beatmap = (
await session.exec(
select(Beatmap)
.where(
select(Beatmap).where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
)
).first()
if not beatmap:
@@ -165,7 +156,7 @@ class BeatmapResp(BeatmapBase):
url: str = ""
@classmethod
def from_db(
async def from_db(
cls,
beatmap: Beatmap,
query_mode: GameMode | None = None,
@@ -179,5 +170,5 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap.beatmap_status.value
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
if not from_set:
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset)
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset)
return cls.model_validate(beatmap_)

View File

@@ -7,6 +7,7 @@ from app.models.score import GameMode
from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -130,7 +131,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel):
tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(BeatmapsetBase, table=True):
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
@@ -200,12 +201,12 @@ class BeatmapsetResp(BeatmapsetBase):
nominations: BeatmapNominations | None = None
@classmethod
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
from .beatmap import BeatmapResp
beatmaps = [
BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in beatmapset.beatmaps
await BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
]
return cls.model_validate(
{

View File

@@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
JSON,
BigInteger,
@@ -128,7 +128,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_bng: bool = False
class User(UserBase, table=True):
class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
@@ -154,17 +154,6 @@ class User(UserBase, table=True):
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
@classmethod
def all_select_option(cls):
return (
selectinload(cls.account_history), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.achievement), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType]
)
class UserResp(UserBase):
id: int | None = None
@@ -249,13 +238,7 @@ class UserResp(UserBase):
await RelationshipResp.from_db(session, r)
for r in (
await session.exec(
select(Relationship)
.options(
joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType]
*User.all_select_option()
)
)
.where(
select(Relationship).where(
Relationship.user_id == obj.id,
Relationship.type == RelationshipType.FOLLOW,
)
@@ -264,23 +247,26 @@ class UserResp(UserBase):
]
if "team" in include:
if obj.team_membership:
if await obj.awaitable_attrs.team_membership:
assert obj.team_membership
u.team = obj.team_membership.team
if "account_history" in include:
u.account_history = [
UserAccountHistoryResp.from_db(ah) for ah in obj.account_history
UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
]
if "daily_challenge_user_stats":
if obj.daily_challenge_stats:
if await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats
)
if "statistics" in include:
current_stattistics = None
for i in obj.statistics:
for i in await obj.awaitable_attrs.statistics:
if i.mode == (ruleset or obj.playmode):
current_stattistics = i
break
@@ -292,17 +278,20 @@ class UserResp(UserBase):
if "statistics_rulesets" in include:
u.statistics_rulesets = {
i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics
i.mode.value: UserStatisticsResp.from_db(i)
for i in await obj.awaitable_attrs.statistics
}
if "monthly_playcounts" in include:
u.monthly_playcounts = [
MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts
MonthlyPlaycountsResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
]
if "achievements" in include:
u.user_achievements = [
UserAchievementResp.from_db(ua) for ua in obj.achievement
UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
]
return u

View File

@@ -42,7 +42,10 @@ class Relationship(SQLModel, table=True):
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: User = SQLRelationship(
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
)
@@ -79,7 +82,6 @@ class RelationshipResp(BaseModel):
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
"achievements",
],
),
mutual=mutual,

View File

@@ -27,7 +27,7 @@ from app.models.score import (
)
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import Beatmapset, BeatmapsetResp
from .beatmapset import BeatmapsetResp
from .best_score import BestScore
from .lazer_user import User, UserResp
from .monthly_playcounts import MonthlyPlaycounts
@@ -35,7 +35,8 @@ from .score_token import ScoreToken
from redis import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.orm import aliased, joinedload
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlmodel import (
JSON,
BigInteger,
@@ -55,7 +56,7 @@ if TYPE_CHECKING:
from app.fetcher import Fetcher
class ScoreBase(SQLModel, UTCBaseModel):
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
accuracy: float
map_md5: str = Field(max_length=32, index=True)
@@ -114,27 +115,12 @@ class Score(ScoreBase, table=True):
# optional
beatmap: Beatmap = Relationship()
user: User = Relationship()
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property
def is_perfect_combo(self) -> bool:
return self.max_combo == self.beatmap.max_combo
@staticmethod
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
clause = select(Score).options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
)
if with_user:
return clause.options(
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
)
return clause
@staticmethod
def select_clause_unique(
*where_clauses: ColumnExpressionArgument[bool] | bool,
@@ -148,18 +134,7 @@ class Score(ScoreBase, table=True):
)
subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True)
return (
select(best)
.where(subq.c.rn == 1)
.options(
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
)
)
return select(best).where(subq.c.rn == 1)
class ScoreResp(ScoreBase):
@@ -186,8 +161,9 @@ class ScoreResp(ScoreBase):
) -> "ScoreResp":
s = cls.model_validate(score.model_dump())
assert score.id
s.beatmap = BeatmapResp.from_db(score.beatmap)
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = MODE_TO_INT[score.gamemode]
@@ -303,7 +279,6 @@ async def get_leaderboard(
query = (
select(Score)
.join(Beatmap)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
@@ -452,7 +427,7 @@ async def get_user_best_score_in_beatmap(
) -> Score | None:
return (
await session.exec(
Score.select_clause(False)
select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,

View File

@@ -34,5 +34,9 @@ class TeamMember(SQLModel, UTCBaseModel, table=True):
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="team_membership")
team: "Team" = Relationship(back_populates="members")
user: "User" = Relationship(
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)