refactor(database): 优化数据库关联对象的载入 (#10)

This commit is contained in:
MingxuanGame
2025-07-31 20:11:22 +08:00
committed by GitHub
parent 1281e75bb1
commit be401e8885
13 changed files with 73 additions and 166 deletions

View File

@@ -8,7 +8,6 @@ from app.models.score import MODE_TO_INT, GameMode
from .beatmapset import Beatmapset, BeatmapsetResp from .beatmapset import Beatmapset, BeatmapsetResp
from sqlalchemy import DECIMAL, Column, DateTime from sqlalchemy import DECIMAL, Column, DateTime
from sqlalchemy.orm import joinedload
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -67,7 +66,9 @@ class Beatmap(BeatmapBase, table=True):
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus beatmap_status: BeatmapRankStatus
# optional # optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps") beatmapset: Beatmapset = Relationship(
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
@property @property
def can_ranked(self) -> bool: def can_ranked(self) -> bool:
@@ -88,13 +89,7 @@ class Beatmap(BeatmapBase, table=True):
session.add(beatmap) session.add(beatmap)
await session.commit() await session.commit()
beatmap = ( beatmap = (
await session.exec( await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(Beatmap.id == resp.id)
)
).first() ).first()
assert beatmap is not None, "Beatmap should not be None after commit" assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap return beatmap
@@ -132,13 +127,9 @@ class Beatmap(BeatmapBase, table=True):
) -> "Beatmap": ) -> "Beatmap":
beatmap = ( beatmap = (
await session.exec( await session.exec(
select(Beatmap) select(Beatmap).where(
.where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
) )
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
) )
).first() ).first()
if not beatmap: if not beatmap:
@@ -165,7 +156,7 @@ class BeatmapResp(BeatmapBase):
url: str = "" url: str = ""
@classmethod @classmethod
def from_db( async def from_db(
cls, cls,
beatmap: Beatmap, beatmap: Beatmap,
query_mode: GameMode | None = None, query_mode: GameMode | None = None,
@@ -179,5 +170,5 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["ranked"] = beatmap.beatmap_status.value
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
if not from_set: if not from_set:
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset) beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset)
return cls.model_validate(beatmap_) return cls.model_validate(beatmap_)

View File

@@ -7,6 +7,7 @@ from app.models.score import GameMode
from pydantic import BaseModel, model_serializer from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel from sqlmodel import Field, Relationship, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -130,7 +131,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel):
tags: str = Field(default="", sa_column=Column(Text)) tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(BeatmapsetBase, table=True): class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] __tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True) id: int | None = Field(default=None, primary_key=True, index=True)
@@ -200,12 +201,12 @@ class BeatmapsetResp(BeatmapsetBase):
nominations: BeatmapNominations | None = None nominations: BeatmapNominations | None = None
@classmethod @classmethod
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
from .beatmap import BeatmapResp from .beatmap import BeatmapResp
beatmaps = [ beatmaps = [
BeatmapResp.from_db(beatmap, from_set=True) await BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in beatmapset.beatmaps for beatmap in await beatmapset.awaitable_attrs.beatmaps
] ]
return cls.model_validate( return cls.model_validate(
{ {

View File

@@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp from .user_account_history import UserAccountHistory, UserAccountHistoryResp
from sqlalchemy.orm import joinedload, selectinload from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import ( from sqlmodel import (
JSON, JSON,
BigInteger, BigInteger,
@@ -128,7 +128,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_bng: bool = False is_bng: bool = False
class User(UserBase, table=True): class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field( id: int | None = Field(
@@ -154,17 +154,6 @@ class User(UserBase, table=True):
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
) )
@classmethod
def all_select_option(cls):
return (
selectinload(cls.account_history), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.achievement), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType]
)
class UserResp(UserBase): class UserResp(UserBase):
id: int | None = None id: int | None = None
@@ -249,13 +238,7 @@ class UserResp(UserBase):
await RelationshipResp.from_db(session, r) await RelationshipResp.from_db(session, r)
for r in ( for r in (
await session.exec( await session.exec(
select(Relationship) select(Relationship).where(
.options(
joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType]
*User.all_select_option()
)
)
.where(
Relationship.user_id == obj.id, Relationship.user_id == obj.id,
Relationship.type == RelationshipType.FOLLOW, Relationship.type == RelationshipType.FOLLOW,
) )
@@ -264,23 +247,26 @@ class UserResp(UserBase):
] ]
if "team" in include: if "team" in include:
if obj.team_membership: if await obj.awaitable_attrs.team_membership:
assert obj.team_membership
u.team = obj.team_membership.team u.team = obj.team_membership.team
if "account_history" in include: if "account_history" in include:
u.account_history = [ u.account_history = [
UserAccountHistoryResp.from_db(ah) for ah in obj.account_history UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
] ]
if "daily_challenge_user_stats": if "daily_challenge_user_stats":
if obj.daily_challenge_stats: if await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats obj.daily_challenge_stats
) )
if "statistics" in include: if "statistics" in include:
current_stattistics = None current_stattistics = None
for i in obj.statistics: for i in await obj.awaitable_attrs.statistics:
if i.mode == (ruleset or obj.playmode): if i.mode == (ruleset or obj.playmode):
current_stattistics = i current_stattistics = i
break break
@@ -292,17 +278,20 @@ class UserResp(UserBase):
if "statistics_rulesets" in include: if "statistics_rulesets" in include:
u.statistics_rulesets = { u.statistics_rulesets = {
i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics i.mode.value: UserStatisticsResp.from_db(i)
for i in await obj.awaitable_attrs.statistics
} }
if "monthly_playcounts" in include: if "monthly_playcounts" in include:
u.monthly_playcounts = [ u.monthly_playcounts = [
MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts MonthlyPlaycountsResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
] ]
if "achievements" in include: if "achievements" in include:
u.user_achievements = [ u.user_achievements = [
UserAchievementResp.from_db(ua) for ua in obj.achievement UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
] ]
return u return u

View File

@@ -42,7 +42,10 @@ class Relationship(SQLModel, table=True):
) )
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: User = SQLRelationship( target: User = SQLRelationship(
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
) )
@@ -79,7 +82,6 @@ class RelationshipResp(BaseModel):
"daily_challenge_user_stats", "daily_challenge_user_stats",
"statistics", "statistics",
"statistics_rulesets", "statistics_rulesets",
"achievements",
], ],
), ),
mutual=mutual, mutual=mutual,

View File

@@ -27,7 +27,7 @@ from app.models.score import (
) )
from .beatmap import Beatmap, BeatmapResp from .beatmap import Beatmap, BeatmapResp
from .beatmapset import Beatmapset, BeatmapsetResp from .beatmapset import BeatmapsetResp
from .best_score import BestScore from .best_score import BestScore
from .lazer_user import User, UserResp from .lazer_user import User, UserResp
from .monthly_playcounts import MonthlyPlaycounts from .monthly_playcounts import MonthlyPlaycounts
@@ -35,7 +35,8 @@ from .score_token import ScoreToken
from redis import Redis from redis import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.orm import aliased, joinedload from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlmodel import ( from sqlmodel import (
JSON, JSON,
BigInteger, BigInteger,
@@ -55,7 +56,7 @@ if TYPE_CHECKING:
from app.fetcher import Fetcher from app.fetcher import Fetcher
class ScoreBase(SQLModel, UTCBaseModel): class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段 # 基本字段
accuracy: float accuracy: float
map_md5: str = Field(max_length=32, index=True) map_md5: str = Field(max_length=32, index=True)
@@ -114,27 +115,12 @@ class Score(ScoreBase, table=True):
# optional # optional
beatmap: Beatmap = Relationship() beatmap: Beatmap = Relationship()
user: User = Relationship() user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property @property
def is_perfect_combo(self) -> bool: def is_perfect_combo(self) -> bool:
return self.max_combo == self.beatmap.max_combo return self.max_combo == self.beatmap.max_combo
@staticmethod
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
clause = select(Score).options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
)
if with_user:
return clause.options(
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
)
return clause
@staticmethod @staticmethod
def select_clause_unique( def select_clause_unique(
*where_clauses: ColumnExpressionArgument[bool] | bool, *where_clauses: ColumnExpressionArgument[bool] | bool,
@@ -148,18 +134,7 @@ class Score(ScoreBase, table=True):
) )
subq = select(Score, rownum).where(*where_clauses).subquery() subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True) best = aliased(Score, subq, adapt_on_names=True)
return ( return select(best).where(subq.c.rn == 1)
select(best)
.where(subq.c.rn == 1)
.options(
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
)
)
class ScoreResp(ScoreBase): class ScoreResp(ScoreBase):
@@ -186,8 +161,9 @@ class ScoreResp(ScoreBase):
) -> "ScoreResp": ) -> "ScoreResp":
s = cls.model_validate(score.model_dump()) s = cls.model_validate(score.model_dump())
assert score.id assert score.id
s.beatmap = BeatmapResp.from_db(score.beatmap) await score.awaitable_attrs.beatmap
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset) s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = MODE_TO_INT[score.gamemode] s.ruleset_id = MODE_TO_INT[score.gamemode]
@@ -303,7 +279,6 @@ async def get_leaderboard(
query = ( query = (
select(Score) select(Score)
.join(Beatmap) .join(Beatmap)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where( .where(
Score.map_md5 == beatmap_md5, Score.map_md5 == beatmap_md5,
Score.gamemode == mode, Score.gamemode == mode,
@@ -452,7 +427,7 @@ async def get_user_best_score_in_beatmap(
) -> Score | None: ) -> Score | None:
return ( return (
await session.exec( await session.exec(
Score.select_clause(False) select(Score)
.where( .where(
Score.gamemode == mode if mode is not None else True, Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap, Score.beatmap_id == beatmap,

View File

@@ -34,5 +34,9 @@ class TeamMember(SQLModel, UTCBaseModel, table=True):
default_factory=datetime.utcnow, sa_column=Column(DateTime) default_factory=datetime.utcnow, sa_column=Column(DateTime)
) )
user: "User" = Relationship(back_populates="team_membership") user: "User" = Relationship(
team: "Team" = Relationship(back_populates="members") back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)

View File

@@ -30,11 +30,5 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None
token_record = await get_token_by_access_token(db, token) token_record = await get_token_by_access_token(db, token)
if not token_record: if not token_record:
return None return None
user = ( user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
await db.exec(
select(User)
.options(*User.all_select_option())
.where(User.id == token_record.user_id)
)
).first()
return user return user

View File

@@ -42,11 +42,12 @@ class Language(IntEnum):
KOREAN = 6 KOREAN = 6
FRENCH = 7 FRENCH = 7
GERMAN = 8 GERMAN = 8
ITALIAN = 9 SWEDISH = 9
SPANISH = 10 ITALIAN = 10
RUSSIAN = 11 SPANISH = 11
POLISH = 12 RUSSIAN = 12
OTHER = 13 POLISH = 13
OTHER = 14
class BeatmapAttributes(BaseModel): class BeatmapAttributes(BaseModel):

View File

@@ -5,7 +5,7 @@ import hashlib
import json import json
from app.calculator import calculate_beatmap_attribute from app.calculator import calculate_beatmap_attribute
from app.database import Beatmap, BeatmapResp, Beatmapset, User from app.database import Beatmap, BeatmapResp, User
from app.dependencies.database import get_db, get_redis from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
@@ -24,7 +24,6 @@ from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel from pydantic import BaseModel
from redis import Redis from redis import Redis
import rosu_pp_py as rosu import rosu_pp_py as rosu
from sqlalchemy.orm import joinedload
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -51,7 +50,7 @@ async def lookup_beatmap(
if beatmap is None: if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap) return await BeatmapResp.from_db(beatmap)
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
@@ -63,7 +62,7 @@ async def get_beatmap(
): ):
try: try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
return BeatmapResp.from_db(beatmap) return await BeatmapResp.from_db(beatmap)
except HTTPError: except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -83,35 +82,15 @@ async def batch_get_beatmaps(
# select 50 beatmaps by last_updated # select 50 beatmaps by last_updated
beatmaps = ( beatmaps = (
await db.exec( await db.exec(
select(Beatmap) select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
) )
).all() ).all()
else: else:
beatmaps = ( beatmaps = (
await db.exec( await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
).all() ).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps])
@router.post( @router.post(

View File

@@ -15,7 +15,6 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query from fastapi import Depends, HTTPException, Query
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError from httpx import HTTPStatusError
from sqlalchemy.orm import selectinload
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -27,13 +26,7 @@ async def get_beatmapset(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
beatmapset = ( beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
if not beatmapset: if not beatmapset:
try: try:
resp = await fetcher.get_beatmapset(sid) resp = await fetcher.get_beatmapset(sid)
@@ -41,7 +34,7 @@ async def get_beatmapset(
except HTTPStatusError: except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found") raise HTTPException(status_code=404, detail="Beatmapset not found")
else: else:
resp = BeatmapsetResp.from_db(beatmapset) resp = await BeatmapsetResp.from_db(beatmapset)
return resp return resp

View File

@@ -9,7 +9,6 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query, Request from fastapi import Depends, HTTPException, Query, Request
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import joinedload
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -27,14 +26,12 @@ async def get_relationship(
else RelationshipType.BLOCK else RelationshipType.BLOCK
) )
relationships = await db.exec( relationships = await db.exec(
select(Relationship) select(Relationship).where(
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
.where(
Relationship.user_id == current_user.id, Relationship.user_id == current_user.id,
Relationship.type == relationship_type, Relationship.type == relationship_type,
) )
) )
return [await RelationshipResp.from_db(db, rel) for rel in relationships] return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
class AddFriendResp(BaseModel): class AddFriendResp(BaseModel):
@@ -92,14 +89,10 @@ async def add_relationship(
if origin_type == RelationshipType.FOLLOW: if origin_type == RelationshipType.FOLLOW:
relationship = ( relationship = (
await db.exec( await db.exec(
select(Relationship) select(Relationship).where(
.where(
Relationship.user_id == current_user_id, Relationship.user_id == current_user_id,
Relationship.target_id == target, Relationship.target_id == target,
) )
.options(
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
)
) )
).first() ).first()
assert relationship, "Relationship should exist after commit" assert relationship, "Relationship should exist after commit"

View File

@@ -99,7 +99,7 @@ async def get_user_beatmap_score(
) )
user_score = ( user_score = (
await db.exec( await db.exec(
Score.select_clause(True) select(Score)
.where( .where(
Score.gamemode == mode if mode is not None else True, Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap, Score.beatmap_id == beatmap,
@@ -139,7 +139,7 @@ async def get_user_all_beatmap_scores(
) )
all_user_scores = ( all_user_scores = (
await db.exec( await db.exec(
Score.select_clause() select(Score)
.where( .where(
Score.gamemode == ruleset if ruleset is not None else True, Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap, Score.beatmap_id == beatmap,
@@ -207,9 +207,7 @@ async def submit_solo_score(
if score_token.score_id: if score_token.score_id:
score = ( score = (
await db.exec( await db.exec(
select(Score) select(Score).where(
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == score_token.score_id, Score.id == score_token.score_id,
Score.user_id == current_user.id, Score.user_id == current_user.id,
) )
@@ -243,8 +241,6 @@ async def submit_solo_score(
score_id = score.id score_id = score.id
score_token.score_id = score_id score_token.score_id = score_id
await process_user(db, current_user, score, ranked) await process_user(db, current_user, score, ranked)
score = ( score = (await db.exec(select(Score).where(Score.id == score_id))).first()
await db.exec(Score.select_clause().where(Score.id == score_id))
).first()
assert score is not None assert score is not None
return await ScoreResp.from_db(db, score, current_user) return await ScoreResp.from_db(db, score, current_user)

View File

@@ -28,19 +28,10 @@ async def get_users(
): ):
if user_ids: if user_ids:
searched_users = ( searched_users = (
await session.exec( await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
select(User)
.options(*User.all_select_option())
.limit(50)
.where(col(User.id).in_(user_ids))
)
).all() ).all()
else: else:
searched_users = ( searched_users = (await session.exec(select(User).limit(50))).all()
await session.exec(
select(User).options(*User.all_select_option()).limit(50)
)
).all()
return BatchUserResponse( return BatchUserResponse(
users=[ users=[
await UserResp.from_db( await UserResp.from_db(
@@ -63,9 +54,7 @@ async def get_user_info(
): ):
searched_user = ( searched_user = (
await session.exec( await session.exec(
select(User) select(User).where(
.options(*User.all_select_option())
.where(
User.id == int(user) User.id == int(user)
if user.isdigit() if user.isdigit()
else User.username == user.removeprefix("@") else User.username == user.removeprefix("@")