refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -5,17 +5,16 @@ from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
BeatmapModel,
|
||||
BeatmapPlaycounts,
|
||||
BeatmapPlaycountsResp,
|
||||
BeatmapResp,
|
||||
BeatmapsetResp,
|
||||
BeatmapsetModel,
|
||||
User,
|
||||
UserResp,
|
||||
)
|
||||
from app.database.beatmap_playcounts import BeatmapPlaycountsModel
|
||||
from app.database.best_scores import BestScore
|
||||
from app.database.events import Event
|
||||
from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores
|
||||
from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED
|
||||
from app.database.score import Score, get_user_first_scores
|
||||
from app.database.user import UserModel
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.cache import UserCacheService
|
||||
from app.dependencies.database import Database, get_redis
|
||||
@@ -26,24 +25,15 @@ from app.models.mods import API_MODS
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import BeatmapsetType
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
from app.utils import utcnow
|
||||
from app.utils import api_doc, utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import exists, false, select
|
||||
from sqlmodel.sql.expression import col
|
||||
|
||||
|
||||
class BatchUserResponse(BaseModel):
|
||||
users: list[UserResp]
|
||||
|
||||
|
||||
class BeatmapsPassedResponse(BaseModel):
|
||||
beatmaps_passed: list[BeatmapResp]
|
||||
|
||||
|
||||
def _get_difficulty_reduction_mods() -> set[str]:
|
||||
mods: set[str] = set()
|
||||
for ruleset_mods in API_MODS.values():
|
||||
@@ -63,13 +53,15 @@ async def visible_to_current_user(user: User, current_user: User | None, session
|
||||
|
||||
@router.get(
|
||||
"/users/",
|
||||
response_model=BatchUserResponse,
|
||||
responses={
|
||||
200: api_doc("批量获取用户信息", {"users": list[UserModel]}, User.CARD_INCLUDES, name="UsersLookupResponse")
|
||||
},
|
||||
name="批量获取用户信息",
|
||||
description="通过用户 ID 列表批量获取用户信息。",
|
||||
tags=["用户"],
|
||||
)
|
||||
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
|
||||
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
||||
@router.get("/users/lookup", include_in_schema=False)
|
||||
@router.get("/users/lookup/", include_in_schema=False)
|
||||
@asset_proxy_response
|
||||
async def get_users(
|
||||
session: Database,
|
||||
@@ -108,16 +100,15 @@ async def get_users(
|
||||
# 将查询到的用户添加到缓存并返回
|
||||
for searched_user in searched_users:
|
||||
if searched_user.id != BANCHOBOT_ID:
|
||||
user_resp = await UserResp.from_db(
|
||||
user_resp = await UserModel.transform(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
includes=User.CARD_INCLUDES,
|
||||
)
|
||||
cached_users.append(user_resp)
|
||||
# 异步缓存,不阻塞响应
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
response = BatchUserResponse(users=cached_users)
|
||||
response = {"users": cached_users}
|
||||
return response
|
||||
else:
|
||||
searched_users = (
|
||||
@@ -127,16 +118,15 @@ async def get_users(
|
||||
for searched_user in searched_users:
|
||||
if searched_user.id == BANCHOBOT_ID:
|
||||
continue
|
||||
user_resp = await UserResp.from_db(
|
||||
user_resp = await UserModel.transform(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
includes=User.CARD_INCLUDES,
|
||||
)
|
||||
users.append(user_resp)
|
||||
# 异步缓存
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
response = BatchUserResponse(users=users)
|
||||
response = {"users": users}
|
||||
return response
|
||||
|
||||
|
||||
@@ -200,10 +190,12 @@ async def get_user_kudosu(
|
||||
|
||||
@router.get(
|
||||
"/users/{user_id}/beatmaps-passed",
|
||||
response_model=BeatmapsPassedResponse,
|
||||
name="获取用户已通过谱面",
|
||||
description="获取指定用户在给定谱面集中的已通过谱面列表。",
|
||||
tags=["用户"],
|
||||
responses={
|
||||
200: api_doc("用户已通过谱面列表", {"beatmaps_passed": list[BeatmapModel]}, name="BeatmapsPassedResponse")
|
||||
},
|
||||
)
|
||||
@asset_proxy_response
|
||||
async def get_user_beatmaps_passed(
|
||||
@@ -226,7 +218,7 @@ async def get_user_beatmaps_passed(
|
||||
no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True,
|
||||
):
|
||||
if not beatmapset_ids:
|
||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
||||
return {"beatmaps_passed": []}
|
||||
if len(beatmapset_ids) > 50:
|
||||
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
|
||||
|
||||
@@ -255,7 +247,7 @@ async def get_user_beatmaps_passed(
|
||||
|
||||
scores = (await session.exec(score_query)).all()
|
||||
if not scores:
|
||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
||||
return {"beatmaps_passed": []}
|
||||
|
||||
difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set()
|
||||
passed_beatmap_ids: set[int] = set()
|
||||
@@ -269,7 +261,7 @@ async def get_user_beatmaps_passed(
|
||||
continue
|
||||
passed_beatmap_ids.add(beatmap_id)
|
||||
if not passed_beatmap_ids:
|
||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
||||
return {"beatmaps_passed": []}
|
||||
|
||||
beatmaps = (
|
||||
await session.exec(
|
||||
@@ -279,19 +271,24 @@ async def get_user_beatmaps_passed(
|
||||
)
|
||||
).all()
|
||||
|
||||
return BeatmapsPassedResponse(
|
||||
beatmaps_passed=[
|
||||
await BeatmapResp.from_db(beatmap, allowed_mode, session=session, user=user) for beatmap in beatmaps
|
||||
return {
|
||||
"beatmaps_passed": [
|
||||
await BeatmapModel.transform(
|
||||
beatmap,
|
||||
)
|
||||
for beatmap in beatmaps
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users/{user_id}/{ruleset}",
|
||||
response_model=UserResp,
|
||||
name="获取用户信息(指定ruleset)",
|
||||
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
|
||||
tags=["用户"],
|
||||
responses={
|
||||
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
|
||||
},
|
||||
)
|
||||
@asset_proxy_response
|
||||
async def get_user_info_ruleset(
|
||||
@@ -325,29 +322,26 @@ async def get_user_info_ruleset(
|
||||
if should_not_show:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
|
||||
include = SEARCH_INCLUDED
|
||||
if searched_is_self:
|
||||
include = ALL_INCLUDED
|
||||
user_resp = await UserResp.from_db(
|
||||
user_resp = await UserModel.transform(
|
||||
searched_user,
|
||||
session,
|
||||
include=include,
|
||||
includes=User.USER_INCLUDES,
|
||||
ruleset=ruleset,
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
|
||||
|
||||
return user_resp
|
||||
|
||||
|
||||
@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
|
||||
@router.get("/users/{user_id}/", include_in_schema=False)
|
||||
@router.get(
|
||||
"/users/{user_id}",
|
||||
response_model=UserResp,
|
||||
name="获取用户信息",
|
||||
description="通过用户 ID 或用户名获取单个用户的详细信息。",
|
||||
tags=["用户"],
|
||||
responses={
|
||||
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
|
||||
},
|
||||
)
|
||||
@asset_proxy_response
|
||||
async def get_user_info(
|
||||
@@ -381,27 +375,31 @@ async def get_user_info(
|
||||
if should_not_show:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
|
||||
include = SEARCH_INCLUDED
|
||||
if searched_is_self:
|
||||
include = ALL_INCLUDED
|
||||
user_resp = await UserResp.from_db(
|
||||
user_resp = await UserModel.transform(
|
||||
searched_user,
|
||||
session,
|
||||
include=include,
|
||||
includes=User.USER_INCLUDES,
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return user_resp
|
||||
|
||||
|
||||
beatmapset_includes = [*BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES, "beatmaps"]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users/{user_id}/beatmapsets/{type}",
|
||||
response_model=list[BeatmapsetResp | BeatmapPlaycountsResp],
|
||||
name="获取用户谱面集列表",
|
||||
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
|
||||
tags=["用户"],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"当类型为 `most_played` 时返回 `list[BeatmapPlaycountsModel]`,其他为 `list[BeatmapsetModel]`",
|
||||
list[BeatmapsetModel] | list[BeatmapPlaycountsModel],
|
||||
beatmapset_includes,
|
||||
)
|
||||
},
|
||||
)
|
||||
@asset_proxy_response
|
||||
async def get_user_beatmapsets(
|
||||
@@ -417,11 +415,7 @@ async def get_user_beatmapsets(
|
||||
# 先尝试从缓存获取
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
|
||||
if cached_result is not None:
|
||||
# 根据类型恢复对象
|
||||
if type == BeatmapsetType.MOST_PLAYED:
|
||||
return [BeatmapPlaycountsResp(**item) for item in cached_result]
|
||||
else:
|
||||
return [BeatmapsetResp(**item) for item in cached_result]
|
||||
return cached_result
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if not user or user.id == BANCHOBOT_ID:
|
||||
@@ -444,7 +438,10 @@ async def get_user_beatmapsets(
|
||||
raise HTTPException(404, detail="User not found")
|
||||
favourites = await user.awaitable_attrs.favourite_beatmapsets
|
||||
resp = [
|
||||
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
|
||||
await BeatmapsetModel.transform(
|
||||
favourite.beatmapset, session=session, user=user, includes=beatmapset_includes
|
||||
)
|
||||
for favourite in favourites
|
||||
]
|
||||
|
||||
elif type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -459,7 +456,10 @@ async def get_user_beatmapsets(
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
|
||||
resp = [
|
||||
await BeatmapPlaycountsModel.transform(most_played_beatmap, user=user, includes=beatmapset_includes)
|
||||
for most_played_beatmap in most_played
|
||||
]
|
||||
else:
|
||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||
|
||||
@@ -477,7 +477,6 @@ async def get_user_beatmapsets(
|
||||
|
||||
@router.get(
|
||||
"/users/{user_id}/scores/{type}",
|
||||
response_model=list[ScoreResp] | list[LegacyScoreResp],
|
||||
name="获取用户成绩列表",
|
||||
description=(
|
||||
"获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n"
|
||||
@@ -523,6 +522,7 @@ async def get_user_scores(
|
||||
gamemode = mode or db_user.playmode
|
||||
order_by = None
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
|
||||
includes = Score.USER_PROFILE_INCLUDES.copy()
|
||||
if not include_fails:
|
||||
where_clause &= col(Score.passed).is_(True)
|
||||
if type == "pinned":
|
||||
@@ -531,6 +531,7 @@ async def get_user_scores(
|
||||
elif type == "best":
|
||||
where_clause &= exists().where(col(BestScore.score_id) == Score.id)
|
||||
order_by = col(Score.pp).desc()
|
||||
includes.append("weight")
|
||||
elif type == "recent":
|
||||
where_clause &= Score.ended_at > utcnow() - timedelta(hours=24)
|
||||
order_by = col(Score.ended_at).desc()
|
||||
@@ -551,6 +552,7 @@ async def get_user_scores(
|
||||
await score.to_resp(
|
||||
session,
|
||||
api_version,
|
||||
includes=includes,
|
||||
)
|
||||
for score in scores
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user