refactor(database): 优化数据库关联对象的载入 (#10)

This commit is contained in:
MingxuanGame
2025-07-31 20:11:22 +08:00
committed by GitHub
parent 1281e75bb1
commit be401e8885
13 changed files with 73 additions and 166 deletions

View File

@@ -5,7 +5,7 @@ import hashlib
import json
from app.calculator import calculate_beatmap_attribute
from app.database import Beatmap, BeatmapResp, Beatmapset, User
from app.database import Beatmap, BeatmapResp, User
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -24,7 +24,6 @@ from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis import Redis
import rosu_pp_py as rosu
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -51,7 +50,7 @@ async def lookup_beatmap(
if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
return await BeatmapResp.from_db(beatmap)
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
@@ -63,7 +62,7 @@ async def get_beatmap(
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
return BeatmapResp.from_db(beatmap)
return await BeatmapResp.from_db(beatmap)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -83,35 +82,15 @@ async def batch_get_beatmaps(
# select 50 beatmaps by last_updated
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
else:
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps])
@router.post(

View File

@@ -15,7 +15,6 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -27,13 +26,7 @@ async def get_beatmapset(
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmapset = (
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
if not beatmapset:
try:
resp = await fetcher.get_beatmapset(sid)
@@ -41,7 +34,7 @@ async def get_beatmapset(
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
else:
resp = BeatmapsetResp.from_db(beatmapset)
resp = await BeatmapsetResp.from_db(beatmapset)
return resp

View File

@@ -9,7 +9,6 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query, Request
from pydantic import BaseModel
from sqlalchemy.orm import joinedload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -27,14 +26,12 @@ async def get_relationship(
else RelationshipType.BLOCK
)
relationships = await db.exec(
select(Relationship)
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
.where(
select(Relationship).where(
Relationship.user_id == current_user.id,
Relationship.type == relationship_type,
)
)
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
class AddFriendResp(BaseModel):
@@ -92,14 +89,10 @@ async def add_relationship(
if origin_type == RelationshipType.FOLLOW:
relationship = (
await db.exec(
select(Relationship)
.where(
select(Relationship).where(
Relationship.user_id == current_user_id,
Relationship.target_id == target,
)
.options(
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
)
)
).first()
assert relationship, "Relationship should exist after commit"

View File

@@ -99,7 +99,7 @@ async def get_user_beatmap_score(
)
user_score = (
await db.exec(
Score.select_clause(True)
select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
@@ -139,7 +139,7 @@ async def get_user_all_beatmap_scores(
)
all_user_scores = (
await db.exec(
Score.select_clause()
select(Score)
.where(
Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap,
@@ -207,9 +207,7 @@ async def submit_solo_score(
if score_token.score_id:
score = (
await db.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
select(Score).where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
@@ -243,8 +241,6 @@ async def submit_solo_score(
score_id = score.id
score_token.score_id = score_id
await process_user(db, current_user, score, ranked)
score = (
await db.exec(Score.select_clause().where(Score.id == score_id))
).first()
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
assert score is not None
return await ScoreResp.from_db(db, score, current_user)

View File

@@ -28,19 +28,10 @@ async def get_users(
):
if user_ids:
searched_users = (
await session.exec(
select(User)
.options(*User.all_select_option())
.limit(50)
.where(col(User.id).in_(user_ids))
)
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
).all()
else:
searched_users = (
await session.exec(
select(User).options(*User.all_select_option()).limit(50)
)
).all()
searched_users = (await session.exec(select(User).limit(50))).all()
return BatchUserResponse(
users=[
await UserResp.from_db(
@@ -63,9 +54,7 @@ async def get_user_info(
):
searched_user = (
await session.exec(
select(User)
.options(*User.all_select_option())
.where(
select(User).where(
User.id == int(user)
if user.isdigit()
else User.username == user.removeprefix("@")