refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -5,17 +5,16 @@ from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import (
Beatmap,
BeatmapModel,
BeatmapPlaycounts,
BeatmapPlaycountsResp,
BeatmapResp,
BeatmapsetResp,
BeatmapsetModel,
User,
UserResp,
)
from app.database.beatmap_playcounts import BeatmapPlaycountsModel
from app.database.best_scores import BestScore
from app.database.events import Event
from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores
from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED
from app.database.score import Score, get_user_first_scores
from app.database.user import UserModel
from app.dependencies.api_version import APIVersion
from app.dependencies.cache import UserCacheService
from app.dependencies.database import Database, get_redis
@@ -26,24 +25,15 @@ from app.models.mods import API_MODS
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from app.service.user_cache_service import get_user_cache_service
from app.utils import utcnow
from app.utils import api_doc, utcnow
from .router import router
from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.sql.expression import col
class BatchUserResponse(BaseModel):
users: list[UserResp]
class BeatmapsPassedResponse(BaseModel):
beatmaps_passed: list[BeatmapResp]
def _get_difficulty_reduction_mods() -> set[str]:
mods: set[str] = set()
for ruleset_mods in API_MODS.values():
@@ -63,13 +53,15 @@ async def visible_to_current_user(user: User, current_user: User | None, session
@router.get(
"/users/",
response_model=BatchUserResponse,
responses={
200: api_doc("批量获取用户信息", {"users": list[UserModel]}, User.CARD_INCLUDES, name="UsersLookupResponse")
},
name="批量获取用户信息",
description="通过用户 ID 列表批量获取用户信息。",
tags=["用户"],
)
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup", include_in_schema=False)
@router.get("/users/lookup/", include_in_schema=False)
@asset_proxy_response
async def get_users(
session: Database,
@@ -108,16 +100,15 @@ async def get_users(
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=SEARCH_INCLUDED,
includes=User.CARD_INCLUDES,
)
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
background_task.add_task(cache_service.cache_user, user_resp)
response = BatchUserResponse(users=cached_users)
response = {"users": cached_users}
return response
else:
searched_users = (
@@ -127,16 +118,15 @@ async def get_users(
for searched_user in searched_users:
if searched_user.id == BANCHOBOT_ID:
continue
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=SEARCH_INCLUDED,
includes=User.CARD_INCLUDES,
)
users.append(user_resp)
# 异步缓存
background_task.add_task(cache_service.cache_user, user_resp)
response = BatchUserResponse(users=users)
response = {"users": users}
return response
@@ -200,10 +190,12 @@ async def get_user_kudosu(
@router.get(
"/users/{user_id}/beatmaps-passed",
response_model=BeatmapsPassedResponse,
name="获取用户已通过谱面",
description="获取指定用户在给定谱面集中的已通过谱面列表。",
tags=["用户"],
responses={
200: api_doc("用户已通过谱面列表", {"beatmaps_passed": list[BeatmapModel]}, name="BeatmapsPassedResponse")
},
)
@asset_proxy_response
async def get_user_beatmaps_passed(
@@ -226,7 +218,7 @@ async def get_user_beatmaps_passed(
no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True,
):
if not beatmapset_ids:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
if len(beatmapset_ids) > 50:
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
@@ -255,7 +247,7 @@ async def get_user_beatmaps_passed(
scores = (await session.exec(score_query)).all()
if not scores:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set()
passed_beatmap_ids: set[int] = set()
@@ -269,7 +261,7 @@ async def get_user_beatmaps_passed(
continue
passed_beatmap_ids.add(beatmap_id)
if not passed_beatmap_ids:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
beatmaps = (
await session.exec(
@@ -279,19 +271,24 @@ async def get_user_beatmaps_passed(
)
).all()
return BeatmapsPassedResponse(
beatmaps_passed=[
await BeatmapResp.from_db(beatmap, allowed_mode, session=session, user=user) for beatmap in beatmaps
return {
"beatmaps_passed": [
await BeatmapModel.transform(
beatmap,
)
for beatmap in beatmaps
]
)
}
@router.get(
"/users/{user_id}/{ruleset}",
response_model=UserResp,
name="获取用户信息(指定ruleset)",
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
tags=["用户"],
responses={
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
},
)
@asset_proxy_response
async def get_user_info_ruleset(
@@ -325,29 +322,26 @@ async def get_user_info_ruleset(
if should_not_show:
raise HTTPException(404, detail="User not found")
include = SEARCH_INCLUDED
if searched_is_self:
include = ALL_INCLUDED
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=include,
includes=User.USER_INCLUDES,
ruleset=ruleset,
)
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
return user_resp
@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
@router.get("/users/{user_id}/", include_in_schema=False)
@router.get(
"/users/{user_id}",
response_model=UserResp,
name="获取用户信息",
description="通过用户 ID 或用户名获取单个用户的详细信息。",
tags=["用户"],
responses={
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
},
)
@asset_proxy_response
async def get_user_info(
@@ -381,27 +375,31 @@ async def get_user_info(
if should_not_show:
raise HTTPException(404, detail="User not found")
include = SEARCH_INCLUDED
if searched_is_self:
include = ALL_INCLUDED
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=include,
includes=User.USER_INCLUDES,
)
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp)
return user_resp
beatmapset_includes = [*BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES, "beatmaps"]
@router.get(
"/users/{user_id}/beatmapsets/{type}",
response_model=list[BeatmapsetResp | BeatmapPlaycountsResp],
name="获取用户谱面集列表",
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
tags=["用户"],
responses={
200: api_doc(
"当类型为 `most_played` 时返回 `list[BeatmapPlaycountsModel]`,其他为 `list[BeatmapsetModel]`",
list[BeatmapsetModel] | list[BeatmapPlaycountsModel],
beatmapset_includes,
)
},
)
@asset_proxy_response
async def get_user_beatmapsets(
@@ -417,11 +415,7 @@ async def get_user_beatmapsets(
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
if cached_result is not None:
# 根据类型恢复对象
if type == BeatmapsetType.MOST_PLAYED:
return [BeatmapPlaycountsResp(**item) for item in cached_result]
else:
return [BeatmapsetResp(**item) for item in cached_result]
return cached_result
user = await session.get(User, user_id)
if not user or user.id == BANCHOBOT_ID:
@@ -444,7 +438,10 @@ async def get_user_beatmapsets(
raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
await BeatmapsetModel.transform(
favourite.beatmapset, session=session, user=user, includes=beatmapset_includes
)
for favourite in favourites
]
elif type == BeatmapsetType.MOST_PLAYED:
@@ -459,7 +456,10 @@ async def get_user_beatmapsets(
.limit(limit)
.offset(offset)
)
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
resp = [
await BeatmapPlaycountsModel.transform(most_played_beatmap, user=user, includes=beatmapset_includes)
for most_played_beatmap in most_played
]
else:
raise HTTPException(400, detail="Invalid beatmapset type")
@@ -477,7 +477,6 @@ async def get_user_beatmapsets(
@router.get(
"/users/{user_id}/scores/{type}",
response_model=list[ScoreResp] | list[LegacyScoreResp],
name="获取用户成绩列表",
description=(
"获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n"
@@ -523,6 +522,7 @@ async def get_user_scores(
gamemode = mode or db_user.playmode
order_by = None
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
includes = Score.USER_PROFILE_INCLUDES.copy()
if not include_fails:
where_clause &= col(Score.passed).is_(True)
if type == "pinned":
@@ -531,6 +531,7 @@ async def get_user_scores(
elif type == "best":
where_clause &= exists().where(col(BestScore.score_id) == Score.id)
order_by = col(Score.pp).desc()
includes.append("weight")
elif type == "recent":
where_clause &= Score.ended_at > utcnow() - timedelta(hours=24)
order_by = col(Score.ended_at).desc()
@@ -551,6 +552,7 @@ async def get_user_scores(
await score.to_resp(
session,
api_version,
includes=includes,
)
for score in scores
]