refactor(database): 优化数据库关联对象的载入 (#10)
This commit is contained in:
@@ -8,7 +8,6 @@ from app.models.score import MODE_TO_INT, GameMode
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -67,7 +66,9 @@ class Beatmap(BeatmapBase, table=True):
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
beatmap_status: BeatmapRankStatus
|
||||
# optional
|
||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
|
||||
beatmapset: Beatmapset = Relationship(
|
||||
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
|
||||
@property
|
||||
def can_ranked(self) -> bool:
|
||||
@@ -88,13 +89,7 @@ class Beatmap(BeatmapBase, table=True):
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(Beatmap.id == resp.id)
|
||||
)
|
||||
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
|
||||
).first()
|
||||
assert beatmap is not None, "Beatmap should not be None after commit"
|
||||
return beatmap
|
||||
@@ -132,13 +127,9 @@ class Beatmap(BeatmapBase, table=True):
|
||||
) -> "Beatmap":
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap)
|
||||
.where(
|
||||
select(Beatmap).where(
|
||||
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
|
||||
)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
@@ -165,7 +156,7 @@ class BeatmapResp(BeatmapBase):
|
||||
url: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_db(
|
||||
async def from_db(
|
||||
cls,
|
||||
beatmap: Beatmap,
|
||||
query_mode: GameMode | None = None,
|
||||
@@ -179,5 +170,5 @@ class BeatmapResp(BeatmapBase):
|
||||
beatmap_["ranked"] = beatmap.beatmap_status.value
|
||||
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
|
||||
if not from_set:
|
||||
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset)
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset)
|
||||
return cls.model_validate(beatmap_)
|
||||
|
||||
@@ -7,6 +7,7 @@ from app.models.score import GameMode
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -130,7 +131,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel):
|
||||
tags: str = Field(default="", sa_column=Column(Text))
|
||||
|
||||
|
||||
class Beatmapset(BeatmapsetBase, table=True):
|
||||
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
@@ -200,12 +201,12 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
nominations: BeatmapNominations | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
|
||||
async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
|
||||
from .beatmap import BeatmapResp
|
||||
|
||||
beatmaps = [
|
||||
BeatmapResp.from_db(beatmap, from_set=True)
|
||||
for beatmap in beatmapset.beatmaps
|
||||
await BeatmapResp.from_db(beatmap, from_set=True)
|
||||
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
||||
]
|
||||
return cls.model_validate(
|
||||
{
|
||||
|
||||
@@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp
|
||||
from .team import Team, TeamMember
|
||||
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
|
||||
|
||||
from sqlalchemy.orm import joinedload, selectinload
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
@@ -128,7 +128,7 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
is_bng: bool = False
|
||||
|
||||
|
||||
class User(UserBase, table=True):
|
||||
class User(AsyncAttrs, UserBase, table=True):
|
||||
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(
|
||||
@@ -154,17 +154,6 @@ class User(UserBase, table=True):
|
||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all_select_option(cls):
|
||||
return (
|
||||
selectinload(cls.account_history), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.achievement), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
|
||||
class UserResp(UserBase):
|
||||
id: int | None = None
|
||||
@@ -249,13 +238,7 @@ class UserResp(UserBase):
|
||||
await RelationshipResp.from_db(session, r)
|
||||
for r in (
|
||||
await session.exec(
|
||||
select(Relationship)
|
||||
.options(
|
||||
joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType]
|
||||
*User.all_select_option()
|
||||
)
|
||||
)
|
||||
.where(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == obj.id,
|
||||
Relationship.type == RelationshipType.FOLLOW,
|
||||
)
|
||||
@@ -264,23 +247,26 @@ class UserResp(UserBase):
|
||||
]
|
||||
|
||||
if "team" in include:
|
||||
if obj.team_membership:
|
||||
if await obj.awaitable_attrs.team_membership:
|
||||
assert obj.team_membership
|
||||
u.team = obj.team_membership.team
|
||||
|
||||
if "account_history" in include:
|
||||
u.account_history = [
|
||||
UserAccountHistoryResp.from_db(ah) for ah in obj.account_history
|
||||
UserAccountHistoryResp.from_db(ah)
|
||||
for ah in await obj.awaitable_attrs.account_history
|
||||
]
|
||||
|
||||
if "daily_challenge_user_stats":
|
||||
if obj.daily_challenge_stats:
|
||||
if await obj.awaitable_attrs.daily_challenge_stats:
|
||||
assert obj.daily_challenge_stats
|
||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
|
||||
obj.daily_challenge_stats
|
||||
)
|
||||
|
||||
if "statistics" in include:
|
||||
current_stattistics = None
|
||||
for i in obj.statistics:
|
||||
for i in await obj.awaitable_attrs.statistics:
|
||||
if i.mode == (ruleset or obj.playmode):
|
||||
current_stattistics = i
|
||||
break
|
||||
@@ -292,17 +278,20 @@ class UserResp(UserBase):
|
||||
|
||||
if "statistics_rulesets" in include:
|
||||
u.statistics_rulesets = {
|
||||
i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics
|
||||
i.mode.value: UserStatisticsResp.from_db(i)
|
||||
for i in await obj.awaitable_attrs.statistics
|
||||
}
|
||||
|
||||
if "monthly_playcounts" in include:
|
||||
u.monthly_playcounts = [
|
||||
MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts
|
||||
MonthlyPlaycountsResp.from_db(pc)
|
||||
for pc in await obj.awaitable_attrs.monthly_playcounts
|
||||
]
|
||||
|
||||
if "achievements" in include:
|
||||
u.user_achievements = [
|
||||
UserAchievementResp.from_db(ua) for ua in obj.achievement
|
||||
UserAchievementResp.from_db(ua)
|
||||
for ua in await obj.awaitable_attrs.achievement
|
||||
]
|
||||
|
||||
return u
|
||||
|
||||
@@ -42,7 +42,10 @@ class Relationship(SQLModel, table=True):
|
||||
)
|
||||
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
||||
target: User = SQLRelationship(
|
||||
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[Relationship.target_id]",
|
||||
"lazy": "selectin",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -79,7 +82,6 @@ class RelationshipResp(BaseModel):
|
||||
"daily_challenge_user_stats",
|
||||
"statistics",
|
||||
"statistics_rulesets",
|
||||
"achievements",
|
||||
],
|
||||
),
|
||||
mutual=mutual,
|
||||
|
||||
@@ -27,7 +27,7 @@ from app.models.score import (
|
||||
)
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
from .best_score import BestScore
|
||||
from .lazer_user import User, UserResp
|
||||
from .monthly_playcounts import MonthlyPlaycounts
|
||||
@@ -35,7 +35,8 @@ from .score_token import ScoreToken
|
||||
|
||||
from redis import Redis
|
||||
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
|
||||
from sqlalchemy.orm import aliased, joinedload
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
@@ -55,7 +56,7 @@ if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
|
||||
class ScoreBase(SQLModel, UTCBaseModel):
|
||||
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
# 基本字段
|
||||
accuracy: float
|
||||
map_md5: str = Field(max_length=32, index=True)
|
||||
@@ -114,27 +115,12 @@ class Score(ScoreBase, table=True):
|
||||
|
||||
# optional
|
||||
beatmap: Beatmap = Relationship()
|
||||
user: User = Relationship()
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@property
|
||||
def is_perfect_combo(self) -> bool:
|
||||
return self.max_combo == self.beatmap.max_combo
|
||||
|
||||
@staticmethod
|
||||
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
|
||||
clause = select(Score).options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
if with_user:
|
||||
return clause.options(
|
||||
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
return clause
|
||||
|
||||
@staticmethod
|
||||
def select_clause_unique(
|
||||
*where_clauses: ColumnExpressionArgument[bool] | bool,
|
||||
@@ -148,18 +134,7 @@ class Score(ScoreBase, table=True):
|
||||
)
|
||||
subq = select(Score, rownum).where(*where_clauses).subquery()
|
||||
best = aliased(Score, subq, adapt_on_names=True)
|
||||
return (
|
||||
select(best)
|
||||
.where(subq.c.rn == 1)
|
||||
.options(
|
||||
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
return select(best).where(subq.c.rn == 1)
|
||||
|
||||
|
||||
class ScoreResp(ScoreBase):
|
||||
@@ -186,8 +161,9 @@ class ScoreResp(ScoreBase):
|
||||
) -> "ScoreResp":
|
||||
s = cls.model_validate(score.model_dump())
|
||||
assert score.id
|
||||
s.beatmap = BeatmapResp.from_db(score.beatmap)
|
||||
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
|
||||
await score.awaitable_attrs.beatmap
|
||||
s.beatmap = await BeatmapResp.from_db(score.beatmap)
|
||||
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset)
|
||||
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
||||
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
|
||||
s.ruleset_id = MODE_TO_INT[score.gamemode]
|
||||
@@ -303,7 +279,6 @@ async def get_leaderboard(
|
||||
query = (
|
||||
select(Score)
|
||||
.join(Beatmap)
|
||||
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
@@ -452,7 +427,7 @@ async def get_user_best_score_in_beatmap(
|
||||
) -> Score | None:
|
||||
return (
|
||||
await session.exec(
|
||||
Score.select_clause(False)
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
|
||||
@@ -34,5 +34,9 @@ class TeamMember(SQLModel, UTCBaseModel, table=True):
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="team_membership")
|
||||
team: "Team" = Relationship(back_populates="members")
|
||||
user: "User" = Relationship(
|
||||
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
team: "Team" = Relationship(
|
||||
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
|
||||
@@ -30,11 +30,5 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
user = (
|
||||
await db.exec(
|
||||
select(User)
|
||||
.options(*User.all_select_option())
|
||||
.where(User.id == token_record.user_id)
|
||||
)
|
||||
).first()
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
return user
|
||||
|
||||
@@ -42,11 +42,12 @@ class Language(IntEnum):
|
||||
KOREAN = 6
|
||||
FRENCH = 7
|
||||
GERMAN = 8
|
||||
ITALIAN = 9
|
||||
SPANISH = 10
|
||||
RUSSIAN = 11
|
||||
POLISH = 12
|
||||
OTHER = 13
|
||||
SWEDISH = 9
|
||||
ITALIAN = 10
|
||||
SPANISH = 11
|
||||
RUSSIAN = 12
|
||||
POLISH = 13
|
||||
OTHER = 14
|
||||
|
||||
|
||||
class BeatmapAttributes(BaseModel):
|
||||
|
||||
@@ -5,7 +5,7 @@ import hashlib
|
||||
import json
|
||||
|
||||
from app.calculator import calculate_beatmap_attribute
|
||||
from app.database import Beatmap, BeatmapResp, Beatmapset, User
|
||||
from app.database import Beatmap, BeatmapResp, User
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
@@ -24,7 +24,6 @@ from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
import rosu_pp_py as rosu
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -51,7 +50,7 @@ async def lookup_beatmap(
|
||||
if beatmap is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
return await BeatmapResp.from_db(beatmap)
|
||||
|
||||
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
@@ -63,7 +62,7 @@ async def get_beatmap(
|
||||
):
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
return await BeatmapResp.from_db(beatmap)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
@@ -83,35 +82,15 @@ async def batch_get_beatmaps(
|
||||
# select 50 beatmaps by last_updated
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.order_by(col(Beatmap.last_updated).desc())
|
||||
.limit(50)
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(col(Beatmap.id).in_(b_ids))
|
||||
.limit(50)
|
||||
)
|
||||
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
|
||||
).all()
|
||||
|
||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -15,7 +15,6 @@ from .api_router import router
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from httpx import HTTPStatusError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -27,13 +26,7 @@ async def get_beatmapset(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmapset = (
|
||||
await db.exec(
|
||||
select(Beatmapset)
|
||||
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmapset.id == sid)
|
||||
)
|
||||
).first()
|
||||
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
|
||||
if not beatmapset:
|
||||
try:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
@@ -41,7 +34,7 @@ async def get_beatmapset(
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
else:
|
||||
resp = BeatmapsetResp.from_db(beatmapset)
|
||||
resp = await BeatmapsetResp.from_db(beatmapset)
|
||||
return resp
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -27,14 +26,12 @@ async def get_relationship(
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationships = await db.exec(
|
||||
select(Relationship)
|
||||
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
Relationship.type == relationship_type,
|
||||
)
|
||||
)
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
|
||||
|
||||
|
||||
class AddFriendResp(BaseModel):
|
||||
@@ -92,14 +89,10 @@ async def add_relationship(
|
||||
if origin_type == RelationshipType.FOLLOW:
|
||||
relationship = (
|
||||
await db.exec(
|
||||
select(Relationship)
|
||||
.where(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user_id,
|
||||
Relationship.target_id == target,
|
||||
)
|
||||
.options(
|
||||
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
assert relationship, "Relationship should exist after commit"
|
||||
|
||||
@@ -99,7 +99,7 @@ async def get_user_beatmap_score(
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
Score.select_clause(True)
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
@@ -139,7 +139,7 @@ async def get_user_all_beatmap_scores(
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
Score.select_clause()
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == ruleset if ruleset is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
@@ -207,9 +207,7 @@ async def submit_solo_score(
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
select(Score).where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
@@ -243,8 +241,6 @@ async def submit_solo_score(
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await process_user(db, current_user, score, ranked)
|
||||
score = (
|
||||
await db.exec(Score.select_clause().where(Score.id == score_id))
|
||||
).first()
|
||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score, current_user)
|
||||
|
||||
@@ -28,19 +28,10 @@ async def get_users(
|
||||
):
|
||||
if user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
select(User)
|
||||
.options(*User.all_select_option())
|
||||
.limit(50)
|
||||
.where(col(User.id).in_(user_ids))
|
||||
)
|
||||
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
|
||||
).all()
|
||||
else:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
select(User).options(*User.all_select_option()).limit(50)
|
||||
)
|
||||
).all()
|
||||
searched_users = (await session.exec(select(User).limit(50))).all()
|
||||
return BatchUserResponse(
|
||||
users=[
|
||||
await UserResp.from_db(
|
||||
@@ -63,9 +54,7 @@ async def get_user_info(
|
||||
):
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User)
|
||||
.options(*User.all_select_option())
|
||||
.where(
|
||||
select(User).where(
|
||||
User.id == int(user)
|
||||
if user.isdigit()
|
||||
else User.username == user.removeprefix("@")
|
||||
|
||||
Reference in New Issue
Block a user