refactor(database): 优化数据库关联对象的载入 (#10)
This commit is contained in:
@@ -8,7 +8,6 @@ from app.models.score import MODE_TO_INT, GameMode
|
|||||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||||
|
|
||||||
from sqlalchemy import DECIMAL, Column, DateTime
|
from sqlalchemy import DECIMAL, Column, DateTime
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
|
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -67,7 +66,9 @@ class Beatmap(BeatmapBase, table=True):
|
|||||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||||
beatmap_status: BeatmapRankStatus
|
beatmap_status: BeatmapRankStatus
|
||||||
# optional
|
# optional
|
||||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
|
beatmapset: Beatmapset = Relationship(
|
||||||
|
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_ranked(self) -> bool:
|
def can_ranked(self) -> bool:
|
||||||
@@ -88,13 +89,7 @@ class Beatmap(BeatmapBase, table=True):
|
|||||||
session.add(beatmap)
|
session.add(beatmap)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
beatmap = (
|
beatmap = (
|
||||||
await session.exec(
|
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
|
||||||
select(Beatmap)
|
|
||||||
.options(
|
|
||||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
.where(Beatmap.id == resp.id)
|
|
||||||
)
|
|
||||||
).first()
|
).first()
|
||||||
assert beatmap is not None, "Beatmap should not be None after commit"
|
assert beatmap is not None, "Beatmap should not be None after commit"
|
||||||
return beatmap
|
return beatmap
|
||||||
@@ -132,13 +127,9 @@ class Beatmap(BeatmapBase, table=True):
|
|||||||
) -> "Beatmap":
|
) -> "Beatmap":
|
||||||
beatmap = (
|
beatmap = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(Beatmap)
|
select(Beatmap).where(
|
||||||
.where(
|
|
||||||
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
|
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
|
||||||
)
|
)
|
||||||
.options(
|
|
||||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
).first()
|
).first()
|
||||||
if not beatmap:
|
if not beatmap:
|
||||||
@@ -165,7 +156,7 @@ class BeatmapResp(BeatmapBase):
|
|||||||
url: str = ""
|
url: str = ""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db(
|
async def from_db(
|
||||||
cls,
|
cls,
|
||||||
beatmap: Beatmap,
|
beatmap: Beatmap,
|
||||||
query_mode: GameMode | None = None,
|
query_mode: GameMode | None = None,
|
||||||
@@ -179,5 +170,5 @@ class BeatmapResp(BeatmapBase):
|
|||||||
beatmap_["ranked"] = beatmap.beatmap_status.value
|
beatmap_["ranked"] = beatmap.beatmap_status.value
|
||||||
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
|
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
|
||||||
if not from_set:
|
if not from_set:
|
||||||
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset)
|
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset)
|
||||||
return cls.model_validate(beatmap_)
|
return cls.model_validate(beatmap_)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from app.models.score import GameMode
|
|||||||
|
|
||||||
from pydantic import BaseModel, model_serializer
|
from pydantic import BaseModel, model_serializer
|
||||||
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
|
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import Field, Relationship, SQLModel
|
from sqlmodel import Field, Relationship, SQLModel
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -130,7 +131,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel):
|
|||||||
tags: str = Field(default="", sa_column=Column(Text))
|
tags: str = Field(default="", sa_column=Column(Text))
|
||||||
|
|
||||||
|
|
||||||
class Beatmapset(BeatmapsetBase, table=True):
|
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||||
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
|
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||||
@@ -200,12 +201,12 @@ class BeatmapsetResp(BeatmapsetBase):
|
|||||||
nominations: BeatmapNominations | None = None
|
nominations: BeatmapNominations | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
|
async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
|
||||||
from .beatmap import BeatmapResp
|
from .beatmap import BeatmapResp
|
||||||
|
|
||||||
beatmaps = [
|
beatmaps = [
|
||||||
BeatmapResp.from_db(beatmap, from_set=True)
|
await BeatmapResp.from_db(beatmap, from_set=True)
|
||||||
for beatmap in beatmapset.beatmaps
|
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
||||||
]
|
]
|
||||||
return cls.model_validate(
|
return cls.model_validate(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp
|
|||||||
from .team import Team, TeamMember
|
from .team import Team, TeamMember
|
||||||
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
|
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
|
||||||
|
|
||||||
from sqlalchemy.orm import joinedload, selectinload
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
JSON,
|
JSON,
|
||||||
BigInteger,
|
BigInteger,
|
||||||
@@ -128,7 +128,7 @@ class UserBase(UTCBaseModel, SQLModel):
|
|||||||
is_bng: bool = False
|
is_bng: bool = False
|
||||||
|
|
||||||
|
|
||||||
class User(UserBase, table=True):
|
class User(AsyncAttrs, UserBase, table=True):
|
||||||
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
|
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: int | None = Field(
|
id: int | None = Field(
|
||||||
@@ -154,17 +154,6 @@ class User(UserBase, table=True):
|
|||||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def all_select_option(cls):
|
|
||||||
return (
|
|
||||||
selectinload(cls.account_history), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(cls.achievement), # pyright: ignore[reportArgumentType]
|
|
||||||
joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType]
|
|
||||||
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserResp(UserBase):
|
class UserResp(UserBase):
|
||||||
id: int | None = None
|
id: int | None = None
|
||||||
@@ -249,13 +238,7 @@ class UserResp(UserBase):
|
|||||||
await RelationshipResp.from_db(session, r)
|
await RelationshipResp.from_db(session, r)
|
||||||
for r in (
|
for r in (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(Relationship)
|
select(Relationship).where(
|
||||||
.options(
|
|
||||||
joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType]
|
|
||||||
*User.all_select_option()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.where(
|
|
||||||
Relationship.user_id == obj.id,
|
Relationship.user_id == obj.id,
|
||||||
Relationship.type == RelationshipType.FOLLOW,
|
Relationship.type == RelationshipType.FOLLOW,
|
||||||
)
|
)
|
||||||
@@ -264,23 +247,26 @@ class UserResp(UserBase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
if "team" in include:
|
if "team" in include:
|
||||||
if obj.team_membership:
|
if await obj.awaitable_attrs.team_membership:
|
||||||
|
assert obj.team_membership
|
||||||
u.team = obj.team_membership.team
|
u.team = obj.team_membership.team
|
||||||
|
|
||||||
if "account_history" in include:
|
if "account_history" in include:
|
||||||
u.account_history = [
|
u.account_history = [
|
||||||
UserAccountHistoryResp.from_db(ah) for ah in obj.account_history
|
UserAccountHistoryResp.from_db(ah)
|
||||||
|
for ah in await obj.awaitable_attrs.account_history
|
||||||
]
|
]
|
||||||
|
|
||||||
if "daily_challenge_user_stats":
|
if "daily_challenge_user_stats":
|
||||||
if obj.daily_challenge_stats:
|
if await obj.awaitable_attrs.daily_challenge_stats:
|
||||||
|
assert obj.daily_challenge_stats
|
||||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
|
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
|
||||||
obj.daily_challenge_stats
|
obj.daily_challenge_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
if "statistics" in include:
|
if "statistics" in include:
|
||||||
current_stattistics = None
|
current_stattistics = None
|
||||||
for i in obj.statistics:
|
for i in await obj.awaitable_attrs.statistics:
|
||||||
if i.mode == (ruleset or obj.playmode):
|
if i.mode == (ruleset or obj.playmode):
|
||||||
current_stattistics = i
|
current_stattistics = i
|
||||||
break
|
break
|
||||||
@@ -292,17 +278,20 @@ class UserResp(UserBase):
|
|||||||
|
|
||||||
if "statistics_rulesets" in include:
|
if "statistics_rulesets" in include:
|
||||||
u.statistics_rulesets = {
|
u.statistics_rulesets = {
|
||||||
i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics
|
i.mode.value: UserStatisticsResp.from_db(i)
|
||||||
|
for i in await obj.awaitable_attrs.statistics
|
||||||
}
|
}
|
||||||
|
|
||||||
if "monthly_playcounts" in include:
|
if "monthly_playcounts" in include:
|
||||||
u.monthly_playcounts = [
|
u.monthly_playcounts = [
|
||||||
MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts
|
MonthlyPlaycountsResp.from_db(pc)
|
||||||
|
for pc in await obj.awaitable_attrs.monthly_playcounts
|
||||||
]
|
]
|
||||||
|
|
||||||
if "achievements" in include:
|
if "achievements" in include:
|
||||||
u.user_achievements = [
|
u.user_achievements = [
|
||||||
UserAchievementResp.from_db(ua) for ua in obj.achievement
|
UserAchievementResp.from_db(ua)
|
||||||
|
for ua in await obj.awaitable_attrs.achievement
|
||||||
]
|
]
|
||||||
|
|
||||||
return u
|
return u
|
||||||
|
|||||||
@@ -42,7 +42,10 @@ class Relationship(SQLModel, table=True):
|
|||||||
)
|
)
|
||||||
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
||||||
target: User = SQLRelationship(
|
target: User = SQLRelationship(
|
||||||
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
|
sa_relationship_kwargs={
|
||||||
|
"foreign_keys": "[Relationship.target_id]",
|
||||||
|
"lazy": "selectin",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +82,6 @@ class RelationshipResp(BaseModel):
|
|||||||
"daily_challenge_user_stats",
|
"daily_challenge_user_stats",
|
||||||
"statistics",
|
"statistics",
|
||||||
"statistics_rulesets",
|
"statistics_rulesets",
|
||||||
"achievements",
|
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
mutual=mutual,
|
mutual=mutual,
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from app.models.score import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .beatmap import Beatmap, BeatmapResp
|
from .beatmap import Beatmap, BeatmapResp
|
||||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
from .beatmapset import BeatmapsetResp
|
||||||
from .best_score import BestScore
|
from .best_score import BestScore
|
||||||
from .lazer_user import User, UserResp
|
from .lazer_user import User, UserResp
|
||||||
from .monthly_playcounts import MonthlyPlaycounts
|
from .monthly_playcounts import MonthlyPlaycounts
|
||||||
@@ -35,7 +35,8 @@ from .score_token import ScoreToken
|
|||||||
|
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
|
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
|
||||||
from sqlalchemy.orm import aliased, joinedload
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
|
from sqlalchemy.orm import aliased
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
JSON,
|
JSON,
|
||||||
BigInteger,
|
BigInteger,
|
||||||
@@ -55,7 +56,7 @@ if TYPE_CHECKING:
|
|||||||
from app.fetcher import Fetcher
|
from app.fetcher import Fetcher
|
||||||
|
|
||||||
|
|
||||||
class ScoreBase(SQLModel, UTCBaseModel):
|
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||||
# 基本字段
|
# 基本字段
|
||||||
accuracy: float
|
accuracy: float
|
||||||
map_md5: str = Field(max_length=32, index=True)
|
map_md5: str = Field(max_length=32, index=True)
|
||||||
@@ -114,27 +115,12 @@ class Score(ScoreBase, table=True):
|
|||||||
|
|
||||||
# optional
|
# optional
|
||||||
beatmap: Beatmap = Relationship()
|
beatmap: Beatmap = Relationship()
|
||||||
user: User = Relationship()
|
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_perfect_combo(self) -> bool:
|
def is_perfect_combo(self) -> bool:
|
||||||
return self.max_combo == self.beatmap.max_combo
|
return self.max_combo == self.beatmap.max_combo
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
|
|
||||||
clause = select(Score).options(
|
|
||||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
|
||||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
|
||||||
.selectinload(
|
|
||||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if with_user:
|
|
||||||
return clause.options(
|
|
||||||
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
return clause
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def select_clause_unique(
|
def select_clause_unique(
|
||||||
*where_clauses: ColumnExpressionArgument[bool] | bool,
|
*where_clauses: ColumnExpressionArgument[bool] | bool,
|
||||||
@@ -148,18 +134,7 @@ class Score(ScoreBase, table=True):
|
|||||||
)
|
)
|
||||||
subq = select(Score, rownum).where(*where_clauses).subquery()
|
subq = select(Score, rownum).where(*where_clauses).subquery()
|
||||||
best = aliased(Score, subq, adapt_on_names=True)
|
best = aliased(Score, subq, adapt_on_names=True)
|
||||||
return (
|
return select(best).where(subq.c.rn == 1)
|
||||||
select(best)
|
|
||||||
.where(subq.c.rn == 1)
|
|
||||||
.options(
|
|
||||||
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
|
|
||||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
|
||||||
.selectinload(
|
|
||||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
|
||||||
),
|
|
||||||
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ScoreResp(ScoreBase):
|
class ScoreResp(ScoreBase):
|
||||||
@@ -186,8 +161,9 @@ class ScoreResp(ScoreBase):
|
|||||||
) -> "ScoreResp":
|
) -> "ScoreResp":
|
||||||
s = cls.model_validate(score.model_dump())
|
s = cls.model_validate(score.model_dump())
|
||||||
assert score.id
|
assert score.id
|
||||||
s.beatmap = BeatmapResp.from_db(score.beatmap)
|
await score.awaitable_attrs.beatmap
|
||||||
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
|
s.beatmap = await BeatmapResp.from_db(score.beatmap)
|
||||||
|
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset)
|
||||||
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
||||||
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
|
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
|
||||||
s.ruleset_id = MODE_TO_INT[score.gamemode]
|
s.ruleset_id = MODE_TO_INT[score.gamemode]
|
||||||
@@ -303,7 +279,6 @@ async def get_leaderboard(
|
|||||||
query = (
|
query = (
|
||||||
select(Score)
|
select(Score)
|
||||||
.join(Beatmap)
|
.join(Beatmap)
|
||||||
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
|
|
||||||
.where(
|
.where(
|
||||||
Score.map_md5 == beatmap_md5,
|
Score.map_md5 == beatmap_md5,
|
||||||
Score.gamemode == mode,
|
Score.gamemode == mode,
|
||||||
@@ -452,7 +427,7 @@ async def get_user_best_score_in_beatmap(
|
|||||||
) -> Score | None:
|
) -> Score | None:
|
||||||
return (
|
return (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
Score.select_clause(False)
|
select(Score)
|
||||||
.where(
|
.where(
|
||||||
Score.gamemode == mode if mode is not None else True,
|
Score.gamemode == mode if mode is not None else True,
|
||||||
Score.beatmap_id == beatmap,
|
Score.beatmap_id == beatmap,
|
||||||
|
|||||||
@@ -34,5 +34,9 @@ class TeamMember(SQLModel, UTCBaseModel, table=True):
|
|||||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||||
)
|
)
|
||||||
|
|
||||||
user: "User" = Relationship(back_populates="team_membership")
|
user: "User" = Relationship(
|
||||||
team: "Team" = Relationship(back_populates="members")
|
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
|
||||||
|
)
|
||||||
|
team: "Team" = Relationship(
|
||||||
|
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
|
||||||
|
)
|
||||||
|
|||||||
@@ -30,11 +30,5 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None
|
|||||||
token_record = await get_token_by_access_token(db, token)
|
token_record = await get_token_by_access_token(db, token)
|
||||||
if not token_record:
|
if not token_record:
|
||||||
return None
|
return None
|
||||||
user = (
|
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||||
await db.exec(
|
|
||||||
select(User)
|
|
||||||
.options(*User.all_select_option())
|
|
||||||
.where(User.id == token_record.user_id)
|
|
||||||
)
|
|
||||||
).first()
|
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -42,11 +42,12 @@ class Language(IntEnum):
|
|||||||
KOREAN = 6
|
KOREAN = 6
|
||||||
FRENCH = 7
|
FRENCH = 7
|
||||||
GERMAN = 8
|
GERMAN = 8
|
||||||
ITALIAN = 9
|
SWEDISH = 9
|
||||||
SPANISH = 10
|
ITALIAN = 10
|
||||||
RUSSIAN = 11
|
SPANISH = 11
|
||||||
POLISH = 12
|
RUSSIAN = 12
|
||||||
OTHER = 13
|
POLISH = 13
|
||||||
|
OTHER = 14
|
||||||
|
|
||||||
|
|
||||||
class BeatmapAttributes(BaseModel):
|
class BeatmapAttributes(BaseModel):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from app.calculator import calculate_beatmap_attribute
|
from app.calculator import calculate_beatmap_attribute
|
||||||
from app.database import Beatmap, BeatmapResp, Beatmapset, User
|
from app.database import Beatmap, BeatmapResp, User
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import get_db, get_redis
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
@@ -24,7 +24,6 @@ from httpx import HTTPError, HTTPStatusError
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
import rosu_pp_py as rosu
|
import rosu_pp_py as rosu
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -51,7 +50,7 @@ async def lookup_beatmap(
|
|||||||
if beatmap is None:
|
if beatmap is None:
|
||||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||||
|
|
||||||
return BeatmapResp.from_db(beatmap)
|
return await BeatmapResp.from_db(beatmap)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||||
@@ -63,7 +62,7 @@ async def get_beatmap(
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
|
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
|
||||||
return BeatmapResp.from_db(beatmap)
|
return await BeatmapResp.from_db(beatmap)
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||||
|
|
||||||
@@ -83,35 +82,15 @@ async def batch_get_beatmaps(
|
|||||||
# select 50 beatmaps by last_updated
|
# select 50 beatmaps by last_updated
|
||||||
beatmaps = (
|
beatmaps = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
select(Beatmap)
|
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||||
.options(
|
|
||||||
joinedload(
|
|
||||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
|
||||||
).selectinload(
|
|
||||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.order_by(col(Beatmap.last_updated).desc())
|
|
||||||
.limit(50)
|
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
else:
|
else:
|
||||||
beatmaps = (
|
beatmaps = (
|
||||||
await db.exec(
|
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
|
||||||
select(Beatmap)
|
|
||||||
.options(
|
|
||||||
joinedload(
|
|
||||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
|
||||||
).selectinload(
|
|
||||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.where(col(Beatmap.id).in_(b_ids))
|
|
||||||
.limit(50)
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from .api_router import router
|
|||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import Depends, HTTPException, Query
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from httpx import HTTPStatusError
|
from httpx import HTTPStatusError
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -27,13 +26,7 @@ async def get_beatmapset(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
beatmapset = (
|
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
|
||||||
await db.exec(
|
|
||||||
select(Beatmapset)
|
|
||||||
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
|
||||||
.where(Beatmapset.id == sid)
|
|
||||||
)
|
|
||||||
).first()
|
|
||||||
if not beatmapset:
|
if not beatmapset:
|
||||||
try:
|
try:
|
||||||
resp = await fetcher.get_beatmapset(sid)
|
resp = await fetcher.get_beatmapset(sid)
|
||||||
@@ -41,7 +34,7 @@ async def get_beatmapset(
|
|||||||
except HTTPStatusError:
|
except HTTPStatusError:
|
||||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||||
else:
|
else:
|
||||||
resp = BeatmapsetResp.from_db(beatmapset)
|
resp = await BeatmapsetResp.from_db(beatmapset)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from .api_router import router
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query, Request
|
from fastapi import Depends, HTTPException, Query, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -27,14 +26,12 @@ async def get_relationship(
|
|||||||
else RelationshipType.BLOCK
|
else RelationshipType.BLOCK
|
||||||
)
|
)
|
||||||
relationships = await db.exec(
|
relationships = await db.exec(
|
||||||
select(Relationship)
|
select(Relationship).where(
|
||||||
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
|
|
||||||
.where(
|
|
||||||
Relationship.user_id == current_user.id,
|
Relationship.user_id == current_user.id,
|
||||||
Relationship.type == relationship_type,
|
Relationship.type == relationship_type,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
|
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
|
||||||
|
|
||||||
|
|
||||||
class AddFriendResp(BaseModel):
|
class AddFriendResp(BaseModel):
|
||||||
@@ -92,14 +89,10 @@ async def add_relationship(
|
|||||||
if origin_type == RelationshipType.FOLLOW:
|
if origin_type == RelationshipType.FOLLOW:
|
||||||
relationship = (
|
relationship = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
select(Relationship)
|
select(Relationship).where(
|
||||||
.where(
|
|
||||||
Relationship.user_id == current_user_id,
|
Relationship.user_id == current_user_id,
|
||||||
Relationship.target_id == target,
|
Relationship.target_id == target,
|
||||||
)
|
)
|
||||||
.options(
|
|
||||||
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
).first()
|
).first()
|
||||||
assert relationship, "Relationship should exist after commit"
|
assert relationship, "Relationship should exist after commit"
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ async def get_user_beatmap_score(
|
|||||||
)
|
)
|
||||||
user_score = (
|
user_score = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
Score.select_clause(True)
|
select(Score)
|
||||||
.where(
|
.where(
|
||||||
Score.gamemode == mode if mode is not None else True,
|
Score.gamemode == mode if mode is not None else True,
|
||||||
Score.beatmap_id == beatmap,
|
Score.beatmap_id == beatmap,
|
||||||
@@ -139,7 +139,7 @@ async def get_user_all_beatmap_scores(
|
|||||||
)
|
)
|
||||||
all_user_scores = (
|
all_user_scores = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
Score.select_clause()
|
select(Score)
|
||||||
.where(
|
.where(
|
||||||
Score.gamemode == ruleset if ruleset is not None else True,
|
Score.gamemode == ruleset if ruleset is not None else True,
|
||||||
Score.beatmap_id == beatmap,
|
Score.beatmap_id == beatmap,
|
||||||
@@ -207,9 +207,7 @@ async def submit_solo_score(
|
|||||||
if score_token.score_id:
|
if score_token.score_id:
|
||||||
score = (
|
score = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
select(Score)
|
select(Score).where(
|
||||||
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
|
|
||||||
.where(
|
|
||||||
Score.id == score_token.score_id,
|
Score.id == score_token.score_id,
|
||||||
Score.user_id == current_user.id,
|
Score.user_id == current_user.id,
|
||||||
)
|
)
|
||||||
@@ -243,8 +241,6 @@ async def submit_solo_score(
|
|||||||
score_id = score.id
|
score_id = score.id
|
||||||
score_token.score_id = score_id
|
score_token.score_id = score_id
|
||||||
await process_user(db, current_user, score, ranked)
|
await process_user(db, current_user, score, ranked)
|
||||||
score = (
|
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||||
await db.exec(Score.select_clause().where(Score.id == score_id))
|
|
||||||
).first()
|
|
||||||
assert score is not None
|
assert score is not None
|
||||||
return await ScoreResp.from_db(db, score, current_user)
|
return await ScoreResp.from_db(db, score, current_user)
|
||||||
|
|||||||
@@ -28,19 +28,10 @@ async def get_users(
|
|||||||
):
|
):
|
||||||
if user_ids:
|
if user_ids:
|
||||||
searched_users = (
|
searched_users = (
|
||||||
await session.exec(
|
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
|
||||||
select(User)
|
|
||||||
.options(*User.all_select_option())
|
|
||||||
.limit(50)
|
|
||||||
.where(col(User.id).in_(user_ids))
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
else:
|
else:
|
||||||
searched_users = (
|
searched_users = (await session.exec(select(User).limit(50))).all()
|
||||||
await session.exec(
|
|
||||||
select(User).options(*User.all_select_option()).limit(50)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
return BatchUserResponse(
|
return BatchUserResponse(
|
||||||
users=[
|
users=[
|
||||||
await UserResp.from_db(
|
await UserResp.from_db(
|
||||||
@@ -63,9 +54,7 @@ async def get_user_info(
|
|||||||
):
|
):
|
||||||
searched_user = (
|
searched_user = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(User)
|
select(User).where(
|
||||||
.options(*User.all_select_option())
|
|
||||||
.where(
|
|
||||||
User.id == int(user)
|
User.id == int(user)
|
||||||
if user.isdigit()
|
if user.isdigit()
|
||||||
else User.username == user.removeprefix("@")
|
else User.username == user.removeprefix("@")
|
||||||
|
|||||||
Reference in New Issue
Block a user