refactor(database): 优化数据库关联对象的载入 (#10)

This commit is contained in:
MingxuanGame
2025-07-31 20:11:22 +08:00
committed by GitHub
parent 1281e75bb1
commit be401e8885
13 changed files with 73 additions and 166 deletions

View File

@@ -8,7 +8,6 @@ from app.models.score import MODE_TO_INT, GameMode
from .beatmapset import Beatmapset, BeatmapsetResp
from sqlalchemy import DECIMAL, Column, DateTime
from sqlalchemy.orm import joinedload
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -67,7 +66,9 @@ class Beatmap(BeatmapBase, table=True):
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus
# optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
beatmapset: Beatmapset = Relationship(
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
@property
def can_ranked(self) -> bool:
@@ -88,13 +89,7 @@ class Beatmap(BeatmapBase, table=True):
session.add(beatmap)
await session.commit()
beatmap = (
await session.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(Beatmap.id == resp.id)
)
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
).first()
assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap
@@ -132,13 +127,9 @@ class Beatmap(BeatmapBase, table=True):
) -> "Beatmap":
beatmap = (
await session.exec(
select(Beatmap)
.where(
select(Beatmap).where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
)
).first()
if not beatmap:
@@ -165,7 +156,7 @@ class BeatmapResp(BeatmapBase):
url: str = ""
@classmethod
def from_db(
async def from_db(
cls,
beatmap: Beatmap,
query_mode: GameMode | None = None,
@@ -179,5 +170,5 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap.beatmap_status.value
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
if not from_set:
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset)
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset)
return cls.model_validate(beatmap_)

View File

@@ -7,6 +7,7 @@ from app.models.score import GameMode
from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -130,7 +131,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel):
tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(BeatmapsetBase, table=True):
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
@@ -200,12 +201,12 @@ class BeatmapsetResp(BeatmapsetBase):
nominations: BeatmapNominations | None = None
@classmethod
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
from .beatmap import BeatmapResp
beatmaps = [
BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in beatmapset.beatmaps
await BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
]
return cls.model_validate(
{

View File

@@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
JSON,
BigInteger,
@@ -128,7 +128,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_bng: bool = False
class User(UserBase, table=True):
class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
@@ -154,17 +154,6 @@ class User(UserBase, table=True):
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
@classmethod
def all_select_option(cls):
return (
selectinload(cls.account_history), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.achievement), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType]
)
class UserResp(UserBase):
id: int | None = None
@@ -249,13 +238,7 @@ class UserResp(UserBase):
await RelationshipResp.from_db(session, r)
for r in (
await session.exec(
select(Relationship)
.options(
joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType]
*User.all_select_option()
)
)
.where(
select(Relationship).where(
Relationship.user_id == obj.id,
Relationship.type == RelationshipType.FOLLOW,
)
@@ -264,23 +247,26 @@ class UserResp(UserBase):
]
if "team" in include:
if obj.team_membership:
if await obj.awaitable_attrs.team_membership:
assert obj.team_membership
u.team = obj.team_membership.team
if "account_history" in include:
u.account_history = [
UserAccountHistoryResp.from_db(ah) for ah in obj.account_history
UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
]
if "daily_challenge_user_stats":
if obj.daily_challenge_stats:
if await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats
)
if "statistics" in include:
current_stattistics = None
for i in obj.statistics:
for i in await obj.awaitable_attrs.statistics:
if i.mode == (ruleset or obj.playmode):
current_stattistics = i
break
@@ -292,17 +278,20 @@ class UserResp(UserBase):
if "statistics_rulesets" in include:
u.statistics_rulesets = {
i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics
i.mode.value: UserStatisticsResp.from_db(i)
for i in await obj.awaitable_attrs.statistics
}
if "monthly_playcounts" in include:
u.monthly_playcounts = [
MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts
MonthlyPlaycountsResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
]
if "achievements" in include:
u.user_achievements = [
UserAchievementResp.from_db(ua) for ua in obj.achievement
UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
]
return u

View File

@@ -42,7 +42,10 @@ class Relationship(SQLModel, table=True):
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: User = SQLRelationship(
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
)
@@ -79,7 +82,6 @@ class RelationshipResp(BaseModel):
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
"achievements",
],
),
mutual=mutual,

View File

@@ -27,7 +27,7 @@ from app.models.score import (
)
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import Beatmapset, BeatmapsetResp
from .beatmapset import BeatmapsetResp
from .best_score import BestScore
from .lazer_user import User, UserResp
from .monthly_playcounts import MonthlyPlaycounts
@@ -35,7 +35,8 @@ from .score_token import ScoreToken
from redis import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.orm import aliased, joinedload
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlmodel import (
JSON,
BigInteger,
@@ -55,7 +56,7 @@ if TYPE_CHECKING:
from app.fetcher import Fetcher
class ScoreBase(SQLModel, UTCBaseModel):
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
accuracy: float
map_md5: str = Field(max_length=32, index=True)
@@ -114,27 +115,12 @@ class Score(ScoreBase, table=True):
# optional
beatmap: Beatmap = Relationship()
user: User = Relationship()
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property
def is_perfect_combo(self) -> bool:
return self.max_combo == self.beatmap.max_combo
@staticmethod
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
clause = select(Score).options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
)
if with_user:
return clause.options(
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
)
return clause
@staticmethod
def select_clause_unique(
*where_clauses: ColumnExpressionArgument[bool] | bool,
@@ -148,18 +134,7 @@ class Score(ScoreBase, table=True):
)
subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True)
return (
select(best)
.where(subq.c.rn == 1)
.options(
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
)
)
return select(best).where(subq.c.rn == 1)
class ScoreResp(ScoreBase):
@@ -186,8 +161,9 @@ class ScoreResp(ScoreBase):
) -> "ScoreResp":
s = cls.model_validate(score.model_dump())
assert score.id
s.beatmap = BeatmapResp.from_db(score.beatmap)
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = MODE_TO_INT[score.gamemode]
@@ -303,7 +279,6 @@ async def get_leaderboard(
query = (
select(Score)
.join(Beatmap)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
@@ -452,7 +427,7 @@ async def get_user_best_score_in_beatmap(
) -> Score | None:
return (
await session.exec(
Score.select_clause(False)
select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,

View File

@@ -34,5 +34,9 @@ class TeamMember(SQLModel, UTCBaseModel, table=True):
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="team_membership")
team: "Team" = Relationship(back_populates="members")
user: "User" = Relationship(
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)

View File

@@ -30,11 +30,5 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
user = (
await db.exec(
select(User)
.options(*User.all_select_option())
.where(User.id == token_record.user_id)
)
).first()
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user

View File

@@ -42,11 +42,12 @@ class Language(IntEnum):
KOREAN = 6
FRENCH = 7
GERMAN = 8
ITALIAN = 9
SPANISH = 10
RUSSIAN = 11
POLISH = 12
OTHER = 13
SWEDISH = 9
ITALIAN = 10
SPANISH = 11
RUSSIAN = 12
POLISH = 13
OTHER = 14
class BeatmapAttributes(BaseModel):

View File

@@ -5,7 +5,7 @@ import hashlib
import json
from app.calculator import calculate_beatmap_attribute
from app.database import Beatmap, BeatmapResp, Beatmapset, User
from app.database import Beatmap, BeatmapResp, User
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -24,7 +24,6 @@ from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis import Redis
import rosu_pp_py as rosu
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -51,7 +50,7 @@ async def lookup_beatmap(
if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
return await BeatmapResp.from_db(beatmap)
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
@@ -63,7 +62,7 @@ async def get_beatmap(
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
return BeatmapResp.from_db(beatmap)
return await BeatmapResp.from_db(beatmap)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -83,35 +82,15 @@ async def batch_get_beatmaps(
# select 50 beatmaps by last_updated
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
else:
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps])
@router.post(

View File

@@ -15,7 +15,6 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -27,13 +26,7 @@ async def get_beatmapset(
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmapset = (
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
if not beatmapset:
try:
resp = await fetcher.get_beatmapset(sid)
@@ -41,7 +34,7 @@ async def get_beatmapset(
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
else:
resp = BeatmapsetResp.from_db(beatmapset)
resp = await BeatmapsetResp.from_db(beatmapset)
return resp

View File

@@ -9,7 +9,6 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query, Request
from pydantic import BaseModel
from sqlalchemy.orm import joinedload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -27,14 +26,12 @@ async def get_relationship(
else RelationshipType.BLOCK
)
relationships = await db.exec(
select(Relationship)
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
.where(
select(Relationship).where(
Relationship.user_id == current_user.id,
Relationship.type == relationship_type,
)
)
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
class AddFriendResp(BaseModel):
@@ -92,14 +89,10 @@ async def add_relationship(
if origin_type == RelationshipType.FOLLOW:
relationship = (
await db.exec(
select(Relationship)
.where(
select(Relationship).where(
Relationship.user_id == current_user_id,
Relationship.target_id == target,
)
.options(
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
)
)
).first()
assert relationship, "Relationship should exist after commit"

View File

@@ -99,7 +99,7 @@ async def get_user_beatmap_score(
)
user_score = (
await db.exec(
Score.select_clause(True)
select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
@@ -139,7 +139,7 @@ async def get_user_all_beatmap_scores(
)
all_user_scores = (
await db.exec(
Score.select_clause()
select(Score)
.where(
Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap,
@@ -207,9 +207,7 @@ async def submit_solo_score(
if score_token.score_id:
score = (
await db.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
select(Score).where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
@@ -243,8 +241,6 @@ async def submit_solo_score(
score_id = score.id
score_token.score_id = score_id
await process_user(db, current_user, score, ranked)
score = (
await db.exec(Score.select_clause().where(Score.id == score_id))
).first()
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
assert score is not None
return await ScoreResp.from_db(db, score, current_user)

View File

@@ -28,19 +28,10 @@ async def get_users(
):
if user_ids:
searched_users = (
await session.exec(
select(User)
.options(*User.all_select_option())
.limit(50)
.where(col(User.id).in_(user_ids))
)
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
).all()
else:
searched_users = (
await session.exec(
select(User).options(*User.all_select_option()).limit(50)
)
).all()
searched_users = (await session.exec(select(User).limit(50))).all()
return BatchUserResponse(
users=[
await UserResp.from_db(
@@ -63,9 +54,7 @@ async def get_user_info(
):
searched_user = (
await session.exec(
select(User)
.options(*User.all_select_option())
.where(
select(User).where(
User.id == int(user)
if user.isdigit()
else User.username == user.removeprefix("@")