refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -447,7 +447,7 @@ async def create_multiplayer_room(
# 让房主加入频道
host_user = await db.get(User, host_user_id)
if host_user:
await server.batch_join_channel([host_user], channel, db)
await server.batch_join_channel([host_user], channel)
# Add playlist items
await _add_playlist_items(db, room_id, room_data, host_user_id)

View File

@@ -3,11 +3,14 @@ from collections.abc import Awaitable, Callable
from math import ceil
import random
import shlex
from typing import TYPE_CHECKING
from app.calculator import calculate_weighted_pp
from app.const import BANCHOBOT_ID
from app.database import ChatMessageResp
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
if TYPE_CHECKING:
pass
from app.database.chat import ChannelType, ChatChannel, ChatMessage, ChatMessageModel, MessageType
from app.database.score import Score, get_best_id
from app.database.statistics import UserStatistics, get_rank
from app.database.user import User
@@ -95,7 +98,7 @@ class Bot:
await session.commit()
await session.refresh(msg)
await session.refresh(bot)
resp = await ChatMessageResp.from_db(msg, session, bot)
resp = await ChatMessageModel.transform(msg, includes=["sender"])
await server.send_message_to_channel(resp)
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
@@ -119,7 +122,7 @@ class Bot:
await session.refresh(channel)
await session.refresh(user)
await session.refresh(bot)
await server.batch_join_channel([user, bot], channel, session)
await server.batch_join_channel([user, bot], channel)
return channel
async def _send_reply(

View File

@@ -1,37 +1,40 @@
from typing import Annotated, Any, Literal, Self
from typing import Annotated, Literal, Self
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatChannelModel,
ChatMessage,
SilenceUser,
UserSilenceResp,
)
from app.database.user import User, UserResp
from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
from app.utils import api_doc
from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, model_validator
from sqlmodel import col, select
class UpdateResponse(BaseModel):
presence: list[ChatChannelResp] = Field(default_factory=list)
silences: list[Any] = Field(default_factory=list)
@router.get(
"/chat/updates",
response_model=UpdateResponse,
name="获取更新",
description="获取当前用户所在频道的最新的禁言情况。",
tags=["聊天"],
responses={
200: api_doc(
"获取更新响应。",
{"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]},
ChatChannel.LISTING_INCLUDES,
name="UpdateResponse",
)
},
)
async def get_update(
session: Database,
@@ -44,45 +47,44 @@ async def get_update(
Query(alias="includes[]", description="要包含的更新类型"),
] = ["presence", "silences"],
):
resp = UpdateResponse()
resp = {
"presence": [],
"silences": [],
}
if "presence" in includes:
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
# 使用明确的查询避免延迟加载
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
resp.presence.append(
await ChatChannelResp.from_db(
resp["presence"].append(
await ChatChannelModel.transform(
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
user=current_user,
server=server,
includes=ChatChannel.LISTING_INCLUDES,
)
)
if "silences" in includes:
if history_since:
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@router.put(
"/chat/channels/{channel}/users/{user}",
response_model=ChatChannelResp,
name="加入频道",
description="加入指定的公开/房间频道。",
tags=["聊天"],
responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)},
)
async def join_channel(
session: Database,
@@ -101,7 +103,7 @@ async def join_channel(
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
return await server.join_channel(current_user, db_channel, session)
return await server.join_channel(current_user, db_channel)
@router.delete(
@@ -128,13 +130,13 @@ async def leave_channel(
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
await server.leave_channel(current_user, db_channel, session)
await server.leave_channel(current_user, db_channel)
return
@router.get(
"/chat/channels",
response_model=list[ChatChannelResp],
responses={200: api_doc("加入的频道", list[ChatChannelModel])},
name="获取频道列表",
description="获取所有公开频道。",
tags=["聊天"],
@@ -142,35 +144,30 @@ async def leave_channel(
async def get_channel_list(
session: Database,
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
redis: Redis,
):
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
results = []
for channel in channels:
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
results = await ChatChannelModel.transform_many(
channels,
user=current_user,
server=server,
)
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
return results
class GetChannelResp(BaseModel):
channel: ChatChannelResp
users: list[UserResp] = Field(default_factory=list)
@router.get(
"/chat/channels/{channel}",
response_model=GetChannelResp,
responses={
200: api_doc(
"频道详细信息",
{
"channel": ChatChannelModel,
"users": list[UserModel],
},
ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES,
name="GetChannelResponse",
)
},
name="获取频道信息",
description="获取指定频道的信息。",
tags=["聊天"],
@@ -191,7 +188,6 @@ async def get_channel(
raise HTTPException(status_code=404, detail="Channel not found")
# 立即提取需要的属性
channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
@@ -209,15 +205,15 @@ async def get_channel(
users.extend([target_user, current_user])
break
return GetChannelResp(
channel=await ChatChannelResp.from_db(
return {
"channel": await ChatChannelModel.transform(
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
user=current_user,
server=server,
includes=ChatChannel.LISTING_INCLUDES,
),
"users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES),
}
class CreateChannelReq(BaseModel):
@@ -244,7 +240,7 @@ class CreateChannelReq(BaseModel):
@router.post(
"/chat/channels",
response_model=ChatChannelResp,
responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])},
name="创建频道",
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
tags=["聊天"],
@@ -289,21 +285,13 @@ async def create_channel(
await session.refresh(current_user)
if req.type == "PM":
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.batch_join_channel([*target_users, current_user], channel)
await server.join_channel(current_user, channel, session)
await server.join_channel(current_user, channel)
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, []),
include_recent_messages=True,
return await ChatChannelModel.transform(
channel, user=current_user, server=server, includes=["recent_messages.sender"]
)

View File

@@ -1,11 +1,11 @@
from typing import Annotated
from app.database import ChatMessageResp
from app.database import ChatChannelModel
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatMessage,
ChatMessageModel,
MessageType,
SilenceUser,
UserSilenceResp,
@@ -18,6 +18,7 @@ from app.log import log
from app.models.notification import ChannelMessage, ChannelMessageTeam
from app.router.v2 import api_v2_router as router
from app.service.redis_message_system import redis_message_system
from app.utils import api_doc
from .banchobot import bot
from .server import server
@@ -68,7 +69,7 @@ class MessageReq(BaseModel):
@router.post(
"/chat/channels/{channel}/messages",
response_model=ChatMessageResp,
responses={200: api_doc("发送的消息", ChatMessageModel, ["sender", "is_action"])},
name="发送消息",
description="发送消息到指定频道。",
tags=["聊天"],
@@ -130,7 +131,7 @@ async def send_message(
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
temp_msg = ChatMessage(
message_id=resp.message_id, # 使用 Redis 系统生成的ID
message_id=resp["message_id"], # 使用 Redis 系统生成的ID
channel_id=channel_id,
content=req.message,
sender_id=user_id,
@@ -151,7 +152,7 @@ async def send_message(
@router.get(
"/chat/channels/{channel}/messages",
response_model=list[ChatMessageResp],
responses={200: api_doc("获取的消息", list[ChatMessageModel], ["sender"])},
name="获取消息",
description="获取指定频道的消息列表(统一按时间正序返回)。",
tags=["聊天"],
@@ -177,7 +178,7 @@ async def get_message(
try:
messages = await redis_message_system.get_messages(channel_id, limit, since)
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
if len(messages) >= 2 and messages[0]["message_id"] > messages[-1]["message_id"]:
messages.reverse()
return messages
except Exception as e:
@@ -189,7 +190,7 @@ async def get_message(
# 向前加载新消息 → 直接 ASC
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
rows = (await session.exec(query)).all()
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
# 已经 ASC无需反转
return resp
@@ -202,15 +203,14 @@ async def get_message(
rows = (await session.exec(query)).all()
rows = list(rows)
rows.reverse() # 反转为 ASC
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
return resp
query = base.order_by(col(ChatMessage.message_id).desc()).limit(limit)
rows = (await session.exec(query)).all()
rows = list(rows)
rows.reverse() # 反转为 ASC
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
return resp
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
return resp
@@ -248,17 +248,23 @@ class PMReq(BaseModel):
uuid: str | None = None
class NewPMResp(BaseModel):
channel: ChatChannelResp
message: ChatMessageResp
new_channel_id: int
@router.post(
"/chat/new",
name="创建私聊频道",
description="创建一个新的私聊频道。",
tags=["聊天"],
responses={
200: api_doc(
"创建私聊频道响应",
{
"channel": ChatChannelModel,
"message": ChatMessageModel,
"new_channel_id": int,
},
["recent_messages.sender", "sender"],
name="NewPMResponse",
)
},
)
async def create_new_pm(
session: Database,
@@ -290,9 +296,9 @@ async def create_new_pm(
await session.refresh(target)
await session.refresh(current_user)
await server.batch_join_channel([target, current_user], channel, session)
channel_resp = await ChatChannelResp.from_db(
channel, session, current_user, redis, server.channels[channel.channel_id]
await server.batch_join_channel([target, current_user], channel)
channel_resp = await ChatChannelModel.transform(
channel, user=current_user, server=server, includes=["recent_messages.sender"]
)
msg = ChatMessage(
channel_id=channel.channel_id,
@@ -306,10 +312,10 @@ async def create_new_pm(
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(channel)
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
message_resp = await ChatMessageModel.transform(msg, user=current_user, includes=["sender"])
await server.send_message_to_channel(message_resp)
return NewPMResp(
channel=channel_resp,
message=message_resp,
new_channel_id=channel_resp.channel_id,
)
return {
"channel": channel_resp,
"message": message_resp,
"new_channel_id": channel_resp["channel_id"],
}

View File

@@ -1,7 +1,8 @@
import asyncio
from typing import Annotated, overload
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database import ChatMessageDict
from app.database.chat import ChannelType, ChatChannel, ChatChannelDict, ChatChannelModel
from app.database.notification import UserNotification, insert_notification
from app.database.user import User
from app.dependencies.database import (
@@ -16,7 +17,7 @@ from app.log import log
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks
from app.utils import bg_tasks, safe_json_dumps
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
@@ -65,7 +66,7 @@ class ChatServer:
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
).first()
if db_channel:
await self.leave_channel(user, db_channel, session)
await self.leave_channel(user, db_channel)
@overload
async def send_event(self, client: int, event: ChatEvent): ...
@@ -80,7 +81,7 @@ class ChatServer:
return
client = client_
if client.client_state == WebSocketState.CONNECTED:
await client.send_text(event.model_dump_json())
await client.send_text(safe_json_dumps(event))
async def broadcast(self, channel_id: int, event: ChatEvent):
users_in_channel = self.channels.get(channel_id, [])
@@ -107,38 +108,38 @@ class ChatServer:
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
async def send_message_to_channel(self, message: ChatMessageDict, is_bot_command: bool = False):
logger.info(
f"Sending message to channel {message.channel_id}, message_id: "
f"{message.message_id}, is_bot_command: {is_bot_command}"
f"Sending message to channel {message['channel_id']}, message_id: "
f"{message['message_id']}, is_bot_command: {is_bot_command}"
)
event = ChatEvent(
event="chat.message.new",
data={"messages": [message], "users": [message.sender]},
data={"messages": [message], "users": [message["sender"]]}, # pyright: ignore[reportTypedDictNotRequiredAccess]
)
if is_bot_command:
logger.info(f"Sending bot command to user {message.sender_id}")
bg_tasks.add_task(self.send_event, message.sender_id, event)
logger.info(f"Sending bot command to user {message['sender_id']}")
bg_tasks.add_task(self.send_event, message["sender_id"], event)
else:
# 总是广播消息无论是临时ID还是真实ID
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
logger.info(f"Broadcasting message to all users in channel {message['channel_id']}")
bg_tasks.add_task(
self.broadcast,
message.channel_id,
message["channel_id"],
event,
)
# 只有真实消息 ID正数且非零才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数所以这里应该都能正常处理
if message.message_id and message.message_id > 0:
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
if message["message_id"] and message["message_id"] > 0:
await self.mark_as_read(message["channel_id"], message["sender_id"], message["message_id"])
await self.redis.set(f"chat:{message['channel_id']}:last_msg", message["message_id"])
logger.info(f"Updated last message ID for channel {message['channel_id']} to {message['message_id']}")
else:
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
logger.debug(f"Skipping last message update for message ID: {message['message_id']}")
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
async def batch_join_channel(self, users: list[User], channel: ChatChannel):
channel_id = channel.channel_id
not_joined = []
@@ -151,22 +152,18 @@ class ChatServer:
not_joined.append(user)
for user in not_joined:
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
channel_resp = await ChatChannelModel.transform(
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user.id,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
async def join_channel(self, user: User, channel: ChatChannel) -> ChatChannelDict:
user_id = user.id
channel_id = channel.channel_id
@@ -175,25 +172,21 @@ class ChatServer:
if user_id not in self.channels[channel_id]:
self.channels[channel_id].append(user_id)
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
channel_resp: ChatChannelDict = await ChatChannelModel.transform(
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
return channel_resp
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
async def leave_channel(self, user: User, channel: ChatChannel) -> None:
user_id = user.id
channel_id = channel.channel_id
@@ -203,18 +196,14 @@ class ChatServer:
if (c := self.channels.get(channel_id)) is not None and not c:
del self.channels[channel_id]
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
channel_resp = await ChatChannelModel.transform(
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.part",
data=channel_resp.model_dump(),
data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
@@ -232,7 +221,7 @@ class ChatServer:
return
logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
await self.join_channel(user, db_channel, session)
await self.join_channel(user, db_channel)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
@@ -248,7 +237,7 @@ class ChatServer:
return
logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
await self.leave_channel(user, db_channel, session)
await self.leave_channel(user, db_channel)
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
@@ -336,6 +325,6 @@ async def chat_websocket(
# 使用明确的查询避免延迟加载
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
if db_channel is not None:
await server.join_channel(user, db_channel, session)
await server.join_channel(user, db_channel)
await _listen_stop(websocket, user_id, factory)

View File

@@ -2,7 +2,7 @@ import hashlib
from typing import Annotated
from app.database.team import Team, TeamMember, TeamRequest, TeamResp
from app.database.user import BASE_INCLUDES, User, UserResp
from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
@@ -14,12 +14,11 @@ from app.models.notification import (
from app.models.score import GameMode
from app.router.notification import server
from app.service.ranking_cache_service import get_ranking_cache_service
from app.utils import check_image, utcnow
from app.utils import api_doc, check_image, utcnow
from .router import router
from fastapi import File, Form, HTTPException, Path, Query, Request
from pydantic import BaseModel
from sqlmodel import col, exists, select
@@ -214,12 +213,22 @@ async def delete_team(
await cache_service.invalidate_team_cache()
class TeamQueryResp(BaseModel):
team: TeamResp
members: list[UserResp]
@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
@router.get(
"/team/{team_id}",
name="查询战队",
tags=["战队", "g0v0 API"],
responses={
200: api_doc(
"战队信息",
{
"team": TeamResp,
"members": list[UserModel],
},
["statistics", "country"],
name="TeamQueryResp",
)
},
)
async def get_team(
session: Database,
team_id: Annotated[int, Path(..., description="战队 ID")],
@@ -233,10 +242,10 @@ async def get_team(
)
)
).all()
return TeamQueryResp(
team=await TeamResp.from_db(members[0].team, session, gamemode),
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
)
return {
"team": await TeamResp.from_db(members[0].team, session, gamemode),
"members": await UserModel.transform_many([m.user for m in members], includes=["statistics", "country"]),
}
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Annotated, Literal
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.database.statistics import UserStatistics, UserStatisticsModel
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.log import logger
@@ -46,7 +46,7 @@ class V1User(AllStrModel):
return f"v1_user:{user_id}"
@classmethod
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
async def from_db(cls, db_user: User, ruleset: GameMode | None = None) -> "V1User":
ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None
for i in await db_user.awaitable_attrs.statistics:
@@ -54,31 +54,33 @@ class V1User(AllStrModel):
current_statistics = i
break
if current_statistics:
statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
statistics = await UserStatisticsModel.transform(
current_statistics, country_code=db_user.country_code, includes=["country_rank"]
)
else:
statistics = None
return cls(
user_id=db_user.id,
username=db_user.username,
join_date=db_user.join_date,
count300=statistics.count_300 if statistics else 0,
count100=statistics.count_100 if statistics else 0,
count50=statistics.count_50 if statistics else 0,
playcount=statistics.play_count if statistics else 0,
ranked_score=statistics.ranked_score if statistics else 0,
total_score=statistics.total_score if statistics else 0,
pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
count300=current_statistics.count_300 if current_statistics else 0,
count100=current_statistics.count_100 if current_statistics else 0,
count50=current_statistics.count_50 if current_statistics else 0,
playcount=current_statistics.play_count if current_statistics else 0,
ranked_score=current_statistics.ranked_score if current_statistics else 0,
total_score=current_statistics.total_score if current_statistics else 0,
pp_rank=statistics.get("global_rank") or 0 if statistics else 0,
level=current_statistics.level_current if current_statistics else 0,
pp_raw=statistics.pp if statistics else 0.0,
accuracy=statistics.hit_accuracy if statistics else 0,
pp_raw=current_statistics.pp if current_statistics else 0.0,
accuracy=current_statistics.hit_accuracy if current_statistics else 0,
count_rank_ss=current_statistics.grade_ss if current_statistics else 0,
count_rank_ssh=current_statistics.grade_ssh if current_statistics else 0,
count_rank_s=current_statistics.grade_s if current_statistics else 0,
count_rank_sh=current_statistics.grade_sh if current_statistics else 0,
count_rank_a=current_statistics.grade_a if current_statistics else 0,
country=db_user.country_code,
total_seconds_played=statistics.play_time if statistics else 0,
pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
total_seconds_played=current_statistics.play_time if current_statistics else 0,
pp_country_rank=statistics.get("country_rank") or 0 if statistics else 0,
events=[], # TODO
)
@@ -134,7 +136,7 @@ async def get_user(
try:
# 生成用户数据
v1_user = await V1User.from_db(session, db_user, ruleset)
v1_user = await V1User.from_db(db_user, ruleset)
# 异步缓存结果如果有用户ID
if db_user.id is not None:

View File

@@ -5,7 +5,11 @@ from typing import Annotated
from app.calculator import get_calculator
from app.calculators.performance import ConvertError
from app.database import Beatmap, BeatmapResp, User
from app.database import (
Beatmap,
BeatmapModel,
User,
)
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
@@ -19,29 +23,20 @@ from app.models.performance import (
from app.models.score import (
GameMode,
)
from app.utils import api_doc
from .router import router
from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from sqlmodel import col, select
class BatchGetResp(BaseModel):
"""批量获取谱面返回模型。
返回字段说明:
- beatmaps: 谱面详细信息列表。"""
beatmaps: list[BeatmapResp]
@router.get(
"/beatmaps/lookup",
tags=["谱面"],
name="查询单个谱面",
response_model=BeatmapResp,
responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
)
@asset_proxy_response
@@ -67,14 +62,14 @@ async def lookup_beatmap(
raise HTTPException(status_code=404, detail="Beatmap not found")
await db.refresh(current_user)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
return await BeatmapModel.transform(beatmap, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
@router.get(
"/beatmaps/{beatmap_id}",
tags=["谱面"],
name="获取谱面详情",
response_model=BeatmapResp,
responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
description="获取单个谱面详情。",
)
@asset_proxy_response
@@ -86,7 +81,12 @@ async def get_beatmap(
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
await db.refresh(current_user)
return await BeatmapModel.transform(
beatmap,
user=current_user,
includes=BeatmapModel.TRANSFORMER_INCLUDES,
)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -95,7 +95,11 @@ async def get_beatmap(
"/beatmaps/",
tags=["谱面"],
name="批量获取谱面",
response_model=BatchGetResp,
responses={
200: api_doc(
"谱面列表", {"beatmaps": list[BeatmapModel]}, BeatmapModel.TRANSFORMER_INCLUDES, name="BatchBeatmapResponse"
)
},
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
)
@asset_proxy_response
@@ -124,7 +128,12 @@ async def batch_get_beatmaps(
for beatmap in beatmaps:
await db.refresh(beatmap)
await db.refresh(current_user)
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
return {
"beatmaps": [
await BeatmapModel.transform(bm, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
for bm in beatmaps
]
}
@router.post(

View File

@@ -2,17 +2,24 @@ import re
from typing import Annotated, Literal
from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.database import (
Beatmap,
Beatmapset,
BeatmapsetModel,
FavouriteBeatmapset,
SearchBeatmapsetsResp,
User,
)
from app.dependencies.beatmap_download import DownloadService
from app.dependencies.cache import BeatmapsetCacheService, UserCacheService
from app.dependencies.database import Database, Redis, with_db
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.geoip import IPAddress, get_geoip_helper
from app.dependencies.user import ClientUser, get_current_user
from app.helpers.asset_proxy_helper import asset_proxy_response
from app.models.beatmap import SearchQueryModel
from app.service.beatmapset_cache_service import generate_hash
from app.utils import api_doc
from .router import router
@@ -27,14 +34,7 @@ from fastapi import (
)
from fastapi.responses import RedirectResponse
from httpx import HTTPError
from sqlmodel import exists, select
async def _save_to_db(sets: SearchBeatmapsetsResp):
async with with_db() as session:
for s in sets.beatmapsets:
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
await Beatmapset.from_resp(session, s)
from sqlmodel import select
@router.get(
@@ -105,7 +105,6 @@ async def search_beatmapset(
try:
sets = await fetcher.search_beatmapset(query, cursor, redis)
background_tasks.add_task(_save_to_db, sets)
# 缓存搜索结果
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
@@ -117,8 +116,8 @@ async def search_beatmapset(
@router.get(
"/beatmapsets/lookup",
tags=["谱面集"],
responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
name="查询谱面集 (通过谱面 ID)",
response_model=BeatmapsetResp,
description=("通过谱面 ID 查询所属谱面集。"),
)
@asset_proxy_response
@@ -137,7 +136,10 @@ async def lookup_beatmapset(
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
resp = await BeatmapsetModel.transform(
beatmap.beatmapset, user=current_user, includes=BeatmapsetModel.API_INCLUDES
)
# 缓存结果
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
@@ -149,8 +151,8 @@ async def lookup_beatmapset(
@router.get(
"/beatmapsets/{beatmapset_id}",
tags=["谱面集"],
responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
name="获取谱面集详情",
response_model=BeatmapsetResp,
description="获取单个谱面集详情。",
)
@asset_proxy_response
@@ -169,7 +171,8 @@ async def get_beatmapset(
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
resp = await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
await db.refresh(current_user)
resp = await BeatmapsetModel.transform(beatmapset, includes=BeatmapsetModel.API_INCLUDES, user=current_user)
# 缓存结果
await cache_service.cache_beatmapset(resp)

View File

@@ -1,20 +1,29 @@
from typing import Annotated
from app.database import FavouriteBeatmapset, MeResp, User
from app.database import FavouriteBeatmapset, User
from app.database.user import UserModel
from app.dependencies.database import Database
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
from app.models.score import GameMode
from app.utils import api_doc
from .router import router
from fastapi import Path, Security
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from sqlmodel import select
ME_INCLUDES = [*User.USER_INCLUDES, "session_verified", "session_verification_method"]
class BeatmapsetIds(BaseModel):
beatmapset_ids: list[int]
@router.get(
"/me/beatmapset-favourites",
response_model=list[int],
response_model=BeatmapsetIds,
name="获取当前用户收藏的谱面集 ID 列表",
description="获取当前登录用户收藏的谱面集 ID 列表。",
tags=["用户", "谱面集"],
@@ -26,37 +35,39 @@ async def get_user_beatmapset_favourites(
beatmapset_ids = await session.exec(
select(FavouriteBeatmapset.beatmapset_id).where(FavouriteBeatmapset.user_id == current_user.id)
)
return beatmapset_ids.all()
return BeatmapsetIds(beatmapset_ids=list(beatmapset_ids.all()))
@router.get(
"/me/{ruleset}",
response_model=MeResp,
responses={200: api_doc("当前用户信息(含指定 ruleset 统计)", UserModel, ME_INCLUDES)},
name="获取当前用户信息 (指定 ruleset)",
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
tags=["用户"],
)
async def get_user_info_with_ruleset(
session: Database,
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
user_resp = await UserModel.transform(
user_and_token[0], ruleset=ruleset, token_id=user_and_token[1].id, includes=ME_INCLUDES
)
return user_resp
@router.get(
"/me/",
response_model=MeResp,
responses={200: api_doc("当前用户信息", UserModel, ME_INCLUDES)},
name="获取当前用户信息",
description="获取当前登录用户信息。",
tags=["用户"],
)
async def get_user_info_default(
session: Database,
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
user_resp = await UserModel.transform(
user_and_token[0], ruleset=None, token_id=user_and_token[1].id, includes=ME_INCLUDES
)
return user_resp

View File

@@ -1,11 +1,13 @@
from typing import Annotated, Literal
from app.config import settings
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
from app.database import Team, TeamMember, User, UserStatistics
from app.database.statistics import UserStatisticsModel
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.service.ranking_cache_service import get_ranking_cache_service
from app.utils import api_doc
from .router import router
@@ -308,14 +310,16 @@ async def get_country_ranking(
return response
class TopUsersResponse(BaseModel):
ranking: list[UserStatisticsResp]
total: int
@router.get(
"/rankings/{ruleset}/{sort}",
response_model=TopUsersResponse,
responses={
200: api_doc(
"用户排行榜",
{"ranking": list[UserStatisticsModel], "total": int},
["user.country", "user.cover"],
name="TopUsersResponse",
)
},
name="获取用户排行榜",
description="获取在指定模式下的用户排行榜",
tags=["排行榜"],
@@ -339,10 +343,10 @@ async def get_user_ranking(
if cached_data and cached_stats:
# 从缓存返回数据
return TopUsersResponse(
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data],
total=cached_stats.get("total", 0),
)
return {
"ranking": cached_data,
"total": cached_stats.get("total", 0),
}
# 缓存未命中,从数据库查询
wheres = [
@@ -350,7 +354,7 @@ async def get_user_ranking(
col(UserStatistics.pp) > 0,
col(UserStatistics.is_ranked),
]
include = ["user"]
include = UserStatistics.RANKING_INCLUDES.copy()
if sort == "performance":
order_by = col(UserStatistics.pp).desc()
include.append("rank_change_since_30_days")
@@ -358,6 +362,7 @@ async def get_user_ranking(
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
include.append("country_rank")
# 查询总数
count_query = select(func.count()).select_from(UserStatistics).where(*wheres)
@@ -378,12 +383,14 @@ async def get_user_ranking(
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
user_stats_resp = await UserStatisticsModel.transform(
statistics, includes=include, user_country=current_user.country_code
)
ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data]
cache_data = ranking_data
stats_data = {"total": total_count}
# 创建后台任务来缓存数据
@@ -407,5 +414,7 @@ async def get_user_ranking(
ttl=settings.ranking_cache_expire_minutes * 60,
)
resp = TopUsersResponse(ranking=ranking_data, total=total_count)
return resp
return {
"ranking": ranking_data,
"total": total_count,
}

View File

@@ -1,15 +1,16 @@
from typing import Annotated
from typing import Annotated, Any
from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.database.user import UserResp
from app.database import Relationship, RelationshipType, User
from app.database.relationship import RelationshipModel
from app.database.user import UserModel
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database
from app.dependencies.user import ClientUser, get_current_user
from app.utils import api_doc
from .router import router
from fastapi import HTTPException, Path, Query, Request, Security
from pydantic import BaseModel
from sqlmodel import col, exists, select
@@ -17,38 +18,19 @@ from sqlmodel import col, exists, select
"/friends",
tags=["用户关系"],
responses={
200: {
"description": "好友列表",
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"type": "array",
"items": {"$ref": "#/components/schemas/RelationshipResp"},
"description": "好友列表",
},
{
"type": "array",
"items": {"$ref": "#/components/schemas/UserResp"},
"description": "好友列表 (`x-api-version < 20241022`)",
},
]
}
}
},
}
200: api_doc(
"好友列表\n\n如果 `x-api-version < 20241022`,返回值为 `User` 列表,否则为 `Relationship` 列表",
list[RelationshipModel] | list[UserModel],
[f"target.{inc}" for inc in User.LIST_INCLUDES],
)
},
name="获取好友列表",
description=(
"获取当前用户的好友列表。\n\n"
"如果 `x-api-version < 20241022`,返回值为 `UserResp` 列表,否则为 `RelationshipResp` 列表。"
),
description="获取当前用户的好友列表。",
)
@router.get(
"/blocks",
tags=["用户关系"],
response_model=list[RelationshipResp],
response_model=list[dict[str, Any]],
name="获取屏蔽列表",
description="获取当前用户的屏蔽用户列表。",
)
@@ -67,35 +49,29 @@ async def get_relationship(
)
)
if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK:
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
return [
await RelationshipModel.transform(
rel,
includes=[f"target.{inc}" for inc in User.LIST_INCLUDES],
ruleset=current_user.playmode,
)
for rel in relationships.unique()
]
else:
return [
await UserResp.from_db(
await UserModel.transform(
rel.target,
db,
include=[
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
],
ruleset=current_user.playmode,
includes=User.LIST_INCLUDES,
)
for rel in relationships.unique()
]
class AddFriendResp(BaseModel):
"""添加好友/屏蔽 返回模型。
- user_relation: 新的或更新后的关系对象。"""
user_relation: RelationshipResp
@router.post(
"/friends",
tags=["用户关系"],
response_model=AddFriendResp,
responses={200: api_doc("好友关系", {"user_relation": RelationshipModel}, name="UserRelationshipResponse")},
name="添加或更新好友关系",
description="\n添加或更新与目标用户的好友关系。",
)
@@ -163,7 +139,13 @@ async def add_relationship(
)
)
).one()
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
return {
"user_relation": await RelationshipModel.transform(
relationship,
includes=[],
ruleset=current_user.playmode,
)
}
@router.delete(

View File

@@ -1,25 +1,27 @@
from datetime import UTC
from typing import Annotated, Literal
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmapset import BeatmapsetResp
from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsResp
from app.database.beatmap import (
Beatmap,
BeatmapModel,
)
from app.database.beatmapset import BeatmapsetModel
from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsCountModel
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from app.database.playlists import Playlist, PlaylistResp
from app.database.room import APIUploadedRoom, Room, RoomResp
from app.database.playlists import Playlist, PlaylistModel
from app.database.room import APIUploadedRoom, Room, RoomModel
from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score
from app.database.user import User, UserResp
from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser, get_current_user
from app.models.room import MatchType, RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api
from app.utils import utcnow
from app.utils import api_doc, utcnow
from .router import router
from fastapi import HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -28,7 +30,19 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get(
"/rooms",
tags=["房间"],
response_model=list[RoomResp],
responses={
200: api_doc(
"房间列表",
list[RoomModel],
[
"current_playlist_item.beatmap.beatmapset",
"difficulty_range",
"host.country",
"playlist_item_stats",
"recent_participants",
],
)
},
name="获取房间列表",
description="获取房间列表。支持按状态/模式筛选",
)
@@ -49,7 +63,7 @@ async def get_all_rooms(
] = RoomCategory.NORMAL,
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
):
resp_list: list[RoomResp] = []
resp_list = []
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category, col(Room.type) != MatchType.MATCHMAKING]
now = utcnow()
@@ -90,22 +104,24 @@ async def get_all_rooms(
.all()
)
for room in db_rooms:
resp = await RoomResp.from_db(room, db)
resp = await RoomModel.transform(
room,
includes=[
"current_playlist_item.beatmap.beatmapset",
"difficulty_range",
"host.country",
"playlist_item_stats",
"recent_participants",
],
)
if category == RoomCategory.REALTIME:
resp.category = RoomCategory.NORMAL
resp["category"] = RoomCategory.NORMAL
resp_list.append(resp)
return resp_list
class APICreatedRoom(RoomResp):
"""创建房间返回模型,继承 RoomResp。额外字段:
- error: 错误信息(为空表示成功)。"""
error: str = ""
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
participated_user = (
await session.exec(
@@ -133,9 +149,15 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
@router.post(
"/rooms",
tags=["房间"],
response_model=APICreatedRoom,
name="创建房间",
description="\n创建一个新的房间。",
responses={
200: api_doc(
"创建的房间信息",
RoomModel,
Room.SHOW_RESPONSE_INCLUDES,
)
},
)
async def create_room(
db: Database,
@@ -145,23 +167,27 @@ async def create_room(
):
if await current_user.is_restricted(db):
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db, redis)
await db.commit()
await db.refresh(db_room)
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
created_room.error = ""
created_room = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
return created_room
@router.get(
"/rooms/{room_id}",
tags=["房间"],
response_model=RoomResp,
responses={
200: api_doc(
"房间详细信息",
RoomModel,
Room.SHOW_RESPONSE_INCLUDES,
)
},
name="获取房间详情",
description="获取单个房间详情。",
description="获取指定房间详情。",
)
async def get_room(
db: Database,
@@ -177,7 +203,7 @@ async def get_room(
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES, user=current_user)
return resp
@@ -225,10 +251,10 @@ async def add_user_to_room(
await _participate_room(room_id, user_id, db_room, db, redis)
await db.commit()
await db.refresh(db_room)
resp = await RoomResp.from_db(db_room, db)
resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
return resp
else:
raise HTTPException(404, "room not found0")
raise HTTPException(404, "room not found")
@router.delete(
@@ -268,21 +294,22 @@ async def remove_user_from_room(
raise HTTPException(404, "Room not found")
class APILeaderboard(BaseModel):
"""房间全局排行榜返回模型。
- leaderboard: 用户游玩统计(尝试次数/分数等)。
- user_score: 当前用户对应统计。"""
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
user_score: ItemAttemptsResp | None = None
@router.get(
"/rooms/{room_id}/leaderboard",
tags=["房间"],
response_model=APILeaderboard,
name="获取房间排行榜",
description="获取房间内累计得分排行榜。",
responses={
200: api_doc(
"房间排行榜",
{
"leaderboard": list[ItemAttemptsCountModel],
"user_score": ItemAttemptsCountModel | None,
},
["user.country", "position"],
name="RoomLeaderboardResponse",
)
},
)
async def get_room_leaderboard(
db: Database,
@@ -300,45 +327,43 @@ async def get_room_leaderboard(
aggs_resp = []
user_agg = None
for i, agg in enumerate(aggs):
resp = await ItemAttemptsResp.from_db(agg, db)
resp.position = i + 1
includes = ["user.country"]
if agg.user_id == current_user.id:
includes.append("position")
resp = await ItemAttemptsCountModel.transform(agg, includes=includes)
aggs_resp.append(resp)
if agg.user_id == current_user.id:
user_agg = resp
return APILeaderboard(
leaderboard=aggs_resp,
user_score=user_agg,
)
class RoomEvents(BaseModel):
"""房间事件流返回模型。
- beatmaps: 本次结果涉及的谱面列表。
- beatmapsets: 谱面集映射。
- current_playlist_item_id: 当前游玩列表(项目)项 ID。
- events: 事件列表。
- first_event_id / last_event_id: 事件范围。
- playlist_items: 房间游玩列表(项目)详情。
- room: 房间详情。
- user: 关联用户列表。"""
beatmaps: list[BeatmapResp] = Field(default_factory=list)
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
current_playlist_item_id: int = 0
events: list[MultiplayerEventResp] = Field(default_factory=list)
first_event_id: int = 0
last_event_id: int = 0
playlist_items: list[PlaylistResp] = Field(default_factory=list)
room: RoomResp
user: list[UserResp] = Field(default_factory=list)
return {
"leaderboard": aggs_resp,
"user_score": user_agg,
}
@router.get(
"/rooms/{room_id}/events",
response_model=RoomEvents,
tags=["房间"],
name="获取房间事件",
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
responses={
200: api_doc(
"房间事件",
{
"beatmaps": list[BeatmapModel],
"beatmapsets": list[BeatmapsetModel],
"current_playlist_item_id": int,
"events": list[MultiplayerEventResp],
"first_event_id": int,
"last_event_id": int,
"playlist_items": list[PlaylistModel],
"room": RoomModel,
"user": list[UserModel],
},
["country", "details", "scores"],
name="RoomEventsResponse",
)
},
)
async def get_room_events(
db: Database,
@@ -402,28 +427,44 @@ async def get_room_events(
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if room is None:
raise HTTPException(404, "Room not found")
room_resp = await RoomResp.from_db(room, db)
if room.category == RoomCategory.REALTIME and room_resp.current_playlist_item:
current_playlist_item_id = room_resp.current_playlist_item.id
room_resp = await RoomModel.transform(room, includes=["current_playlist_item"])
if room.category == RoomCategory.REALTIME:
current_playlist_item_id = (await Room.current_playlist_item(db, room))["id"]
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users]
user_resps = [await UserModel.transform(user, includes=["country"]) for user in users]
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
beatmapset_resps = {}
for beatmap_resp in beatmap_resps:
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
beatmap_resps = [
await BeatmapModel.transform(
beatmap,
)
for beatmap in beatmaps
]
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
beatmapsets = []
for beatmap in beatmaps:
if beatmap.beatmapset_id not in beatmapsets:
beatmapsets.append(beatmap.beatmapset)
beatmapset_resps = [
await BeatmapsetModel.transform(
beatmapset,
)
for beatmapset in beatmapsets
]
return RoomEvents(
beatmaps=beatmap_resps,
beatmapsets=beatmapset_resps,
current_playlist_item_id=current_playlist_item_id,
events=event_resps,
first_event_id=first_event_id,
last_event_id=last_event_id,
playlist_items=playlist_items_resps,
room=room_resp,
user=user_resps,
)
playlist_items_resps = [
await PlaylistModel.transform(item, includes=["details", "scores"]) for item in playlist_items.values()
]
return {
"beatmaps": beatmap_resps,
"beatmapsets": beatmapset_resps,
"current_playlist_item_id": current_playlist_item_id,
"events": event_resps,
"first_event_id": first_event_id,
"last_event_id": last_event_id,
"playlist_items": playlist_items_resps,
"room": room_resp,
"user": user_resps,
}

View File

@@ -9,7 +9,6 @@ from app.database import (
Playlist,
Room,
Score,
ScoreResp,
ScoreToken,
ScoreTokenResp,
User,
@@ -27,8 +26,10 @@ from app.database.relationship import Relationship, RelationshipType
from app.database.score import (
LegacyScoreResp,
MultiplayerScores,
ScoreAround,
MultiplayScoreDict,
ScoreModel,
get_leaderboard,
get_score_position_by_id,
process_score,
process_user,
)
@@ -49,7 +50,7 @@ from app.models.score import (
)
from app.service.beatmap_cache_service import get_beatmap_cache_service
from app.service.user_cache_service import refresh_user_cache_background
from app.utils import utcnow
from app.utils import api_doc, utcnow
from .router import router
@@ -72,6 +73,7 @@ from sqlmodel import col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 10
DEFAULT_SCORE_INCLUDES = ["user", "user.country", "user.cover", "user.team"]
logger = log("Score")
@@ -180,13 +182,15 @@ async def submit_score(
await db.refresh(score)
background_task.add_task(_process_user, score_id, user_id, redis, fetcher)
resp: ScoreResp = await ScoreResp.from_db(db, score)
resp = await ScoreModel.transform(
score,
)
score_gamemode = score.gamemode
await db.commit()
if user_id is not None:
background_task.add_task(refresh_user_cache_background, redis, user_id, score_gamemode)
background_task.add_task(_process_user_achievement, resp.id)
background_task.add_task(_process_user_achievement, resp["id"])
return resp
@@ -218,27 +222,36 @@ async def _preload_beatmap_for_pp_calculation(beatmap_id: int) -> None:
logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")
class BeatmapUserScore[T: ScoreResp | LegacyScoreResp](BaseModel):
LeaderboardScoreType = ScoreModel.generate_typeddict(tuple(DEFAULT_SCORE_INCLUDES)) | LegacyScoreResp
class BeatmapUserScore(BaseModel):
position: int
score: T
score: LeaderboardScoreType # pyright: ignore[reportInvalidTypeForm]
class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
scores: list[T]
user_score: BeatmapUserScore[T] | None = None
class BeatmapScores(BaseModel):
scores: list[LeaderboardScoreType] # pyright: ignore[reportInvalidTypeForm]
user_score: BeatmapUserScore | None = None
score_count: int = 0
@router.get(
"/beatmaps/{beatmap_id}/scores",
tags=["成绩"],
response_model=BeatmapScores[ScoreResp] | BeatmapScores[LegacyScoreResp],
responses={
200: {
"model": BeatmapScores,
"description": (
"排行榜及当前用户成绩。\n\n"
f"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[Score]`"
f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}"
"否则为 `BeatmapScores[LegacyScoreResp]`。"
),
}
},
name="获取谱面排行榜",
description=(
"获取指定谱面在特定条件下的排行榜及当前用户成绩。\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[ScoreResp]`"
"否则为 `BeatmapScores[LegacyScoreResp]`。"
),
description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
)
async def get_beatmap_scores(
db: Database,
@@ -266,27 +279,46 @@ async def get_beatmap_scores(
mods=sorted(mods),
)
user_score_resp = await user_score.to_resp(db, api_version) if user_score else None
resp = BeatmapScores(
scores=[await score.to_resp(db, api_version) for score in all_scores],
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
if user_score_resp
else None,
score_count=count,
)
return resp
user_score_resp = await user_score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) if user_score else None
return {
"scores": [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_scores],
"user_score": (
{
"score": user_score_resp,
"position": (
await get_score_position_by_id(
db,
user_score.beatmap_id,
user_score.id,
mode=user_score.gamemode,
user=user_score.user,
)
or 0
),
}
if user_score and user_score_resp
else None
),
"score_count": count,
}
@router.get(
"/beatmaps/{beatmap_id}/scores/users/{user_id}",
tags=["成绩"],
response_model=BeatmapUserScore[ScoreResp] | BeatmapUserScore[LegacyScoreResp],
responses={
200: {
"model": BeatmapUserScore,
"description": (
"指定用户在指定谱面上的最高成绩\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[Score]`"
f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}"
"否则为 `BeatmapUserScore[LegacyScoreResp]`。"
),
}
},
name="获取用户谱面最高成绩",
description=(
"获取指定用户在指定谱面上的最高成绩。\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[ScoreResp]`"
"否则为 `BeatmapUserScore[LegacyScoreResp]`。"
),
description="获取指定用户在指定谱面上的最高成绩。",
)
async def get_user_beatmap_score(
db: Database,
@@ -318,23 +350,38 @@ async def get_user_beatmap_score(
detail=f"Cannot find user {user_id}'s score on this beatmap",
)
else:
resp = await user_score.to_resp(db, api_version=api_version)
return BeatmapUserScore(
position=resp.rank_global or 0,
score=resp,
)
resp = await user_score.to_resp(db, api_version=api_version, includes=DEFAULT_SCORE_INCLUDES)
return {
"position": (
await get_score_position_by_id(
db,
user_score.beatmap_id,
user_score.id,
mode=user_score.gamemode,
user=user_score.user,
)
or 0
),
"score": resp,
}
@router.get(
"/beatmaps/{beatmap_id}/scores/users/{user_id}/all",
tags=["成绩"],
response_model=list[ScoreResp] | list[LegacyScoreResp],
responses={
200: api_doc(
(
"用户谱面全部成绩\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `Score`列表,"
"否则为 `LegacyScoreResp`列表。"
),
list[ScoreModel] | list[LegacyScoreResp],
DEFAULT_SCORE_INCLUDES,
)
},
name="获取用户谱面全部成绩",
description=(
"获取指定用户在指定谱面上的全部成绩列表。\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `ScoreResp`列表,"
"否则为 `LegacyScoreResp`列表。"
),
description="获取指定用户在指定谱面上的全部成绩列表。",
)
async def get_user_all_beatmap_scores(
db: Database,
@@ -359,7 +406,7 @@ async def get_user_all_beatmap_scores(
)
).all()
return [await score.to_resp(db, api_version) for score in all_user_scores]
return [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_user_scores]
@router.post(
@@ -413,9 +460,9 @@ async def create_solo_score(
@router.put(
"/beatmaps/{beatmap_id}/solo/scores/{token}",
tags=["游玩"],
response_model=ScoreResp,
name="提交单曲成绩",
description="\n使用令牌提交单曲成绩。",
responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
)
async def submit_solo_score(
background_task: BackgroundTasks,
@@ -520,6 +567,7 @@ async def create_playlist_score(
tags=["游玩"],
name="提交房间项目成绩",
description="\n提交房间游玩项目成绩。",
responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
)
async def submit_playlist_score(
background_task: BackgroundTasks,
@@ -560,13 +608,13 @@ async def submit_playlist_score(
room_id,
playlist_id,
user_id,
score_resp.id,
score_resp.total_score,
score_resp["id"],
score_resp["total_score"],
session,
redis,
)
await session.commit()
if room_category == RoomCategory.DAILY_CHALLENGE and score_resp.passed:
if room_category == RoomCategory.DAILY_CHALLENGE and score_resp["passed"]:
await process_daily_challenge_score(session, user_id, room_id)
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
await session.commit()
@@ -575,15 +623,23 @@ async def submit_playlist_score(
class IndexedScoreResp(MultiplayerScores):
total: int
user_score: ScoreResp | None = None
user_score: MultiplayScoreDict | None = None # pyright: ignore[reportInvalidTypeForm]
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores",
response_model=IndexedScoreResp,
# response_model=IndexedScoreResp,
name="获取房间项目排行榜",
description="获取房间游玩项目排行榜。",
tags=["成绩"],
responses={
200: {
"description": (
f"房间项目排行榜。\n\n包含:{', '.join([f'`{inc}`' for inc in Score.MULTIPLAYER_BASE_INCLUDES])}"
),
"model": IndexedScoreResp,
}
},
)
async def index_playlist_scores(
session: Database,
@@ -620,16 +676,14 @@ async def index_playlist_scores(
scores = scores[:-1]
user_score = None
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
score_resp = [await ScoreModel.transform(score.score, includes=Score.MULTIPLAYER_BASE_INCLUDES) for score in scores]
for score in score_resp:
score.position = await get_position(room_id, playlist_id, score.id, session)
if score.user_id == user_id:
if (room.category == RoomCategory.DAILY_CHALLENGE and score["user_id"] == user_id and score["passed"]) or score[
"user_id"
] == user_id:
user_score = score
if room.category == RoomCategory.DAILY_CHALLENGE:
score_resp = [s for s in score_resp if s.passed]
if user_score and not user_score.passed:
user_score = None
user_score["position"] = await get_position(room_id, playlist_id, score["id"], session)
break
resp = IndexedScoreResp(
scores=score_resp,
@@ -648,10 +702,16 @@ async def index_playlist_scores(
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
response_model=ScoreResp,
name="获取房间项目单个成绩",
description="获取指定房间游玩项目中单个成绩详情。",
tags=["成绩"],
responses={
200: api_doc(
"房间项目单个成绩详情。",
ScoreModel,
[*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
)
},
)
async def show_playlist_score(
session: Database,
@@ -687,39 +747,25 @@ async def show_playlist_score(
break
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(room_id, playlist_id, score_id, session)
includes = [
*Score.MULTIPLAYER_BASE_INCLUDES,
"position",
]
if completed:
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
)
)
).all()
higher_scores = []
lower_scores = []
for score in scores:
resp = await ScoreResp.from_db(session, score.score)
if is_playlist and not resp.passed:
continue
if score.total_score > resp.total_score:
higher_scores.append(resp)
elif score.total_score < resp.total_score:
lower_scores.append(resp)
resp.scores_around = ScoreAround(
higher=MultiplayerScores(scores=higher_scores),
lower=MultiplayerScores(scores=lower_scores),
)
includes.append("scores_around")
resp = await ScoreModel.transform(score_record.score, includes=includes)
return resp
@router.get(
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
response_model=ScoreResp,
responses={
200: api_doc(
"房间项目单个成绩详情。",
ScoreModel,
[*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
)
},
name="获取房间项目用户成绩",
description="获取指定用户在房间游玩项目中的成绩。",
tags=["成绩"],
@@ -749,8 +795,14 @@ async def get_user_playlist_score(
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
resp = await ScoreModel.transform(
score_record.score,
includes=[
*Score.MULTIPLAYER_BASE_INCLUDES,
"position",
"scores_around",
],
)
return resp

View File

@@ -5,17 +5,16 @@ from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import (
Beatmap,
BeatmapModel,
BeatmapPlaycounts,
BeatmapPlaycountsResp,
BeatmapResp,
BeatmapsetResp,
BeatmapsetModel,
User,
UserResp,
)
from app.database.beatmap_playcounts import BeatmapPlaycountsModel
from app.database.best_scores import BestScore
from app.database.events import Event
from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores
from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED
from app.database.score import Score, get_user_first_scores
from app.database.user import UserModel
from app.dependencies.api_version import APIVersion
from app.dependencies.cache import UserCacheService
from app.dependencies.database import Database, get_redis
@@ -26,24 +25,15 @@ from app.models.mods import API_MODS
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from app.service.user_cache_service import get_user_cache_service
from app.utils import utcnow
from app.utils import api_doc, utcnow
from .router import router
from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.sql.expression import col
class BatchUserResponse(BaseModel):
users: list[UserResp]
class BeatmapsPassedResponse(BaseModel):
beatmaps_passed: list[BeatmapResp]
def _get_difficulty_reduction_mods() -> set[str]:
mods: set[str] = set()
for ruleset_mods in API_MODS.values():
@@ -63,13 +53,15 @@ async def visible_to_current_user(user: User, current_user: User | None, session
@router.get(
"/users/",
response_model=BatchUserResponse,
responses={
200: api_doc("批量获取用户信息", {"users": list[UserModel]}, User.CARD_INCLUDES, name="UsersLookupResponse")
},
name="批量获取用户信息",
description="通过用户 ID 列表批量获取用户信息。",
tags=["用户"],
)
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup", include_in_schema=False)
@router.get("/users/lookup/", include_in_schema=False)
@asset_proxy_response
async def get_users(
session: Database,
@@ -108,16 +100,15 @@ async def get_users(
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=SEARCH_INCLUDED,
includes=User.CARD_INCLUDES,
)
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
background_task.add_task(cache_service.cache_user, user_resp)
response = BatchUserResponse(users=cached_users)
response = {"users": cached_users}
return response
else:
searched_users = (
@@ -127,16 +118,15 @@ async def get_users(
for searched_user in searched_users:
if searched_user.id == BANCHOBOT_ID:
continue
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=SEARCH_INCLUDED,
includes=User.CARD_INCLUDES,
)
users.append(user_resp)
# 异步缓存
background_task.add_task(cache_service.cache_user, user_resp)
response = BatchUserResponse(users=users)
response = {"users": users}
return response
@@ -200,10 +190,12 @@ async def get_user_kudosu(
@router.get(
"/users/{user_id}/beatmaps-passed",
response_model=BeatmapsPassedResponse,
name="获取用户已通过谱面",
description="获取指定用户在给定谱面集中的已通过谱面列表。",
tags=["用户"],
responses={
200: api_doc("用户已通过谱面列表", {"beatmaps_passed": list[BeatmapModel]}, name="BeatmapsPassedResponse")
},
)
@asset_proxy_response
async def get_user_beatmaps_passed(
@@ -226,7 +218,7 @@ async def get_user_beatmaps_passed(
no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True,
):
if not beatmapset_ids:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
if len(beatmapset_ids) > 50:
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
@@ -255,7 +247,7 @@ async def get_user_beatmaps_passed(
scores = (await session.exec(score_query)).all()
if not scores:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set()
passed_beatmap_ids: set[int] = set()
@@ -269,7 +261,7 @@ async def get_user_beatmaps_passed(
continue
passed_beatmap_ids.add(beatmap_id)
if not passed_beatmap_ids:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
beatmaps = (
await session.exec(
@@ -279,19 +271,24 @@ async def get_user_beatmaps_passed(
)
).all()
return BeatmapsPassedResponse(
beatmaps_passed=[
await BeatmapResp.from_db(beatmap, allowed_mode, session=session, user=user) for beatmap in beatmaps
return {
"beatmaps_passed": [
await BeatmapModel.transform(
beatmap,
)
for beatmap in beatmaps
]
)
}
@router.get(
"/users/{user_id}/{ruleset}",
response_model=UserResp,
name="获取用户信息(指定ruleset)",
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
tags=["用户"],
responses={
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
},
)
@asset_proxy_response
async def get_user_info_ruleset(
@@ -325,29 +322,26 @@ async def get_user_info_ruleset(
if should_not_show:
raise HTTPException(404, detail="User not found")
include = SEARCH_INCLUDED
if searched_is_self:
include = ALL_INCLUDED
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=include,
includes=User.USER_INCLUDES,
ruleset=ruleset,
)
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
return user_resp
@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
@router.get("/users/{user_id}/", include_in_schema=False)
@router.get(
"/users/{user_id}",
response_model=UserResp,
name="获取用户信息",
description="通过用户 ID 或用户名获取单个用户的详细信息。",
tags=["用户"],
responses={
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
},
)
@asset_proxy_response
async def get_user_info(
@@ -381,27 +375,31 @@ async def get_user_info(
if should_not_show:
raise HTTPException(404, detail="User not found")
include = SEARCH_INCLUDED
if searched_is_self:
include = ALL_INCLUDED
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=include,
includes=User.USER_INCLUDES,
)
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp)
return user_resp
beatmapset_includes = [*BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES, "beatmaps"]
@router.get(
"/users/{user_id}/beatmapsets/{type}",
response_model=list[BeatmapsetResp | BeatmapPlaycountsResp],
name="获取用户谱面集列表",
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
tags=["用户"],
responses={
200: api_doc(
"当类型为 `most_played` 时返回 `list[BeatmapPlaycountsModel]`,其他为 `list[BeatmapsetModel]`",
list[BeatmapsetModel] | list[BeatmapPlaycountsModel],
beatmapset_includes,
)
},
)
@asset_proxy_response
async def get_user_beatmapsets(
@@ -417,11 +415,7 @@ async def get_user_beatmapsets(
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
if cached_result is not None:
# 根据类型恢复对象
if type == BeatmapsetType.MOST_PLAYED:
return [BeatmapPlaycountsResp(**item) for item in cached_result]
else:
return [BeatmapsetResp(**item) for item in cached_result]
return cached_result
user = await session.get(User, user_id)
if not user or user.id == BANCHOBOT_ID:
@@ -444,7 +438,10 @@ async def get_user_beatmapsets(
raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
await BeatmapsetModel.transform(
favourite.beatmapset, session=session, user=user, includes=beatmapset_includes
)
for favourite in favourites
]
elif type == BeatmapsetType.MOST_PLAYED:
@@ -459,7 +456,10 @@ async def get_user_beatmapsets(
.limit(limit)
.offset(offset)
)
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
resp = [
await BeatmapPlaycountsModel.transform(most_played_beatmap, user=user, includes=beatmapset_includes)
for most_played_beatmap in most_played
]
else:
raise HTTPException(400, detail="Invalid beatmapset type")
@@ -477,7 +477,6 @@ async def get_user_beatmapsets(
@router.get(
"/users/{user_id}/scores/{type}",
response_model=list[ScoreResp] | list[LegacyScoreResp],
name="获取用户成绩列表",
description=(
"获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n"
@@ -523,6 +522,7 @@ async def get_user_scores(
gamemode = mode or db_user.playmode
order_by = None
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
includes = Score.USER_PROFILE_INCLUDES.copy()
if not include_fails:
where_clause &= col(Score.passed).is_(True)
if type == "pinned":
@@ -531,6 +531,7 @@ async def get_user_scores(
elif type == "best":
where_clause &= exists().where(col(BestScore.score_id) == Score.id)
order_by = col(Score.pp).desc()
includes.append("weight")
elif type == "recent":
where_clause &= Score.ended_at > utcnow() - timedelta(hours=24)
order_by = col(Score.ended_at).desc()
@@ -551,6 +552,7 @@ async def get_user_scores(
await score.to_resp(
session,
api_version,
includes=includes,
)
for score in scores
]