diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 751bc5c..2ab5ad0 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -8,7 +8,6 @@ from app.models.score import MODE_TO_INT, GameMode from .beatmapset import Beatmapset, BeatmapsetResp from sqlalchemy import DECIMAL, Column, DateTime -from sqlalchemy.orm import joinedload from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -67,7 +66,9 @@ class Beatmap(BeatmapBase, table=True): beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmap_status: BeatmapRankStatus # optional - beatmapset: Beatmapset = Relationship(back_populates="beatmaps") + beatmapset: Beatmapset = Relationship( + back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"} + ) @property def can_ranked(self) -> bool: @@ -88,13 +89,7 @@ class Beatmap(BeatmapBase, table=True): session.add(beatmap) await session.commit() beatmap = ( - await session.exec( - select(Beatmap) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) - .where(Beatmap.id == resp.id) - ) + await session.exec(select(Beatmap).where(Beatmap.id == resp.id)) ).first() assert beatmap is not None, "Beatmap should not be None after commit" return beatmap @@ -132,13 +127,9 @@ class Beatmap(BeatmapBase, table=True): ) -> "Beatmap": beatmap = ( await session.exec( - select(Beatmap) - .where( + select(Beatmap).where( Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 ) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) ) ).first() if not beatmap: @@ -165,7 +156,7 @@ class BeatmapResp(BeatmapBase): url: str = "" @classmethod - def from_db( + async def from_db( cls, beatmap: Beatmap, query_mode: GameMode | None = None, @@ -179,5 +170,5 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] if not from_set: - beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset) return cls.model_validate(beatmap_) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 2ef6280..5a618b7 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -7,6 +7,7 @@ from app.models.score import GameMode from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text +from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import Field, Relationship, SQLModel from sqlmodel.ext.asyncio.session import AsyncSession @@ -130,7 +131,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): tags: str = Field(default="", sa_column=Column(Text)) -class Beatmapset(BeatmapsetBase, table=True): +class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): __tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) @@ -200,12 +201,12 @@ class BeatmapsetResp(BeatmapsetBase): nominations: BeatmapNominations | None = None @classmethod - def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": + async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": from .beatmap import BeatmapResp beatmaps = [ - BeatmapResp.from_db(beatmap, from_set=True) - for beatmap in beatmapset.beatmaps + await BeatmapResp.from_db(beatmap, from_set=True) + for beatmap in await beatmapset.awaitable_attrs.beatmaps ] return cls.model_validate( { diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 9b98c98..d502ccb 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp from .team import Team, TeamMember from .user_account_history import UserAccountHistory, UserAccountHistoryResp -from sqlalchemy.orm import joinedload, selectinload +from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( JSON, BigInteger, @@ -128,7 +128,7 @@ class UserBase(UTCBaseModel, SQLModel): is_bng: bool = False -class User(UserBase, table=True): +class User(AsyncAttrs, UserBase, table=True): __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] id: int | None = Field( @@ -154,17 +154,6 @@ class User(UserBase, table=True): default=None, sa_column=Column(DateTime(timezone=True)), exclude=True ) - @classmethod - def all_select_option(cls): - return ( - selectinload(cls.account_history), # pyright: ignore[reportArgumentType] - selectinload(cls.statistics), # pyright: ignore[reportArgumentType] - selectinload(cls.achievement), # pyright: ignore[reportArgumentType] - joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType] - joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] - selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType] - ) - class UserResp(UserBase): id: int | None = None @@ -249,13 +238,7 @@ class UserResp(UserBase): await RelationshipResp.from_db(session, r) for r in ( await session.exec( - select(Relationship) - .options( - joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType] - *User.all_select_option() - ) - ) - .where( + select(Relationship).where( Relationship.user_id == obj.id, Relationship.type == RelationshipType.FOLLOW, ) @@ -264,23 +247,26 @@ class UserResp(UserBase): ] if "team" in include: - if obj.team_membership: + if await obj.awaitable_attrs.team_membership: + assert obj.team_membership u.team = obj.team_membership.team if "account_history" in include: u.account_history = [ - UserAccountHistoryResp.from_db(ah) for ah in obj.account_history + UserAccountHistoryResp.from_db(ah) + for ah in await obj.awaitable_attrs.account_history ] if "daily_challenge_user_stats": - if obj.daily_challenge_stats: + if await obj.awaitable_attrs.daily_challenge_stats: + assert obj.daily_challenge_stats u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( obj.daily_challenge_stats ) if "statistics" in include: current_stattistics = None - for i in obj.statistics: + for i in await obj.awaitable_attrs.statistics: if i.mode == (ruleset or obj.playmode): current_stattistics = i break @@ -292,17 +278,20 @@ class UserResp(UserBase): if "statistics_rulesets" in include: u.statistics_rulesets = { - i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics + i.mode.value: UserStatisticsResp.from_db(i) + for i in await obj.awaitable_attrs.statistics } if "monthly_playcounts" in include: u.monthly_playcounts = [ - MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts + MonthlyPlaycountsResp.from_db(pc) + for pc in await obj.awaitable_attrs.monthly_playcounts ] if "achievements" in include: u.user_achievements = [ - UserAchievementResp.from_db(ua) for ua in obj.achievement + UserAchievementResp.from_db(ua) + for ua in await obj.awaitable_attrs.achievement ] return u diff --git a/app/database/relationship.py b/app/database/relationship.py index 07daa25..7a351aa 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -42,7 +42,10 @@ class Relationship(SQLModel, table=True): ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) target: User = SQLRelationship( - sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} + sa_relationship_kwargs={ + "foreign_keys": "[Relationship.target_id]", + "lazy": "selectin", + } ) @@ -79,7 +82,6 @@ class RelationshipResp(BaseModel): "daily_challenge_user_stats", "statistics", "statistics_rulesets", - "achievements", ], ), mutual=mutual, diff --git a/app/database/score.py b/app/database/score.py index 32c8cf5..4cee832 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -27,7 +27,7 @@ from app.models.score import ( ) from .beatmap import Beatmap, BeatmapResp -from .beatmapset import Beatmapset, BeatmapsetResp +from .beatmapset import BeatmapsetResp from .best_score import BestScore from .lazer_user import User, UserResp from .monthly_playcounts import MonthlyPlaycounts @@ -35,7 +35,8 @@ from .score_token import ScoreToken from redis import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime -from sqlalchemy.orm import aliased, joinedload +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import aliased from sqlmodel import ( JSON, BigInteger, @@ -55,7 +56,7 @@ if TYPE_CHECKING: from app.fetcher import Fetcher -class ScoreBase(SQLModel, UTCBaseModel): +class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): # 基本字段 accuracy: float map_md5: str = Field(max_length=32, index=True) @@ -114,27 +115,12 @@ class Score(ScoreBase, table=True): # optional beatmap: Beatmap = Relationship() - user: User = Relationship() + user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) @property def is_perfect_combo(self) -> bool: return self.max_combo == self.beatmap.max_combo - @staticmethod - def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]: - clause = select(Score).options( - joinedload(Score.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - ) - if with_user: - return clause.options( - joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType] - ) - return clause - @staticmethod def select_clause_unique( *where_clauses: ColumnExpressionArgument[bool] | bool, @@ -148,18 +134,7 @@ class Score(ScoreBase, table=True): ) subq = select(Score, rownum).where(*where_clauses).subquery() best = aliased(Score, subq, adapt_on_names=True) - return ( - select(best) - .where(subq.c.rn == 1) - .options( - joinedload(best.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType] - ) - ) + return select(best).where(subq.c.rn == 1) class ScoreResp(ScoreBase): @@ -186,8 +161,9 @@ class ScoreResp(ScoreBase): ) -> "ScoreResp": s = cls.model_validate(score.model_dump()) assert score.id - s.beatmap = BeatmapResp.from_db(score.beatmap) - s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset) + await score.awaitable_attrs.beatmap + s.beatmap = await BeatmapResp.from_db(score.beatmap) + s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = MODE_TO_INT[score.gamemode] @@ -303,7 +279,6 @@ async def get_leaderboard( query = ( select(Score) .join(Beatmap) - .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType] .where( Score.map_md5 == beatmap_md5, Score.gamemode == mode, @@ -452,7 +427,7 @@ async def get_user_best_score_in_beatmap( ) -> Score | None: return ( await session.exec( - Score.select_clause(False) + select(Score) .where( Score.gamemode == mode if mode is not None else True, Score.beatmap_id == beatmap, diff --git a/app/database/team.py b/app/database/team.py index 146ca9f..562b0c8 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -34,5 +34,9 @@ class TeamMember(SQLModel, UTCBaseModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: "User" = Relationship(back_populates="team_membership") - team: "Team" = Relationship(back_populates="members") + user: "User" = Relationship( + back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"} + ) + team: "Team" = Relationship( + back_populates="members", sa_relationship_kwargs={"lazy": "joined"} + ) diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 769247c..5537f4f 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -30,11 +30,5 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None token_record = await get_token_by_access_token(db, token) if not token_record: return None - user = ( - await db.exec( - select(User) - .options(*User.all_select_option()) - .where(User.id == token_record.user_id) - ) - ).first() + user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() return user diff --git a/app/models/beatmap.py b/app/models/beatmap.py index 4f12e13..fae18ba 100644 --- a/app/models/beatmap.py +++ b/app/models/beatmap.py @@ -42,11 +42,12 @@ class Language(IntEnum): KOREAN = 6 FRENCH = 7 GERMAN = 8 - ITALIAN = 9 - SPANISH = 10 - RUSSIAN = 11 - POLISH = 12 - OTHER = 13 + SWEDISH = 9 + ITALIAN = 10 + SPANISH = 11 + RUSSIAN = 12 + POLISH = 13 + OTHER = 14 class BeatmapAttributes(BaseModel): diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 4af9c9a..df5f20d 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -5,7 +5,7 @@ import hashlib import json from app.calculator import calculate_beatmap_attribute -from app.database import Beatmap, BeatmapResp, Beatmapset, User +from app.database import Beatmap, BeatmapResp, User from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -24,7 +24,6 @@ from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel from redis import Redis import rosu_pp_py as rosu -from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -51,7 +50,7 @@ async def lookup_beatmap( if beatmap is None: raise HTTPException(status_code=404, detail="Beatmap not found") - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @@ -63,7 +62,7 @@ async def get_beatmap( ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -83,35 +82,15 @@ async def batch_get_beatmaps( # select 50 beatmaps by last_updated beatmaps = ( await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .order_by(col(Beatmap.last_updated).desc()) - .limit(50) + select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50) ) ).all() else: beatmaps = ( - await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .where(col(Beatmap.id).in_(b_ids)) - .limit(50) - ) + await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) ).all() - return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) + return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps]) @router.post( diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index 80396fd..b82678d 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -15,7 +15,6 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from fastapi.responses import RedirectResponse from httpx import HTTPStatusError -from sqlalchemy.orm import selectinload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -27,13 +26,7 @@ async def get_beatmapset( db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): - beatmapset = ( - await db.exec( - select(Beatmapset) - .options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] - .where(Beatmapset.id == sid) - ) - ).first() + beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first() if not beatmapset: try: resp = await fetcher.get_beatmapset(sid) @@ -41,7 +34,7 @@ async def get_beatmapset( except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmapset not found") else: - resp = BeatmapsetResp.from_db(beatmapset) + resp = await BeatmapsetResp.from_db(beatmapset) return resp diff --git a/app/router/relationship.py b/app/router/relationship.py index 9e39e8b..02292c9 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -9,7 +9,6 @@ from .api_router import router from fastapi import Depends, HTTPException, Query, Request from pydantic import BaseModel -from sqlalchemy.orm import joinedload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -27,14 +26,12 @@ async def get_relationship( else RelationshipType.BLOCK ) relationships = await db.exec( - select(Relationship) - .options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType] - .where( + select(Relationship).where( Relationship.user_id == current_user.id, Relationship.type == relationship_type, ) ) - return [await RelationshipResp.from_db(db, rel) for rel in relationships] + return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()] class AddFriendResp(BaseModel): @@ -92,14 +89,10 @@ async def add_relationship( if origin_type == RelationshipType.FOLLOW: relationship = ( await db.exec( - select(Relationship) - .where( + select(Relationship).where( Relationship.user_id == current_user_id, Relationship.target_id == target, ) - .options( - joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType] - ) ) ).first() assert relationship, "Relationship should exist after commit" diff --git a/app/router/score.py b/app/router/score.py index baab3a2..cd0a236 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -99,7 +99,7 @@ async def get_user_beatmap_score( ) user_score = ( await db.exec( - Score.select_clause(True) + select(Score) .where( Score.gamemode == mode if mode is not None else True, Score.beatmap_id == beatmap, @@ -139,7 +139,7 @@ async def get_user_all_beatmap_scores( ) all_user_scores = ( await db.exec( - Score.select_clause() + select(Score) .where( Score.gamemode == ruleset if ruleset is not None else True, Score.beatmap_id == beatmap, @@ -207,9 +207,7 @@ async def submit_solo_score( if score_token.score_id: score = ( await db.exec( - select(Score) - .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] - .where( + select(Score).where( Score.id == score_token.score_id, Score.user_id == current_user.id, ) @@ -243,8 +241,6 @@ async def submit_solo_score( score_id = score.id score_token.score_id = score_id await process_user(db, current_user, score, ranked) - score = ( - await db.exec(Score.select_clause().where(Score.id == score_id)) - ).first() + score = (await db.exec(select(Score).where(Score.id == score_id))).first() assert score is not None return await ScoreResp.from_db(db, score, current_user) diff --git a/app/router/user.py b/app/router/user.py index 3df5a49..649f1d4 100644 --- a/app/router/user.py +++ b/app/router/user.py @@ -28,19 +28,10 @@ async def get_users( ): if user_ids: searched_users = ( - await session.exec( - select(User) - .options(*User.all_select_option()) - .limit(50) - .where(col(User.id).in_(user_ids)) - ) + await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids))) ).all() else: - searched_users = ( - await session.exec( - select(User).options(*User.all_select_option()).limit(50) - ) - ).all() + searched_users = (await session.exec(select(User).limit(50))).all() return BatchUserResponse( users=[ await UserResp.from_db( @@ -63,9 +54,7 @@ async def get_user_info( ): searched_user = ( await session.exec( - select(User) - .options(*User.all_select_option()) - .where( + select(User).where( User.id == int(user) if user.isdigit() else User.username == user.removeprefix("@")