refactor(app): update database code

This commit is contained in:
MingxuanGame
2025-08-18 16:37:30 +00:00
parent 6bae937e01
commit 1c65b21bb9
34 changed files with 167 additions and 188 deletions

View File

@@ -15,17 +15,16 @@ from app.database.events import EventResp
from app.database.lazer_user import SEARCH_INCLUDED
from app.database.pp_best_score import PPBestScore
from app.database.score import Score, ScoreResp
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security
from fastapi import HTTPException, Path, Query, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col
@@ -43,6 +42,7 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
async def get_users(
session: Database,
user_ids: list[int] = Query(
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
),
@@ -50,7 +50,6 @@ async def get_users(
include_variant_statistics: bool = Query(
default=False, description="是否包含各模式的统计信息"
), # TODO: future use
session: AsyncSession = Depends(get_db),
):
if user_ids:
searched_users = (
@@ -79,9 +78,9 @@ async def get_users(
tags=["用户"],
)
async def get_user_info_ruleset(
session: Database,
user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"),
session: AsyncSession = Depends(get_db),
# current_user: User = Security(get_current_user, scopes=["public"]),
):
searched_user = (
@@ -112,8 +111,8 @@ async def get_user_info_ruleset(
tags=["用户"],
)
async def get_user_info(
session: Database,
user_id: str = Path(description="用户 ID 或用户名"),
session: AsyncSession = Depends(get_db),
# current_user: User = Security(get_current_user, scopes=["public"]),
):
searched_user = (
@@ -142,10 +141,10 @@ async def get_user_info(
tags=["用户"],
)
async def get_user_beatmapsets(
session: Database,
user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
):
@@ -202,6 +201,7 @@ async def get_user_beatmapsets(
tags=["用户"],
)
async def get_user_scores(
session: Database,
user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=(
@@ -216,7 +216,6 @@ async def get_user_scores(
),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
db_user = await session.get(User, user_id)
@@ -267,10 +266,10 @@ async def get_user_scores(
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
)
async def get_user_events(
session: Database,
user: int,
limit: int | None = Query(None),
offset: str | None = Query(None), # TODO: 搞清楚并且添加这个奇怪的分页偏移
session: AsyncSession = Depends(get_db),
):
db_user = await session.get(User, user)
if db_user is None or db_user.id == BANCHOBOT_ID: