refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -447,7 +447,7 @@ async def create_multiplayer_room(
|
||||
# 让房主加入频道
|
||||
host_user = await db.get(User, host_user_id)
|
||||
if host_user:
|
||||
await server.batch_join_channel([host_user], channel, db)
|
||||
await server.batch_join_channel([host_user], channel)
|
||||
# Add playlist items
|
||||
await _add_playlist_items(db, room_id, room_data, host_user_id)
|
||||
|
||||
|
||||
@@ -3,11 +3,14 @@ from collections.abc import Awaitable, Callable
|
||||
from math import ceil
|
||||
import random
|
||||
import shlex
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.calculator import calculate_weighted_pp
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatMessage, ChatMessageModel, MessageType
|
||||
from app.database.score import Score, get_best_id
|
||||
from app.database.statistics import UserStatistics, get_rank
|
||||
from app.database.user import User
|
||||
@@ -95,7 +98,7 @@ class Bot:
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(bot)
|
||||
resp = await ChatMessageResp.from_db(msg, session, bot)
|
||||
resp = await ChatMessageModel.transform(msg, includes=["sender"])
|
||||
await server.send_message_to_channel(resp)
|
||||
|
||||
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
|
||||
@@ -119,7 +122,7 @@ class Bot:
|
||||
await session.refresh(channel)
|
||||
await session.refresh(user)
|
||||
await session.refresh(bot)
|
||||
await server.batch_join_channel([user, bot], channel, session)
|
||||
await server.batch_join_channel([user, bot], channel)
|
||||
return channel
|
||||
|
||||
async def _send_reply(
|
||||
|
||||
@@ -1,37 +1,40 @@
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatChannelModel,
|
||||
ChatMessage,
|
||||
SilenceUser,
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.user import User, UserResp
|
||||
from app.database.user import User, UserModel
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
from app.utils import api_doc
|
||||
|
||||
from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
class UpdateResponse(BaseModel):
|
||||
presence: list[ChatChannelResp] = Field(default_factory=list)
|
||||
silences: list[Any] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/updates",
|
||||
response_model=UpdateResponse,
|
||||
name="获取更新",
|
||||
description="获取当前用户所在频道的最新的禁言情况。",
|
||||
tags=["聊天"],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"获取更新响应。",
|
||||
{"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]},
|
||||
ChatChannel.LISTING_INCLUDES,
|
||||
name="UpdateResponse",
|
||||
)
|
||||
},
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
@@ -44,45 +47,44 @@ async def get_update(
|
||||
Query(alias="includes[]", description="要包含的更新类型"),
|
||||
] = ["presence", "silences"],
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
resp = {
|
||||
"presence": [],
|
||||
"silences": [],
|
||||
}
|
||||
if "presence" in includes:
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_type = db_channel.type
|
||||
|
||||
resp.presence.append(
|
||||
await ChatChannelResp.from_db(
|
||||
resp["presence"].append(
|
||||
await ChatChannelModel.transform(
|
||||
db_channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
user=current_user,
|
||||
server=server,
|
||||
includes=ChatChannel.LISTING_INCLUDES,
|
||||
)
|
||||
)
|
||||
if "silences" in includes:
|
||||
if history_since:
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
|
||||
).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
return resp
|
||||
|
||||
|
||||
@router.put(
|
||||
"/chat/channels/{channel}/users/{user}",
|
||||
response_model=ChatChannelResp,
|
||||
name="加入频道",
|
||||
description="加入指定的公开/房间频道。",
|
||||
tags=["聊天"],
|
||||
responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)},
|
||||
)
|
||||
async def join_channel(
|
||||
session: Database,
|
||||
@@ -101,7 +103,7 @@ async def join_channel(
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
return await server.join_channel(current_user, db_channel, session)
|
||||
return await server.join_channel(current_user, db_channel)
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -128,13 +130,13 @@ async def leave_channel(
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
await server.leave_channel(current_user, db_channel, session)
|
||||
await server.leave_channel(current_user, db_channel)
|
||||
return
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels",
|
||||
response_model=list[ChatChannelResp],
|
||||
responses={200: api_doc("加入的频道", list[ChatChannelModel])},
|
||||
name="获取频道列表",
|
||||
description="获取所有公开频道。",
|
||||
tags=["聊天"],
|
||||
@@ -142,35 +144,30 @@ async def leave_channel(
|
||||
async def get_channel_list(
|
||||
session: Database,
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
redis: Redis,
|
||||
):
|
||||
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
channel_type = channel.type
|
||||
results = await ChatChannelModel.transform_many(
|
||||
channels,
|
||||
user=current_user,
|
||||
server=server,
|
||||
)
|
||||
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class GetChannelResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
users: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels/{channel}",
|
||||
response_model=GetChannelResp,
|
||||
responses={
|
||||
200: api_doc(
|
||||
"频道详细信息",
|
||||
{
|
||||
"channel": ChatChannelModel,
|
||||
"users": list[UserModel],
|
||||
},
|
||||
ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES,
|
||||
name="GetChannelResponse",
|
||||
)
|
||||
},
|
||||
name="获取频道信息",
|
||||
description="获取指定频道的信息。",
|
||||
tags=["聊天"],
|
||||
@@ -191,7 +188,6 @@ async def get_channel(
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
@@ -209,15 +205,15 @@ async def get_channel(
|
||||
users.extend([target_user, current_user])
|
||||
break
|
||||
|
||||
return GetChannelResp(
|
||||
channel=await ChatChannelResp.from_db(
|
||||
return {
|
||||
"channel": await ChatChannelModel.transform(
|
||||
db_channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
user=current_user,
|
||||
server=server,
|
||||
includes=ChatChannel.LISTING_INCLUDES,
|
||||
),
|
||||
"users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES),
|
||||
}
|
||||
|
||||
|
||||
class CreateChannelReq(BaseModel):
|
||||
@@ -244,7 +240,7 @@ class CreateChannelReq(BaseModel):
|
||||
|
||||
@router.post(
|
||||
"/chat/channels",
|
||||
response_model=ChatChannelResp,
|
||||
responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])},
|
||||
name="创建频道",
|
||||
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
|
||||
tags=["聊天"],
|
||||
@@ -289,21 +285,13 @@ async def create_channel(
|
||||
await session.refresh(current_user)
|
||||
if req.type == "PM":
|
||||
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
else:
|
||||
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
await server.batch_join_channel([*target_users, current_user], channel)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
await server.join_channel(current_user, channel)
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, []),
|
||||
include_recent_messages=True,
|
||||
return await ChatChannelModel.transform(
|
||||
channel, user=current_user, server=server, includes=["recent_messages.sender"]
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database import ChatChannelModel
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatMessage,
|
||||
ChatMessageModel,
|
||||
MessageType,
|
||||
SilenceUser,
|
||||
UserSilenceResp,
|
||||
@@ -18,6 +18,7 @@ from app.log import log
|
||||
from app.models.notification import ChannelMessage, ChannelMessageTeam
|
||||
from app.router.v2 import api_v2_router as router
|
||||
from app.service.redis_message_system import redis_message_system
|
||||
from app.utils import api_doc
|
||||
|
||||
from .banchobot import bot
|
||||
from .server import server
|
||||
@@ -68,7 +69,7 @@ class MessageReq(BaseModel):
|
||||
|
||||
@router.post(
|
||||
"/chat/channels/{channel}/messages",
|
||||
response_model=ChatMessageResp,
|
||||
responses={200: api_doc("发送的消息", ChatMessageModel, ["sender", "is_action"])},
|
||||
name="发送消息",
|
||||
description="发送消息到指定频道。",
|
||||
tags=["聊天"],
|
||||
@@ -130,7 +131,7 @@ async def send_message(
|
||||
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
|
||||
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
|
||||
temp_msg = ChatMessage(
|
||||
message_id=resp.message_id, # 使用 Redis 系统生成的ID
|
||||
message_id=resp["message_id"], # 使用 Redis 系统生成的ID
|
||||
channel_id=channel_id,
|
||||
content=req.message,
|
||||
sender_id=user_id,
|
||||
@@ -151,7 +152,7 @@ async def send_message(
|
||||
|
||||
@router.get(
|
||||
"/chat/channels/{channel}/messages",
|
||||
response_model=list[ChatMessageResp],
|
||||
responses={200: api_doc("获取的消息", list[ChatMessageModel], ["sender"])},
|
||||
name="获取消息",
|
||||
description="获取指定频道的消息列表(统一按时间正序返回)。",
|
||||
tags=["聊天"],
|
||||
@@ -177,7 +178,7 @@ async def get_message(
|
||||
|
||||
try:
|
||||
messages = await redis_message_system.get_messages(channel_id, limit, since)
|
||||
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
|
||||
if len(messages) >= 2 and messages[0]["message_id"] > messages[-1]["message_id"]:
|
||||
messages.reverse()
|
||||
return messages
|
||||
except Exception as e:
|
||||
@@ -189,7 +190,7 @@ async def get_message(
|
||||
# 向前加载新消息 → 直接 ASC
|
||||
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||
rows = (await session.exec(query)).all()
|
||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
||||
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||
# 已经 ASC,无需反转
|
||||
return resp
|
||||
|
||||
@@ -202,15 +203,14 @@ async def get_message(
|
||||
rows = (await session.exec(query)).all()
|
||||
rows = list(rows)
|
||||
rows.reverse() # 反转为 ASC
|
||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
||||
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||
return resp
|
||||
|
||||
query = base.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
rows = (await session.exec(query)).all()
|
||||
rows = list(rows)
|
||||
rows.reverse() # 反转为 ASC
|
||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
||||
return resp
|
||||
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||
return resp
|
||||
|
||||
|
||||
@@ -248,17 +248,23 @@ class PMReq(BaseModel):
|
||||
uuid: str | None = None
|
||||
|
||||
|
||||
class NewPMResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
message: ChatMessageResp
|
||||
new_channel_id: int
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/new",
|
||||
name="创建私聊频道",
|
||||
description="创建一个新的私聊频道。",
|
||||
tags=["聊天"],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"创建私聊频道响应",
|
||||
{
|
||||
"channel": ChatChannelModel,
|
||||
"message": ChatMessageModel,
|
||||
"new_channel_id": int,
|
||||
},
|
||||
["recent_messages.sender", "sender"],
|
||||
name="NewPMResponse",
|
||||
)
|
||||
},
|
||||
)
|
||||
async def create_new_pm(
|
||||
session: Database,
|
||||
@@ -290,9 +296,9 @@ async def create_new_pm(
|
||||
await session.refresh(target)
|
||||
await session.refresh(current_user)
|
||||
|
||||
await server.batch_join_channel([target, current_user], channel, session)
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, current_user, redis, server.channels[channel.channel_id]
|
||||
await server.batch_join_channel([target, current_user], channel)
|
||||
channel_resp = await ChatChannelModel.transform(
|
||||
channel, user=current_user, server=server, includes=["recent_messages.sender"]
|
||||
)
|
||||
msg = ChatMessage(
|
||||
channel_id=channel.channel_id,
|
||||
@@ -306,10 +312,10 @@ async def create_new_pm(
|
||||
await session.refresh(msg)
|
||||
await session.refresh(current_user)
|
||||
await session.refresh(channel)
|
||||
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||
message_resp = await ChatMessageModel.transform(msg, user=current_user, includes=["sender"])
|
||||
await server.send_message_to_channel(message_resp)
|
||||
return NewPMResp(
|
||||
channel=channel_resp,
|
||||
message=message_resp,
|
||||
new_channel_id=channel_resp.channel_id,
|
||||
)
|
||||
return {
|
||||
"channel": channel_resp,
|
||||
"message": message_resp,
|
||||
"new_channel_id": channel_resp["channel_id"],
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import Annotated, overload
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database import ChatMessageDict
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelDict, ChatChannelModel
|
||||
from app.database.notification import UserNotification, insert_notification
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import (
|
||||
@@ -16,7 +17,7 @@ from app.log import log
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
from app.utils import bg_tasks
|
||||
from app.utils import bg_tasks, safe_json_dumps
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
@@ -65,7 +66,7 @@ class ChatServer:
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||
).first()
|
||||
if db_channel:
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
await self.leave_channel(user, db_channel)
|
||||
|
||||
@overload
|
||||
async def send_event(self, client: int, event: ChatEvent): ...
|
||||
@@ -80,7 +81,7 @@ class ChatServer:
|
||||
return
|
||||
client = client_
|
||||
if client.client_state == WebSocketState.CONNECTED:
|
||||
await client.send_text(event.model_dump_json())
|
||||
await client.send_text(safe_json_dumps(event))
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
users_in_channel = self.channels.get(channel_id, [])
|
||||
@@ -107,38 +108,38 @@ class ChatServer:
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
|
||||
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
|
||||
async def send_message_to_channel(self, message: ChatMessageDict, is_bot_command: bool = False):
|
||||
logger.info(
|
||||
f"Sending message to channel {message.channel_id}, message_id: "
|
||||
f"{message.message_id}, is_bot_command: {is_bot_command}"
|
||||
f"Sending message to channel {message['channel_id']}, message_id: "
|
||||
f"{message['message_id']}, is_bot_command: {is_bot_command}"
|
||||
)
|
||||
|
||||
event = ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
data={"messages": [message], "users": [message["sender"]]}, # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||
)
|
||||
if is_bot_command:
|
||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
||||
bg_tasks.add_task(self.send_event, message.sender_id, event)
|
||||
logger.info(f"Sending bot command to user {message['sender_id']}")
|
||||
bg_tasks.add_task(self.send_event, message["sender_id"], event)
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
logger.info(f"Broadcasting message to all users in channel {message['channel_id']}")
|
||||
bg_tasks.add_task(
|
||||
self.broadcast,
|
||||
message.channel_id,
|
||||
message["channel_id"],
|
||||
event,
|
||||
)
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
if message["message_id"] and message["message_id"] > 0:
|
||||
await self.mark_as_read(message["channel_id"], message["sender_id"], message["message_id"])
|
||||
await self.redis.set(f"chat:{message['channel_id']}:last_msg", message["message_id"])
|
||||
logger.info(f"Updated last message ID for channel {message['channel_id']} to {message['message_id']}")
|
||||
else:
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
logger.debug(f"Skipping last message update for message ID: {message['message_id']}")
|
||||
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel):
|
||||
channel_id = channel.channel_id
|
||||
|
||||
not_joined = []
|
||||
@@ -151,22 +152,18 @@ class ChatServer:
|
||||
not_joined.append(user)
|
||||
|
||||
for user in not_joined:
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
await self.send_event(
|
||||
user.id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
|
||||
async def join_channel(self, user: User, channel: ChatChannel) -> ChatChannelDict:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
|
||||
@@ -175,25 +172,21 @@ class ChatServer:
|
||||
if user_id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user_id)
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp: ChatChannelDict = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
|
||||
async def leave_channel(self, user: User, channel: ChatChannel) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
|
||||
@@ -203,18 +196,14 @@ class ChatServer:
|
||||
if (c := self.channels.get(channel_id)) is not None and not c:
|
||||
del self.channels[channel_id]
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.part",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
@@ -232,7 +221,7 @@ class ChatServer:
|
||||
return
|
||||
|
||||
logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
|
||||
await self.join_channel(user, db_channel, session)
|
||||
await self.join_channel(user, db_channel)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
@@ -248,7 +237,7 @@ class ChatServer:
|
||||
return
|
||||
|
||||
logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
await self.leave_channel(user, db_channel)
|
||||
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
@@ -336,6 +325,6 @@ async def chat_websocket(
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
await server.join_channel(user, db_channel)
|
||||
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
|
||||
@@ -2,7 +2,7 @@ import hashlib
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.team import Team, TeamMember, TeamRequest, TeamResp
|
||||
from app.database.user import BASE_INCLUDES, User, UserResp
|
||||
from app.database.user import User, UserModel
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.dependencies.user import ClientUser
|
||||
@@ -14,12 +14,11 @@ from app.models.notification import (
|
||||
from app.models.score import GameMode
|
||||
from app.router.notification import server
|
||||
from app.service.ranking_cache_service import get_ranking_cache_service
|
||||
from app.utils import check_image, utcnow
|
||||
from app.utils import api_doc, check_image, utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import File, Form, HTTPException, Path, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, exists, select
|
||||
|
||||
|
||||
@@ -214,12 +213,22 @@ async def delete_team(
|
||||
await cache_service.invalidate_team_cache()
|
||||
|
||||
|
||||
class TeamQueryResp(BaseModel):
|
||||
team: TeamResp
|
||||
members: list[UserResp]
|
||||
|
||||
|
||||
@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
|
||||
@router.get(
|
||||
"/team/{team_id}",
|
||||
name="查询战队",
|
||||
tags=["战队", "g0v0 API"],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"战队信息",
|
||||
{
|
||||
"team": TeamResp,
|
||||
"members": list[UserModel],
|
||||
},
|
||||
["statistics", "country"],
|
||||
name="TeamQueryResp",
|
||||
)
|
||||
},
|
||||
)
|
||||
async def get_team(
|
||||
session: Database,
|
||||
team_id: Annotated[int, Path(..., description="战队 ID")],
|
||||
@@ -233,10 +242,10 @@ async def get_team(
|
||||
)
|
||||
)
|
||||
).all()
|
||||
return TeamQueryResp(
|
||||
team=await TeamResp.from_db(members[0].team, session, gamemode),
|
||||
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
|
||||
)
|
||||
return {
|
||||
"team": await TeamResp.from_db(members[0].team, session, gamemode),
|
||||
"members": await UserModel.transform_many([m.user for m in members], includes=["statistics", "country"]),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
||||
from app.database.statistics import UserStatistics, UserStatisticsModel
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.log import logger
|
||||
@@ -46,7 +46,7 @@ class V1User(AllStrModel):
|
||||
return f"v1_user:{user_id}"
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
|
||||
async def from_db(cls, db_user: User, ruleset: GameMode | None = None) -> "V1User":
|
||||
ruleset = ruleset or db_user.playmode
|
||||
current_statistics: UserStatistics | None = None
|
||||
for i in await db_user.awaitable_attrs.statistics:
|
||||
@@ -54,31 +54,33 @@ class V1User(AllStrModel):
|
||||
current_statistics = i
|
||||
break
|
||||
if current_statistics:
|
||||
statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
|
||||
statistics = await UserStatisticsModel.transform(
|
||||
current_statistics, country_code=db_user.country_code, includes=["country_rank"]
|
||||
)
|
||||
else:
|
||||
statistics = None
|
||||
return cls(
|
||||
user_id=db_user.id,
|
||||
username=db_user.username,
|
||||
join_date=db_user.join_date,
|
||||
count300=statistics.count_300 if statistics else 0,
|
||||
count100=statistics.count_100 if statistics else 0,
|
||||
count50=statistics.count_50 if statistics else 0,
|
||||
playcount=statistics.play_count if statistics else 0,
|
||||
ranked_score=statistics.ranked_score if statistics else 0,
|
||||
total_score=statistics.total_score if statistics else 0,
|
||||
pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
|
||||
count300=current_statistics.count_300 if current_statistics else 0,
|
||||
count100=current_statistics.count_100 if current_statistics else 0,
|
||||
count50=current_statistics.count_50 if current_statistics else 0,
|
||||
playcount=current_statistics.play_count if current_statistics else 0,
|
||||
ranked_score=current_statistics.ranked_score if current_statistics else 0,
|
||||
total_score=current_statistics.total_score if current_statistics else 0,
|
||||
pp_rank=statistics.get("global_rank") or 0 if statistics else 0,
|
||||
level=current_statistics.level_current if current_statistics else 0,
|
||||
pp_raw=statistics.pp if statistics else 0.0,
|
||||
accuracy=statistics.hit_accuracy if statistics else 0,
|
||||
pp_raw=current_statistics.pp if current_statistics else 0.0,
|
||||
accuracy=current_statistics.hit_accuracy if current_statistics else 0,
|
||||
count_rank_ss=current_statistics.grade_ss if current_statistics else 0,
|
||||
count_rank_ssh=current_statistics.grade_ssh if current_statistics else 0,
|
||||
count_rank_s=current_statistics.grade_s if current_statistics else 0,
|
||||
count_rank_sh=current_statistics.grade_sh if current_statistics else 0,
|
||||
count_rank_a=current_statistics.grade_a if current_statistics else 0,
|
||||
country=db_user.country_code,
|
||||
total_seconds_played=statistics.play_time if statistics else 0,
|
||||
pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
|
||||
total_seconds_played=current_statistics.play_time if current_statistics else 0,
|
||||
pp_country_rank=statistics.get("country_rank") or 0 if statistics else 0,
|
||||
events=[], # TODO
|
||||
)
|
||||
|
||||
@@ -134,7 +136,7 @@ async def get_user(
|
||||
|
||||
try:
|
||||
# 生成用户数据
|
||||
v1_user = await V1User.from_db(session, db_user, ruleset)
|
||||
v1_user = await V1User.from_db(db_user, ruleset)
|
||||
|
||||
# 异步缓存结果(如果有用户ID)
|
||||
if db_user.id is not None:
|
||||
|
||||
@@ -5,7 +5,11 @@ from typing import Annotated
|
||||
|
||||
from app.calculator import get_calculator
|
||||
from app.calculators.performance import ConvertError
|
||||
from app.database import Beatmap, BeatmapResp, User
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
BeatmapModel,
|
||||
User,
|
||||
)
|
||||
from app.database.beatmap import calculate_beatmap_attributes
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
@@ -19,29 +23,20 @@ from app.models.performance import (
|
||||
from app.models.score import (
|
||||
GameMode,
|
||||
)
|
||||
from app.utils import api_doc
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
class BatchGetResp(BaseModel):
|
||||
"""批量获取谱面返回模型。
|
||||
|
||||
返回字段说明:
|
||||
- beatmaps: 谱面详细信息列表。"""
|
||||
|
||||
beatmaps: list[BeatmapResp]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/lookup",
|
||||
tags=["谱面"],
|
||||
name="查询单个谱面",
|
||||
response_model=BeatmapResp,
|
||||
responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
|
||||
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
|
||||
)
|
||||
@asset_proxy_response
|
||||
@@ -67,14 +62,14 @@ async def lookup_beatmap(
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
await db.refresh(current_user)
|
||||
|
||||
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
||||
return await BeatmapModel.transform(beatmap, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap_id}",
|
||||
tags=["谱面"],
|
||||
name="获取谱面详情",
|
||||
response_model=BeatmapResp,
|
||||
responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
|
||||
description="获取单个谱面详情。",
|
||||
)
|
||||
@asset_proxy_response
|
||||
@@ -86,7 +81,12 @@ async def get_beatmap(
|
||||
):
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
||||
await db.refresh(current_user)
|
||||
return await BeatmapModel.transform(
|
||||
beatmap,
|
||||
user=current_user,
|
||||
includes=BeatmapModel.TRANSFORMER_INCLUDES,
|
||||
)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
@@ -95,7 +95,11 @@ async def get_beatmap(
|
||||
"/beatmaps/",
|
||||
tags=["谱面"],
|
||||
name="批量获取谱面",
|
||||
response_model=BatchGetResp,
|
||||
responses={
|
||||
200: api_doc(
|
||||
"谱面列表", {"beatmaps": list[BeatmapModel]}, BeatmapModel.TRANSFORMER_INCLUDES, name="BatchBeatmapResponse"
|
||||
)
|
||||
},
|
||||
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
|
||||
)
|
||||
@asset_proxy_response
|
||||
@@ -124,7 +128,12 @@ async def batch_get_beatmaps(
|
||||
for beatmap in beatmaps:
|
||||
await db.refresh(beatmap)
|
||||
await db.refresh(current_user)
|
||||
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
|
||||
return {
|
||||
"beatmaps": [
|
||||
await BeatmapModel.transform(bm, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
|
||||
for bm in beatmaps
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -2,17 +2,24 @@ import re
|
||||
from typing import Annotated, Literal
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.database.beatmapset import SearchBeatmapsetsResp
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
Beatmapset,
|
||||
BeatmapsetModel,
|
||||
FavouriteBeatmapset,
|
||||
SearchBeatmapsetsResp,
|
||||
User,
|
||||
)
|
||||
from app.dependencies.beatmap_download import DownloadService
|
||||
from app.dependencies.cache import BeatmapsetCacheService, UserCacheService
|
||||
from app.dependencies.database import Database, Redis, with_db
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
from app.dependencies.geoip import IPAddress, get_geoip_helper
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.helpers.asset_proxy_helper import asset_proxy_response
|
||||
from app.models.beatmap import SearchQueryModel
|
||||
from app.service.beatmapset_cache_service import generate_hash
|
||||
from app.utils import api_doc
|
||||
|
||||
from .router import router
|
||||
|
||||
@@ -27,14 +34,7 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.responses import RedirectResponse
|
||||
from httpx import HTTPError
|
||||
from sqlmodel import exists, select
|
||||
|
||||
|
||||
async def _save_to_db(sets: SearchBeatmapsetsResp):
|
||||
async with with_db() as session:
|
||||
for s in sets.beatmapsets:
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
|
||||
await Beatmapset.from_resp(session, s)
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -105,7 +105,6 @@ async def search_beatmapset(
|
||||
|
||||
try:
|
||||
sets = await fetcher.search_beatmapset(query, cursor, redis)
|
||||
background_tasks.add_task(_save_to_db, sets)
|
||||
|
||||
# 缓存搜索结果
|
||||
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
|
||||
@@ -117,8 +116,8 @@ async def search_beatmapset(
|
||||
@router.get(
|
||||
"/beatmapsets/lookup",
|
||||
tags=["谱面集"],
|
||||
responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
|
||||
name="查询谱面集 (通过谱面 ID)",
|
||||
response_model=BeatmapsetResp,
|
||||
description=("通过谱面 ID 查询所属谱面集。"),
|
||||
)
|
||||
@asset_proxy_response
|
||||
@@ -137,7 +136,10 @@ async def lookup_beatmapset(
|
||||
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
|
||||
|
||||
resp = await BeatmapsetModel.transform(
|
||||
beatmap.beatmapset, user=current_user, includes=BeatmapsetModel.API_INCLUDES
|
||||
)
|
||||
|
||||
# 缓存结果
|
||||
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
|
||||
@@ -149,8 +151,8 @@ async def lookup_beatmapset(
|
||||
@router.get(
|
||||
"/beatmapsets/{beatmapset_id}",
|
||||
tags=["谱面集"],
|
||||
responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
|
||||
name="获取谱面集详情",
|
||||
response_model=BeatmapsetResp,
|
||||
description="获取单个谱面集详情。",
|
||||
)
|
||||
@asset_proxy_response
|
||||
@@ -169,7 +171,8 @@ async def get_beatmapset(
|
||||
|
||||
try:
|
||||
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
|
||||
resp = await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
|
||||
await db.refresh(current_user)
|
||||
resp = await BeatmapsetModel.transform(beatmapset, includes=BeatmapsetModel.API_INCLUDES, user=current_user)
|
||||
|
||||
# 缓存结果
|
||||
await cache_service.cache_beatmapset(resp)
|
||||
|
||||
@@ -1,20 +1,29 @@
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import FavouriteBeatmapset, MeResp, User
|
||||
from app.database import FavouriteBeatmapset, User
|
||||
from app.database.user import UserModel
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
|
||||
from app.models.score import GameMode
|
||||
from app.utils import api_doc
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Path, Security
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
|
||||
ME_INCLUDES = [*User.USER_INCLUDES, "session_verified", "session_verification_method"]
|
||||
|
||||
|
||||
class BeatmapsetIds(BaseModel):
|
||||
beatmapset_ids: list[int]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me/beatmapset-favourites",
|
||||
response_model=list[int],
|
||||
response_model=BeatmapsetIds,
|
||||
name="获取当前用户收藏的谱面集 ID 列表",
|
||||
description="获取当前登录用户收藏的谱面集 ID 列表。",
|
||||
tags=["用户", "谱面集"],
|
||||
@@ -26,37 +35,39 @@ async def get_user_beatmapset_favourites(
|
||||
beatmapset_ids = await session.exec(
|
||||
select(FavouriteBeatmapset.beatmapset_id).where(FavouriteBeatmapset.user_id == current_user.id)
|
||||
)
|
||||
return beatmapset_ids.all()
|
||||
return BeatmapsetIds(beatmapset_ids=list(beatmapset_ids.all()))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me/{ruleset}",
|
||||
response_model=MeResp,
|
||||
responses={200: api_doc("当前用户信息(含指定 ruleset 统计)", UserModel, ME_INCLUDES)},
|
||||
name="获取当前用户信息 (指定 ruleset)",
|
||||
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
|
||||
tags=["用户"],
|
||||
)
|
||||
async def get_user_info_with_ruleset(
|
||||
session: Database,
|
||||
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
|
||||
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
||||
):
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
|
||||
user_resp = await UserModel.transform(
|
||||
user_and_token[0], ruleset=ruleset, token_id=user_and_token[1].id, includes=ME_INCLUDES
|
||||
)
|
||||
return user_resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me/",
|
||||
response_model=MeResp,
|
||||
responses={200: api_doc("当前用户信息", UserModel, ME_INCLUDES)},
|
||||
name="获取当前用户信息",
|
||||
description="获取当前登录用户信息。",
|
||||
tags=["用户"],
|
||||
)
|
||||
async def get_user_info_default(
|
||||
session: Database,
|
||||
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
||||
):
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
|
||||
user_resp = await UserModel.transform(
|
||||
user_and_token[0], ruleset=None, token_id=user_and_token[1].id, includes=ME_INCLUDES
|
||||
)
|
||||
return user_resp
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
|
||||
from app.database import Team, TeamMember, User, UserStatistics
|
||||
from app.database.statistics import UserStatisticsModel
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.score import GameMode
|
||||
from app.service.ranking_cache_service import get_ranking_cache_service
|
||||
from app.utils import api_doc
|
||||
|
||||
from .router import router
|
||||
|
||||
@@ -308,14 +310,16 @@ async def get_country_ranking(
|
||||
return response
|
||||
|
||||
|
||||
class TopUsersResponse(BaseModel):
|
||||
ranking: list[UserStatisticsResp]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rankings/{ruleset}/{sort}",
|
||||
response_model=TopUsersResponse,
|
||||
responses={
|
||||
200: api_doc(
|
||||
"用户排行榜",
|
||||
{"ranking": list[UserStatisticsModel], "total": int},
|
||||
["user.country", "user.cover"],
|
||||
name="TopUsersResponse",
|
||||
)
|
||||
},
|
||||
name="获取用户排行榜",
|
||||
description="获取在指定模式下的用户排行榜",
|
||||
tags=["排行榜"],
|
||||
@@ -339,10 +343,10 @@ async def get_user_ranking(
|
||||
|
||||
if cached_data and cached_stats:
|
||||
# 从缓存返回数据
|
||||
return TopUsersResponse(
|
||||
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data],
|
||||
total=cached_stats.get("total", 0),
|
||||
)
|
||||
return {
|
||||
"ranking": cached_data,
|
||||
"total": cached_stats.get("total", 0),
|
||||
}
|
||||
|
||||
# 缓存未命中,从数据库查询
|
||||
wheres = [
|
||||
@@ -350,7 +354,7 @@ async def get_user_ranking(
|
||||
col(UserStatistics.pp) > 0,
|
||||
col(UserStatistics.is_ranked),
|
||||
]
|
||||
include = ["user"]
|
||||
include = UserStatistics.RANKING_INCLUDES.copy()
|
||||
if sort == "performance":
|
||||
order_by = col(UserStatistics.pp).desc()
|
||||
include.append("rank_change_since_30_days")
|
||||
@@ -358,6 +362,7 @@ async def get_user_ranking(
|
||||
order_by = col(UserStatistics.ranked_score).desc()
|
||||
if country:
|
||||
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
||||
include.append("country_rank")
|
||||
|
||||
# 查询总数
|
||||
count_query = select(func.count()).select_from(UserStatistics).where(*wheres)
|
||||
@@ -378,12 +383,14 @@ async def get_user_ranking(
|
||||
# 转换为响应格式
|
||||
ranking_data = []
|
||||
for statistics in statistics_list:
|
||||
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
|
||||
user_stats_resp = await UserStatisticsModel.transform(
|
||||
statistics, includes=include, user_country=current_user.country_code
|
||||
)
|
||||
ranking_data.append(user_stats_resp)
|
||||
|
||||
# 异步缓存数据(不等待完成)
|
||||
# 使用配置文件中的TTL设置
|
||||
cache_data = [item.model_dump() for item in ranking_data]
|
||||
cache_data = ranking_data
|
||||
stats_data = {"total": total_count}
|
||||
|
||||
# 创建后台任务来缓存数据
|
||||
@@ -407,5 +414,7 @@ async def get_user_ranking(
|
||||
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||
)
|
||||
|
||||
resp = TopUsersResponse(ranking=ranking_data, total=total_count)
|
||||
return resp
|
||||
return {
|
||||
"ranking": ranking_data,
|
||||
"total": total_count,
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.database import Relationship, RelationshipResp, RelationshipType, User
|
||||
from app.database.user import UserResp
|
||||
from app.database import Relationship, RelationshipType, User
|
||||
from app.database.relationship import RelationshipModel
|
||||
from app.database.user import UserModel
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.utils import api_doc
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import HTTPException, Path, Query, Request, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, exists, select
|
||||
|
||||
|
||||
@@ -17,38 +18,19 @@ from sqlmodel import col, exists, select
|
||||
"/friends",
|
||||
tags=["用户关系"],
|
||||
responses={
|
||||
200: {
|
||||
"description": "好友列表",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/components/schemas/RelationshipResp"},
|
||||
"description": "好友列表",
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/components/schemas/UserResp"},
|
||||
"description": "好友列表 (`x-api-version < 20241022`)",
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
200: api_doc(
|
||||
"好友列表\n\n如果 `x-api-version < 20241022`,返回值为 `User` 列表,否则为 `Relationship` 列表。",
|
||||
list[RelationshipModel] | list[UserModel],
|
||||
[f"target.{inc}" for inc in User.LIST_INCLUDES],
|
||||
)
|
||||
},
|
||||
name="获取好友列表",
|
||||
description=(
|
||||
"获取当前用户的好友列表。\n\n"
|
||||
"如果 `x-api-version < 20241022`,返回值为 `UserResp` 列表,否则为 `RelationshipResp` 列表。"
|
||||
),
|
||||
description="获取当前用户的好友列表。",
|
||||
)
|
||||
@router.get(
|
||||
"/blocks",
|
||||
tags=["用户关系"],
|
||||
response_model=list[RelationshipResp],
|
||||
response_model=list[dict[str, Any]],
|
||||
name="获取屏蔽列表",
|
||||
description="获取当前用户的屏蔽用户列表。",
|
||||
)
|
||||
@@ -67,35 +49,29 @@ async def get_relationship(
|
||||
)
|
||||
)
|
||||
if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK:
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
|
||||
return [
|
||||
await RelationshipModel.transform(
|
||||
rel,
|
||||
includes=[f"target.{inc}" for inc in User.LIST_INCLUDES],
|
||||
ruleset=current_user.playmode,
|
||||
)
|
||||
for rel in relationships.unique()
|
||||
]
|
||||
else:
|
||||
return [
|
||||
await UserResp.from_db(
|
||||
await UserModel.transform(
|
||||
rel.target,
|
||||
db,
|
||||
include=[
|
||||
"team",
|
||||
"daily_challenge_user_stats",
|
||||
"statistics",
|
||||
"statistics_rulesets",
|
||||
],
|
||||
ruleset=current_user.playmode,
|
||||
includes=User.LIST_INCLUDES,
|
||||
)
|
||||
for rel in relationships.unique()
|
||||
]
|
||||
|
||||
|
||||
class AddFriendResp(BaseModel):
|
||||
"""添加好友/屏蔽 返回模型。
|
||||
|
||||
- user_relation: 新的或更新后的关系对象。"""
|
||||
|
||||
user_relation: RelationshipResp
|
||||
|
||||
|
||||
@router.post(
|
||||
"/friends",
|
||||
tags=["用户关系"],
|
||||
response_model=AddFriendResp,
|
||||
responses={200: api_doc("好友关系", {"user_relation": RelationshipModel}, name="UserRelationshipResponse")},
|
||||
name="添加或更新好友关系",
|
||||
description="\n添加或更新与目标用户的好友关系。",
|
||||
)
|
||||
@@ -163,7 +139,13 @@ async def add_relationship(
|
||||
)
|
||||
)
|
||||
).one()
|
||||
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
|
||||
return {
|
||||
"user_relation": await RelationshipModel.transform(
|
||||
relationship,
|
||||
includes=[],
|
||||
ruleset=current_user.playmode,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@router.delete(
|
||||
|
||||
@@ -1,25 +1,27 @@
|
||||
from datetime import UTC
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.beatmap import Beatmap, BeatmapResp
|
||||
from app.database.beatmapset import BeatmapsetResp
|
||||
from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsResp
|
||||
from app.database.beatmap import (
|
||||
Beatmap,
|
||||
BeatmapModel,
|
||||
)
|
||||
from app.database.beatmapset import BeatmapsetModel
|
||||
from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsCountModel
|
||||
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from app.database.playlists import Playlist, PlaylistResp
|
||||
from app.database.room import APIUploadedRoom, Room, RoomResp
|
||||
from app.database.playlists import Playlist, PlaylistModel
|
||||
from app.database.room import APIUploadedRoom, Room, RoomModel
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.database.score import Score
|
||||
from app.database.user import User, UserResp
|
||||
from app.database.user import User, UserModel
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.models.room import MatchType, RoomCategory, RoomStatus
|
||||
from app.service.room import create_playlist_room_from_api
|
||||
from app.utils import utcnow
|
||||
from app.utils import api_doc, utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -28,7 +30,19 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@router.get(
|
||||
"/rooms",
|
||||
tags=["房间"],
|
||||
response_model=list[RoomResp],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"房间列表",
|
||||
list[RoomModel],
|
||||
[
|
||||
"current_playlist_item.beatmap.beatmapset",
|
||||
"difficulty_range",
|
||||
"host.country",
|
||||
"playlist_item_stats",
|
||||
"recent_participants",
|
||||
],
|
||||
)
|
||||
},
|
||||
name="获取房间列表",
|
||||
description="获取房间列表。支持按状态/模式筛选",
|
||||
)
|
||||
@@ -49,7 +63,7 @@ async def get_all_rooms(
|
||||
] = RoomCategory.NORMAL,
|
||||
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
|
||||
):
|
||||
resp_list: list[RoomResp] = []
|
||||
resp_list = []
|
||||
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category, col(Room.type) != MatchType.MATCHMAKING]
|
||||
now = utcnow()
|
||||
|
||||
@@ -90,22 +104,24 @@ async def get_all_rooms(
|
||||
.all()
|
||||
)
|
||||
for room in db_rooms:
|
||||
resp = await RoomResp.from_db(room, db)
|
||||
resp = await RoomModel.transform(
|
||||
room,
|
||||
includes=[
|
||||
"current_playlist_item.beatmap.beatmapset",
|
||||
"difficulty_range",
|
||||
"host.country",
|
||||
"playlist_item_stats",
|
||||
"recent_participants",
|
||||
],
|
||||
)
|
||||
if category == RoomCategory.REALTIME:
|
||||
resp.category = RoomCategory.NORMAL
|
||||
resp["category"] = RoomCategory.NORMAL
|
||||
|
||||
resp_list.append(resp)
|
||||
|
||||
return resp_list
|
||||
|
||||
|
||||
class APICreatedRoom(RoomResp):
|
||||
"""创建房间返回模型,继承 RoomResp。额外字段:
|
||||
- error: 错误信息(为空表示成功)。"""
|
||||
|
||||
error: str = ""
|
||||
|
||||
|
||||
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
@@ -133,9 +149,15 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
|
||||
@router.post(
|
||||
"/rooms",
|
||||
tags=["房间"],
|
||||
response_model=APICreatedRoom,
|
||||
name="创建房间",
|
||||
description="\n创建一个新的房间。",
|
||||
responses={
|
||||
200: api_doc(
|
||||
"创建的房间信息",
|
||||
RoomModel,
|
||||
Room.SHOW_RESPONSE_INCLUDES,
|
||||
)
|
||||
},
|
||||
)
|
||||
async def create_room(
|
||||
db: Database,
|
||||
@@ -145,23 +167,27 @@ async def create_room(
|
||||
):
|
||||
if await current_user.is_restricted(db):
|
||||
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
|
||||
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
||||
await db.commit()
|
||||
await db.refresh(db_room)
|
||||
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
|
||||
created_room.error = ""
|
||||
created_room = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
|
||||
return created_room
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}",
|
||||
tags=["房间"],
|
||||
response_model=RoomResp,
|
||||
responses={
|
||||
200: api_doc(
|
||||
"房间详细信息",
|
||||
RoomModel,
|
||||
Room.SHOW_RESPONSE_INCLUDES,
|
||||
)
|
||||
},
|
||||
name="获取房间详情",
|
||||
description="获取单个房间详情。",
|
||||
description="获取指定房间详情。",
|
||||
)
|
||||
async def get_room(
|
||||
db: Database,
|
||||
@@ -177,7 +203,7 @@ async def get_room(
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
|
||||
resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES, user=current_user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -225,10 +251,10 @@ async def add_user_to_room(
|
||||
await _participate_room(room_id, user_id, db_room, db, redis)
|
||||
await db.commit()
|
||||
await db.refresh(db_room)
|
||||
resp = await RoomResp.from_db(db_room, db)
|
||||
resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
|
||||
return resp
|
||||
else:
|
||||
raise HTTPException(404, "room not found0")
|
||||
raise HTTPException(404, "room not found")
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -268,21 +294,22 @@ async def remove_user_from_room(
|
||||
raise HTTPException(404, "Room not found")
|
||||
|
||||
|
||||
class APILeaderboard(BaseModel):
|
||||
"""房间全局排行榜返回模型。
|
||||
- leaderboard: 用户游玩统计(尝试次数/分数等)。
|
||||
- user_score: 当前用户对应统计。"""
|
||||
|
||||
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
|
||||
user_score: ItemAttemptsResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/leaderboard",
|
||||
tags=["房间"],
|
||||
response_model=APILeaderboard,
|
||||
name="获取房间排行榜",
|
||||
description="获取房间内累计得分排行榜。",
|
||||
responses={
|
||||
200: api_doc(
|
||||
"房间排行榜",
|
||||
{
|
||||
"leaderboard": list[ItemAttemptsCountModel],
|
||||
"user_score": ItemAttemptsCountModel | None,
|
||||
},
|
||||
["user.country", "position"],
|
||||
name="RoomLeaderboardResponse",
|
||||
)
|
||||
},
|
||||
)
|
||||
async def get_room_leaderboard(
|
||||
db: Database,
|
||||
@@ -300,45 +327,43 @@ async def get_room_leaderboard(
|
||||
aggs_resp = []
|
||||
user_agg = None
|
||||
for i, agg in enumerate(aggs):
|
||||
resp = await ItemAttemptsResp.from_db(agg, db)
|
||||
resp.position = i + 1
|
||||
includes = ["user.country"]
|
||||
if agg.user_id == current_user.id:
|
||||
includes.append("position")
|
||||
resp = await ItemAttemptsCountModel.transform(agg, includes=includes)
|
||||
aggs_resp.append(resp)
|
||||
if agg.user_id == current_user.id:
|
||||
user_agg = resp
|
||||
return APILeaderboard(
|
||||
leaderboard=aggs_resp,
|
||||
user_score=user_agg,
|
||||
)
|
||||
|
||||
|
||||
class RoomEvents(BaseModel):
|
||||
"""房间事件流返回模型。
|
||||
- beatmaps: 本次结果涉及的谱面列表。
|
||||
- beatmapsets: 谱面集映射。
|
||||
- current_playlist_item_id: 当前游玩列表(项目)项 ID。
|
||||
- events: 事件列表。
|
||||
- first_event_id / last_event_id: 事件范围。
|
||||
- playlist_items: 房间游玩列表(项目)详情。
|
||||
- room: 房间详情。
|
||||
- user: 关联用户列表。"""
|
||||
|
||||
beatmaps: list[BeatmapResp] = Field(default_factory=list)
|
||||
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
|
||||
current_playlist_item_id: int = 0
|
||||
events: list[MultiplayerEventResp] = Field(default_factory=list)
|
||||
first_event_id: int = 0
|
||||
last_event_id: int = 0
|
||||
playlist_items: list[PlaylistResp] = Field(default_factory=list)
|
||||
room: RoomResp
|
||||
user: list[UserResp] = Field(default_factory=list)
|
||||
return {
|
||||
"leaderboard": aggs_resp,
|
||||
"user_score": user_agg,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/events",
|
||||
response_model=RoomEvents,
|
||||
tags=["房间"],
|
||||
name="获取房间事件",
|
||||
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
|
||||
responses={
|
||||
200: api_doc(
|
||||
"房间事件",
|
||||
{
|
||||
"beatmaps": list[BeatmapModel],
|
||||
"beatmapsets": list[BeatmapsetModel],
|
||||
"current_playlist_item_id": int,
|
||||
"events": list[MultiplayerEventResp],
|
||||
"first_event_id": int,
|
||||
"last_event_id": int,
|
||||
"playlist_items": list[PlaylistModel],
|
||||
"room": RoomModel,
|
||||
"user": list[UserModel],
|
||||
},
|
||||
["country", "details", "scores"],
|
||||
name="RoomEventsResponse",
|
||||
)
|
||||
},
|
||||
)
|
||||
async def get_room_events(
|
||||
db: Database,
|
||||
@@ -402,28 +427,44 @@ async def get_room_events(
|
||||
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
room_resp = await RoomResp.from_db(room, db)
|
||||
if room.category == RoomCategory.REALTIME and room_resp.current_playlist_item:
|
||||
current_playlist_item_id = room_resp.current_playlist_item.id
|
||||
room_resp = await RoomModel.transform(room, includes=["current_playlist_item"])
|
||||
if room.category == RoomCategory.REALTIME:
|
||||
current_playlist_item_id = (await Room.current_playlist_item(db, room))["id"]
|
||||
|
||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
||||
user_resps = [await UserModel.transform(user, includes=["country"]) for user in users]
|
||||
|
||||
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
||||
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
|
||||
beatmapset_resps = {}
|
||||
for beatmap_resp in beatmap_resps:
|
||||
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
|
||||
beatmap_resps = [
|
||||
await BeatmapModel.transform(
|
||||
beatmap,
|
||||
)
|
||||
for beatmap in beatmaps
|
||||
]
|
||||
|
||||
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
|
||||
beatmapsets = []
|
||||
for beatmap in beatmaps:
|
||||
if beatmap.beatmapset_id not in beatmapsets:
|
||||
beatmapsets.append(beatmap.beatmapset)
|
||||
beatmapset_resps = [
|
||||
await BeatmapsetModel.transform(
|
||||
beatmapset,
|
||||
)
|
||||
for beatmapset in beatmapsets
|
||||
]
|
||||
|
||||
return RoomEvents(
|
||||
beatmaps=beatmap_resps,
|
||||
beatmapsets=beatmapset_resps,
|
||||
current_playlist_item_id=current_playlist_item_id,
|
||||
events=event_resps,
|
||||
first_event_id=first_event_id,
|
||||
last_event_id=last_event_id,
|
||||
playlist_items=playlist_items_resps,
|
||||
room=room_resp,
|
||||
user=user_resps,
|
||||
)
|
||||
playlist_items_resps = [
|
||||
await PlaylistModel.transform(item, includes=["details", "scores"]) for item in playlist_items.values()
|
||||
]
|
||||
|
||||
return {
|
||||
"beatmaps": beatmap_resps,
|
||||
"beatmapsets": beatmapset_resps,
|
||||
"current_playlist_item_id": current_playlist_item_id,
|
||||
"events": event_resps,
|
||||
"first_event_id": first_event_id,
|
||||
"last_event_id": last_event_id,
|
||||
"playlist_items": playlist_items_resps,
|
||||
"room": room_resp,
|
||||
"user": user_resps,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ from app.database import (
|
||||
Playlist,
|
||||
Room,
|
||||
Score,
|
||||
ScoreResp,
|
||||
ScoreToken,
|
||||
ScoreTokenResp,
|
||||
User,
|
||||
@@ -27,8 +26,10 @@ from app.database.relationship import Relationship, RelationshipType
|
||||
from app.database.score import (
|
||||
LegacyScoreResp,
|
||||
MultiplayerScores,
|
||||
ScoreAround,
|
||||
MultiplayScoreDict,
|
||||
ScoreModel,
|
||||
get_leaderboard,
|
||||
get_score_position_by_id,
|
||||
process_score,
|
||||
process_user,
|
||||
)
|
||||
@@ -49,7 +50,7 @@ from app.models.score import (
|
||||
)
|
||||
from app.service.beatmap_cache_service import get_beatmap_cache_service
|
||||
from app.service.user_cache_service import refresh_user_cache_background
|
||||
from app.utils import utcnow
|
||||
from app.utils import api_doc, utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
@@ -72,6 +73,7 @@ from sqlmodel import col, exists, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
READ_SCORE_TIMEOUT = 10
|
||||
DEFAULT_SCORE_INCLUDES = ["user", "user.country", "user.cover", "user.team"]
|
||||
logger = log("Score")
|
||||
|
||||
|
||||
@@ -180,13 +182,15 @@ async def submit_score(
|
||||
await db.refresh(score)
|
||||
|
||||
background_task.add_task(_process_user, score_id, user_id, redis, fetcher)
|
||||
resp: ScoreResp = await ScoreResp.from_db(db, score)
|
||||
resp = await ScoreModel.transform(
|
||||
score,
|
||||
)
|
||||
score_gamemode = score.gamemode
|
||||
|
||||
await db.commit()
|
||||
if user_id is not None:
|
||||
background_task.add_task(refresh_user_cache_background, redis, user_id, score_gamemode)
|
||||
background_task.add_task(_process_user_achievement, resp.id)
|
||||
background_task.add_task(_process_user_achievement, resp["id"])
|
||||
return resp
|
||||
|
||||
|
||||
@@ -218,27 +222,36 @@ async def _preload_beatmap_for_pp_calculation(beatmap_id: int) -> None:
|
||||
logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")
|
||||
|
||||
|
||||
class BeatmapUserScore[T: ScoreResp | LegacyScoreResp](BaseModel):
|
||||
LeaderboardScoreType = ScoreModel.generate_typeddict(tuple(DEFAULT_SCORE_INCLUDES)) | LegacyScoreResp
|
||||
|
||||
|
||||
class BeatmapUserScore(BaseModel):
|
||||
position: int
|
||||
score: T
|
||||
score: LeaderboardScoreType # pyright: ignore[reportInvalidTypeForm]
|
||||
|
||||
|
||||
class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
|
||||
scores: list[T]
|
||||
user_score: BeatmapUserScore[T] | None = None
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[LeaderboardScoreType] # pyright: ignore[reportInvalidTypeForm]
|
||||
user_score: BeatmapUserScore | None = None
|
||||
score_count: int = 0
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap_id}/scores",
|
||||
tags=["成绩"],
|
||||
response_model=BeatmapScores[ScoreResp] | BeatmapScores[LegacyScoreResp],
|
||||
responses={
|
||||
200: {
|
||||
"model": BeatmapScores,
|
||||
"description": (
|
||||
"排行榜及当前用户成绩。\n\n"
|
||||
f"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[Score]`"
|
||||
f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}),"
|
||||
"否则为 `BeatmapScores[LegacyScoreResp]`。"
|
||||
),
|
||||
}
|
||||
},
|
||||
name="获取谱面排行榜",
|
||||
description=(
|
||||
"获取指定谱面在特定条件下的排行榜及当前用户成绩。\n\n"
|
||||
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[ScoreResp]`,"
|
||||
"否则为 `BeatmapScores[LegacyScoreResp]`。"
|
||||
),
|
||||
description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
db: Database,
|
||||
@@ -266,27 +279,46 @@ async def get_beatmap_scores(
|
||||
mods=sorted(mods),
|
||||
)
|
||||
|
||||
user_score_resp = await user_score.to_resp(db, api_version) if user_score else None
|
||||
resp = BeatmapScores(
|
||||
scores=[await score.to_resp(db, api_version) for score in all_scores],
|
||||
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
|
||||
if user_score_resp
|
||||
else None,
|
||||
score_count=count,
|
||||
)
|
||||
return resp
|
||||
user_score_resp = await user_score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) if user_score else None
|
||||
return {
|
||||
"scores": [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_scores],
|
||||
"user_score": (
|
||||
{
|
||||
"score": user_score_resp,
|
||||
"position": (
|
||||
await get_score_position_by_id(
|
||||
db,
|
||||
user_score.beatmap_id,
|
||||
user_score.id,
|
||||
mode=user_score.gamemode,
|
||||
user=user_score.user,
|
||||
)
|
||||
or 0
|
||||
),
|
||||
}
|
||||
if user_score and user_score_resp
|
||||
else None
|
||||
),
|
||||
"score_count": count,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap_id}/scores/users/{user_id}",
|
||||
tags=["成绩"],
|
||||
response_model=BeatmapUserScore[ScoreResp] | BeatmapUserScore[LegacyScoreResp],
|
||||
responses={
|
||||
200: {
|
||||
"model": BeatmapUserScore,
|
||||
"description": (
|
||||
"指定用户在指定谱面上的最高成绩\n\n"
|
||||
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[Score]`,"
|
||||
f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}),"
|
||||
"否则为 `BeatmapUserScore[LegacyScoreResp]`。"
|
||||
),
|
||||
}
|
||||
},
|
||||
name="获取用户谱面最高成绩",
|
||||
description=(
|
||||
"获取指定用户在指定谱面上的最高成绩。\n\n"
|
||||
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[ScoreResp]`,"
|
||||
"否则为 `BeatmapUserScore[LegacyScoreResp]`。"
|
||||
),
|
||||
description="获取指定用户在指定谱面上的最高成绩。",
|
||||
)
|
||||
async def get_user_beatmap_score(
|
||||
db: Database,
|
||||
@@ -318,23 +350,38 @@ async def get_user_beatmap_score(
|
||||
detail=f"Cannot find user {user_id}'s score on this beatmap",
|
||||
)
|
||||
else:
|
||||
resp = await user_score.to_resp(db, api_version=api_version)
|
||||
return BeatmapUserScore(
|
||||
position=resp.rank_global or 0,
|
||||
score=resp,
|
||||
)
|
||||
resp = await user_score.to_resp(db, api_version=api_version, includes=DEFAULT_SCORE_INCLUDES)
|
||||
return {
|
||||
"position": (
|
||||
await get_score_position_by_id(
|
||||
db,
|
||||
user_score.beatmap_id,
|
||||
user_score.id,
|
||||
mode=user_score.gamemode,
|
||||
user=user_score.user,
|
||||
)
|
||||
or 0
|
||||
),
|
||||
"score": resp,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap_id}/scores/users/{user_id}/all",
|
||||
tags=["成绩"],
|
||||
response_model=list[ScoreResp] | list[LegacyScoreResp],
|
||||
responses={
|
||||
200: api_doc(
|
||||
(
|
||||
"用户谱面全部成绩\n\n"
|
||||
"如果 `x-api-version >= 20220705`,返回值为 `Score`列表,"
|
||||
"否则为 `LegacyScoreResp`列表。"
|
||||
),
|
||||
list[ScoreModel] | list[LegacyScoreResp],
|
||||
DEFAULT_SCORE_INCLUDES,
|
||||
)
|
||||
},
|
||||
name="获取用户谱面全部成绩",
|
||||
description=(
|
||||
"获取指定用户在指定谱面上的全部成绩列表。\n\n"
|
||||
"如果 `x-api-version >= 20220705`,返回值为 `ScoreResp`列表,"
|
||||
"否则为 `LegacyScoreResp`列表。"
|
||||
),
|
||||
description="获取指定用户在指定谱面上的全部成绩列表。",
|
||||
)
|
||||
async def get_user_all_beatmap_scores(
|
||||
db: Database,
|
||||
@@ -359,7 +406,7 @@ async def get_user_all_beatmap_scores(
|
||||
)
|
||||
).all()
|
||||
|
||||
return [await score.to_resp(db, api_version) for score in all_user_scores]
|
||||
return [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_user_scores]
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -413,9 +460,9 @@ async def create_solo_score(
|
||||
@router.put(
|
||||
"/beatmaps/{beatmap_id}/solo/scores/{token}",
|
||||
tags=["游玩"],
|
||||
response_model=ScoreResp,
|
||||
name="提交单曲成绩",
|
||||
description="\n使用令牌提交单曲成绩。",
|
||||
responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
|
||||
)
|
||||
async def submit_solo_score(
|
||||
background_task: BackgroundTasks,
|
||||
@@ -520,6 +567,7 @@ async def create_playlist_score(
|
||||
tags=["游玩"],
|
||||
name="提交房间项目成绩",
|
||||
description="\n提交房间游玩项目成绩。",
|
||||
responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
|
||||
)
|
||||
async def submit_playlist_score(
|
||||
background_task: BackgroundTasks,
|
||||
@@ -560,13 +608,13 @@ async def submit_playlist_score(
|
||||
room_id,
|
||||
playlist_id,
|
||||
user_id,
|
||||
score_resp.id,
|
||||
score_resp.total_score,
|
||||
score_resp["id"],
|
||||
score_resp["total_score"],
|
||||
session,
|
||||
redis,
|
||||
)
|
||||
await session.commit()
|
||||
if room_category == RoomCategory.DAILY_CHALLENGE and score_resp.passed:
|
||||
if room_category == RoomCategory.DAILY_CHALLENGE and score_resp["passed"]:
|
||||
await process_daily_challenge_score(session, user_id, room_id)
|
||||
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
|
||||
await session.commit()
|
||||
@@ -575,15 +623,23 @@ async def submit_playlist_score(
|
||||
|
||||
class IndexedScoreResp(MultiplayerScores):
|
||||
total: int
|
||||
user_score: ScoreResp | None = None
|
||||
user_score: MultiplayScoreDict | None = None # pyright: ignore[reportInvalidTypeForm]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores",
|
||||
response_model=IndexedScoreResp,
|
||||
# response_model=IndexedScoreResp,
|
||||
name="获取房间项目排行榜",
|
||||
description="获取房间游玩项目排行榜。",
|
||||
tags=["成绩"],
|
||||
responses={
|
||||
200: {
|
||||
"description": (
|
||||
f"房间项目排行榜。\n\n包含:{', '.join([f'`{inc}`' for inc in Score.MULTIPLAYER_BASE_INCLUDES])}"
|
||||
),
|
||||
"model": IndexedScoreResp,
|
||||
}
|
||||
},
|
||||
)
|
||||
async def index_playlist_scores(
|
||||
session: Database,
|
||||
@@ -620,16 +676,14 @@ async def index_playlist_scores(
|
||||
scores = scores[:-1]
|
||||
|
||||
user_score = None
|
||||
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
|
||||
score_resp = [await ScoreModel.transform(score.score, includes=Score.MULTIPLAYER_BASE_INCLUDES) for score in scores]
|
||||
for score in score_resp:
|
||||
score.position = await get_position(room_id, playlist_id, score.id, session)
|
||||
if score.user_id == user_id:
|
||||
if (room.category == RoomCategory.DAILY_CHALLENGE and score["user_id"] == user_id and score["passed"]) or score[
|
||||
"user_id"
|
||||
] == user_id:
|
||||
user_score = score
|
||||
|
||||
if room.category == RoomCategory.DAILY_CHALLENGE:
|
||||
score_resp = [s for s in score_resp if s.passed]
|
||||
if user_score and not user_score.passed:
|
||||
user_score = None
|
||||
user_score["position"] = await get_position(room_id, playlist_id, score["id"], session)
|
||||
break
|
||||
|
||||
resp = IndexedScoreResp(
|
||||
scores=score_resp,
|
||||
@@ -648,10 +702,16 @@ async def index_playlist_scores(
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
|
||||
response_model=ScoreResp,
|
||||
name="获取房间项目单个成绩",
|
||||
description="获取指定房间游玩项目中单个成绩详情。",
|
||||
tags=["成绩"],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"房间项目单个成绩详情。",
|
||||
ScoreModel,
|
||||
[*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
|
||||
)
|
||||
},
|
||||
)
|
||||
async def show_playlist_score(
|
||||
session: Database,
|
||||
@@ -687,39 +747,25 @@ async def show_playlist_score(
|
||||
break
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(room_id, playlist_id, score_id, session)
|
||||
includes = [
|
||||
*Score.MULTIPLAYER_BASE_INCLUDES,
|
||||
"position",
|
||||
]
|
||||
if completed:
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
higher_scores = []
|
||||
lower_scores = []
|
||||
for score in scores:
|
||||
resp = await ScoreResp.from_db(session, score.score)
|
||||
if is_playlist and not resp.passed:
|
||||
continue
|
||||
if score.total_score > resp.total_score:
|
||||
higher_scores.append(resp)
|
||||
elif score.total_score < resp.total_score:
|
||||
lower_scores.append(resp)
|
||||
resp.scores_around = ScoreAround(
|
||||
higher=MultiplayerScores(scores=higher_scores),
|
||||
lower=MultiplayerScores(scores=lower_scores),
|
||||
)
|
||||
|
||||
includes.append("scores_around")
|
||||
resp = await ScoreModel.transform(score_record.score, includes=includes)
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
|
||||
response_model=ScoreResp,
|
||||
responses={
|
||||
200: api_doc(
|
||||
"房间项目单个成绩详情。",
|
||||
ScoreModel,
|
||||
[*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
|
||||
)
|
||||
},
|
||||
name="获取房间项目用户成绩",
|
||||
description="获取指定用户在房间游玩项目中的成绩。",
|
||||
tags=["成绩"],
|
||||
@@ -749,8 +795,14 @@ async def get_user_playlist_score(
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
|
||||
resp = await ScoreModel.transform(
|
||||
score_record.score,
|
||||
includes=[
|
||||
*Score.MULTIPLAYER_BASE_INCLUDES,
|
||||
"position",
|
||||
"scores_around",
|
||||
],
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
|
||||
@@ -5,17 +5,16 @@ from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
BeatmapModel,
|
||||
BeatmapPlaycounts,
|
||||
BeatmapPlaycountsResp,
|
||||
BeatmapResp,
|
||||
BeatmapsetResp,
|
||||
BeatmapsetModel,
|
||||
User,
|
||||
UserResp,
|
||||
)
|
||||
from app.database.beatmap_playcounts import BeatmapPlaycountsModel
|
||||
from app.database.best_scores import BestScore
|
||||
from app.database.events import Event
|
||||
from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores
|
||||
from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED
|
||||
from app.database.score import Score, get_user_first_scores
|
||||
from app.database.user import UserModel
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.cache import UserCacheService
|
||||
from app.dependencies.database import Database, get_redis
|
||||
@@ -26,24 +25,15 @@ from app.models.mods import API_MODS
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import BeatmapsetType
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
from app.utils import utcnow
|
||||
from app.utils import api_doc, utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import exists, false, select
|
||||
from sqlmodel.sql.expression import col
|
||||
|
||||
|
||||
class BatchUserResponse(BaseModel):
|
||||
users: list[UserResp]
|
||||
|
||||
|
||||
class BeatmapsPassedResponse(BaseModel):
|
||||
beatmaps_passed: list[BeatmapResp]
|
||||
|
||||
|
||||
def _get_difficulty_reduction_mods() -> set[str]:
|
||||
mods: set[str] = set()
|
||||
for ruleset_mods in API_MODS.values():
|
||||
@@ -63,13 +53,15 @@ async def visible_to_current_user(user: User, current_user: User | None, session
|
||||
|
||||
@router.get(
|
||||
"/users/",
|
||||
response_model=BatchUserResponse,
|
||||
responses={
|
||||
200: api_doc("批量获取用户信息", {"users": list[UserModel]}, User.CARD_INCLUDES, name="UsersLookupResponse")
|
||||
},
|
||||
name="批量获取用户信息",
|
||||
description="通过用户 ID 列表批量获取用户信息。",
|
||||
tags=["用户"],
|
||||
)
|
||||
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
|
||||
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
||||
@router.get("/users/lookup", include_in_schema=False)
|
||||
@router.get("/users/lookup/", include_in_schema=False)
|
||||
@asset_proxy_response
|
||||
async def get_users(
|
||||
session: Database,
|
||||
@@ -108,16 +100,15 @@ async def get_users(
|
||||
# 将查询到的用户添加到缓存并返回
|
||||
for searched_user in searched_users:
|
||||
if searched_user.id != BANCHOBOT_ID:
|
||||
user_resp = await UserResp.from_db(
|
||||
user_resp = await UserModel.transform(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
includes=User.CARD_INCLUDES,
|
||||
)
|
||||
cached_users.append(user_resp)
|
||||
# 异步缓存,不阻塞响应
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
response = BatchUserResponse(users=cached_users)
|
||||
response = {"users": cached_users}
|
||||
return response
|
||||
else:
|
||||
searched_users = (
|
||||
@@ -127,16 +118,15 @@ async def get_users(
|
||||
for searched_user in searched_users:
|
||||
if searched_user.id == BANCHOBOT_ID:
|
||||
continue
|
||||
user_resp = await UserResp.from_db(
|
||||
user_resp = await UserModel.transform(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
includes=User.CARD_INCLUDES,
|
||||
)
|
||||
users.append(user_resp)
|
||||
# 异步缓存
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
response = BatchUserResponse(users=users)
|
||||
response = {"users": users}
|
||||
return response
|
||||
|
||||
|
||||
@@ -200,10 +190,12 @@ async def get_user_kudosu(
|
||||
|
||||
@router.get(
|
||||
"/users/{user_id}/beatmaps-passed",
|
||||
response_model=BeatmapsPassedResponse,
|
||||
name="获取用户已通过谱面",
|
||||
description="获取指定用户在给定谱面集中的已通过谱面列表。",
|
||||
tags=["用户"],
|
||||
responses={
|
||||
200: api_doc("用户已通过谱面列表", {"beatmaps_passed": list[BeatmapModel]}, name="BeatmapsPassedResponse")
|
||||
},
|
||||
)
|
||||
@asset_proxy_response
|
||||
async def get_user_beatmaps_passed(
|
||||
@@ -226,7 +218,7 @@ async def get_user_beatmaps_passed(
|
||||
no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True,
|
||||
):
|
||||
if not beatmapset_ids:
|
||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
||||
return {"beatmaps_passed": []}
|
||||
if len(beatmapset_ids) > 50:
|
||||
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
|
||||
|
||||
@@ -255,7 +247,7 @@ async def get_user_beatmaps_passed(
|
||||
|
||||
scores = (await session.exec(score_query)).all()
|
||||
if not scores:
|
||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
||||
return {"beatmaps_passed": []}
|
||||
|
||||
difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set()
|
||||
passed_beatmap_ids: set[int] = set()
|
||||
@@ -269,7 +261,7 @@ async def get_user_beatmaps_passed(
|
||||
continue
|
||||
passed_beatmap_ids.add(beatmap_id)
|
||||
if not passed_beatmap_ids:
|
||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
||||
return {"beatmaps_passed": []}
|
||||
|
||||
beatmaps = (
|
||||
await session.exec(
|
||||
@@ -279,19 +271,24 @@ async def get_user_beatmaps_passed(
|
||||
)
|
||||
).all()
|
||||
|
||||
return BeatmapsPassedResponse(
|
||||
beatmaps_passed=[
|
||||
await BeatmapResp.from_db(beatmap, allowed_mode, session=session, user=user) for beatmap in beatmaps
|
||||
return {
|
||||
"beatmaps_passed": [
|
||||
await BeatmapModel.transform(
|
||||
beatmap,
|
||||
)
|
||||
for beatmap in beatmaps
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users/{user_id}/{ruleset}",
|
||||
response_model=UserResp,
|
||||
name="获取用户信息(指定ruleset)",
|
||||
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
|
||||
tags=["用户"],
|
||||
responses={
|
||||
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
|
||||
},
|
||||
)
|
||||
@asset_proxy_response
|
||||
async def get_user_info_ruleset(
|
||||
@@ -325,29 +322,26 @@ async def get_user_info_ruleset(
|
||||
if should_not_show:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
|
||||
include = SEARCH_INCLUDED
|
||||
if searched_is_self:
|
||||
include = ALL_INCLUDED
|
||||
user_resp = await UserResp.from_db(
|
||||
user_resp = await UserModel.transform(
|
||||
searched_user,
|
||||
session,
|
||||
include=include,
|
||||
includes=User.USER_INCLUDES,
|
||||
ruleset=ruleset,
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
|
||||
|
||||
return user_resp
|
||||
|
||||
|
||||
@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
|
||||
@router.get("/users/{user_id}/", include_in_schema=False)
|
||||
@router.get(
|
||||
"/users/{user_id}",
|
||||
response_model=UserResp,
|
||||
name="获取用户信息",
|
||||
description="通过用户 ID 或用户名获取单个用户的详细信息。",
|
||||
tags=["用户"],
|
||||
responses={
|
||||
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
|
||||
},
|
||||
)
|
||||
@asset_proxy_response
|
||||
async def get_user_info(
|
||||
@@ -381,27 +375,31 @@ async def get_user_info(
|
||||
if should_not_show:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
|
||||
include = SEARCH_INCLUDED
|
||||
if searched_is_self:
|
||||
include = ALL_INCLUDED
|
||||
user_resp = await UserResp.from_db(
|
||||
user_resp = await UserModel.transform(
|
||||
searched_user,
|
||||
session,
|
||||
include=include,
|
||||
includes=User.USER_INCLUDES,
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return user_resp
|
||||
|
||||
|
||||
beatmapset_includes = [*BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES, "beatmaps"]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users/{user_id}/beatmapsets/{type}",
|
||||
response_model=list[BeatmapsetResp | BeatmapPlaycountsResp],
|
||||
name="获取用户谱面集列表",
|
||||
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
|
||||
tags=["用户"],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"当类型为 `most_played` 时返回 `list[BeatmapPlaycountsModel]`,其他为 `list[BeatmapsetModel]`",
|
||||
list[BeatmapsetModel] | list[BeatmapPlaycountsModel],
|
||||
beatmapset_includes,
|
||||
)
|
||||
},
|
||||
)
|
||||
@asset_proxy_response
|
||||
async def get_user_beatmapsets(
|
||||
@@ -417,11 +415,7 @@ async def get_user_beatmapsets(
|
||||
# 先尝试从缓存获取
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
|
||||
if cached_result is not None:
|
||||
# 根据类型恢复对象
|
||||
if type == BeatmapsetType.MOST_PLAYED:
|
||||
return [BeatmapPlaycountsResp(**item) for item in cached_result]
|
||||
else:
|
||||
return [BeatmapsetResp(**item) for item in cached_result]
|
||||
return cached_result
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if not user or user.id == BANCHOBOT_ID:
|
||||
@@ -444,7 +438,10 @@ async def get_user_beatmapsets(
|
||||
raise HTTPException(404, detail="User not found")
|
||||
favourites = await user.awaitable_attrs.favourite_beatmapsets
|
||||
resp = [
|
||||
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
|
||||
await BeatmapsetModel.transform(
|
||||
favourite.beatmapset, session=session, user=user, includes=beatmapset_includes
|
||||
)
|
||||
for favourite in favourites
|
||||
]
|
||||
|
||||
elif type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -459,7 +456,10 @@ async def get_user_beatmapsets(
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
|
||||
resp = [
|
||||
await BeatmapPlaycountsModel.transform(most_played_beatmap, user=user, includes=beatmapset_includes)
|
||||
for most_played_beatmap in most_played
|
||||
]
|
||||
else:
|
||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||
|
||||
@@ -477,7 +477,6 @@ async def get_user_beatmapsets(
|
||||
|
||||
@router.get(
|
||||
"/users/{user_id}/scores/{type}",
|
||||
response_model=list[ScoreResp] | list[LegacyScoreResp],
|
||||
name="获取用户成绩列表",
|
||||
description=(
|
||||
"获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n"
|
||||
@@ -523,6 +522,7 @@ async def get_user_scores(
|
||||
gamemode = mode or db_user.playmode
|
||||
order_by = None
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
|
||||
includes = Score.USER_PROFILE_INCLUDES.copy()
|
||||
if not include_fails:
|
||||
where_clause &= col(Score.passed).is_(True)
|
||||
if type == "pinned":
|
||||
@@ -531,6 +531,7 @@ async def get_user_scores(
|
||||
elif type == "best":
|
||||
where_clause &= exists().where(col(BestScore.score_id) == Score.id)
|
||||
order_by = col(Score.pp).desc()
|
||||
includes.append("weight")
|
||||
elif type == "recent":
|
||||
where_clause &= Score.ended_at > utcnow() - timedelta(hours=24)
|
||||
order_by = col(Score.ended_at).desc()
|
||||
@@ -551,6 +552,7 @@ async def get_user_scores(
|
||||
await score.to_resp(
|
||||
session,
|
||||
api_version,
|
||||
includes=includes,
|
||||
)
|
||||
for score in scores
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user