refactor(app): update database code

This commit is contained in:
MingxuanGame
2025-08-18 16:37:30 +00:00
parent 6bae937e01
commit 1c65b21bb9
34 changed files with 167 additions and 188 deletions

View File

@@ -7,7 +7,7 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
from app.dependencies.database import engine, get_db, get_redis
from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
@@ -30,11 +30,10 @@ from fastapi import (
from fastapi.responses import RedirectResponse
from httpx import HTTPError
from sqlmodel import exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def _save_to_db(sets: SearchBeatmapsetsResp):
async with AsyncSession(engine) as session:
async with with_db() as session:
for s in sets.beatmapsets:
if not (
await session.exec(select(exists()).where(Beatmapset.id == s.id))
@@ -49,13 +48,13 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
response_model=SearchBeatmapsetsResp,
)
async def search_beatmapset(
db: Database,
query: Annotated[SearchQueryModel, Query(...)],
request: Request,
background_tasks: BackgroundTasks,
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
redis = Depends(get_redis),
redis=Depends(get_redis),
):
params = parse_qs(qs=request.url.query, keep_blank_values=True)
cursor = {}
@@ -112,9 +111,9 @@ async def search_beatmapset(
description=("通过谱面 ID 查询所属谱面集。"),
)
async def lookup_beatmapset(
db: Database,
beatmap_id: int = Query(description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
@@ -132,9 +131,9 @@ async def lookup_beatmapset(
description="获取单个谱面集详情。",
)
async def get_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
try:
@@ -196,12 +195,12 @@ async def download_beatmapset(
description="**客户端专属**\n收藏或取消收藏指定谱面集。",
)
async def favourite_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form(
description="操作类型favourite 收藏 / unfavourite 取消收藏"
),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
assert current_user.id is not None
existing_favourite = (