refactor(database): use asyncio

This commit is contained in:
MingxuanGame
2025-07-25 20:43:50 +08:00
parent 2e1489c6d4
commit f347b680b2
21 changed files with 296 additions and 536 deletions

View File

@@ -5,6 +5,7 @@ from app.database import (
BeatmapResp,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
@@ -12,16 +13,24 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel import Session, col, select
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
beatmap = (
await db.exec(
select(Beatmap)
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmap.id == bid)
)
).first()
if not beatmap:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
if not b_ids:
# select 50 beatmaps by last_updated
beatmaps = db.exec(
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
)
).all()
else:
beatmaps = db.exec(
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])