diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 6818b66..d6ed022 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -9,23 +9,26 @@ git clone https://github.com/GooGuTeam/g0v0-server.git
此外,您还需要:
- clone 旁观服务器到 g0v0-server 的文件夹。
+
```bash
git clone https://github.com/GooGuTeam/osu-server-spectator.git spectator-server
```
+
- clone 表现分计算器到 g0v0-server 的文件夹。
```bash
git clone https://github.com/GooGuTeam/osu-performance-server.git performance-server
```
+
- 下载并放置自定义规则集 DLL 到 `rulesets/` 目录(如果需要)。
## 开发环境
为了确保一致的开发环境,我们强烈建议使用提供的 Dev Container。这将设置一个容器化的环境,预先安装所有必要的工具和依赖项。
-1. 安装 [Docker](https://www.docker.com/products/docker-desktop/)。
-2. 在 Visual Studio Code 中安装 [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)。
-3. 在 VS Code 中打开项目。当被提示时,点击“在容器中重新打开”以启动开发容器。
+1. 安装 [Docker](https://www.docker.com/products/docker-desktop/)。
+2. 在 Visual Studio Code 中安装 [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)。
+3. 在 VS Code 中打开项目。当被提示时,点击“在容器中重新打开”以启动开发容器。
## 配置项目
@@ -67,54 +70,109 @@ uv sync
以下是项目主要目录和文件的结构说明:
-- `main.py`: FastAPI 应用的主入口点,负责初始化和启动服务器。
-- `pyproject.toml`: 项目配置文件,用于管理依赖项 (uv)、代码格式化 (Ruff) 和类型检查 (Pyright)。
-- `alembic.ini`: Alembic 数据库迁移工具的配置文件。
-- `app/`: 存放所有核心应用代码。
- - `router/`: 包含所有 API 端点的定义,根据 API 版本和功能进行组织。
- - `service/`: 存放核心业务逻辑,例如用户排名计算、每日挑战处理等。
- - `database/`: 定义数据库模型 (SQLModel) 和会话管理。
- - `models/`: 定义非数据库模型和其他模型。
- - `tasks/`: 包含由 APScheduler 调度的后台任务和启动/关闭任务。
- - `dependencies/`: 管理 FastAPI 的依赖项注入。
- - `achievements/`: 存放与成就相关的逻辑。
- - `storage/`: 存储服务代码。
- - `fetcher/`: 用于从外部服务(如 osu! 官网)获取数据的模块。
- - `middleware/`: 定义中间件,例如会话验证。
- - `helpers/`: 存放辅助函数和工具类。
- - `config.py`: 应用配置,使用 pydantic-settings 管理。
- - `calculator.py`: 存放所有的计算逻辑,例如 pp 和等级。
- - `log.py`: 日志记录模块,提供统一的日志接口。
- - `const.py`: 定义常量。
- - `path.py`: 定义跨文件使用的常量。
-- `migrations/`: 存放 Alembic 生成的数据库迁移脚本。
-- `static/`: 存放静态文件,如 `mods.json`。
+- `main.py`: FastAPI 应用的主入口点,负责初始化和启动服务器。
+- `pyproject.toml`: 项目配置文件,用于管理依赖项 (uv)、代码格式化 (Ruff) 和类型检查 (Pyright)。
+- `alembic.ini`: Alembic 数据库迁移工具的配置文件。
+- `app/`: 存放所有核心应用代码。
+ - `router/`: 包含所有 API 端点的定义,根据 API 版本和功能进行组织。
+ - `service/`: 存放核心业务逻辑,例如用户排名计算、每日挑战处理等。
+ - `database/`: 定义数据库模型 (SQLModel) 和会话管理。
+ - `models/`: 定义非数据库模型和其他模型。
+ - `tasks/`: 包含由 APScheduler 调度的后台任务和启动/关闭任务。
+ - `dependencies/`: 管理 FastAPI 的依赖项注入。
+ - `achievements/`: 存放与成就相关的逻辑。
+ - `storage/`: 存储服务代码。
+ - `fetcher/`: 用于从外部服务(如 osu! 官网)获取数据的模块。
+ - `middleware/`: 定义中间件,例如会话验证。
+ - `helpers/`: 存放辅助函数和工具类。
+ - `config.py`: 应用配置,使用 pydantic-settings 管理。
+ - `calculator.py`: 存放所有的计算逻辑,例如 pp 和等级。
+ - `log.py`: 日志记录模块,提供统一的日志接口。
+ - `const.py`: 定义常量。
+ - `path.py`: 定义跨文件使用的常量。
+- `migrations/`: 存放 Alembic 生成的数据库迁移脚本。
+- `static/`: 存放静态文件,如 `mods.json`。
### 数据库模型定义
所有的数据库模型定义在 `app.database` 里,并且在 `__init__.py` 中导出。
-如果这个模型的数据表结构和响应不完全相同,遵循 `Base` - `Table` - `Resp` 结构:
+本项目使用一种“按需返回”的设计模式,遵循 `Dict` - `Model` - `Table` 结构。详细设计思路请参考[这篇文章](https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/)。
+
+#### 基本结构
+
+1. **Dict**: 定义模型转换后的字典结构,用于类型检查。必须在 Model 之前定义。
+2. **Model**: 继承自 `DatabaseModel[Dict]`,定义字段和计算属性。
+3. **Table**: 继承自 `Model`,定义数据库表结构。
```python
-class ModelBase(SQLModel):
- # 定义共有内容
- ...
+from typing import TypedDict, NotRequired
+from app.database._base import DatabaseModel, OnDemand, included, ondemand
+from sqlmodel import Field
+# 1. 定义 Dict
+class UserDict(TypedDict):
+ id: int
+ username: str
+ email: NotRequired[str] # 可选字段
+ followers_count: int # 计算属性
-class Model(ModelBase, table=True):
- # 定义数据库表内容
- ...
+# 2. 定义 Model
+class UserModel(DatabaseModel[UserDict]):
+ id: int = Field(primary_key=True)
+ username: str
+ email: OnDemand[str] # 使用 OnDemand 标记可选字段
-
-class ModelResp(ModelBase):
- # 定义响应内容
- ...
-
- @classmethod
- def from_db(cls, db: Model) -> "ModelResp":
- # 从数据库模型转换
+ # 普通计算属性 (总是返回)
+ @included
+ @staticmethod
+ async def followers_count(session: AsyncSession, instance: "User") -> int:
+ return await session.scalar(select(func.count()).where(Follower.followed_id == instance.id))
+
+ # 可选计算属性 (仅在 includes 中指定时返回)
+ @ondemand
+ @staticmethod
+ async def some_optional_property(session: AsyncSession, instance: "User") -> str:
...
+
+# 3. 定义 Table
+class User(UserModel, table=True):
+ password: str # 仅在数据库中存在的字段
+ ...
+```
+
+#### 字段类型
+
+- **普通属性**: 直接定义在 Model 中,总是返回。
+- **可选属性**: 使用 `OnDemand[T]` 标记,仅在 `includes` 中指定时返回。
+- **普通计算属性**: 使用 `@included` 装饰的静态方法,总是返回。
+- **可选计算属性**: 使用 `@ondemand` 装饰的静态方法,仅在 `includes` 中指定时返回。
+
+#### 使用方法
+
+**转换模型**:
+
+使用 `Model.transform` 方法将数据库实例转换为字典:
+
+```python
+user = await session.get(User, 1)
+user_dict = await UserModel.transform(
+ user,
+ includes=["email"], # 指定需要返回的可选字段
+ some_context="foo-bar", # 如果计算属性需要上下文,可以传入额外参数
+ session=session # 可选传入自己的 session
+)
+```
+
+**API 文档**:
+
+在 FastAPI 路由中,使用 `Model.generate_typeddict` 生成准确的响应文档:
+
+```python
+@router.get("/users/{id}", response_model=UserModel.generate_typeddict(includes=("email",)))
+async def get_user(id: int) -> dict:
+ ...
+ return await UserModel.transform(user, includes=["email"])
```
数据库模块名应与表名相同,定义了多个模型的除外。
@@ -227,16 +285,16 @@ pre-commit 不提供 pyright 的 hook,您需要手动运行 `pyright` 检查
**类型** 必须是以下之一:
-* **feat**:新功能
-* **fix**:错误修复
-* **docs**:仅文档更改
-* **style**:不影响代码含义的更改(空格、格式、缺少分号等)
-* **refactor**:代码重构
-* **perf**:改善性能的代码更改
-* **test**:添加缺失的测试或修正现有测试
-* **chore**:对构建过程或辅助工具和库(如文档生成)的更改
-* **ci**:持续集成相关的更改
-* **deploy**: 部署相关的更改
+- **feat**:新功能
+- **fix**:错误修复
+- **docs**:仅文档更改
+- **style**:不影响代码含义的更改(空格、格式、缺少分号等)
+- **refactor**:代码重构
+- **perf**:改善性能的代码更改
+- **test**:添加缺失的测试或修正现有测试
+- **chore**:对构建过程或辅助工具和库(如文档生成)的更改
+- **ci**:持续集成相关的更改
+- **deploy**: 部署相关的更改
**范围** 可以是任何指定提交更改位置的内容。例如 `api`、`db`、`auth` 等等。对整个项目的更改使用 `project`。
diff --git a/app/database/__init__.py b/app/database/__init__.py
index 6922184..601b235 100644
--- a/app/database/__init__.py
+++ b/app/database/__init__.py
@@ -2,23 +2,31 @@ from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
from .beatmap import (
Beatmap,
- BeatmapResp,
+ BeatmapDict,
+ BeatmapModel,
+)
+from .beatmap_playcounts import (
+ BeatmapPlaycounts,
+ BeatmapPlaycountsDict,
+ BeatmapPlaycountsModel,
)
-from .beatmap_playcounts import BeatmapPlaycounts, BeatmapPlaycountsResp
from .beatmap_sync import BeatmapSync
from .beatmap_tags import BeatmapTagVote
from .beatmapset import (
Beatmapset,
- BeatmapsetResp,
+ BeatmapsetDict,
+ BeatmapsetModel,
)
from .beatmapset_ratings import BeatmapRating
from .best_scores import BestScore
from .chat import (
ChannelType,
ChatChannel,
- ChatChannelResp,
+ ChatChannelDict,
+ ChatChannelModel,
ChatMessage,
- ChatMessageResp,
+ ChatMessageDict,
+ ChatMessageModel,
)
from .counts import (
CountResp,
@@ -30,8 +38,8 @@ from .events import Event
from .favourite_beatmapset import FavouriteBeatmapset
from .item_attempts_count import (
ItemAttemptsCount,
- ItemAttemptsResp,
- PlaylistAggregateScore,
+ ItemAttemptsCountDict,
+ ItemAttemptsCountModel,
)
from .matchmaking import (
MatchmakingPool,
@@ -42,30 +50,32 @@ from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from .notification import Notification, UserNotification
from .password_reset import PasswordReset
from .playlist_best_score import PlaylistBestScore
-from .playlists import Playlist, PlaylistResp
+from .playlists import Playlist, PlaylistDict, PlaylistModel
from .rank_history import RankHistory, RankHistoryResp, RankTop
-from .relationship import Relationship, RelationshipResp, RelationshipType
-from .room import APIUploadedRoom, Room, RoomResp
+from .relationship import Relationship, RelationshipDict, RelationshipModel, RelationshipType
+from .room import APIUploadedRoom, Room, RoomDict, RoomModel
from .room_participated_user import RoomParticipatedUser
from .score import (
MultiplayerScores,
Score,
ScoreAround,
- ScoreBase,
- ScoreResp,
+ ScoreDict,
+ ScoreModel,
ScoreStatistics,
)
from .score_token import ScoreToken, ScoreTokenResp
+from .search_beatmapset import SearchBeatmapsetsResp
from .statistics import (
UserStatistics,
- UserStatisticsResp,
+ UserStatisticsDict,
+ UserStatisticsModel,
)
from .team import Team, TeamMember, TeamRequest, TeamResp
from .total_score_best_scores import TotalScoreBestScore
from .user import (
- MeResp,
User,
- UserResp,
+ UserDict,
+ UserModel,
)
from .user_account_history import (
UserAccountHistory,
@@ -79,20 +89,25 @@ from .verification import EmailVerification, LoginSession, LoginSessionResp, Tru
__all__ = [
"APIUploadedRoom",
"Beatmap",
+ "BeatmapDict",
+ "BeatmapModel",
"BeatmapPlaycounts",
- "BeatmapPlaycountsResp",
+ "BeatmapPlaycountsDict",
+ "BeatmapPlaycountsModel",
"BeatmapRating",
- "BeatmapResp",
"BeatmapSync",
"BeatmapTagVote",
"Beatmapset",
- "BeatmapsetResp",
+ "BeatmapsetDict",
+ "BeatmapsetModel",
"BestScore",
"ChannelType",
"ChatChannel",
- "ChatChannelResp",
+ "ChatChannelDict",
+ "ChatChannelModel",
"ChatMessage",
- "ChatMessageResp",
+ "ChatMessageDict",
+ "ChatMessageModel",
"CountResp",
"DailyChallengeStats",
"DailyChallengeStatsResp",
@@ -100,13 +115,13 @@ __all__ = [
"Event",
"FavouriteBeatmapset",
"ItemAttemptsCount",
- "ItemAttemptsResp",
+ "ItemAttemptsCountDict",
+ "ItemAttemptsCountModel",
"LoginSession",
"LoginSessionResp",
"MatchmakingPool",
"MatchmakingPoolBeatmap",
"MatchmakingUserStats",
- "MeResp",
"MonthlyPlaycounts",
"MultiplayerEvent",
"MultiplayerEventResp",
@@ -116,26 +131,29 @@ __all__ = [
"OAuthToken",
"PasswordReset",
"Playlist",
- "PlaylistAggregateScore",
"PlaylistBestScore",
- "PlaylistResp",
+ "PlaylistDict",
+ "PlaylistModel",
"RankHistory",
"RankHistoryResp",
"RankTop",
"Relationship",
- "RelationshipResp",
+ "RelationshipDict",
+ "RelationshipModel",
"RelationshipType",
"ReplayWatchedCount",
"Room",
+ "RoomDict",
+ "RoomModel",
"RoomParticipatedUser",
- "RoomResp",
"Score",
"ScoreAround",
- "ScoreBase",
- "ScoreResp",
+ "ScoreDict",
+ "ScoreModel",
"ScoreStatistics",
"ScoreToken",
"ScoreTokenResp",
+ "SearchBeatmapsetsResp",
"Team",
"TeamMember",
"TeamRequest",
@@ -149,17 +167,18 @@ __all__ = [
"UserAccountHistoryResp",
"UserAccountHistoryType",
"UserAchievement",
- "UserAchievement",
"UserAchievementResp",
+ "UserDict",
"UserLoginLog",
+ "UserModel",
"UserNotification",
"UserPreference",
- "UserResp",
"UserStatistics",
- "UserStatisticsResp",
+ "UserStatisticsDict",
+ "UserStatisticsModel",
"V1APIKeys",
]
for i in __all__:
- if i.endswith("Resp"):
- globals()[i].model_rebuild() # type: ignore[call-arg]
+ if i.endswith("Model") or i.endswith("Resp"):
+ globals()[i].model_rebuild()
diff --git a/app/database/_base.py b/app/database/_base.py
new file mode 100644
index 0000000..7a05010
--- /dev/null
+++ b/app/database/_base.py
@@ -0,0 +1,499 @@
+from collections.abc import Awaitable, Callable, Sequence
+from functools import lru_cache, wraps
+import inspect
+import sys
+from types import NoneType, get_original_bases
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ ClassVar,
+ Concatenate,
+ ForwardRef,
+ ParamSpec,
+ TypedDict,
+ cast,
+ get_args,
+ get_origin,
+ overload,
+)
+
+from app.models.model import UTCBaseModel
+from app.utils import type_is_optional
+
+from sqlalchemy.ext.asyncio import async_object_session
+from sqlmodel import SQLModel
+from sqlmodel.ext.asyncio.session import AsyncSession
+from sqlmodel.main import SQLModelMetaclass
+
+_dict_to_model: dict[type, type["DatabaseModel"]] = {}
+
+
+def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any:
+ """Safely evaluate a ForwardRef, with fallback to app.database module"""
+ if isinstance(type_, str):
+ type_ = ForwardRef(type_)
+
+ try:
+ return evaluate_forwardref(
+ type_,
+ globalns=vars(sys.modules[module_name]),
+ localns={},
+ )
+ except (NameError, AttributeError, KeyError):
+ # Fallback to app.database module
+ try:
+ import app.database
+
+ return evaluate_forwardref(
+ type_,
+ globalns=vars(app.database),
+ localns={},
+ )
+ except (NameError, AttributeError, KeyError):
+ return None
+
+
+class OnDemand[T]:
+ if TYPE_CHECKING:
+
+ def __get__(self, instance: object | None, owner: Any) -> T: ...
+
+ def __set__(self, instance: Any, value: T) -> None: ...
+
+ def __delete__(self, instance: Any) -> None: ...
+
+
+class Exclude[T]:
+ if TYPE_CHECKING:
+
+ def __get__(self, instance: object | None, owner: Any) -> T: ...
+
+ def __set__(self, instance: Any, value: T) -> None: ...
+
+ def __delete__(self, instance: Any) -> None: ...
+
+
+# https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140
+def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
+ raw_annotations: dict[str, Any] = class_dict.get("__annotations__", {})
+ if sys.version_info >= (3, 14) and "__annotations__" not in class_dict:
+ # See https://github.com/pydantic/pydantic/pull/11991
+ from annotationlib import (
+ Format,
+ call_annotate_function,
+ get_annotate_from_class_namespace,
+ )
+
+ if annotate := get_annotate_from_class_namespace(class_dict):
+ raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF)
+ return raw_annotations
+
+
+# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77
+if sys.version_info < (3, 12, 4):
+
+ def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
+ # Even though it is the right signature for python 3.9, mypy complains with
+ # `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
+ # Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
+ # TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
+ return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
+
+else:
+
+ def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
+ # Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid
+ # warnings:
+ return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
+
+
+class DatabaseModelMetaclass(SQLModelMetaclass):
+ def __new__(
+ cls,
+ name: str,
+ bases: tuple[type, ...],
+ namespace: dict[str, Any],
+ **kwargs: Any,
+ ) -> "DatabaseModelMetaclass":
+ original_annotations = _get_annotations(namespace)
+ new_annotations = {}
+ ondemands = []
+ excludes = []
+
+ for k, v in original_annotations.items():
+ if get_origin(v) is OnDemand:
+ inner_type = v.__args__[0]
+ new_annotations[k] = inner_type
+ ondemands.append(k)
+ elif get_origin(v) is Exclude:
+ inner_type = v.__args__[0]
+ new_annotations[k] = inner_type
+ excludes.append(k)
+ else:
+ new_annotations[k] = v
+
+ new_class = super().__new__(
+ cls,
+ name,
+ bases,
+ {
+ **namespace,
+ "__annotations__": new_annotations,
+ },
+ **kwargs,
+ )
+
+ new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {}))
+ new_class._ONDEMAND_DATABASE_FIELDS = list(getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", [])) + list(
+ ondemands
+ )
+ new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {}))
+ new_class._EXCLUDED_DATABASE_FIELDS = list(getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", [])) + list(excludes)
+
+ for attr_name, attr_value in namespace.items():
+ target = _get_callable_target(attr_value)
+ if target is None:
+ continue
+
+ if getattr(target, "__included__", False):
+ new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target)
+ _pre_calculate_context_params(target, attr_value)
+
+ if getattr(target, "__calculated_ondemand__", False):
+ new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target)
+ _pre_calculate_context_params(target, attr_value)
+
+ # Register TDict to DatabaseModel mapping
+ for base in get_original_bases(new_class):
+ cls_name = base.__name__
+ if "DatabaseModel" in cls_name and "[" in cls_name and "]" in cls_name:
+ generic_type_name = cls_name[cls_name.index("[") : cls_name.rindex("]") + 1]
+ generic_type = evaluate_forwardref(
+ ForwardRef(generic_type_name),
+ globalns=vars(sys.modules[new_class.__module__]),
+ localns={},
+ )
+ _dict_to_model[generic_type[0]] = new_class
+
+ return new_class
+
+
+def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None:
+ if hasattr(target, "__context_params__"):
+ return
+
+ sig = inspect.signature(target)
+ params = list(sig.parameters.keys())
+
+ start_index = 2
+ if isinstance(attr_value, classmethod):
+ start_index = 3
+
+ context_params = [] if len(params) < start_index else params[start_index:]
+
+ setattr(target, "__context_params__", context_params)
+
+
+def _get_callable_target(value: Any) -> Callable | None:
+ if isinstance(value, (staticmethod, classmethod)):
+ return value.__func__
+ if inspect.isfunction(value):
+ return value
+ if inspect.ismethod(value):
+ return value.__func__
+ return None
+
+
+def _mark_callable(value: Any, flag: str) -> Callable | None:
+ target = _get_callable_target(value)
+ if target is None:
+ return None
+ setattr(target, flag, True)
+ return target
+
+
+def _get_return_type(func: Callable) -> type:
+ sig = inspect.get_annotations(func)
+ return sig.get("return", Any)
+
+
+P = ParamSpec("P")
+CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]]
+DecoratorTarget = CalculatedField | staticmethod | classmethod
+
+
+def included(func: DecoratorTarget) -> DecoratorTarget:
+ marker = _mark_callable(func, "__included__")
+ if marker is None:
+ raise RuntimeError("@included is only usable on callables.")
+
+ @wraps(marker)
+ async def wrapper(*args, **kwargs):
+ return await marker(*args, **kwargs)
+
+ if isinstance(func, staticmethod):
+ return staticmethod(wrapper)
+ if isinstance(func, classmethod):
+ return classmethod(wrapper)
+ return wrapper
+
+
+def ondemand(func: DecoratorTarget) -> DecoratorTarget:
+ marker = _mark_callable(func, "__calculated_ondemand__")
+ if marker is None:
+ raise RuntimeError("@ondemand is only usable on callables.")
+
+ @wraps(marker)
+ async def wrapper(*args, **kwargs):
+ return await marker(*args, **kwargs)
+
+ if isinstance(func, staticmethod):
+ return staticmethod(wrapper)
+ if isinstance(func, classmethod):
+ return classmethod(wrapper)
+ return wrapper
+
+
+async def call_awaitable_with_context(
+ func: CalculatedField,
+ session: AsyncSession,
+ instance: Any,
+ context: dict[str, Any],
+) -> Any:
+ context_params: list[str] | None = getattr(func, "__context_params__", None)
+
+ if context_params is None:
+ # Fallback if not pre-calculated
+ sig = inspect.signature(func)
+ if len(sig.parameters) == 2:
+ return await func(session, instance)
+ else:
+ call_params = {}
+ for param in sig.parameters.values():
+ if param.name in context:
+ call_params[param.name] = context[param.name]
+ return await func(session, instance, **call_params)
+
+ if not context_params:
+ return await func(session, instance)
+
+ call_params = {}
+ for name in context_params:
+ if name in context:
+ call_params[name] = context[name]
+ return await func(session, instance, **call_params)
+
+
+class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass):
+ _CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
+
+ _ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = []
+ _ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
+
+ _EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = []
+
+ @overload
+ @classmethod
+ async def transform(
+ cls,
+ db_instance: "DatabaseModel",
+ *,
+ session: AsyncSession,
+ includes: list[str] | None = None,
+ **context: Any,
+ ) -> TDict: ...
+
+ @overload
+ @classmethod
+ async def transform(
+ cls,
+ db_instance: "DatabaseModel",
+ *,
+ includes: list[str] | None = None,
+ **context: Any,
+ ) -> TDict: ...
+
+ @classmethod
+ async def transform(
+ cls,
+ db_instance: "DatabaseModel",
+ *,
+ session: AsyncSession | None = None,
+ includes: list[str] | None = None,
+ **context: Any,
+ ) -> TDict:
+ includes = includes.copy() if includes is not None else []
+ session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session
+ if session is None:
+ raise RuntimeError("DatabaseModel.transform requires a session-bound instance.")
+ resp_obj = cls.model_validate(db_instance.model_dump())
+ data = resp_obj.model_dump()
+
+ for field in cls._CALCULATED_FIELDS:
+ func = getattr(cls, field)
+ value = await call_awaitable_with_context(func, session, db_instance, context)
+ data[field] = value
+
+ sub_include_map: dict[str, list[str]] = {}
+ for include in [i for i in includes if "." in i]:
+ parent, sub_include = include.split(".", 1)
+ if parent not in sub_include_map:
+ sub_include_map[parent] = []
+ sub_include_map[parent].append(sub_include)
+ includes.remove(include) # pyright: ignore[reportOptionalMemberAccess]
+
+ for field, sub_includes in sub_include_map.items():
+ if field in cls._ONDEMAND_CALCULATED_FIELDS:
+ func = getattr(cls, field)
+ value = await call_awaitable_with_context(
+ func, session, db_instance, {**context, "includes": sub_includes}
+ )
+ data[field] = value
+
+ for include in includes:
+ if include in data:
+ continue
+
+ if include in cls._ONDEMAND_CALCULATED_FIELDS:
+ func = getattr(cls, include)
+ value = await call_awaitable_with_context(func, session, db_instance, context)
+ data[include] = value
+
+ for field in cls._ONDEMAND_DATABASE_FIELDS:
+ if field not in includes:
+ del data[field]
+
+ for field in cls._EXCLUDED_DATABASE_FIELDS:
+ if field in data:
+ del data[field]
+
+ return cast(TDict, data)
+
+ @classmethod
+ async def transform_many(
+ cls,
+ db_instances: Sequence["DatabaseModel"],
+ *,
+ session: AsyncSession | None = None,
+ includes: list[str] | None = None,
+ **context: Any,
+ ) -> list[TDict]:
+ if not db_instances:
+ return []
+
+ # SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here
+ # if the transform method performs any database operations using the shared session.
+ # Since we don't know if the transform method (or its calculated fields) will use the DB,
+ # we must execute them serially to be safe.
+ results = []
+ for instance in db_instances:
+ results.append(await cls.transform(instance, session=session, includes=includes, **context))
+ return results
+
+ @classmethod
+ @lru_cache
+ def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm]
+ def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any:
+ # Evaluate ForwardRef if present
+ if isinstance(field_type, (str, ForwardRef)):
+ resolved = _safe_evaluate_forwardref(field_type, cls.__module__)
+ if resolved is not None:
+ field_type = resolved
+
+ origin_type = get_origin(field_type)
+ inner_type = field_type
+ args = get_args(field_type)
+
+ is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType]
+ if is_optional:
+ inner_type = next((arg for arg in args if arg is not NoneType), field_type)
+
+ is_list = False
+ if origin_type is list:
+ is_list = True
+ inner_type = args[0]
+
+ # Evaluate ForwardRef in inner_type if present
+ if isinstance(inner_type, (str, ForwardRef)):
+ resolved = _safe_evaluate_forwardref(inner_type, cls.__module__)
+ if resolved is not None:
+ inner_type = resolved
+
+ if not resolve_database_model:
+ if is_optional:
+ return inner_type | None # pyright: ignore[reportOperatorIssue]
+ elif is_list:
+ return list[inner_type]
+ return inner_type
+
+ model_class = None
+
+ # First check if inner_type is directly a DatabaseModel subclass
+ try:
+ if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore
+ model_class = inner_type
+ except TypeError:
+ pass
+
+ # If not found, look up in _dict_to_model
+ if model_class is None:
+ model_class = _dict_to_model.get(inner_type) # type: ignore
+
+ if model_class is not None:
+ nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ())))
+ resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore
+
+ if is_optional:
+ resolved_type = resolved_type | None # type: ignore
+
+ return resolved_type
+
+ # Fallback: use the resolved inner_type
+ resolved_type = list[inner_type] if is_list else inner_type # type: ignore
+ if is_optional:
+ resolved_type = resolved_type | None # type: ignore
+ return resolved_type
+
+ if includes is None:
+ includes = ()
+
+ # Parse nested includes
+ direct_includes = []
+ sub_include_map: dict[str, list[str]] = {}
+ for include in includes:
+ if "." in include:
+ parent, sub_include = include.split(".", 1)
+ if parent not in sub_include_map:
+ sub_include_map[parent] = []
+ sub_include_map[parent].append(sub_include)
+ if parent not in direct_includes:
+ direct_includes.append(parent)
+ else:
+ direct_includes.append(include)
+
+ fields = {}
+
+ # Process model fields
+ for field_name, field_info in cls.model_fields.items():
+ field_type = field_info.annotation or Any
+ field_type = _evaluate_type(field_type, field_name=field_name)
+
+ if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes:
+ continue
+ else:
+ fields[field_name] = field_type
+
+ # Process calculated fields
+ for field_name, field_type in cls._CALCULATED_FIELDS.items():
+ field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
+ fields[field_name] = field_type
+
+ # Process ondemand calculated fields
+ for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items():
+ if field_name not in direct_includes:
+ continue
+
+ field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
+ fields[field_name] = field_type
+
+ return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType]
diff --git a/app/database/beatmap.py b/app/database/beatmap.py
index 695e8a7..8072d46 100644
--- a/app/database/beatmap.py
+++ b/app/database/beatmap.py
@@ -1,117 +1,339 @@
from datetime import datetime
import hashlib
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.calculator import get_calculator
from app.config import settings
-from app.database.beatmap_tags import BeatmapTagVote
-from app.database.failtime import FailTime, FailTimeResp
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from app.models.performance import DifficultyAttributesUnion
from app.models.score import GameMode
+from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap_playcounts import BeatmapPlaycounts
-from .beatmapset import Beatmapset, BeatmapsetResp
+from .beatmap_tags import BeatmapTagVote
+from .beatmapset import Beatmapset, BeatmapsetDict, BeatmapsetModel
+from .failtime import FailTime, FailTimeResp
+from .user import User, UserDict, UserModel
from pydantic import BaseModel, TypeAdapter
from redis.asyncio import Redis
from sqlalchemy import Column, DateTime
+from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
- from .user import User
-
class BeatmapOwner(SQLModel):
id: int
username: str
-class BeatmapBase(SQLModel):
- # Beatmap
- url: str
+class BeatmapDict(TypedDict):
+ beatmapset_id: int
+ difficulty_rating: float
+ id: int
mode: GameMode
+ total_length: int
+ user_id: int
+ version: str
+ url: str
+
+ checksum: NotRequired[str]
+ max_combo: NotRequired[int | None]
+ ar: NotRequired[float]
+ cs: NotRequired[float]
+ drain: NotRequired[float]
+ accuracy: NotRequired[float]
+ bpm: NotRequired[float]
+ count_circles: NotRequired[int]
+ count_sliders: NotRequired[int]
+ count_spinners: NotRequired[int]
+ deleted_at: NotRequired[datetime | None]
+ hit_length: NotRequired[int]
+ last_updated: NotRequired[datetime]
+
+ status: NotRequired[str]
+ beatmapset: NotRequired[BeatmapsetDict]
+ current_user_playcount: NotRequired[int]
+ current_user_tag_ids: NotRequired[list[int]]
+ failtimes: NotRequired[FailTimeResp]
+ top_tag_ids: NotRequired[list[dict[str, int]]]
+ user: NotRequired[UserDict]
+ convert: NotRequired[bool]
+ is_scoreable: NotRequired[bool]
+ mode_int: NotRequired[int]
+ ranked: NotRequired[int]
+ playcount: NotRequired[int]
+ passcount: NotRequired[int]
+
+
+class BeatmapModel(DatabaseModel[BeatmapDict]):
+ BEATMAP_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
+ "checksum",
+ "accuracy",
+ "ar",
+ "bpm",
+ "convert",
+ "count_circles",
+ "count_sliders",
+ "count_spinners",
+ "cs",
+ "deleted_at",
+ "drain",
+ "hit_length",
+ "is_scoreable",
+ "last_updated",
+ "mode_int",
+ "passcount",
+ "playcount",
+ "ranked",
+ "url",
+ ]
+ DEFAULT_API_INCLUDES: ClassVar[list[str]] = [
+ "beatmapset.ratings",
+ "current_user_playcount",
+ "failtimes",
+ "max_combo",
+ "owners",
+ ]
+ TRANSFORMER_INCLUDES: ClassVar[list[str]] = [*DEFAULT_API_INCLUDES, *BEATMAP_TRANSFORMER_INCLUDES]
+
+ # Beatmap
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
difficulty_rating: float = Field(default=0.0, index=True)
+ id: int = Field(primary_key=True, index=True)
+ mode: GameMode
total_length: int
user_id: int = Field(index=True)
version: str = Field(index=True)
+ url: OnDemand[str]
# optional
- checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
- current_user_playcount: int = Field(default=0)
- max_combo: int | None = Field(default=0)
- # TODO: failtimes, owners
+ checksum: OnDemand[str] = Field(sa_column=Column(VARCHAR(32), index=True))
+ max_combo: OnDemand[int | None] = Field(default=0)
+ # TODO: owners
# BeatmapExtended
- ar: float = Field(default=0.0)
- cs: float = Field(default=0.0)
- drain: float = Field(default=0.0) # hp
- accuracy: float = Field(default=0.0) # od
- bpm: float = Field(default=0.0)
- count_circles: int = Field(default=0)
- count_sliders: int = Field(default=0)
- count_spinners: int = Field(default=0)
- deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
- hit_length: int = Field(default=0)
- last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
+ ar: OnDemand[float] = Field(default=0.0)
+ cs: OnDemand[float] = Field(default=0.0)
+ drain: OnDemand[float] = Field(default=0.0) # hp
+ accuracy: OnDemand[float] = Field(default=0.0) # od
+ bpm: OnDemand[float] = Field(default=0.0)
+ count_circles: OnDemand[int] = Field(default=0)
+ count_sliders: OnDemand[int] = Field(default=0)
+ count_spinners: OnDemand[int] = Field(default=0)
+ deleted_at: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime))
+ hit_length: OnDemand[int] = Field(default=0)
+ last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
+
+ @included
+ @staticmethod
+ async def status(_session: AsyncSession, beatmap: "Beatmap") -> str:
+ if settings.enable_all_beatmap_leaderboard and not beatmap.beatmap_status.has_leaderboard():
+ return BeatmapRankStatus.APPROVED.name.lower()
+ return beatmap.beatmap_status.name.lower()
+
+ @ondemand
+ @staticmethod
+ async def beatmapset(
+ _session: AsyncSession,
+ beatmap: "Beatmap",
+ includes: list[str] | None = None,
+ ) -> BeatmapsetDict | None:
+ if beatmap.beatmapset is not None:
+ return await BeatmapsetModel.transform(
+ beatmap.beatmapset, includes=(includes or []) + Beatmapset.BEATMAPSET_TRANSFORMER_INCLUDES
+ )
+
+ @ondemand
+ @staticmethod
+ async def current_user_playcount(_session: AsyncSession, beatmap: "Beatmap", user: "User") -> int:
+ playcount = (
+ await _session.exec(
+ select(BeatmapPlaycounts.playcount).where(
+ BeatmapPlaycounts.beatmap_id == beatmap.id, BeatmapPlaycounts.user_id == user.id
+ )
+ )
+ ).first()
+ return int(playcount or 0)
+
+ @ondemand
+ @staticmethod
+ async def current_user_tag_ids(_session: AsyncSession, beatmap: "Beatmap", user: "User | None" = None) -> list[int]:
+ if user is None:
+ return []
+ tag_ids = (
+ await _session.exec(
+ select(BeatmapTagVote.tag_id).where(
+ BeatmapTagVote.beatmap_id == beatmap.id,
+ BeatmapTagVote.user_id == user.id,
+ )
+ )
+ ).all()
+ return list(tag_ids)
+
+ @ondemand
+ @staticmethod
+ async def failtimes(_session: AsyncSession, beatmap: "Beatmap") -> FailTimeResp:
+ if beatmap.failtimes is not None:
+ return FailTimeResp.from_db(beatmap.failtimes)
+ return FailTimeResp()
+
+ @ondemand
+ @staticmethod
+ async def top_tag_ids(_session: AsyncSession, beatmap: "Beatmap") -> list[dict[str, int]]:
+ all_votes = (
+ await _session.exec(
+ select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
+ .where(BeatmapTagVote.beatmap_id == beatmap.id)
+ .group_by(col(BeatmapTagVote.tag_id))
+ .having(func.count() > settings.beatmap_tag_top_count)
+ )
+ ).all()
+ top_tag_ids: list[dict[str, int]] = []
+ for id, votes in all_votes:
+ top_tag_ids.append({"tag_id": id, "count": votes})
+ top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
+ return top_tag_ids
+
+ @ondemand
+ @staticmethod
+ async def user(
+ _session: AsyncSession,
+ beatmap: "Beatmap",
+ includes: list[str] | None = None,
+ ) -> UserDict | None:
+ from .user import User
+
+ user = await _session.get(User, beatmap.user_id)
+ if user is None:
+ return None
+ return await UserModel.transform(user, includes=includes)
+
+ @ondemand
+ @staticmethod
+ async def convert(_session: AsyncSession, _beatmap: "Beatmap") -> bool:
+ return False
+
+ @ondemand
+ @staticmethod
+ async def is_scoreable(_session: AsyncSession, beatmap: "Beatmap") -> bool:
+ beatmap_status = beatmap.beatmap_status
+ if settings.enable_all_beatmap_leaderboard:
+ return True
+ return beatmap_status.has_leaderboard()
+
+ @ondemand
+ @staticmethod
+ async def mode_int(_session: AsyncSession, beatmap: "Beatmap") -> int:
+ return int(beatmap.mode)
+
+ @ondemand
+ @staticmethod
+ async def ranked(_session: AsyncSession, beatmap: "Beatmap") -> int:
+ beatmap_status = beatmap.beatmap_status
+ if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
+ return BeatmapRankStatus.APPROVED.value
+ return beatmap_status.value
+
+ @ondemand
+ @staticmethod
+ async def playcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
+ result = (
+ await _session.exec(
+ select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
+ )
+ ).first()
+ return int(result or 0)
+
+ @ondemand
+ @staticmethod
+ async def passcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
+ from .score import Score
+
+ return (
+ await _session.exec(
+ select(func.count())
+ .select_from(Score)
+ .where(
+ Score.beatmap_id == beatmap.id,
+ col(Score.passed).is_(True),
+ )
+ )
+ ).one()
-class Beatmap(BeatmapBase, table=True):
+class Beatmap(AsyncAttrs, BeatmapModel, table=True):
__tablename__: str = "beatmaps"
- id: int = Field(primary_key=True, index=True)
- beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
+
beatmap_status: BeatmapRankStatus = Field(index=True)
# optional
- beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
+ beatmapset: "Beatmapset" = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
- async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
- d = resp.model_dump()
- del d["beatmapset"]
+ async def from_resp_no_save(cls, _session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
+ d = {k: v for k, v in resp.items() if k != "beatmapset"}
+ beatmapset_id = resp.get("beatmapset_id")
+ bid = resp.get("id")
+ ranked = resp.get("ranked")
+ if beatmapset_id is None or bid is None or ranked is None:
+ raise ValueError("beatmapset_id, id and ranked are required")
beatmap = cls.model_validate(
{
**d,
- "beatmapset_id": resp.beatmapset_id,
- "id": resp.id,
- "beatmap_status": BeatmapRankStatus(resp.ranked),
+ "beatmapset_id": beatmapset_id,
+ "id": bid,
+ "beatmap_status": BeatmapRankStatus(ranked),
}
)
return beatmap
@classmethod
- async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
+ async def from_resp(cls, session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
beatmap = await cls.from_resp_no_save(session, resp)
- if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
+ resp_id = resp.get("id")
+ if resp_id is None:
+ raise ValueError("id is required")
+ if not (await session.exec(select(exists()).where(Beatmap.id == resp_id))).first():
session.add(beatmap)
await session.commit()
- return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
+ return (await session.exec(select(Beatmap).where(Beatmap.id == resp_id))).one()
@classmethod
- async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
+ async def from_resp_batch(cls, session: AsyncSession, inp: list[BeatmapDict], from_: int = 0) -> list["Beatmap"]:
beatmaps = []
- for resp in inp:
- if resp.id == from_:
+ for resp_dict in inp:
+ bid = resp_dict.get("id")
+ if bid == from_ or bid is None:
continue
- d = resp.model_dump()
- del d["beatmapset"]
+
+ beatmapset_id = resp_dict.get("beatmapset_id")
+ ranked = resp_dict.get("ranked")
+ if beatmapset_id is None or ranked is None:
+ continue
+
+ # 创建 beatmap 字典,移除 beatmapset
+ d = {k: v for k, v in resp_dict.items() if k != "beatmapset"}
+
beatmap = Beatmap.model_validate(
{
**d,
- "beatmapset_id": resp.beatmapset_id,
- "id": resp.id,
- "beatmap_status": BeatmapRankStatus(resp.ranked),
+ "beatmapset_id": beatmapset_id,
+ "id": bid,
+ "beatmap_status": BeatmapRankStatus(ranked),
}
)
- if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
+ if not (await session.exec(select(exists()).where(Beatmap.id == bid))).first():
session.add(beatmap)
beatmaps.append(beatmap)
await session.commit()
+ for beatmap in beatmaps:
+ await session.refresh(beatmap)
return beatmaps
@classmethod
@@ -132,10 +354,14 @@ class Beatmap(BeatmapBase, table=True):
beatmap = (await session.exec(stmt)).first()
if not beatmap:
resp = await fetcher.get_beatmap(bid, md5)
- r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
+ beatmapset_id = resp.get("beatmapset_id")
+ if beatmapset_id is None:
+ raise ValueError("beatmapset_id is required")
+ r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == beatmapset_id))
if not r.first():
- set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
- await Beatmapset.from_resp(session, set_resp, from_=resp.id)
+ set_resp = await fetcher.get_beatmapset(beatmapset_id)
+ resp_id = resp.get("id")
+ await Beatmapset.from_resp(session, set_resp, from_=resp_id or 0)
return await Beatmap.from_resp(session, resp)
return beatmap
@@ -145,97 +371,6 @@ class APIBeatmapTag(BaseModel):
count: int
-class BeatmapResp(BeatmapBase):
- id: int
- beatmapset_id: int
- beatmapset: BeatmapsetResp | None = None
- convert: bool = False
- is_scoreable: bool
- status: str
- mode_int: int
- ranked: int
- url: str = ""
- playcount: int = 0
- passcount: int = 0
- failtimes: FailTimeResp | None = None
- top_tag_ids: list[APIBeatmapTag] | None = None
- current_user_tag_ids: list[int] | None = None
- is_deleted: bool = False
-
- @classmethod
- async def from_db(
- cls,
- beatmap: Beatmap,
- query_mode: GameMode | None = None,
- from_set: bool = False,
- session: AsyncSession | None = None,
- user: "User | None" = None,
- ) -> "BeatmapResp":
- from .score import Score
-
- beatmap_ = beatmap.model_dump()
- beatmap_status = beatmap.beatmap_status
- if query_mode is not None and beatmap.mode != query_mode:
- beatmap_["convert"] = True
- beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
- if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
- beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
- beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
- else:
- beatmap_["status"] = beatmap_status.name.lower()
- beatmap_["ranked"] = beatmap_status.value
- beatmap_["mode_int"] = int(beatmap.mode)
- beatmap_["is_deleted"] = beatmap.deleted_at is not None
- if not from_set:
- beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
- if beatmap.failtimes is not None:
- beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
- else:
- beatmap_["failtimes"] = FailTimeResp()
- if session:
- beatmap_["playcount"] = (
- await session.exec(
- select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
- )
- ).first() or 0
- beatmap_["passcount"] = (
- await session.exec(
- select(func.count())
- .select_from(Score)
- .where(
- Score.beatmap_id == beatmap.id,
- col(Score.passed).is_(True),
- )
- )
- ).one()
-
- all_votes = (
- await session.exec(
- select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
- .where(BeatmapTagVote.beatmap_id == beatmap.id)
- .group_by(col(BeatmapTagVote.tag_id))
- .having(func.count() > settings.beatmap_tag_top_count)
- )
- ).all()
- top_tag_ids: list[dict[str, int]] = []
- for id, votes in all_votes:
- top_tag_ids.append({"tag_id": id, "count": votes})
- top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
- beatmap_["top_tag_ids"] = top_tag_ids
-
- if user is not None:
- beatmap_["current_user_tag_ids"] = (
- await session.exec(
- select(BeatmapTagVote.tag_id)
- .where(BeatmapTagVote.beatmap_id == beatmap.id)
- .where(BeatmapTagVote.user_id == user.id)
- )
- ).all()
- else:
- beatmap_["current_user_tag_ids"] = []
- return cls.model_validate(beatmap_)
-
-
class BannedBeatmaps(SQLModel, table=True):
__tablename__: str = "banned_beatmaps"
id: int | None = Field(primary_key=True, index=True, default=None)
diff --git a/app/database/beatmap_playcounts.py b/app/database/beatmap_playcounts.py
index 2635d6a..cf2fb6c 100644
--- a/app/database/beatmap_playcounts.py
+++ b/app/database/beatmap_playcounts.py
@@ -1,10 +1,11 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, NotRequired, TypedDict
from app.config import settings
-from app.database.events import Event, EventType
from app.utils import utcnow
-from pydantic import BaseModel
+from ._base import DatabaseModel, included
+from .events import Event, EventType
+
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
@@ -12,52 +13,65 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
- SQLModel,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
- from .beatmap import Beatmap, BeatmapResp
- from .beatmapset import BeatmapsetResp
+ from .beatmap import Beatmap, BeatmapDict
+ from .beatmapset import BeatmapsetDict
from .user import User
-class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
+class BeatmapPlaycountsDict(TypedDict):
+ user_id: int
+ beatmap_id: int
+ count: NotRequired[int]
+ beatmap: NotRequired["BeatmapDict"]
+ beatmapset: NotRequired["BeatmapsetDict"]
+
+
+class BeatmapPlaycountsModel(AsyncAttrs, DatabaseModel[BeatmapPlaycountsDict]):
__tablename__: str = "beatmap_playcounts"
id: int | None = Field(
- default=None,
- sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
+ default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), exclude=True
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
- playcount: int = Field(default=0)
+ playcount: int = Field(default=0, exclude=True)
+ @included
+ @staticmethod
+ async def count(_session: AsyncSession, obj: "BeatmapPlaycounts") -> int:
+ return obj.playcount
+
+ @included
+ @staticmethod
+ async def beatmap(
+ _session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
+ ) -> "BeatmapDict":
+ from .beatmap import BeatmapModel
+
+ await obj.awaitable_attrs.beatmap
+ return await BeatmapModel.transform(obj.beatmap, includes=includes)
+
+ @included
+ @staticmethod
+ async def beatmapset(
+ _session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
+ ) -> "BeatmapsetDict":
+ from .beatmap import BeatmapsetModel
+
+ await obj.awaitable_attrs.beatmap
+ return await BeatmapsetModel.transform(obj.beatmap.beatmapset, includes=includes)
+
+
+class BeatmapPlaycounts(BeatmapPlaycountsModel, table=True):
user: "User" = Relationship()
beatmap: "Beatmap" = Relationship()
-class BeatmapPlaycountsResp(BaseModel):
- beatmap_id: int
- beatmap: "BeatmapResp | None" = None
- beatmapset: "BeatmapsetResp | None" = None
- count: int
-
- @classmethod
- async def from_db(cls, db_model: BeatmapPlaycounts) -> "BeatmapPlaycountsResp":
- from .beatmap import BeatmapResp
- from .beatmapset import BeatmapsetResp
-
- await db_model.awaitable_attrs.beatmap
- return cls(
- beatmap_id=db_model.beatmap_id,
- count=db_model.playcount,
- beatmap=await BeatmapResp.from_db(db_model.beatmap),
- beatmapset=await BeatmapsetResp.from_db(db_model.beatmap.beatmapset),
- )
-
-
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
existing_playcount = (
await session.exec(
diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py
index fd58c8b..ff8346d 100644
--- a/app/database/beatmapset.py
+++ b/app/database/beatmapset.py
@@ -1,14 +1,15 @@
from datetime import datetime
-from typing import TYPE_CHECKING, NotRequired, Self, TypedDict
+from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
from app.config import settings
-from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode
-from .user import BASE_INCLUDES, User, UserResp
+from ._base import DatabaseModel, OnDemand, included, ondemand
+from .beatmap_playcounts import BeatmapPlaycounts
+from .user import User, UserDict
-from pydantic import BaseModel, field_validator, model_validator
+from pydantic import BaseModel
from sqlalchemy import JSON, Boolean, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
@@ -17,7 +18,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
- from .beatmap import Beatmap, BeatmapResp
+ from .beatmap import Beatmap, BeatmapDict
from .favourite_beatmapset import FavouriteBeatmapset
@@ -68,8 +69,99 @@ class BeatmapTranslationText(BaseModel):
id: int | None = None
-class BeatmapsetBase(SQLModel):
+class BeatmapsetDict(TypedDict):
+ id: int
+ artist: str
+ artist_unicode: str
+ covers: BeatmapCovers | None
+ creator: str
+ nsfw: bool
+ preview_url: str
+ source: str
+ spotlight: bool
+ title: str
+ title_unicode: str
+ track_id: int | None
+ user_id: int
+ video: bool
+ current_nominations: list[BeatmapNomination] | None
+ description: BeatmapDescription | None
+ pack_tags: list[str]
+
+ bpm: NotRequired[float]
+ can_be_hyped: NotRequired[bool]
+ discussion_locked: NotRequired[bool]
+ last_updated: NotRequired[datetime]
+ ranked_date: NotRequired[datetime | None]
+ storyboard: NotRequired[bool]
+ submitted_date: NotRequired[datetime]
+ tags: NotRequired[str]
+ discussion_enabled: NotRequired[bool]
+ legacy_thread_url: NotRequired[str | None]
+ status: NotRequired[str]
+ ranked: NotRequired[int]
+ is_scoreable: NotRequired[bool]
+ favourite_count: NotRequired[int]
+ genre_id: NotRequired[int]
+ hype: NotRequired[BeatmapHype]
+ language_id: NotRequired[int]
+ play_count: NotRequired[int]
+ availability: NotRequired[BeatmapAvailability]
+ beatmaps: NotRequired[list["BeatmapDict"]]
+ has_favourited: NotRequired[bool]
+ recent_favourites: NotRequired[list[UserDict]]
+ genre: NotRequired[BeatmapTranslationText]
+ language: NotRequired[BeatmapTranslationText]
+ nominations: NotRequired["BeatmapNominations"]
+ ratings: NotRequired[list[int]]
+
+
+class BeatmapsetModel(DatabaseModel[BeatmapsetDict]):
+ BEATMAPSET_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
+ "availability",
+ "has_favourited",
+ "bpm",
+ "deleted_atcan_be_hyped",
+ "discussion_locked",
+ "is_scoreable",
+ "last_updated",
+ "legacy_thread_url",
+ "ranked",
+ "ranked_date",
+ "submitted_date",
+ "tags",
+ "rating",
+ "storyboard",
+ ]
+ API_INCLUDES: ClassVar[list[str]] = [
+ *BEATMAPSET_TRANSFORMER_INCLUDES,
+ "beatmaps.current_user_playcount",
+ "beatmaps.current_user_tag_ids",
+ "beatmaps.max_combo",
+ "current_nominations",
+ "current_user_attributes",
+ "description",
+ "genre",
+ "language",
+ "pack_tags",
+ "ratings",
+ "recent_favourites",
+ "related_tags",
+ "related_users",
+ "user",
+ "version_count",
+ *[
+ f"beatmaps.{inc}"
+ for inc in {
+ "failtimes",
+ "owners",
+ "top_tag_ids",
+ }
+ ],
+ ]
+
# Beatmapset
+ id: int = Field(default=None, primary_key=True, index=True)
artist: str = Field(index=True)
artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
@@ -77,41 +169,285 @@ class BeatmapsetBase(SQLModel):
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
preview_url: str
source: str = Field(default="")
-
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
title: str = Field(index=True)
title_unicode: str = Field(index=True)
+ track_id: int | None = Field(default=None, index=True) # feature artist?
user_id: int = Field(index=True)
video: bool = Field(sa_column=Column(Boolean, index=True))
# optional
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
- current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
- description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
+ current_nominations: OnDemand[list[BeatmapNomination] | None] = Field(None, sa_column=Column(JSON))
+ description: OnDemand[BeatmapDescription | None] = Field(default=None, sa_column=Column(JSON))
# TODO: discussions: list[BeatmapsetDiscussion] = None
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
# TODO: events: Optional[list[BeatmapsetEvent]] = None
- pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
+ pack_tags: OnDemand[list[str]] = Field(default=[], sa_column=Column(JSON))
# TODO: related_users: Optional[list[User]] = None
# TODO: user: Optional[User] = Field(default=None)
- track_id: int | None = Field(default=None, index=True) # feature artist?
# BeatmapsetExtended
- bpm: float = Field(default=0.0)
- can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
- discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
- last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
- ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
- storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
- submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
- tags: str = Field(default="", sa_column=Column(Text))
+ bpm: OnDemand[float] = Field(default=0.0)
+ can_be_hyped: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
+ discussion_locked: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
+ last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
+ ranked_date: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime, index=True))
+ storyboard: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean, index=True))
+ submitted_date: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
+ tags: OnDemand[str] = Field(default="", sa_column=Column(Text))
+
+ @ondemand
+ @staticmethod
+ async def legacy_thread_url(
+ _session: AsyncSession,
+ _beatmapset: "Beatmapset",
+ ) -> str | None:
+ return None
+
+ @included
+ @staticmethod
+ async def discussion_enabled(
+ _session: AsyncSession,
+ _beatmapset: "Beatmapset",
+ ) -> bool:
+ return True
+
+ @included
+ @staticmethod
+ async def status(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> str:
+ beatmap_status = beatmapset.beatmap_status
+ if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
+ return BeatmapRankStatus.APPROVED.name.lower()
+ return beatmap_status.name.lower()
+
+ @included
+ @staticmethod
+ async def ranked(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> int:
+ beatmap_status = beatmapset.beatmap_status
+ if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
+ return BeatmapRankStatus.APPROVED.value
+ return beatmap_status.value
+
+ @ondemand
+ @staticmethod
+ async def is_scoreable(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> bool:
+ beatmap_status = beatmapset.beatmap_status
+ if settings.enable_all_beatmap_leaderboard:
+ return True
+ return beatmap_status.has_leaderboard()
+
+ @included
+ @staticmethod
+ async def favourite_count(
+ session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> int:
+ from .favourite_beatmapset import FavouriteBeatmapset
+
+ count = await session.exec(
+ select(func.count())
+ .select_from(FavouriteBeatmapset)
+ .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
+ )
+ return count.one()
+
+ @included
+ @staticmethod
+ async def genre_id(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> int:
+ return beatmapset.beatmap_genre.value
+
+ @ondemand
+ @staticmethod
+ async def hype(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> BeatmapHype:
+ return BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required)
+
+ @included
+ @staticmethod
+ async def language_id(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> int:
+ return beatmapset.beatmap_language.value
+
+ @included
+ @staticmethod
+ async def play_count(
+ session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> int:
+ from .beatmap import Beatmap
+
+ playcount = await session.exec(
+ select(func.sum(BeatmapPlaycounts.playcount)).where(
+ col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
+ )
+ )
+ return int(playcount.first() or 0)
+
+ @ondemand
+ @staticmethod
+ async def availability(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> BeatmapAvailability:
+ return BeatmapAvailability(
+ more_information=beatmapset.availability_info,
+ download_disabled=beatmapset.download_disabled,
+ )
+
+ @ondemand
+ @staticmethod
+ async def beatmaps(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ includes: list[str] | None = None,
+ user: "User | None" = None,
+ ) -> list["BeatmapDict"]:
+ from .beatmap import BeatmapModel
+
+ return [
+ await BeatmapModel.transform(
+ beatmap, includes=(includes or []) + BeatmapModel.BEATMAP_TRANSFORMER_INCLUDES, user=user
+ )
+ for beatmap in await beatmapset.awaitable_attrs.beatmaps
+ ]
+
+ # @ondemand
+ # @staticmethod
+ # async def current_nominations(
+ # _session: AsyncSession,
+ # beatmapset: "Beatmapset",
+ # ) -> list[BeatmapNomination] | None:
+ # return beatmapset.current_nominations or []
+
+ @ondemand
+ @staticmethod
+ async def has_favourited(
+ session: AsyncSession,
+ beatmapset: "Beatmapset",
+ user: User | None = None,
+ ) -> bool:
+ from .favourite_beatmapset import FavouriteBeatmapset
+
+ if session is None:
+ return False
+ query = select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
+ if user is not None:
+ query = query.where(FavouriteBeatmapset.user_id == user.id)
+ existing = (await session.exec(query)).first()
+ return existing is not None
+
+ @ondemand
+ @staticmethod
+ async def recent_favourites(
+ session: AsyncSession,
+ beatmapset: "Beatmapset",
+ includes: list[str] | None = None,
+ ) -> list[UserDict]:
+ from .favourite_beatmapset import FavouriteBeatmapset
+
+ recent_favourites = (
+ await session.exec(
+ select(FavouriteBeatmapset)
+ .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
+ .order_by(col(FavouriteBeatmapset.date).desc())
+ .limit(50)
+ )
+ ).all()
+ return [
+ await User.transform(
+ (await favourite.awaitable_attrs.user),
+ includes=includes,
+ )
+ for favourite in recent_favourites
+ ]
+
+ @ondemand
+ @staticmethod
+ async def genre(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> BeatmapTranslationText:
+ return BeatmapTranslationText(
+ name=beatmapset.beatmap_genre.name,
+ id=beatmapset.beatmap_genre.value,
+ )
+
+ @ondemand
+ @staticmethod
+ async def language(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> BeatmapTranslationText:
+ return BeatmapTranslationText(
+ name=beatmapset.beatmap_language.name,
+ id=beatmapset.beatmap_language.value,
+ )
+
+ @ondemand
+ @staticmethod
+ async def nominations(
+ _session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> BeatmapNominations:
+ return BeatmapNominations(
+ required=beatmapset.nominations_required,
+ current=beatmapset.nominations_current,
+ )
+
+ # @ondemand
+ # @staticmethod
+ # async def user(
+ # session: AsyncSession,
+ # beatmapset: Beatmapset,
+ # includes: list[str] | None = None,
+ # ) -> dict[str, Any] | None:
+ # db_user = await session.get(User, beatmapset.user_id)
+ # if not db_user:
+ # return None
+ # return await UserResp.transform(db_user, includes=includes)
+
+ @ondemand
+ @staticmethod
+ async def ratings(
+ session: AsyncSession,
+ beatmapset: "Beatmapset",
+ ) -> list[int]:
+ # Provide a stable default shape if no session is available
+ if session is None:
+ return []
+
+ from .beatmapset_ratings import BeatmapRating
+
+ beatmapset_all_ratings = (
+ await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
+ ).all()
+ ratings_list = [0] * 11
+ for rating in beatmapset_all_ratings:
+ ratings_list[rating.rating] += 1
+ return ratings_list
-class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
+class Beatmapset(AsyncAttrs, BeatmapsetModel, table=True):
__tablename__: str = "beatmapsets"
- id: int = Field(default=None, primary_key=True, index=True)
# Beatmapset
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
@@ -130,29 +466,45 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
- async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
- d = resp.model_dump()
- if resp.nominations:
- d["nominations_required"] = resp.nominations.required
- d["nominations_current"] = resp.nominations.current
- if resp.hype:
- d["hype_current"] = resp.hype.current
- d["hype_required"] = resp.hype.required
- if resp.genre_id:
- d["beatmap_genre"] = Genre(resp.genre_id)
- elif resp.genre:
- d["beatmap_genre"] = Genre(resp.genre.id)
- if resp.language_id:
- d["beatmap_language"] = Language(resp.language_id)
- elif resp.language:
- d["beatmap_language"] = Language(resp.language.id)
+ async def from_resp_no_save(cls, resp: BeatmapsetDict) -> "Beatmapset":
+ # make a shallow copy so we can mutate safely
+ d: dict[str, Any] = dict(resp)
+
+ # nominations = resp.get("nominations")
+ # if nominations is not None:
+ # d["nominations_required"] = nominations.required
+ # d["nominations_current"] = nominations.current
+
+ hype = resp.get("hype")
+ if hype is not None:
+ d["hype_current"] = hype.current
+ d["hype_required"] = hype.required
+
+ genre_id = resp.get("genre_id")
+ genre = resp.get("genre")
+ if genre_id is not None:
+ d["beatmap_genre"] = Genre(genre_id)
+ elif genre is not None:
+ d["beatmap_genre"] = Genre(genre.id)
+
+ language_id = resp.get("language_id")
+ language = resp.get("language")
+ if language_id is not None:
+ d["beatmap_language"] = Language(language_id)
+ elif language is not None:
+ d["beatmap_language"] = Language(language.id)
+
+ availability = resp.get("availability")
+ ranked = resp.get("ranked")
+ if ranked is None:
+ raise ValueError("ranked field is required")
+
beatmapset = Beatmapset.model_validate(
{
**d,
- "id": resp.id,
- "beatmap_status": BeatmapRankStatus(resp.ranked),
- "availability_info": resp.availability.more_information,
- "download_disabled": resp.availability.download_disabled or False,
+ "beatmap_status": BeatmapRankStatus(ranked),
+ "availability_info": availability.more_information if availability is not None else None,
+ "download_disabled": bool(availability.download_disabled) if availability is not None else False,
}
)
return beatmapset
@@ -161,17 +513,19 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
async def from_resp(
cls,
session: AsyncSession,
- resp: "BeatmapsetResp",
+ resp: BeatmapsetDict,
from_: int = 0,
) -> "Beatmapset":
from .beatmap import Beatmap
+ beatmapset_id = resp["id"]
beatmapset = await cls.from_resp_no_save(resp)
- if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
+ if not (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first():
session.add(beatmapset)
await session.commit()
- await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
- beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one()
+ beatmaps = resp.get("beatmaps", [])
+ await Beatmap.from_resp_batch(session, beatmaps, from_=from_)
+ beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))).one()
return beatmapset
@classmethod
@@ -183,170 +537,5 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
resp = await fetcher.get_beatmapset(sid)
beatmapset = await cls.from_resp(session, resp)
await get_beatmapset_update_service().add(resp)
+ await session.refresh(beatmapset)
return beatmapset
-
-
-class BeatmapsetResp(BeatmapsetBase):
- id: int
- beatmaps: list["BeatmapResp"] = Field(default_factory=list)
- discussion_enabled: bool = True
- status: str
- ranked: int
- legacy_thread_url: str | None = ""
- is_scoreable: bool
- hype: BeatmapHype | None = None
- availability: BeatmapAvailability
- genre: BeatmapTranslationText | None = None
- genre_id: int
- language: BeatmapTranslationText | None = None
- language_id: int
- nominations: BeatmapNominations | None = None
- has_favourited: bool = False
- favourite_count: int = 0
- recent_favourites: list[UserResp] = Field(default_factory=list)
- play_count: int = 0
-
- @field_validator(
- "nsfw",
- "spotlight",
- "video",
- "can_be_hyped",
- "discussion_locked",
- "storyboard",
- "discussion_enabled",
- "is_scoreable",
- "has_favourited",
- mode="before",
- )
- @classmethod
- def validate_bool_fields(cls, v):
- """将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
- if isinstance(v, int):
- return bool(v)
- return v
-
- @model_validator(mode="after")
- def fix_genre_language(self) -> Self:
- if self.genre is None:
- self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
- if self.language is None:
- self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
- return self
-
- @classmethod
- async def from_db(
- cls,
- beatmapset: Beatmapset,
- include: list[str] = [],
- session: AsyncSession | None = None,
- user: User | None = None,
- ) -> "BeatmapsetResp":
- from .beatmap import Beatmap, BeatmapResp
- from .favourite_beatmapset import FavouriteBeatmapset
-
- update = {
- "beatmaps": [
- await BeatmapResp.from_db(beatmap, from_set=True, session=session)
- for beatmap in await beatmapset.awaitable_attrs.beatmaps
- ],
- "hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
- "availability": BeatmapAvailability(
- more_information=beatmapset.availability_info,
- download_disabled=beatmapset.download_disabled,
- ),
- "genre": BeatmapTranslationText(
- name=beatmapset.beatmap_genre.name,
- id=beatmapset.beatmap_genre.value,
- ),
- "language": BeatmapTranslationText(
- name=beatmapset.beatmap_language.name,
- id=beatmapset.beatmap_language.value,
- ),
- "genre_id": beatmapset.beatmap_genre.value,
- "language_id": beatmapset.beatmap_language.value,
- "nominations": BeatmapNominations(
- required=beatmapset.nominations_required,
- current=beatmapset.nominations_current,
- ),
- "is_scoreable": beatmapset.beatmap_status.has_leaderboard(),
- **beatmapset.model_dump(),
- }
-
- if session is not None:
- # 从数据库读取对应谱面集的评分
- from .beatmapset_ratings import BeatmapRating
-
- beatmapset_all_ratings = (
- await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
- ).all()
- ratings_list = [0] * 11
- for rating in beatmapset_all_ratings:
- ratings_list[rating.rating] += 1
- update["ratings"] = ratings_list
- else:
- # 返回非空值避免客户端崩溃
- if update.get("ratings") is None:
- update["ratings"] = []
-
- beatmap_status = beatmapset.beatmap_status
- if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
- update["status"] = BeatmapRankStatus.APPROVED.name.lower()
- update["ranked"] = BeatmapRankStatus.APPROVED.value
- else:
- update["status"] = beatmap_status.name.lower()
- update["ranked"] = beatmap_status.value
-
- if session and user:
- existing_favourite = (
- await session.exec(
- select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
- )
- ).first()
- update["has_favourited"] = existing_favourite is not None
-
- if session and "recent_favourites" in include:
- recent_favourites = (
- await session.exec(
- select(FavouriteBeatmapset)
- .where(
- FavouriteBeatmapset.beatmapset_id == beatmapset.id,
- )
- .order_by(col(FavouriteBeatmapset.date).desc())
- .limit(50)
- )
- ).all()
- update["recent_favourites"] = [
- await UserResp.from_db(
- await favourite.awaitable_attrs.user,
- session=session,
- include=BASE_INCLUDES,
- )
- for favourite in recent_favourites
- ]
-
- if session:
- update["favourite_count"] = (
- await session.exec(
- select(func.count())
- .select_from(FavouriteBeatmapset)
- .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
- )
- ).one()
-
- update["play_count"] = (
- await session.exec(
- select(func.sum(BeatmapPlaycounts.playcount)).where(
- col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
- )
- )
- ).first() or 0
- return cls.model_validate(
- update,
- )
-
-
-class SearchBeatmapsetsResp(SQLModel):
- beatmapsets: list[BeatmapsetResp]
- total: int
- cursor: dict[str, int | float | str] | None = None
- cursor_string: str | None = None
diff --git a/app/database/beatmapset_ratings.py b/app/database/beatmapset_ratings.py
index 8b63d88..c6e1d75 100644
--- a/app/database/beatmapset_ratings.py
+++ b/app/database/beatmapset_ratings.py
@@ -1,5 +1,5 @@
-from app.database.beatmapset import Beatmapset
-from app.database.user import User
+from .beatmapset import Beatmapset
+from .user import User
from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel
diff --git a/app/database/best_scores.py b/app/database/best_scores.py
index b1f9059..13fead7 100644
--- a/app/database/best_scores.py
+++ b/app/database/best_scores.py
@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING
-from app.database.statistics import UserStatistics
from app.models.score import GameMode
+from .statistics import UserStatistics
from .user import User
from sqlmodel import (
diff --git a/app/database/chat.py b/app/database/chat.py
index b05c790..3dfb0cb 100644
--- a/app/database/chat.py
+++ b/app/database/chat.py
@@ -1,13 +1,14 @@
from datetime import datetime
from enum import Enum
-from typing import Self
+from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
-from app.database.user import RANKING_INCLUDES, User, UserResp
from app.models.model import UTCBaseModel
from app.utils import utcnow
+from ._base import DatabaseModel, included, ondemand
+from .user import User, UserDict, UserModel
+
from pydantic import BaseModel
-from redis.asyncio import Redis
from sqlmodel import (
VARCHAR,
BigInteger,
@@ -22,6 +23,8 @@ from sqlmodel import (
)
from sqlmodel.ext.asyncio.session import AsyncSession
+if TYPE_CHECKING:
+ from app.router.notification.server import ChatServer
# ChatChannel
@@ -44,16 +47,168 @@ class ChannelType(str, Enum):
TEAM = "TEAM"
-class ChatChannelBase(SQLModel):
- name: str = Field(sa_column=Column(VARCHAR(50), index=True))
+class MessageType(str, Enum):
+ ACTION = "action"
+ MARKDOWN = "markdown"
+ PLAIN = "plain"
+
+
+class ChatChannelDict(TypedDict):
+ channel_id: int
+ description: str
+ name: str
+ icon: str | None
+ type: ChannelType
+ uuid: NotRequired[str | None]
+ message_length_limit: NotRequired[int]
+ moderated: NotRequired[bool]
+ current_user_attributes: NotRequired[ChatUserAttributes]
+ last_read_id: NotRequired[int | None]
+ last_message_id: NotRequired[int | None]
+ recent_messages: NotRequired[list["ChatMessageDict"]]
+ users: NotRequired[list[int]]
+
+
+class ChatChannelModel(DatabaseModel[ChatChannelDict]):
+ CONVERSATION_INCLUDES: ClassVar[list[str]] = [
+ "last_message_id",
+ "users",
+ ]
+ LISTING_INCLUDES: ClassVar[list[str]] = [
+ *CONVERSATION_INCLUDES,
+ "current_user_attributes",
+ "last_read_id",
+ ]
+
+ channel_id: int = Field(primary_key=True, index=True, default=None)
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
icon: str | None = Field(default=None)
type: ChannelType = Field(index=True)
+ @included
+ @staticmethod
+ async def name(session: AsyncSession, channel: "ChatChannel", user: User, server: "ChatServer") -> str:
+ users = server.channels.get(channel.channel_id, [])
+ if channel.type == ChannelType.PM and users and len(users) == 2:
+ target_user_id = next(u for u in users if u != user.id)
+ target_name = await session.exec(select(User.username).where(User.id == target_user_id))
+ return target_name.one()
+ return channel.name
-class ChatChannel(ChatChannelBase, table=True):
+ @included
+ @staticmethod
+ async def moderated(session: AsyncSession, channel: "ChatChannel", user: User) -> bool:
+ silence = (
+ await session.exec(
+ select(SilenceUser).where(
+ SilenceUser.channel_id == channel.channel_id,
+ SilenceUser.user_id == user.id,
+ )
+ )
+ ).first()
+
+ return silence is not None
+
+ @ondemand
+ @staticmethod
+ async def current_user_attributes(
+ session: AsyncSession,
+ channel: "ChatChannel",
+ user: User,
+ ) -> ChatUserAttributes:
+ from app.dependencies.database import get_redis
+
+ silence = (
+ await session.exec(
+ select(SilenceUser).where(
+ SilenceUser.channel_id == channel.channel_id,
+ SilenceUser.user_id == user.id,
+ )
+ )
+ ).first()
+ can_message = silence is None
+ can_message_error = "You are silenced in this channel" if not can_message else None
+
+ redis = get_redis()
+ last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
+ last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
+ last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
+ last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else (last_msg or 0)
+
+ return ChatUserAttributes(
+ can_message=can_message,
+ can_message_error=can_message_error,
+ last_read_id=last_read_id,
+ )
+
+ @ondemand
+ @staticmethod
+ async def last_read_id(_session: AsyncSession, channel: "ChatChannel", user: User) -> int | None:
+ from app.dependencies.database import get_redis
+
+ redis = get_redis()
+ last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
+ last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
+ last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
+ return int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
+
+ @ondemand
+ @staticmethod
+ async def last_message_id(_session: AsyncSession, channel: "ChatChannel") -> int | None:
+ from app.dependencies.database import get_redis
+
+ redis = get_redis()
+ last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
+ return int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
+
+ @ondemand
+ @staticmethod
+ async def recent_messages(
+ session: AsyncSession,
+ channel: "ChatChannel",
+ ) -> list["ChatMessageDict"]:
+ messages = (
+ await session.exec(
+ select(ChatMessage)
+ .where(ChatMessage.channel_id == channel.channel_id)
+ .order_by(col(ChatMessage.message_id).desc())
+ .limit(50)
+ )
+ ).all()
+ result = [
+ await ChatMessageModel.transform(
+ msg,
+ )
+ for msg in reversed(messages)
+ ]
+ return result
+
+ @ondemand
+ @staticmethod
+ async def users(
+ _session: AsyncSession,
+ channel: "ChatChannel",
+ server: "ChatServer",
+ user: User,
+ ) -> list[int]:
+ if channel.type == ChannelType.PUBLIC:
+ return []
+ users = server.channels.get(channel.channel_id, []).copy()
+ if channel.type == ChannelType.PM and users and len(users) == 2:
+ target_user_id = next(u for u in users if u != user.id)
+ users = [target_user_id, user.id]
+ return users
+
+ @included
+ @staticmethod
+ async def message_length_limit(_session: AsyncSession, _channel: "ChatChannel") -> int:
+ return 1000
+
+
+class ChatChannel(ChatChannelModel, table=True):
__tablename__: str = "chat_channels"
- channel_id: int = Field(primary_key=True, index=True, default=None)
+
+ name: str = Field(sa_column=Column(VARCHAR(50), index=True))
@classmethod
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
@@ -74,93 +229,20 @@ class ChatChannel(ChatChannelBase, table=True):
return channel
-class ChatChannelResp(ChatChannelBase):
- channel_id: int
- moderated: bool = False
- uuid: str | None = None
- current_user_attributes: ChatUserAttributes | None = None
- last_read_id: int | None = None
- last_message_id: int | None = None
- recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
- users: list[int] = Field(default_factory=list)
- message_length_limit: int = 1000
-
- @classmethod
- async def from_db(
- cls,
- channel: ChatChannel,
- session: AsyncSession,
- user: User,
- redis: Redis,
- users: list[int] | None = None,
- include_recent_messages: bool = False,
- ) -> Self:
- c = cls.model_validate(channel)
- silence = (
- await session.exec(
- select(SilenceUser).where(
- SilenceUser.channel_id == channel.channel_id,
- SilenceUser.user_id == user.id,
- )
- )
- ).first()
-
- last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
- last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
-
- last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
- last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
-
- if silence is not None:
- attribute = ChatUserAttributes(
- can_message=False,
- can_message_error=silence.reason or "You are muted in this channel.",
- last_read_id=last_read_id or 0,
- )
- c.moderated = True
- else:
- attribute = ChatUserAttributes(
- can_message=True,
- last_read_id=last_read_id or 0,
- )
- c.moderated = False
-
- c.current_user_attributes = attribute
- if c.type != ChannelType.PUBLIC and users is not None:
- c.users = users
- c.last_message_id = last_msg
- c.last_read_id = last_read_id
-
- if include_recent_messages:
- messages = (
- await session.exec(
- select(ChatMessage)
- .where(ChatMessage.channel_id == channel.channel_id)
- .order_by(col(ChatMessage.timestamp).desc())
- .limit(10)
- )
- ).all()
- c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
- c.recent_messages.reverse()
-
- if c.type == ChannelType.PM and users and len(users) == 2:
- target_user_id = next(u for u in users if u != user.id)
- target_name = await session.exec(select(User.username).where(User.id == target_user_id))
- c.name = target_name.one()
- c.users = [target_user_id, user.id]
- return c
-
-
# ChatMessage
+class ChatMessageDict(TypedDict):
+ channel_id: int
+ content: str
+ message_id: int
+ sender_id: int
+ timestamp: datetime
+ type: MessageType
+ uuid: str | None
+ is_action: NotRequired[bool]
+ sender: NotRequired[UserDict]
-class MessageType(str, Enum):
- ACTION = "action"
- MARKDOWN = "markdown"
- PLAIN = "plain"
-
-
-class ChatMessageBase(UTCBaseModel, SQLModel):
+class ChatMessageModel(DatabaseModel[ChatMessageDict]):
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
content: str = Field(sa_column=Column(VARCHAR(1000)))
message_id: int = Field(index=True, primary_key=True, default=None)
@@ -169,31 +251,21 @@ class ChatMessageBase(UTCBaseModel, SQLModel):
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
uuid: str | None = Field(default=None)
+ @included
+ @staticmethod
+ async def is_action(_session: AsyncSession, db_message: "ChatMessage") -> bool:
+ return db_message.type == MessageType.ACTION
-class ChatMessage(ChatMessageBase, table=True):
+ @ondemand
+ @staticmethod
+ async def sender(_session: AsyncSession, db_message: "ChatMessage") -> UserDict:
+ return await UserModel.transform(db_message.user)
+
+
+class ChatMessage(ChatMessageModel, table=True):
__tablename__: str = "chat_messages"
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
- channel: ChatChannel = Relationship()
-
-
-class ChatMessageResp(ChatMessageBase):
- sender: UserResp | None = None
- is_action: bool = False
-
- @classmethod
- async def from_db(
- cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
- ) -> "ChatMessageResp":
- m = cls.model_validate(db_message.model_dump())
- m.is_action = db_message.type == MessageType.ACTION
- if user:
- m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
- else:
- m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
- return m
-
-
-# SilenceUser
+ channel: "ChatChannel" = Relationship()
class SilenceUser(UTCBaseModel, SQLModel, table=True):
diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py
index 308bc30..2925207 100644
--- a/app/database/favourite_beatmapset.py
+++ b/app/database/favourite_beatmapset.py
@@ -1,7 +1,7 @@
import datetime
-from app.database.beatmapset import Beatmapset
-from app.database.user import User
+from .beatmapset import Beatmapset
+from .user import User
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
diff --git a/app/database/item_attempts_count.py b/app/database/item_attempts_count.py
index a4487c2..4522787 100644
--- a/app/database/item_attempts_count.py
+++ b/app/database/item_attempts_count.py
@@ -1,7 +1,9 @@
-from .playlist_best_score import PlaylistBestScore
-from .user import User, UserResp
+from typing import Any, NotRequired, TypedDict
+
+from ._base import DatabaseModel, ondemand
+from .playlist_best_score import PlaylistBestScore
+from .user import User, UserDict, UserModel
-from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
@@ -9,7 +11,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
- SQLModel,
col,
func,
select,
@@ -17,17 +18,66 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
-class ItemAttemptsCountBase(SQLModel):
- room_id: int = Field(foreign_key="rooms.id", index=True)
+class ItemAttemptsCountDict(TypedDict):
+ accuracy: float
+ attempts: int
+ completed: int
+ pp: float
+ room_id: int
+ total_score: int
+ user_id: int
+ user: NotRequired[UserDict]
+ position: NotRequired[int]
+ playlist_item_attempts: NotRequired[list[dict[str, Any]]]
+
+
+class ItemAttemptsCountModel(DatabaseModel[ItemAttemptsCountDict]):
+ accuracy: float = 0.0
attempts: int = Field(default=0)
completed: int = Field(default=0)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
- accuracy: float = 0.0
pp: float = 0
+ room_id: int = Field(foreign_key="rooms.id", index=True)
total_score: int = 0
+ user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
+
+ @ondemand
+ @staticmethod
+ async def user(_session: AsyncSession, item_attempts: "ItemAttemptsCount") -> UserDict:
+ user_instance = await item_attempts.awaitable_attrs.user
+ return await UserModel.transform(user_instance)
+
+ @ondemand
+ @staticmethod
+ async def position(session: AsyncSession, item_attempts: "ItemAttemptsCount") -> int:
+ return await item_attempts.get_position(session)
+
+ @ondemand
+ @staticmethod
+ async def playlist_item_attempts(
+ session: AsyncSession,
+ item_attempts: "ItemAttemptsCount",
+ ) -> list[dict[str, Any]]:
+ playlist_scores = (
+ await session.exec(
+ select(PlaylistBestScore).where(
+ PlaylistBestScore.room_id == item_attempts.room_id,
+ PlaylistBestScore.user_id == item_attempts.user_id,
+ )
+ )
+ ).all()
+ result: list[dict[str, Any]] = []
+ for score in playlist_scores:
+ result.append(
+ {
+ "id": score.playlist_id,
+ "attempts": score.attempts,
+ "passed": score.score.passed,
+ }
+ )
+ return result
-class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
+class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountModel, table=True):
__tablename__: str = "item_attempts_count"
id: int | None = Field(default=None, primary_key=True)
@@ -37,15 +87,15 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
rownum = (
func.row_number()
.over(
- partition_by=col(ItemAttemptsCountBase.room_id),
- order_by=col(ItemAttemptsCountBase.total_score).desc(),
+ partition_by=col(ItemAttemptsCount.room_id),
+ order_by=col(ItemAttemptsCount.total_score).desc(),
)
.label("rn")
)
- subq = select(ItemAttemptsCountBase, rownum).subquery()
+ subq = select(ItemAttemptsCount, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
result = await session.exec(stmt)
- return result.one()
+ return result.first() or 0
async def update(self, session: AsyncSession):
playlist_scores = (
@@ -88,62 +138,3 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
await session.refresh(item_attempts)
await item_attempts.update(session)
return item_attempts
-
-
-class ItemAttemptsResp(ItemAttemptsCountBase):
- user: UserResp | None = None
- position: int | None = None
-
- @classmethod
- async def from_db(
- cls,
- item_attempts: ItemAttemptsCount,
- session: AsyncSession,
- include: list[str] = [],
- ) -> "ItemAttemptsResp":
- resp = cls.model_validate(item_attempts.model_dump())
- resp.user = await UserResp.from_db(
- await item_attempts.awaitable_attrs.user,
- session=session,
- include=["statistics", "team", "daily_challenge_user_stats"],
- )
- if "position" in include:
- resp.position = await item_attempts.get_position(session)
- # resp.accuracy *= 100
- return resp
-
-
-class ItemAttemptsCountForItem(BaseModel):
- id: int
- attempts: int
- passed: bool
-
-
-class PlaylistAggregateScore(BaseModel):
- playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list)
-
- @classmethod
- async def from_db(
- cls,
- room_id: int,
- user_id: int,
- session: AsyncSession,
- ) -> "PlaylistAggregateScore":
- playlist_scores = (
- await session.exec(
- select(PlaylistBestScore).where(
- PlaylistBestScore.room_id == room_id,
- PlaylistBestScore.user_id == user_id,
- )
- )
- ).all()
- playlist_item_attempts = []
- for score in playlist_scores:
- playlist_item_attempts.append(
- ItemAttemptsCountForItem(
- id=score.playlist_id,
- attempts=score.attempts,
- passed=score.score.passed,
- )
- )
- return cls(playlist_item_attempts=playlist_item_attempts)
diff --git a/app/database/playlists.py b/app/database/playlists.py
index 6edf16f..6d5b6f7 100644
--- a/app/database/playlists.py
+++ b/app/database/playlists.py
@@ -1,11 +1,11 @@
from datetime import datetime
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
-from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.playlist import PlaylistItem
-from .beatmap import Beatmap, BeatmapResp
+from ._base import DatabaseModel, ondemand
+from .beatmap import Beatmap, BeatmapDict, BeatmapModel
from sqlmodel import (
JSON,
@@ -15,7 +15,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
- SQLModel,
func,
select,
)
@@ -23,18 +22,34 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .room import Room
+ from .score import ScoreDict
-class PlaylistBase(SQLModel, UTCBaseModel):
- id: int = Field(index=True)
- owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
+class PlaylistDict(TypedDict):
+ id: int
+ room_id: int
+ beatmap_id: int
+ created_at: datetime | None
ruleset_id: int
- expired: bool = Field(default=False)
- playlist_order: int = Field(default=0)
- played_at: datetime | None = Field(
- sa_column=Column(DateTime(timezone=True)),
- default=None,
+ allowed_mods: list[APIMod]
+ required_mods: list[APIMod]
+ freestyle: bool
+ expired: bool
+ owner_id: int
+ playlist_order: int
+ played_at: datetime | None
+ beatmap: NotRequired["BeatmapDict"]
+ scores: NotRequired[list[dict[str, Any]]]
+
+
+class PlaylistModel(DatabaseModel[PlaylistDict]):
+ id: int = Field(index=True)
+ room_id: int = Field(foreign_key="rooms.id")
+ beatmap_id: int = Field(
+ foreign_key="beatmaps.id",
)
+ created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
+ ruleset_id: int
allowed_mods: list[APIMod] = Field(
default_factory=list,
sa_column=Column(JSON),
@@ -43,16 +58,46 @@ class PlaylistBase(SQLModel, UTCBaseModel):
default_factory=list,
sa_column=Column(JSON),
)
- beatmap_id: int = Field(
- foreign_key="beatmaps.id",
- )
freestyle: bool = Field(default=False)
+ expired: bool = Field(default=False)
+ owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
+ playlist_order: int = Field(default=0)
+ played_at: datetime | None = Field(
+ sa_column=Column(DateTime(timezone=True)),
+ default=None,
+ )
+
+ @ondemand
+ @staticmethod
+ async def beatmap(_session: AsyncSession, playlist: "Playlist", includes: list[str] | None = None) -> BeatmapDict:
+ return await BeatmapModel.transform(playlist.beatmap, includes=includes)
+
+ @ondemand
+ @staticmethod
+ async def scores(session: AsyncSession, playlist: "Playlist") -> list["ScoreDict"]:
+ from .score import Score, ScoreModel
+
+ scores = (
+ await session.exec(
+ select(Score).where(
+ Score.playlist_item_id == playlist.id,
+ Score.room_id == playlist.room_id,
+ )
+ )
+ ).all()
+ result: list[ScoreDict] = []
+ for score in scores:
+ result.append(
+ await ScoreModel.transform(
+ score,
+ )
+ )
+ return result
-class Playlist(PlaylistBase, table=True):
+class Playlist(PlaylistModel, table=True):
__tablename__: str = "room_playlists"
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
- room_id: int = Field(foreign_key="rooms.id", exclude=True)
beatmap: Beatmap = Relationship(
sa_relationship_kwargs={
@@ -60,7 +105,6 @@ class Playlist(PlaylistBase, table=True):
}
)
room: "Room" = Relationship()
- created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
updated_at: datetime | None = Field(
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
)
@@ -121,15 +165,3 @@ class Playlist(PlaylistBase, table=True):
raise ValueError("Playlist item not found")
await session.delete(db_playlist)
await session.commit()
-
-
-class PlaylistResp(PlaylistBase):
- beatmap: BeatmapResp | None = None
-
- @classmethod
- async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
- data = playlist.model_dump()
- if "beatmap" in include:
- data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
- resp = cls.model_validate(data)
- return resp
diff --git a/app/database/relationship.py b/app/database/relationship.py
index f792f31..55ecf20 100644
--- a/app/database/relationship.py
+++ b/app/database/relationship.py
@@ -1,26 +1,39 @@
from enum import Enum
+from typing import TYPE_CHECKING, NotRequired, TypedDict
-from .user import User, UserResp
+from app.models.score import GameMode
+
+from ._base import DatabaseModel, included, ondemand
-from pydantic import BaseModel
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship as SQLRelationship,
- SQLModel,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
+if TYPE_CHECKING:
+ from .user import User, UserDict
+
class RelationshipType(str, Enum):
- FOLLOW = "Friend"
- BLOCK = "Block"
+ FOLLOW = "friend"
+ BLOCK = "block"
-class Relationship(SQLModel, table=True):
+class RelationshipDict(TypedDict):
+ target_id: int | None
+ type: RelationshipType
+ id: NotRequired[int | None]
+ user_id: NotRequired[int | None]
+ mutual: NotRequired[bool]
+ target: NotRequired["UserDict"]
+
+
+class RelationshipModel(DatabaseModel[RelationshipDict]):
__tablename__: str = "relationship"
id: int | None = Field(
default=None,
@@ -34,6 +47,7 @@ class Relationship(SQLModel, table=True):
ForeignKey("lazer_users.id"),
index=True,
),
+ exclude=True,
)
target_id: int = Field(
default=None,
@@ -44,22 +58,10 @@ class Relationship(SQLModel, table=True):
),
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
- target: User = SQLRelationship(
- sa_relationship_kwargs={
- "foreign_keys": "[Relationship.target_id]",
- "lazy": "selectin",
- }
- )
-
-class RelationshipResp(BaseModel):
- target_id: int
- target: UserResp
- mutual: bool = False
- type: RelationshipType
-
- @classmethod
- async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
+ @included
+ @staticmethod
+ async def mutual(session: AsyncSession, relationship: "Relationship") -> bool:
target_relationship = (
await session.exec(
select(Relationship).where(
@@ -68,23 +70,29 @@ class RelationshipResp(BaseModel):
)
)
).first()
- mutual = bool(
+ return bool(
target_relationship is not None
and relationship.type == RelationshipType.FOLLOW
and target_relationship.type == RelationshipType.FOLLOW
)
- return cls(
- target_id=relationship.target_id,
- target=await UserResp.from_db(
- relationship.target,
- session,
- include=[
- "team",
- "daily_challenge_user_stats",
- "statistics",
- "statistics_rulesets",
- ],
- ),
- mutual=mutual,
- type=relationship.type,
- )
+
+ @ondemand
+ @staticmethod
+ async def target(
+ _session: AsyncSession,
+ relationship: "Relationship",
+ ruleset: GameMode | None = None,
+ includes: list[str] | None = None,
+ ) -> "UserDict":
+ from .user import UserModel
+
+ return await UserModel.transform(relationship.target, ruleset=ruleset, includes=includes)
+
+
+class Relationship(RelationshipModel, table=True):
+ target: "User" = SQLRelationship(
+ sa_relationship_kwargs={
+ "foreign_keys": "[Relationship.target_id]",
+ "lazy": "selectin",
+ }
+ )
diff --git a/app/database/room.py b/app/database/room.py
index add0dbc..afa609b 100644
--- a/app/database/room.py
+++ b/app/database/room.py
@@ -1,8 +1,6 @@
from datetime import datetime
+from typing import ClassVar, NotRequired, TypedDict
-from app.database.item_attempts_count import PlaylistAggregateScore
-from app.database.room_participated_user import RoomParticipatedUser
-from app.models.model import UTCBaseModel
from app.models.room import (
MatchType,
QueueMode,
@@ -13,28 +11,58 @@ from app.models.room import (
)
from app.utils import utcnow
-from .playlists import Playlist, PlaylistResp
-from .user import User, UserResp
+from ._base import DatabaseModel, included, ondemand
+from .item_attempts_count import ItemAttemptsCount, ItemAttemptsCountDict, ItemAttemptsCountModel
+from .playlists import Playlist, PlaylistDict, PlaylistModel
+from .room_participated_user import RoomParticipatedUser
+from .user import User, UserDict, UserModel
+from pydantic import field_validator
from sqlalchemy.ext.asyncio import AsyncAttrs
-from sqlmodel import (
- BigInteger,
- Column,
- DateTime,
- Field,
- ForeignKey,
- Relationship,
- SQLModel,
- col,
- func,
- select,
-)
+from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, Relationship, SQLModel, col, select
from sqlmodel.ext.asyncio.session import AsyncSession
-class RoomBase(SQLModel, UTCBaseModel):
+class RoomDict(TypedDict):
+ id: int
+ name: str
+ category: RoomCategory
+ status: RoomStatus
+ type: MatchType
+ duration: int | None
+ starts_at: datetime | None
+ ends_at: datetime | None
+ max_attempts: int | None
+ participant_count: int
+ channel_id: int
+ queue_mode: QueueMode
+ auto_skip: bool
+ auto_start_duration: int
+ has_password: NotRequired[bool]
+ current_playlist_item: NotRequired["PlaylistDict | None"]
+ playlist: NotRequired[list["PlaylistDict"]]
+ playlist_item_stats: NotRequired[RoomPlaylistItemStats]
+ difficulty_range: NotRequired[RoomDifficultyRange]
+ host: NotRequired[UserDict]
+ recent_participants: NotRequired[list[UserDict]]
+ current_user_score: NotRequired["ItemAttemptsCountDict | None"]
+
+
+class RoomModel(DatabaseModel[RoomDict]):
+ SHOW_RESPONSE_INCLUDES: ClassVar[list[str]] = [
+ "current_user_score.playlist_item_attempts",
+ "host.country",
+ "playlist.beatmap.beatmapset",
+ "playlist.beatmap.checksum",
+ "playlist.beatmap.max_combo",
+ "recent_participants",
+ ]
+
+ id: int = Field(default=None, primary_key=True, index=True)
name: str = Field(index=True)
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
+ status: RoomStatus
+ type: MatchType
duration: int | None = Field(default=None) # minutes
starts_at: datetime | None = Field(
sa_column=Column(
@@ -48,76 +76,88 @@ class RoomBase(SQLModel, UTCBaseModel):
),
default=None,
)
- participant_count: int = Field(default=0)
max_attempts: int | None = Field(default=None) # playlists
- type: MatchType
+ participant_count: int = Field(default=0)
+ channel_id: int = 0
queue_mode: QueueMode
auto_skip: bool
+
auto_start_duration: int
- status: RoomStatus
- channel_id: int | None = None
-
-
-class Room(AsyncAttrs, RoomBase, table=True):
- __tablename__: str = "rooms"
- id: int = Field(default=None, primary_key=True, index=True)
- host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
- password: str | None = Field(default=None)
-
- host: User = Relationship()
- playlist: list[Playlist] = Relationship(
- sa_relationship_kwargs={
- "lazy": "selectin",
- "cascade": "all, delete-orphan",
- "overlaps": "room",
- }
- )
-
-
-class RoomResp(RoomBase):
- id: int
- has_password: bool = False
- host: UserResp | None = None
- playlist: list[PlaylistResp] = []
- playlist_item_stats: RoomPlaylistItemStats | None = None
- difficulty_range: RoomDifficultyRange | None = None
- current_playlist_item: PlaylistResp | None = None
- current_user_score: PlaylistAggregateScore | None = None
- recent_participants: list[UserResp] = Field(default_factory=list)
- channel_id: int = 0
+ @field_validator("channel_id", mode="before")
@classmethod
- async def from_db(
- cls,
- room: Room,
- session: AsyncSession,
- include: list[str] = [],
- user: User | None = None,
- ) -> "RoomResp":
- d = room.model_dump()
- d["channel_id"] = d.get("channel_id", 0) or 0
- d["has_password"] = bool(room.password)
- resp = cls.model_validate(d)
+ def validate_channel_id(cls, v):
+ """将 None 转换为 0"""
+ if v is None:
+ return 0
+ return v
- stats = RoomPlaylistItemStats(count_active=0, count_total=0)
- difficulty_range = RoomDifficultyRange(
- min=0,
- max=0,
- )
- rulesets = set()
- for playlist in room.playlist:
+ @included
+ @staticmethod
+ async def has_password(_session: AsyncSession, room: "Room") -> bool:
+ return bool(room.password)
+
+ @ondemand
+ @staticmethod
+ async def current_playlist_item(
+ _session: AsyncSession, room: "Room", includes: list[str] | None = None
+ ) -> "PlaylistDict | None":
+ playlists = await room.awaitable_attrs.playlist
+ if not playlists:
+ return None
+ return await PlaylistModel.transform(playlists[-1], includes=includes)
+
+ @ondemand
+ @staticmethod
+ async def playlist(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> list["PlaylistDict"]:
+ playlists = await room.awaitable_attrs.playlist
+ result: list[PlaylistDict] = []
+ for playlist_item in playlists:
+ result.append(await PlaylistModel.transform(playlist_item, includes=includes))
+ return result
+
+ @ondemand
+ @staticmethod
+ async def playlist_item_stats(_session: AsyncSession, room: "Room") -> RoomPlaylistItemStats:
+ playlists = await room.awaitable_attrs.playlist
+ stats = RoomPlaylistItemStats(count_active=0, count_total=0, ruleset_ids=[])
+ rulesets: set[int] = set()
+ for playlist in playlists:
stats.count_total += 1
if not playlist.expired:
stats.count_active += 1
rulesets.add(playlist.ruleset_id)
- difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
- difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
- resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
stats.ruleset_ids = list(rulesets)
- resp.playlist_item_stats = stats
- resp.difficulty_range = difficulty_range
- resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
- resp.recent_participants = []
+ return stats
+
+ @ondemand
+ @staticmethod
+ async def difficulty_range(_session: AsyncSession, room: "Room") -> RoomDifficultyRange:
+ playlists = await room.awaitable_attrs.playlist
+ if not playlists:
+ return RoomDifficultyRange(min=0.0, max=0.0)
+ min_diff = float("inf")
+ max_diff = float("-inf")
+ for playlist in playlists:
+ rating = playlist.beatmap.difficulty_rating
+ min_diff = min(min_diff, rating)
+ max_diff = max(max_diff, rating)
+ if min_diff == float("inf"):
+ min_diff = 0.0
+ if max_diff == float("-inf"):
+ max_diff = 0.0
+ return RoomDifficultyRange(min=min_diff, max=max_diff)
+
+ @ondemand
+ @staticmethod
+ async def host(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> UserDict:
+ host_user = await room.awaitable_attrs.host
+ return await UserModel.transform(host_user, includes=includes)
+
+ @ondemand
+ @staticmethod
+ async def recent_participants(session: AsyncSession, room: "Room") -> list[UserDict]:
+ participants: list[UserDict] = []
if room.category == RoomCategory.REALTIME:
query = (
select(RoomParticipatedUser)
@@ -137,39 +177,67 @@ class RoomResp(RoomBase):
.limit(8)
.order_by(col(RoomParticipatedUser.joined_at).desc())
)
- resp.participant_count = (
- await session.exec(
- select(func.count())
- .select_from(RoomParticipatedUser)
- .where(
- RoomParticipatedUser.room_id == room.id,
- )
- )
- ).first() or 0
for recent_participant in await session.exec(query):
- resp.recent_participants.append(
- await UserResp.from_db(
- await recent_participant.awaitable_attrs.user,
- session,
- include=["statistics"],
+ user_instance = await recent_participant.awaitable_attrs.user
+ participants.append(await UserModel.transform(user_instance))
+ return participants
+
+ @ondemand
+ @staticmethod
+ async def current_user_score(
+ session: AsyncSession, room: "Room", includes: list[str] | None = None
+ ) -> "ItemAttemptsCountDict | None":
+ item_attempt = (
+ await session.exec(
+ select(ItemAttemptsCount).where(
+ ItemAttemptsCount.room_id == room.id,
)
)
- resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
- if "current_user_score" in include and user:
- resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
- return resp
+ ).first()
+ if item_attempt is None:
+ return None
+
+ return await ItemAttemptsCountModel.transform(item_attempt, includes=includes)
-class APIUploadedRoom(RoomBase):
- def to_room(self) -> Room:
- """
- 将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。
- """
- room_dict = self.model_dump()
- room_dict.pop("playlist", None)
- # host_id 已在字段中
- return Room(**room_dict)
+class Room(AsyncAttrs, RoomModel, table=True):
+ __tablename__: str = "rooms"
- id: int | None
- host_id: int | None = None
+ host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
+ password: str | None = Field(default=None)
+
+ host: User = Relationship()
+ playlist: list[Playlist] = Relationship(
+ sa_relationship_kwargs={
+ "lazy": "selectin",
+ "cascade": "all, delete-orphan",
+ "overlaps": "room",
+ }
+ )
+
+
+class APIUploadedRoom(SQLModel):
+ name: str = Field(index=True)
+ category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
+ status: RoomStatus
+ type: MatchType
+ duration: int | None = Field(default=None) # minutes
+ starts_at: datetime | None = Field(
+ sa_column=Column(
+ DateTime(timezone=True),
+ ),
+ default_factory=utcnow,
+ )
+ ends_at: datetime | None = Field(
+ sa_column=Column(
+ DateTime(timezone=True),
+ ),
+ default=None,
+ )
+ max_attempts: int | None = Field(default=None) # playlists
+ participant_count: int = Field(default=0)
+ channel_id: int = 0
+ queue_mode: QueueMode
+ auto_skip: bool
+ auto_start_duration: int
playlist: list[Playlist] = Field(default_factory=list)
diff --git a/app/database/score.py b/app/database/score.py
index 49bdecd..6bb4a28 100644
--- a/app/database/score.py
+++ b/app/database/score.py
@@ -3,7 +3,7 @@ from datetime import date, datetime
import json
import math
import sys
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
from app.calculator import (
calculate_pp_weight,
@@ -15,8 +15,6 @@ from app.calculator import (
pre_fetch_and_calculate_pp,
)
from app.config import settings
-from app.database.beatmap_playcounts import BeatmapPlaycounts
-from app.database.team import TeamMember
from app.dependencies.database import get_redis
from app.log import log
from app.models.beatmap import BeatmapRankStatus
@@ -39,8 +37,10 @@ from app.models.scoring_mode import ScoringMode
from app.storage import StorageService
from app.utils import utcnow
-from .beatmap import Beatmap, BeatmapResp
-from .beatmapset import BeatmapsetResp
+from ._base import DatabaseModel, OnDemand, included, ondemand
+from .beatmap import Beatmap, BeatmapDict, BeatmapModel
+from .beatmap_playcounts import BeatmapPlaycounts
+from .beatmapset import BeatmapsetDict, BeatmapsetModel
from .best_scores import BestScore
from .counts import MonthlyPlaycounts
from .events import Event, EventType
@@ -50,8 +50,9 @@ from .relationship import (
RelationshipType,
)
from .score_token import ScoreToken
+from .team import TeamMember
from .total_score_best_scores import TotalScoreBestScore
-from .user import User, UserResp
+from .user import User, UserDict, UserModel
from pydantic import BaseModel, field_serializer, field_validator
from redis.asyncio import Redis
@@ -80,30 +81,290 @@ if TYPE_CHECKING:
logger = log("Score")
-class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
- # 基本字段
+class ScoreDict(TypedDict):
+ beatmap_id: int
+ id: int
+ rank: Rank
+ type: str
+ user_id: int
+ accuracy: float
+ build_id: int | None
+ ended_at: datetime
+ has_replay: bool
+ max_combo: int
+ passed: bool
+ pp: float
+ started_at: datetime
+ total_score: int
+ maximum_statistics: ScoreStatistics
+ mods: list[APIMod]
+ classic_total_score: int | None
+ preserve: bool
+ processed: bool
+ ranked: bool
+ playlist_item_id: NotRequired[int | None]
+ room_id: NotRequired[int | None]
+ best_id: NotRequired[int | None]
+ legacy_perfect: NotRequired[bool]
+ is_perfect_combo: NotRequired[bool]
+ ruleset_id: NotRequired[int]
+ statistics: NotRequired[ScoreStatistics]
+ beatmapset: NotRequired[BeatmapsetDict]
+ beatmap: NotRequired[BeatmapDict]
+ current_user_attributes: NotRequired[CurrentUserAttributes]
+ position: NotRequired[int | None]
+ scores_around: NotRequired["ScoreAround | None"]
+ rank_country: NotRequired[int | None]
+ rank_global: NotRequired[int | None]
+ user: NotRequired[UserDict]
+ weight: NotRequired[float | None]
+
+ # ScoreResp 字段
+ legacy_total_score: NotRequired[int]
+
+
+class ScoreModel(AsyncAttrs, DatabaseModel[ScoreDict]):
+ # https://github.com/ppy/osu-web/blob/master/app/Transformers/ScoreTransformer.php#L72-L84
+ MULTIPLAYER_SCORE_INCLUDE: ClassVar[list[str]] = ["playlist_item_id", "room_id", "solo_score_id"]
+ MULTIPLAYER_BASE_INCLUDES: ClassVar[list[str]] = [
+ "user.country",
+ "user.cover",
+ "user.team",
+ *MULTIPLAYER_SCORE_INCLUDE,
+ ]
+ USER_PROFILE_INCLUDES: ClassVar[list[str]] = ["beatmap", "beatmapset", "user"]
+
+ # 基本字段
+ beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
+ id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
+ rank: Rank
+ type: str
+ user_id: int = Field(
+ default=None,
+ sa_column=Column(
+ BigInteger,
+ ForeignKey("lazer_users.id"),
+ index=True,
+ ),
+ )
accuracy: float
- map_md5: str = Field(max_length=32, index=True)
build_id: int | None = Field(default=None)
- classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
ended_at: datetime = Field(sa_column=Column(DateTime))
has_replay: bool = Field(sa_column=Column(Boolean))
max_combo: int
- mods: list[APIMod] = Field(sa_column=Column(JSON))
passed: bool = Field(sa_column=Column(Boolean))
- playlist_item_id: int | None = Field(default=None) # multiplayer
pp: float = Field(default=0.0)
- preserve: bool = Field(default=True, sa_column=Column(Boolean))
- rank: Rank
- room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
- total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
- type: str
- beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
- processed: bool = False # solo_score
- ranked: bool = False
+ mods: list[APIMod] = Field(sa_column=Column(JSON))
+ total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
+
+ # solo
+ classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger))
+ preserve: bool = Field(default=True, sa_column=Column(Boolean))
+ processed: bool = Field(default=False)
+ ranked: bool = Field(default=False)
+
+ # multiplayer
+ playlist_item_id: OnDemand[int | None] = Field(default=None)
+ room_id: OnDemand[int | None] = Field(default=None)
+
+ @included
+ @staticmethod
+ async def best_id(
+ session: AsyncSession,
+ score: "Score",
+ ) -> int | None:
+ return await get_best_id(session, score.id)
+
+ @included
+ @staticmethod
+ async def legacy_perfect(
+ _session: AsyncSession,
+ score: "Score",
+ ) -> bool:
+ await score.awaitable_attrs.beatmap
+ return score.max_combo == score.beatmap.max_combo
+
+ @included
+ @staticmethod
+ async def is_perfect_combo(
+ _session: AsyncSession,
+ score: "Score",
+ ) -> bool:
+ await score.awaitable_attrs.beatmap
+ return score.max_combo == score.beatmap.max_combo
+
+ @included
+ @staticmethod
+ async def ruleset_id(
+ _session: AsyncSession,
+ score: "Score",
+ ) -> int:
+ return int(score.gamemode)
+
+ @included
+ @staticmethod
+ async def statistics(
+ _session: AsyncSession,
+ score: "Score",
+ ) -> ScoreStatistics:
+ stats = {
+ HitResult.MISS: score.nmiss,
+ HitResult.MEH: score.n50,
+ HitResult.OK: score.n100,
+ HitResult.GREAT: score.n300,
+ HitResult.PERFECT: score.ngeki,
+ HitResult.GOOD: score.nkatu,
+ }
+ if score.nlarge_tick_miss is not None:
+ stats[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
+ if score.nslider_tail_hit is not None:
+ stats[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
+ if score.nsmall_tick_hit is not None:
+ stats[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
+ if score.nlarge_tick_hit is not None:
+ stats[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
+ return stats
+
+ @ondemand
+ @staticmethod
+ async def beatmapset(
+ _session: AsyncSession,
+ score: "Score",
+ includes: list[str] | None = None,
+ ) -> BeatmapsetDict:
+ await score.awaitable_attrs.beatmap
+ return await BeatmapsetModel.transform(score.beatmap.beatmapset, includes=includes)
+
+ # reorder beatmapset and beatmap
+ # https://github.com/ppy/osu/blob/d8900defd34690de92be3406003fb3839fc0df1d/osu.Game/Online/API/Requests/Responses/SoloScoreInfo.cs#L111-L112
+ @ondemand
+ @staticmethod
+ async def beatmap(
+ _session: AsyncSession,
+ score: "Score",
+ includes: list[str] | None = None,
+ ) -> BeatmapDict:
+ await score.awaitable_attrs.beatmap
+ return await BeatmapModel.transform(score.beatmap, includes=includes)
+
+ @ondemand
+ @staticmethod
+ async def current_user_attributes(
+ _session: AsyncSession,
+ score: "Score",
+ ) -> CurrentUserAttributes:
+ return CurrentUserAttributes(pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id))
+
+ @ondemand
+ @staticmethod
+ async def position(
+ session: AsyncSession,
+ score: "Score",
+ ) -> int | None:
+ return await get_score_position_by_id(
+ session,
+ score.beatmap_id,
+ score.id,
+ mode=score.gamemode,
+ user=score.user,
+ )
+
+ @ondemand
+ @staticmethod
+ async def scores_around(
+ session: AsyncSession, _score: "Score", playlist_id: int, room_id: int, is_playlist: bool
+ ) -> "ScoreAround | None":
+ scores = (
+ await session.exec(
+ select(PlaylistBestScore).where(
+ PlaylistBestScore.playlist_id == playlist_id,
+ PlaylistBestScore.room_id == room_id,
+ ~User.is_restricted_query(col(PlaylistBestScore.user_id)),
+ col(PlaylistBestScore.score).has(col(Score.passed).is_(True)) if not is_playlist else True,
+ )
+ )
+ ).all()
+
+ higher_scores = []
+ lower_scores = []
+ for score in scores:
+ total_score = score.score.total_score
+ resp = await ScoreModel.transform(score.score, includes=ScoreModel.MULTIPLAYER_BASE_INCLUDES)
+ if score.total_score > total_score:
+ higher_scores.append(resp)
+ elif score.total_score < total_score:
+ lower_scores.append(resp)
+
+ return ScoreAround(
+ higher=MultiplayerScores(scores=higher_scores),
+ lower=MultiplayerScores(scores=lower_scores),
+ )
+
+ @ondemand
+ @staticmethod
+ async def rank_country(
+ session: AsyncSession,
+ score: "Score",
+ ) -> int | None:
+ return (
+ await get_score_position_by_id(
+ session,
+ score.beatmap_id,
+ score.id,
+ score.gamemode,
+ score.user,
+ type=LeaderboardType.COUNTRY,
+ )
+ or None
+ )
+
+ @ondemand
+ @staticmethod
+ async def rank_global(
+ session: AsyncSession,
+ score: "Score",
+ ) -> int | None:
+ return (
+ await get_score_position_by_id(
+ session,
+ score.beatmap_id,
+ score.id,
+ mode=score.gamemode,
+ user=score.user,
+ )
+ or None
+ )
+
+ @ondemand
+ @staticmethod
+ async def user(
+ _session: AsyncSession,
+ score: "Score",
+ includes: list[str] | None = None,
+ ) -> UserDict:
+ return await UserModel.transform(score.user, ruleset=score.gamemode, includes=includes or [])
+
+ @ondemand
+ @staticmethod
+ async def weight(
+ session: AsyncSession,
+ score: "Score",
+ ) -> float | None:
+ best_id = await get_best_id(session, score.id)
+ if best_id:
+ return calculate_pp_weight(best_id - 1)
+ return None
+
+ @ondemand
+ @staticmethod
+ async def legacy_total_score(
+ _session: AsyncSession,
+ _score: "Score",
+ ) -> int:
+ return 0
@field_validator("maximum_statistics", mode="before")
@classmethod
@@ -151,17 +412,9 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# TODO: current_user_attributes
-class Score(ScoreBase, table=True):
+class Score(ScoreModel, table=True):
__tablename__: str = "scores"
- id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
- user_id: int = Field(
- default=None,
- sa_column=Column(
- BigInteger,
- ForeignKey("lazer_users.id"),
- index=True,
- ),
- )
+
# ScoreStatistics
n300: int = Field(exclude=True)
n100: int = Field(exclude=True)
@@ -175,6 +428,7 @@ class Score(ScoreBase, table=True):
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
gamemode: GameMode = Field(index=True)
pinned_order: int = Field(default=0, exclude=True)
+ map_md5: str = Field(max_length=32, index=True, exclude=True)
@field_validator("gamemode", mode="before")
@classmethod
@@ -245,9 +499,11 @@ class Score(ScoreBase, table=True):
maximum_statistics=self.maximum_statistics,
)
- async def to_resp(self, session: AsyncSession, api_version: int) -> "ScoreResp | LegacyScoreResp":
+ async def to_resp(
+ self, session: AsyncSession, api_version: int, includes: list[str] = []
+ ) -> "ScoreDict | LegacyScoreResp":
if api_version >= 20220705:
- return await ScoreResp.from_db(session, self)
+ return await ScoreModel.transform(self, includes=includes)
return await LegacyScoreResp.from_db(session, self)
async def delete(
@@ -270,141 +526,7 @@ class Score(ScoreBase, table=True):
await session.delete(self)
-class ScoreResp(ScoreBase):
- id: int
- user_id: int
- is_perfect_combo: bool = False
- legacy_perfect: bool = False
- legacy_total_score: int = 0 # FIXME
- weight: float = 0.0
- best_id: int | None = None
- ruleset_id: int | None = None
- beatmap: BeatmapResp | None = None
- beatmapset: BeatmapsetResp | None = None
- user: UserResp | None = None
- statistics: ScoreStatistics | None = None
- maximum_statistics: ScoreStatistics | None = None
- rank_global: int | None = None
- rank_country: int | None = None
- position: int | None = None
- scores_around: "ScoreAround | None" = None
- current_user_attributes: CurrentUserAttributes | None = None
-
- @field_validator(
- "has_replay",
- "passed",
- "preserve",
- "is_perfect_combo",
- "legacy_perfect",
- "processed",
- "ranked",
- mode="before",
- )
- @classmethod
- def validate_bool_fields(cls, v):
- """将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
- if isinstance(v, int):
- return bool(v)
- return v
-
- @field_validator("statistics", "maximum_statistics", mode="before")
- @classmethod
- def validate_statistics_fields(cls, v):
- """处理统计字段中的字符串键,转换为 HitResult 枚举"""
- if isinstance(v, dict):
- converted = {}
- for key, value in v.items():
- if isinstance(key, str):
- try:
- # 尝试将字符串转换为 HitResult 枚举
- enum_key = HitResult(key)
- converted[enum_key] = value
- except ValueError:
- # 如果转换失败,跳过这个键值对
- continue
- else:
- converted[key] = value
- return converted
- return v
-
- @field_serializer("statistics", when_used="json")
- def serialize_statistics_fields(self, v):
- """序列化统计字段,确保枚举值正确转换为字符串"""
- if isinstance(v, dict):
- serialized = {}
- for key, value in v.items():
- if hasattr(key, "value"):
- # 如果是枚举,使用其值
- serialized[key.value] = value
- else:
- # 否则直接使用键
- serialized[str(key)] = value
- return serialized
- return v
-
- @classmethod
- async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
- # 确保 score 对象完全加载,避免懒加载问题
- await session.refresh(score)
-
- s = cls.model_validate(score.model_dump())
- await score.awaitable_attrs.beatmap
- s.beatmap = await BeatmapResp.from_db(score.beatmap)
- s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
- s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
- s.legacy_perfect = s.max_combo == s.beatmap.max_combo
- s.ruleset_id = int(score.gamemode)
- best_id = await get_best_id(session, score.id)
- if best_id:
- s.best_id = best_id
- s.weight = calculate_pp_weight(best_id - 1)
- s.statistics = {
- HitResult.MISS: score.nmiss,
- HitResult.MEH: score.n50,
- HitResult.OK: score.n100,
- HitResult.GREAT: score.n300,
- HitResult.PERFECT: score.ngeki,
- HitResult.GOOD: score.nkatu,
- }
- if score.nlarge_tick_miss is not None:
- s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
- if score.nslider_tail_hit is not None:
- s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
- if score.nsmall_tick_hit is not None:
- s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
- if score.nlarge_tick_hit is not None:
- s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
- s.user = await UserResp.from_db(
- score.user,
- session,
- include=["statistics", "team", "daily_challenge_user_stats"],
- ruleset=score.gamemode,
- )
- s.rank_global = (
- await get_score_position_by_id(
- session,
- score.beatmap_id,
- score.id,
- mode=score.gamemode,
- user=score.user,
- )
- or None
- )
- s.rank_country = (
- await get_score_position_by_id(
- session,
- score.beatmap_id,
- score.id,
- score.gamemode,
- score.user,
- type=LeaderboardType.COUNTRY,
- )
- or None
- )
- s.current_user_attributes = CurrentUserAttributes(
- pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
- )
- return s
+MultiplayScoreDict = ScoreModel.generate_typeddict(tuple(Score.MULTIPLAYER_BASE_INCLUDES)) # pyright: ignore[reportGeneralTypeIssues]
class LegacyStatistics(BaseModel):
@@ -417,31 +539,25 @@ class LegacyStatistics(BaseModel):
class LegacyScoreResp(UTCBaseModel):
- accuracy: float
- best_id: int
- created_at: datetime
id: int
- max_combo: int
- mode: GameMode
- mode_int: int
+ best_id: int
+ user_id: int
+ accuracy: float
mods: list[str] # acronym
- passed: bool
+ score: int
+ max_combo: int
perfect: bool = False
+ statistics: LegacyStatistics
+ passed: bool
pp: float
rank: Rank
+ created_at: datetime
+ mode: GameMode
+ mode_int: int
replay: bool
- score: int
- statistics: LegacyStatistics
- type: str
- user_id: int
- current_user_attributes: CurrentUserAttributes
- user: UserResp
- beatmap: BeatmapResp
- rank_global: int | None = Field(default=None, exclude=True)
@classmethod
- async def from_db(cls, session: AsyncSession, score: Score) -> "LegacyScoreResp":
- await session.refresh(score)
+ async def from_db(cls, session: AsyncSession, score: "Score") -> "LegacyScoreResp":
await score.awaitable_attrs.beatmap
return cls(
accuracy=score.accuracy,
@@ -465,34 +581,13 @@ class LegacyScoreResp(UTCBaseModel):
count_geki=score.ngeki or 0,
count_katu=score.nkatu or 0,
),
- type=score.type,
user_id=score.user_id,
- current_user_attributes=CurrentUserAttributes(
- pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
- ),
- user=await UserResp.from_db(
- score.user,
- session,
- include=["statistics", "team", "daily_challenge_user_stats"],
- ruleset=score.gamemode,
- ),
- beatmap=await BeatmapResp.from_db(score.beatmap),
perfect=score.is_perfect_combo,
- rank_global=(
- await get_score_position_by_id(
- session,
- score.beatmap_id,
- score.id,
- mode=score.gamemode,
- user=score.user,
- )
- or None
- ),
)
class MultiplayerScores(RespWithCursor):
- scores: list[ScoreResp] = Field(default_factory=list)
+ scores: list[MultiplayScoreDict] = Field(default_factory=list) # pyright: ignore[reportInvalidTypeForm]
params: dict[str, Any] = Field(default_factory=dict)
@@ -842,13 +937,13 @@ async def get_user_best_pp(
# https://github.com/ppy/osu-queue-score-statistics/blob/master/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/PlayValidityHelper.cs
-def get_play_length(score: Score, beatmap_length: int):
+def get_play_length(score: "Score", beatmap_length: int):
speed_rate = get_speed_rate(score.mods)
length = beatmap_length / speed_rate
return int(min(length, (score.ended_at - score.started_at).total_seconds()))
-def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
+def calculate_playtime(score: "Score", beatmap_length: int) -> tuple[int, bool]:
total_length = get_play_length(score, beatmap_length)
total_obj_hited = (
score.n300
@@ -937,7 +1032,7 @@ async def process_score(
return score
-async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
+async def _process_score_pp(score: "Score", session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
if score.pp != 0:
logger.debug(
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
@@ -984,7 +1079,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
)
-async def _process_score_events(score: Score, session: AsyncSession):
+async def _process_score_events(score: "Score", session: AsyncSession):
total_users = (await session.exec(select(func.count()).select_from(User))).one()
rank_global = await get_score_position_by_id(
session,
@@ -1088,7 +1183,7 @@ async def _process_statistics(
session: AsyncSession,
redis: Redis,
user: User,
- score: Score,
+ score: "Score",
score_token: int,
beatmap_length: int,
beatmap_status: BeatmapRankStatus,
@@ -1318,7 +1413,7 @@ async def process_user(
redis: Redis,
fetcher: "Fetcher",
user: User,
- score: Score,
+ score: "Score",
score_token: int,
beatmap_length: int,
beatmap_status: BeatmapRankStatus,
diff --git a/app/database/search_beatmapset.py b/app/database/search_beatmapset.py
new file mode 100644
index 0000000..c680175
--- /dev/null
+++ b/app/database/search_beatmapset.py
@@ -0,0 +1,13 @@
+from . import beatmap # noqa: F401
+from .beatmapset import BeatmapsetModel
+
+from sqlmodel import SQLModel
+
+SearchBeatmapset = BeatmapsetModel.generate_typeddict(("beatmaps.max_combo", "pack_tags"))
+
+
+class SearchBeatmapsetsResp(SQLModel):
+ beatmapsets: list[SearchBeatmapset] # pyright: ignore[reportInvalidTypeForm]
+ total: int
+ cursor: dict[str, int | float | str] | None = None
+ cursor_string: str | None = None
diff --git a/app/database/statistics.py b/app/database/statistics.py
index 3d0b1dd..a1c90d9 100644
--- a/app/database/statistics.py
+++ b/app/database/statistics.py
@@ -1,10 +1,11 @@
from datetime import timedelta
import math
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.models.score import GameMode
from app.utils import utcnow
+from ._base import DatabaseModel, included, ondemand
from .rank_history import RankHistory
from pydantic import field_validator
@@ -15,7 +16,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
- SQLModel,
col,
func,
select,
@@ -23,10 +23,40 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
- from .user import User, UserResp
+ from .user import User, UserDict
-class UserStatisticsBase(SQLModel):
+class UserStatisticsDict(TypedDict):
+ mode: GameMode
+ count_100: int
+ count_300: int
+ count_50: int
+ count_miss: int
+ pp: float
+ ranked_score: int
+ hit_accuracy: float
+ total_score: int
+ total_hits: int
+ maximum_combo: int
+ play_count: int
+ play_time: int
+ replays_watched_by_others: int
+ is_ranked: bool
+ level: NotRequired[dict[str, int]]
+ global_rank: NotRequired[int | None]
+ grade_counts: NotRequired[dict[str, int]]
+ rank_change_since_30_days: NotRequired[int]
+ country_rank: NotRequired[int | None]
+ user: NotRequired["UserDict"]
+
+
+class UserStatisticsModel(DatabaseModel[UserStatisticsDict]):
+ RANKING_INCLUDES: ClassVar[list[str]] = [
+ "user.country",
+ "user.cover",
+ "user.team",
+ ]
+
mode: GameMode = Field(index=True)
count_100: int = Field(default=0, sa_column=Column(BigInteger))
count_300: int = Field(default=0, sa_column=Column(BigInteger))
@@ -57,8 +87,63 @@ class UserStatisticsBase(SQLModel):
return GameMode.OSU
return v
+ @included
+ @staticmethod
+ async def level(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
+ return {
+ "current": int(statistics.level_current),
+ "progress": int(math.fmod(statistics.level_current, 1) * 100),
+ }
-class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
+ @included
+ @staticmethod
+ async def global_rank(session: AsyncSession, statistics: "UserStatistics") -> int | None:
+ return await get_rank(session, statistics)
+
+ @included
+ @staticmethod
+ async def grade_counts(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
+ return {
+ "ss": statistics.grade_ss,
+ "ssh": statistics.grade_ssh,
+ "s": statistics.grade_s,
+ "sh": statistics.grade_sh,
+ "a": statistics.grade_a,
+ }
+
+ @ondemand
+ @staticmethod
+ async def rank_change_since_30_days(session: AsyncSession, statistics: "UserStatistics") -> int:
+ global_rank = await get_rank(session, statistics)
+ rank_best = (
+ await session.exec(
+ select(func.max(RankHistory.rank)).where(
+ RankHistory.date > utcnow() - timedelta(days=30),
+ RankHistory.user_id == statistics.user_id,
+ )
+ )
+ ).first()
+ if rank_best is None or global_rank is None:
+ return 0
+ return rank_best - global_rank
+
+ @ondemand
+ @staticmethod
+ async def country_rank(
+ session: AsyncSession, statistics: "UserStatistics", user_country: str | None = None
+ ) -> int | None:
+ return await get_rank(session, statistics, user_country)
+
+ @ondemand
+ @staticmethod
+ async def user(_session: AsyncSession, statistics: "UserStatistics") -> "UserDict":
+ from .user import UserModel
+
+ user_instance = await statistics.awaitable_attrs.user
+ return await UserModel.transform(user_instance)
+
+
+class UserStatistics(AsyncAttrs, UserStatisticsModel, table=True):
__tablename__: str = "lazer_user_statistics"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(
@@ -80,74 +165,6 @@ class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
user: "User" = Relationship(back_populates="statistics")
-class UserStatisticsResp(UserStatisticsBase):
- user: "UserResp | None" = None
- rank_change_since_30_days: int | None = 0
- global_rank: int | None = Field(default=None)
- country_rank: int | None = Field(default=None)
- grade_counts: dict[str, int] = Field(
- default_factory=lambda: {
- "ss": 0,
- "ssh": 0,
- "s": 0,
- "sh": 0,
- "a": 0,
- }
- )
- level: dict[str, int] = Field(
- default_factory=lambda: {
- "current": 1,
- "progress": 0,
- }
- )
-
- @classmethod
- async def from_db(
- cls,
- obj: UserStatistics,
- session: AsyncSession,
- user_country: str | None = None,
- include: list[str] = [],
- ) -> "UserStatisticsResp":
- s = cls.model_validate(obj.model_dump())
- s.grade_counts = {
- "ss": obj.grade_ss,
- "ssh": obj.grade_ssh,
- "s": obj.grade_s,
- "sh": obj.grade_sh,
- "a": obj.grade_a,
- }
- s.level = {
- "current": int(obj.level_current),
- "progress": int(math.fmod(obj.level_current, 1) * 100),
- }
- if "user" in include:
- from .user import RANKING_INCLUDES, UserResp
-
- user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
- s.user = user
- user_country = user.country_code
-
- s.global_rank = await get_rank(session, obj)
- s.country_rank = await get_rank(session, obj, user_country)
-
- if "rank_change_since_30_days" in include:
- rank_best = (
- await session.exec(
- select(func.max(RankHistory.rank)).where(
- RankHistory.date > utcnow() - timedelta(days=30),
- RankHistory.user_id == obj.user_id,
- )
- )
- ).first()
- if rank_best is None or s.global_rank is None:
- s.rank_change_since_30_days = 0
- else:
- s.rank_change_since_30_days = rank_best - s.global_rank
-
- return s
-
-
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
from .user import User
@@ -164,7 +181,6 @@ async def get_rank(session: AsyncSession, statistics: UserStatistics, country: s
query = query.join(User).where(User.country_code == country)
subq = query.subquery()
-
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
rank = result.first()
diff --git a/app/database/total_score_best_scores.py b/app/database/total_score_best_scores.py
index 1ab1d5a..f2a8f2d 100644
--- a/app/database/total_score_best_scores.py
+++ b/app/database/total_score_best_scores.py
@@ -1,9 +1,9 @@
from typing import TYPE_CHECKING
from app.calculator import calculate_score_to_level
-from app.database.statistics import UserStatistics
from app.models.score import GameMode, Rank
+from .statistics import UserStatistics
from .user import User
from sqlmodel import (
diff --git a/app/database/user.py b/app/database/user.py
index fd40c35..2f16390 100644
--- a/app/database/user.py
+++ b/app/database/user.py
@@ -1,22 +1,25 @@
from datetime import datetime, timedelta
import json
-from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, overload
+from typing import TYPE_CHECKING, ClassVar, Literal, NotRequired, TypedDict, overload
from app.config import settings
-from app.models.model import UTCBaseModel
+from app.models.notification import NotificationName
from app.models.score import GameMode
from app.models.user import Country, Page
from app.path import STATIC_DIR
from app.utils import utcnow
+from ._base import DatabaseModel, OnDemand, included, ondemand
from .achievement import UserAchievement, UserAchievementResp
from .auth import TotpKeys
from .beatmap_playcounts import BeatmapPlaycounts
from .counts import CountResp, MonthlyPlaycounts, ReplayWatchedCount
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .events import Event
+from .notification import Notification, UserNotification
from .rank_history import RankHistory, RankHistoryResp, RankTop
-from .statistics import UserStatistics, UserStatisticsResp
+from .relationship import RelationshipModel
+from .statistics import UserStatistics, UserStatisticsModel
from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp, UserAccountHistoryType
from .user_preference import DEFAULT_ORDER, UserPreference
@@ -31,7 +34,6 @@ from sqlmodel import (
DateTime,
Field,
Relationship,
- SQLModel,
col,
exists,
func,
@@ -43,7 +45,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .favourite_beatmapset import FavouriteBeatmapset
from .matchmaking import MatchmakingUserStats
- from .relationship import RelationshipResp
+ from .relationship import Relationship, RelationshipDict
+ from .statistics import UserStatisticsDict
class Kudosu(TypedDict):
@@ -76,10 +79,194 @@ Badge = TypedDict(
COUNTRIES = json.loads((STATIC_DIR / "iso3166.json").read_text())
-class UserBase(UTCBaseModel, SQLModel):
+class UserDict(TypedDict):
+ avatar_url: str
+ country_code: str
+ id: int
+ is_active: bool
+ is_bot: bool
+ is_supporter: bool
+ last_visit: datetime | None
+ pm_friends_only: bool
+ profile_colour: str | None
+ username: str
+ g0v0_playmode: GameMode
+ page: NotRequired[Page]
+ previous_usernames: NotRequired[list[str]]
+ support_level: NotRequired[int]
+ badges: NotRequired[list[Badge]]
+ cover: NotRequired[UserProfileCover]
+ beatmap_playcounts_count: NotRequired[int]
+ playmode: NotRequired[GameMode]
+ discord: NotRequired[str | None]
+ has_supported: NotRequired[bool]
+ interests: NotRequired[str | None]
+ join_date: NotRequired[datetime]
+ location: NotRequired[str | None]
+ max_blocks: NotRequired[int]
+ max_friends: NotRequired[int]
+ occupation: NotRequired[str | None]
+ playstyle: NotRequired[list[str]]
+ profile_hue: NotRequired[int | None]
+ title: NotRequired[str | None]
+ title_url: NotRequired[str | None]
+ twitter: NotRequired[str | None]
+ website: NotRequired[str | None]
+ comments_count: NotRequired[int]
+ post_count: NotRequired[int]
+ is_admin: NotRequired[bool]
+ is_gmt: NotRequired[bool]
+ is_qat: NotRequired[bool]
+ is_bng: NotRequired[bool]
+ groups: NotRequired[list[str]]
+ active_tournament_banners: NotRequired[list[dict]]
+ graveyard_beatmapset_count: NotRequired[int]
+ loved_beatmapset_count: NotRequired[int]
+ mapping_follower_count: NotRequired[int]
+ nominated_beatmapset_count: NotRequired[int]
+ guest_beatmapset_count: NotRequired[int]
+ pending_beatmapset_count: NotRequired[int]
+ ranked_beatmapset_count: NotRequired[int]
+ follow_user_mapping: NotRequired[list[int]]
+ is_deleted: NotRequired[bool]
+ country: NotRequired[Country]
+ favourite_beatmapset_count: NotRequired[int]
+ follower_count: NotRequired[int]
+ scores_best_count: NotRequired[int]
+ scores_pinned_count: NotRequired[int]
+ scores_recent_count: NotRequired[int]
+ scores_first_count: NotRequired[int]
+ cover_url: NotRequired[str]
+ profile_order: NotRequired[list[str]]
+ user_preference: NotRequired[UserPreference | None]
+ friends: NotRequired[list["RelationshipDict"]]
+ team: NotRequired[Team | None]
+ account_history: NotRequired[list[UserAccountHistoryResp]]
+ daily_challenge_user_stats: NotRequired[DailyChallengeStatsResp | None]
+ statistics: NotRequired["UserStatisticsDict | None"]
+ statistics_rulesets: NotRequired[dict[str, "UserStatisticsDict"]]
+ monthly_playcounts: NotRequired[list[CountResp]]
+ replay_watched_counts: NotRequired[list[CountResp]]
+ user_achievements: NotRequired[list[UserAchievementResp]]
+ rank_history: NotRequired[RankHistoryResp | None]
+ rank_highest: NotRequired[RankHighest | None]
+ is_restricted: NotRequired[bool]
+ kudosu: NotRequired[Kudosu]
+ unread_pm_count: NotRequired[int]
+ default_group: NotRequired[str]
+ is_online: NotRequired[bool]
+ session_verified: NotRequired[bool]
+ session_verification_method: NotRequired[Literal["totp", "mail"] | None]
+
+
+class UserModel(DatabaseModel[UserDict]):
+ # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserCompactTransformer.php#L22-L39
+ CARD_INCLUDES: ClassVar[list[str]] = [
+ "country",
+ "cover",
+ "groups",
+ "team",
+ ]
+ LIST_INCLUDES: ClassVar[list[str]] = [
+ *CARD_INCLUDES,
+ "statistics",
+ "support_level",
+ ]
+
+ # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserTransformer.php#L36-L53
+ USER_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
+ "cover_url",
+ "discord",
+ "has_supported",
+ "interests",
+ "join_date",
+ "location",
+ "max_blocks",
+ "max_friends",
+ "occupation",
+ "playmode",
+ "playstyle",
+ "post_count",
+ "profile_hue",
+ "profile_order",
+ "title",
+ "title_url",
+ "twitter",
+ "website",
+ # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserTransformer.php#L13C22-L25
+ "cover",
+ "country",
+ "is_admin",
+ "is_bng",
+ "is_full_bn",
+ "is_gmt",
+ "is_limited_bn",
+ "is_moderator",
+ "is_nat",
+ "is_restricted",
+ "is_silenced",
+ "kudosu",
+ ]
+
+ # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserCompactTransformer.php#L41-L51
+ PROFILE_HEADER_INCLUDES: ClassVar[list[str]] = [
+ "active_tournament_banner",
+ "active_tournament_banners",
+ "badges",
+ "comments_count",
+ "follower_count",
+ "groups",
+ "mapping_follower_count",
+ "previous_usernames",
+ "support_level",
+ ]
+
+ # https://github.com/ppy/osu-web/blob/3f08fe12d70bcac1e32455c31e984eb6ef589b42/app/Http/Controllers/UsersController.php#L900-L937
+ USER_INCLUDES: ClassVar[list[str]] = [
+ # == apiIncludes ==
+ # historical
+ "beatmap_playcounts_count",
+ "monthly_playcounts",
+ "replays_watched_counts",
+ "scores_recent_count",
+ # beatmapsets
+ "favourite_beatmapset_count",
+ "graveyard_beatmapset_count",
+ "guest_beatmapset_count",
+ "loved_beatmapset_count",
+ "nominated_beatmapset_count",
+ "pending_beatmapset_count",
+ "ranked_beatmapset_count",
+ # top scores
+ "scores_best_count",
+ "scores_first_count",
+ "scores_pinned_count",
+ # others
+ "account_history",
+ "current_season_stats",
+ "daily_challenge_user_stats",
+ "page",
+ "pending_beatmapset_count",
+ "rank_highest",
+ "rank_history",
+ "statistics",
+ "statistics.country_rank",
+ "statistics.rank",
+ "statistics.variants",
+ "team",
+ "user_achievements",
+ *PROFILE_HEADER_INCLUDES,
+ *USER_TRANSFORMER_INCLUDES,
+ ]
+
+ # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserCompactTransformer.php#L133-L150
avatar_url: str = "https://lazer-data.g0v0.top/default.jpg"
country_code: str = Field(default="CN", max_length=2, index=True)
# ? default_group: str|None
+ id: int = Field(
+ default=None,
+ sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
+ )
is_active: bool = True
is_bot: bool = False
is_supporter: bool = False
@@ -87,45 +274,45 @@ class UserBase(UTCBaseModel, SQLModel):
pm_friends_only: bool = False
profile_colour: str | None = None
username: str = Field(max_length=32, unique=True, index=True)
- page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw=""))
- previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON))
- support_level: int = 0
- badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON))
+
+ page: OnDemand[Page] = Field(sa_column=Column(JSON), default=Page(html="", raw=""))
+ previous_usernames: OnDemand[list[str]] = Field(default_factory=list, sa_column=Column(JSON))
+ support_level: OnDemand[int] = Field(default=0)
+ badges: OnDemand[list[Badge]] = Field(default_factory=list, sa_column=Column(JSON))
# optional
# blocks
- cover: UserProfileCover = Field(
+ cover: OnDemand[UserProfileCover] = Field(
default=UserProfileCover(url=""),
sa_column=Column(JSON),
)
- beatmap_playcounts_count: int = 0
# kudosu
# UserExtended
- playmode: GameMode = GameMode.OSU
- discord: str | None = None
- has_supported: bool = False
- interests: str | None = None
- join_date: datetime = Field(default_factory=utcnow)
- location: str | None = None
- max_blocks: int = 50
- max_friends: int = 500
- occupation: str | None = None
- playstyle: list[str] = Field(default_factory=list, sa_column=Column(JSON))
+ playmode: OnDemand[GameMode] = Field(default=GameMode.OSU)
+ discord: OnDemand[str | None] = Field(default=None)
+ has_supported: OnDemand[bool] = Field(default=False)
+ interests: OnDemand[str | None] = Field(default=None)
+ join_date: OnDemand[datetime] = Field(default_factory=utcnow)
+ location: OnDemand[str | None] = Field(default=None)
+ max_blocks: OnDemand[int] = Field(default=50)
+ max_friends: OnDemand[int] = Field(default=500)
+ occupation: OnDemand[str | None] = Field(default=None)
+ playstyle: OnDemand[list[str]] = Field(default_factory=list, sa_column=Column(JSON))
# TODO: post_count
- profile_hue: int | None = None
- title: str | None = None
- title_url: str | None = None
- twitter: str | None = None
- website: str | None = None
+ profile_hue: OnDemand[int | None] = Field(default=None)
+ title: OnDemand[str | None] = Field(default=None)
+ title_url: OnDemand[str | None] = Field(default=None)
+ twitter: OnDemand[str | None] = Field(default=None)
+ website: OnDemand[str | None] = Field(default=None)
# undocumented
- comments_count: int = 0
- post_count: int = 0
- is_admin: bool = False
- is_gmt: bool = False
- is_qat: bool = False
- is_bng: bool = False
+ comments_count: OnDemand[int] = Field(default=0)
+ post_count: OnDemand[int] = Field(default=0)
+ is_admin: OnDemand[bool] = Field(default=False)
+ is_gmt: OnDemand[bool] = Field(default=False)
+ is_qat: OnDemand[bool] = Field(default=False)
+ is_bng: OnDemand[bool] = Field(default=False)
# g0v0-extra
g0v0_playmode: GameMode = GameMode.OSU
@@ -142,14 +329,388 @@ class UserBase(UTCBaseModel, SQLModel):
return GameMode.OSU
return v
+ @ondemand
+ @staticmethod
+ async def groups(_session: AsyncSession, _obj: "User") -> list[str]:
+ return []
-class User(AsyncAttrs, UserBase, table=True):
+ @ondemand
+ @staticmethod
+ async def active_tournament_banners(_session: AsyncSession, _obj: "User") -> list[dict]:
+ return []
+
+ @ondemand
+ @staticmethod
+ async def graveyard_beatmapset_count(_session: AsyncSession, _obj: "User") -> int:
+ return 0
+
+ @ondemand
+ @staticmethod
+ async def loved_beatmapset_count(_session: AsyncSession, _obj: "User") -> int:
+ return 0
+
+ @ondemand
+ @staticmethod
+ async def mapping_follower_count(_session: AsyncSession, _obj: "User") -> int:
+ return 0
+
+ @ondemand
+ @staticmethod
+ async def nominated_beatmapset_count(_session: AsyncSession, _obj: "User") -> int:
+ return 0
+
+ @ondemand
+ @staticmethod
+ async def guest_beatmapset_count(_session: AsyncSession, _obj: "User") -> int:
+ return 0
+
+ @ondemand
+ @staticmethod
+ async def pending_beatmapset_count(_session: AsyncSession, _obj: "User") -> int:
+ return 0
+
+ @ondemand
+ @staticmethod
+ async def ranked_beatmapset_count(_session: AsyncSession, _obj: "User") -> int:
+ return 0
+
+ @ondemand
+ @staticmethod
+ async def follow_user_mapping(_session: AsyncSession, _obj: "User") -> list[int]:
+ return []
+
+ @ondemand
+ @staticmethod
+ async def is_deleted(_session: AsyncSession, _obj: "User") -> bool:
+ return False
+
+ @ondemand
+ @staticmethod
+ async def country(_session: AsyncSession, obj: "User") -> Country:
+ return Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown"))
+
+ @ondemand
+ @staticmethod
+ async def favourite_beatmapset_count(session: AsyncSession, obj: "User") -> int:
+ from .favourite_beatmapset import FavouriteBeatmapset
+
+ return (
+ await session.exec(
+ select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id)
+ )
+ ).one()
+
+ @ondemand
+ @staticmethod
+ async def follower_count(session: AsyncSession, obj: "User") -> int:
+ from .relationship import Relationship, RelationshipType
+
+ stmt = (
+ select(func.count())
+ .select_from(Relationship)
+ .where(
+ Relationship.target_id == obj.id,
+ Relationship.type == RelationshipType.FOLLOW,
+ )
+ )
+ return (await session.exec(stmt)).one()
+
+ @ondemand
+ @staticmethod
+ async def scores_best_count(
+ session: AsyncSession,
+ obj: "User",
+ ruleset: GameMode | None = None,
+ ) -> int:
+ from .best_scores import BestScore
+
+ mode = ruleset or obj.playmode
+ stmt = (
+ select(func.count())
+ .select_from(BestScore)
+ .where(
+ BestScore.user_id == obj.id,
+ BestScore.gamemode == mode,
+ )
+ .limit(200)
+ )
+ return (await session.exec(stmt)).one()
+
+ @ondemand
+ @staticmethod
+ async def scores_pinned_count(
+ session: AsyncSession,
+ obj: "User",
+ ruleset: GameMode | None = None,
+ ) -> int:
+ from .score import Score
+
+ mode = ruleset or obj.playmode
+ stmt = (
+ select(func.count())
+ .select_from(Score)
+ .where(
+ Score.user_id == obj.id,
+ Score.gamemode == mode,
+ Score.pinned_order > 0,
+ col(Score.passed).is_(True),
+ )
+ )
+ return (await session.exec(stmt)).one()
+
+ @ondemand
+ @staticmethod
+ async def scores_recent_count(
+ session: AsyncSession,
+ obj: "User",
+ ruleset: GameMode | None = None,
+ ) -> int:
+ from .score import Score
+
+ mode = ruleset or obj.playmode
+ stmt = (
+ select(func.count())
+ .select_from(Score)
+ .where(
+ Score.user_id == obj.id,
+ Score.gamemode == mode,
+ col(Score.passed).is_(True),
+ Score.ended_at > utcnow() - timedelta(hours=24),
+ )
+ )
+ return (await session.exec(stmt)).one()
+
+ @ondemand
+ @staticmethod
+ async def scores_first_count(
+ session: AsyncSession,
+ obj: "User",
+ ruleset: GameMode | None = None,
+ ) -> int:
+ from .score import get_user_first_score_count
+
+ mode = ruleset or obj.playmode
+ return await get_user_first_score_count(session, obj.id, mode)
+
+ @ondemand
+ @staticmethod
+ async def beatmap_playcounts_count(session: AsyncSession, obj: "User") -> int:
+ stmt = select(func.count()).select_from(BeatmapPlaycounts).where(BeatmapPlaycounts.user_id == obj.id)
+ return (await session.exec(stmt)).one()
+
+ @ondemand
+ @staticmethod
+ async def cover_url(_session: AsyncSession, obj: "User") -> str:
+ return obj.cover.get("url", "") if obj.cover else ""
+
+ @ondemand
+ @staticmethod
+ async def profile_order(_session: AsyncSession, obj: "User") -> list[str]:
+ await obj.awaitable_attrs.user_preference
+ if obj.user_preference:
+ return list(obj.user_preference.extras_order)
+ return list(DEFAULT_ORDER)
+
+ @ondemand
+ @staticmethod
+ async def user_preference(_session: AsyncSession, obj: "User") -> UserPreference | None:
+ await obj.awaitable_attrs.user_preference
+ return obj.user_preference
+
+ @ondemand
+ @staticmethod
+ async def friends(session: AsyncSession, obj: "User") -> list["RelationshipDict"]:
+ from .relationship import Relationship, RelationshipType
+
+ relationships = (
+ await session.exec(
+ select(Relationship).where(
+ Relationship.user_id == obj.id,
+ Relationship.type == RelationshipType.FOLLOW,
+ )
+ )
+ ).all()
+ return [await RelationshipModel.transform(rel, ruleset=obj.playmode) for rel in relationships]
+
+ @ondemand
+ @staticmethod
+ async def team(_session: AsyncSession, obj: "User") -> Team | None:
+ membership = await obj.awaitable_attrs.team_membership
+ return membership.team if membership else None
+
+ @ondemand
+ @staticmethod
+ async def account_history(_session: AsyncSession, obj: "User") -> list[UserAccountHistoryResp]:
+ await obj.awaitable_attrs.account_history
+ return [UserAccountHistoryResp.from_db(ah) for ah in obj.account_history]
+
+ @ondemand
+ @staticmethod
+ async def daily_challenge_user_stats(_session: AsyncSession, obj: "User") -> DailyChallengeStatsResp | None:
+ stats = await obj.awaitable_attrs.daily_challenge_stats
+ return DailyChallengeStatsResp.from_db(stats) if stats else None
+
+ @ondemand
+ @staticmethod
+ async def statistics(
+ _session: AsyncSession,
+ obj: "User",
+ ruleset: GameMode | None = None,
+ includes: list[str] | None = None,
+ ) -> "UserStatisticsDict | None":
+ mode = ruleset or obj.playmode
+ for stat in await obj.awaitable_attrs.statistics:
+ if stat.mode == mode:
+ return await UserStatisticsModel.transform(stat, user_country=obj.country_code, includes=includes)
+ return None
+
+ @ondemand
+ @staticmethod
+ async def statistics_rulesets(
+ _session: AsyncSession,
+ obj: "User",
+ includes: list[str] | None = None,
+ ) -> dict[str, "UserStatisticsDict"]:
+ stats = await obj.awaitable_attrs.statistics
+ result: dict[str, UserStatisticsDict] = {}
+ for stat in stats:
+ result[stat.mode.value] = await UserStatisticsModel.transform(
+ stat, user_country=obj.country_code, includes=includes
+ )
+ return result
+
+ @ondemand
+ @staticmethod
+ async def monthly_playcounts(_session: AsyncSession, obj: "User") -> list[CountResp]:
+ playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts]
+ if len(playcounts) == 1:
+ d = playcounts[0].start_date
+ playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
+ return playcounts
+
+ @ondemand
+ @staticmethod
+ async def replay_watched_counts(_session: AsyncSession, obj: "User") -> list[CountResp]:
+ counts = [CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts]
+ if len(counts) == 1:
+ d = counts[0].start_date
+ counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
+ return counts
+
+ @ondemand
+ @staticmethod
+ async def user_achievements(_session: AsyncSession, obj: "User") -> list[UserAchievementResp]:
+ return [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement]
+
+ @ondemand
+ @staticmethod
+ async def rank_history(
+ session: AsyncSession,
+ obj: "User",
+ ruleset: GameMode | None = None,
+ ) -> RankHistoryResp | None:
+ mode = ruleset or obj.playmode
+ rank_history = await RankHistoryResp.from_db(session, obj.id, mode)
+ return rank_history if len(rank_history.data) != 0 else None
+
+ @ondemand
+ @staticmethod
+ async def rank_highest(
+ session: AsyncSession,
+ obj: "User",
+ ruleset: GameMode | None = None,
+ ) -> RankHighest | None:
+ mode = ruleset or obj.playmode
+ rank_top = (await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == mode))).first()
+ if not rank_top:
+ return None
+ return RankHighest(
+ rank=rank_top.rank,
+ updated_at=datetime.combine(rank_top.date, datetime.min.time()),
+ )
+
+ @ondemand
+ @staticmethod
+ async def is_restricted(session: AsyncSession, obj: "User") -> bool:
+ return await obj.is_restricted(session)
+
+ @ondemand
+ @staticmethod
+ async def kudosu(_session: AsyncSession, _obj: "User") -> Kudosu:
+ return Kudosu(available=0, total=0) # TODO
+
+ @ondemand
+ @staticmethod
+ async def unread_pm_count(session: AsyncSession, obj: "User") -> int:
+ return (
+ await session.exec(
+ select(func.count())
+ .join(Notification, col(Notification.id) == UserNotification.notification_id)
+ .select_from(UserNotification)
+ .where(
+ col(UserNotification.is_read).is_(False),
+ UserNotification.user_id == obj.id,
+ Notification.name == NotificationName.CHANNEL_MESSAGE,
+ text("details->>'$.type' = 'pm'"),
+ )
+ )
+ ).one()
+
+ @included
+ @staticmethod
+ async def default_group(_session: AsyncSession, obj: "User") -> str:
+ return "default" if not obj.is_bot else "bot"
+
+ @included
+ @staticmethod
+ async def is_online(_session: AsyncSession, obj: "User") -> bool:
+ from app.dependencies.database import get_redis
+
+ redis = get_redis()
+ return bool(await redis.exists(f"metadata:online:{obj.id}"))
+
+ @ondemand
+ @staticmethod
+ async def session_verified(
+ session: AsyncSession,
+ obj: "User",
+ token_id: int | None = None,
+ ) -> bool:
+ from app.service.verification_service import LoginSessionService
+
+ return (
+ not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
+ if token_id
+ else True
+ )
+
+ @ondemand
+ @staticmethod
+ async def session_verification_method(
+ session: AsyncSession,
+ obj: "User",
+ token_id: int | None = None,
+ ) -> Literal["totp", "mail"] | None:
+ from app.dependencies.database import get_redis
+ from app.service.verification_service import LoginSessionService
+
+ if (settings.enable_totp_verification or settings.enable_email_verification) and token_id:
+ redis = get_redis()
+ if not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id):
+ return None
+ return await LoginSessionService.get_login_method(obj.id, token_id, redis)
+ return None
+
+
+class User(AsyncAttrs, UserModel, table=True):
__tablename__: str = "lazer_users"
- id: int = Field(
- default=None,
- sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
- )
+ email: str = Field(max_length=254, unique=True, index=True)
+ priv: int = Field(default=1)
+ pw_bcrypt: str = Field(max_length=60)
+ silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)))
+ donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)))
+
account_history: list[UserAccountHistory] = Relationship(back_populates="user")
statistics: list[UserStatistics] = Relationship(back_populates="user")
achievement: list[UserAchievement] = Relationship(back_populates="user")
@@ -166,12 +727,6 @@ class User(AsyncAttrs, UserBase, table=True):
totp_key: TotpKeys | None = Relationship(back_populates="user")
user_preference: UserPreference | None = Relationship(back_populates="user")
- email: str = Field(max_length=254, unique=True, index=True, exclude=True)
- priv: int = Field(default=1, exclude=True)
- pw_bcrypt: str = Field(max_length=60, exclude=True)
- silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
- donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
-
async def is_user_can_pm(self, from_user: "User", session: AsyncSession) -> tuple[bool, str]:
from .relationship import Relationship, RelationshipType
@@ -241,303 +796,8 @@ class User(AsyncAttrs, UserBase, table=True):
return active_restrictions or False
-class UserResp(UserBase):
- id: int | None = None
- is_online: bool = False
- groups: list = [] # TODO
- country: Country = Field(default_factory=lambda: Country(code="CN", name="China"))
- favourite_beatmapset_count: int = 0
- graveyard_beatmapset_count: int = 0 # TODO
- guest_beatmapset_count: int = 0 # TODO
- loved_beatmapset_count: int = 0 # TODO
- mapping_follower_count: int = 0 # TODO
- nominated_beatmapset_count: int = 0 # TODO
- pending_beatmapset_count: int = 0 # TODO
- ranked_beatmapset_count: int = 0 # TODO
- follow_user_mapping: list[int] = Field(default_factory=list)
- follower_count: int = 0
- friends: list["RelationshipResp"] | None = None
- scores_best_count: int = 0
- scores_first_count: int = 0
- scores_recent_count: int = 0
- scores_pinned_count: int = 0
- beatmap_playcounts_count: int = 0
- account_history: list[UserAccountHistoryResp] = []
- active_tournament_banners: list[dict] = [] # TODO
- kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO
- monthly_playcounts: list[CountResp] = Field(default_factory=list)
- replay_watched_counts: list[CountResp] = Field(default_factory=list)
- unread_pm_count: int = 0 # TODO
- rank_history: RankHistoryResp | None = None
- rank_highest: RankHighest | None = None
- statistics: UserStatisticsResp | None = None
- statistics_rulesets: dict[str, UserStatisticsResp] | None = None
- user_achievements: list[UserAchievementResp] = Field(default_factory=list)
- cover_url: str = "" # deprecated
- team: Team | None = None
- daily_challenge_user_stats: DailyChallengeStatsResp | None = None
- default_group: str = ""
- is_deleted: bool = False # TODO
- is_restricted: bool = False
- user_preference: UserPreference | None = None
- profile_order: list[str] = Field(
- default_factory=lambda: DEFAULT_ORDER,
- )
-
- # TODO: unread_pm_count
-
- @classmethod
- async def from_db(
- cls,
- obj: User,
- session: AsyncSession,
- include: list[str] = [],
- ruleset: GameMode | None = None,
- ) -> "UserResp":
- from app.dependencies.database import get_redis
-
- from .best_scores import BestScore
- from .favourite_beatmapset import FavouriteBeatmapset
- from .relationship import Relationship, RelationshipResp, RelationshipType
- from .score import Score, get_user_first_score_count
- from .total_score_best_scores import TotalScoreBestScore
-
- ruleset = ruleset or obj.playmode
-
- u = cls.model_validate(obj.model_dump())
- u.id = obj.id
- u.default_group = "bot" if u.is_bot else "default"
- u.country = Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown"))
- u.follower_count = (
- await session.exec(
- select(func.count())
- .select_from(Relationship)
- .where(
- Relationship.target_id == obj.id,
- Relationship.type == RelationshipType.FOLLOW,
- )
- )
- ).one()
- u.scores_best_count = (
- await session.exec(
- select(func.count())
- .select_from(TotalScoreBestScore)
- .where(
- TotalScoreBestScore.user_id == obj.id,
- )
- .limit(200)
- )
- ).one()
- redis = get_redis()
- u.is_online = bool(await redis.exists(f"metadata:online:{obj.id}"))
- u.cover_url = obj.cover.get("url", "") if obj.cover else ""
-
- await obj.awaitable_attrs.user_preference
- if obj.user_preference:
- u.profile_order = obj.user_preference.extras_order
-
- if "user_preference" in include:
- u.user_preference = obj.user_preference
-
- if "friends" in include:
- u.friends = [
- await RelationshipResp.from_db(session, r)
- for r in (
- await session.exec(
- select(Relationship).where(
- Relationship.user_id == obj.id,
- Relationship.type == RelationshipType.FOLLOW,
- )
- )
- ).all()
- ]
-
- if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership):
- u.team = team_membership.team
-
- if "account_history" in include:
- u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
-
- if "daily_challenge_user_stats" in include and (
- daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats
- ):
- u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
-
- if "statistics" in include:
- current_stattistics = None
- for i in await obj.awaitable_attrs.statistics:
- if i.mode == ruleset:
- current_stattistics = i
- break
- u.statistics = (
- await UserStatisticsResp.from_db(current_stattistics, session, obj.country_code)
- if current_stattistics
- else None
- )
-
- if "statistics_rulesets" in include:
- u.statistics_rulesets = {
- i.mode.value: await UserStatisticsResp.from_db(i, session, obj.country_code)
- for i in await obj.awaitable_attrs.statistics
- }
-
- if "monthly_playcounts" in include:
- u.monthly_playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts]
- if len(u.monthly_playcounts) == 1:
- d = u.monthly_playcounts[0].start_date
- u.monthly_playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
-
- if "replays_watched_counts" in include:
- u.replay_watched_counts = [
- CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts
- ]
- if len(u.replay_watched_counts) == 1:
- d = u.replay_watched_counts[0].start_date
- u.replay_watched_counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
-
- if "achievements" in include:
- u.user_achievements = [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement]
- if "rank_history" in include:
- rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset)
- if len(rank_history.data) != 0:
- u.rank_history = rank_history
-
- rank_top = (
- await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == ruleset))
- ).first()
- if rank_top:
- u.rank_highest = (
- RankHighest(
- rank=rank_top.rank,
- updated_at=datetime.combine(rank_top.date, datetime.min.time()),
- )
- if rank_top
- else None
- )
- if "is_restricted" in include:
- u.is_restricted = await obj.is_restricted(session)
-
- u.favourite_beatmapset_count = (
- await session.exec(
- select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id)
- )
- ).one()
- u.scores_pinned_count = (
- await session.exec(
- select(func.count())
- .select_from(Score)
- .where(
- Score.user_id == obj.id,
- Score.pinned_order > 0,
- Score.gamemode == ruleset,
- col(Score.passed).is_(True),
- )
- )
- ).one()
- u.scores_best_count = (
- await session.exec(
- select(func.count())
- .select_from(BestScore)
- .where(
- BestScore.user_id == obj.id,
- BestScore.gamemode == ruleset,
- )
- .limit(200)
- )
- ).one()
- u.scores_recent_count = (
- await session.exec(
- select(func.count())
- .select_from(Score)
- .where(
- Score.user_id == obj.id,
- Score.gamemode == ruleset,
- col(Score.passed).is_(True),
- Score.ended_at > utcnow() - timedelta(hours=24),
- )
- )
- ).one()
- u.scores_first_count = await get_user_first_score_count(session, obj.id, ruleset)
- u.beatmap_playcounts_count = (
- await session.exec(
- select(func.count())
- .select_from(BeatmapPlaycounts)
- .where(
- BeatmapPlaycounts.user_id == obj.id,
- )
- )
- ).one()
-
- return u
-
-
-class MeResp(UserResp):
- session_verification_method: Literal["totp", "mail"] | None = None
- session_verified: bool = True
-
- @classmethod
- async def from_db(
- cls,
- obj: User,
- session: AsyncSession,
- ruleset: GameMode | None = None,
- *,
- token_id: int | None = None,
- ) -> "MeResp":
- from app.dependencies.database import get_redis
- from app.service.verification_service import LoginSessionService
-
- u = await super().from_db(obj, session, ALL_INCLUDED, ruleset)
- u.session_verified = (
- not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
- if token_id
- else True
- )
- u = cls.model_validate(u.model_dump())
- if (settings.enable_totp_verification or settings.enable_email_verification) and token_id:
- redis = get_redis()
- if not u.session_verified:
- u.session_verification_method = await LoginSessionService.get_login_method(obj.id, token_id, redis)
- else:
- u.session_verification_method = None
- return u
-
-
-ALL_INCLUDED = [
- "friends",
- "team",
- "account_history",
- "daily_challenge_user_stats",
- "statistics",
- "statistics_rulesets",
- "achievements",
- "monthly_playcounts",
- "replays_watched_counts",
- "rank_history",
- "is_restricted",
- "session_verified",
- "user_preference",
-]
-
-
-SEARCH_INCLUDED = [
- "team",
- "daily_challenge_user_stats",
- "statistics",
- "statistics_rulesets",
- "achievements",
- "monthly_playcounts",
- "replays_watched_counts",
- "rank_history",
-]
-
-BASE_INCLUDES = [
- "team",
- "daily_challenge_user_stats",
- "statistics",
-]
-
-RANKING_INCLUDES = [
- "team",
- "statistics",
-]
+# 为了向后兼容,在 SQL 查询中使用 User
+# 例如: select(User).where(User.id == 1)
+# 但类型注解和返回值使用 User
+# 例如: async def get_user() -> User | None:
+# return (await session.exec(select(User)...)).first()
diff --git a/app/fetcher/beatmap.py b/app/fetcher/beatmap.py
index 272f41f..206cf34 100644
--- a/app/fetcher/beatmap.py
+++ b/app/fetcher/beatmap.py
@@ -1,13 +1,40 @@
-from app.database.beatmap import BeatmapResp
+from app.database.beatmap import BeatmapDict, BeatmapModel
from app.log import fetcher_logger
from ._base import BaseFetcher
+from pydantic import TypeAdapter
+
logger = fetcher_logger("BeatmapFetcher")
+adapter = TypeAdapter(
+ BeatmapModel.generate_typeddict(
+ (
+ "checksum",
+ "accuracy",
+ "ar",
+ "bpm",
+ "convert",
+ "count_circles",
+ "count_sliders",
+ "count_spinners",
+ "cs",
+ "deleted_at",
+ "drain",
+ "hit_length",
+ "is_scoreable",
+ "last_updated",
+ "mode_int",
+ "ranked",
+ "url",
+ "max_combo",
+ "beatmapset",
+ )
+ )
+)
class BeatmapFetcher(BaseFetcher):
- async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp:
+ async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapDict:
if beatmap_id:
params = {"id": beatmap_id}
elif beatmap_checksum:
@@ -16,7 +43,7 @@ class BeatmapFetcher(BaseFetcher):
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
logger.opt(colors=True).debug(f"get_beatmap: {params}")
- return BeatmapResp.model_validate(
+ return adapter.validate_python( # pyright: ignore[reportReturnType]
await self.request_api(
"https://osu.ppy.sh/api/v2/beatmaps/lookup",
params=params,
diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py
index 893552f..910163f 100644
--- a/app/fetcher/beatmapset.py
+++ b/app/fetcher/beatmapset.py
@@ -3,7 +3,7 @@ import base64
import hashlib
import json
-from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
+from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp
from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import fetcher_logger
from app.models.beatmap import SearchQueryModel
@@ -13,6 +13,7 @@ from app.utils import bg_tasks
from ._base import BaseFetcher
from httpx import AsyncClient
+from pydantic import TypeAdapter
import redis.asyncio as redis
@@ -26,6 +27,46 @@ logger = fetcher_logger("BeatmapsetFetcher")
MAX_RETRY_ATTEMPTS = 3
+adapter = TypeAdapter(
+ BeatmapsetModel.generate_typeddict(
+ (
+ "availability",
+ "bpm",
+ "last_updated",
+ "ranked",
+ "ranked_date",
+ "submitted_date",
+ "tags",
+ "storyboard",
+ "description",
+ "genre",
+ "language",
+ *[
+ f"beatmaps.{inc}"
+ for inc in (
+ "checksum",
+ "accuracy",
+ "ar",
+ "bpm",
+ "convert",
+ "count_circles",
+ "count_sliders",
+ "count_spinners",
+ "cs",
+ "deleted_at",
+ "drain",
+ "hit_length",
+ "is_scoreable",
+ "last_updated",
+ "mode_int",
+ "ranked",
+ "url",
+ "max_combo",
+ )
+ ],
+ )
+ )
+)
class BeatmapsetFetcher(BaseFetcher):
@@ -139,10 +180,9 @@ class BeatmapsetFetcher(BaseFetcher):
except Exception:
return {}
- async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
+ async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetDict:
logger.opt(colors=True).debug(f"get_beatmapset: {beatmap_set_id}")
-
- return BeatmapsetResp.model_validate(
+ return adapter.validate_python( # pyright: ignore[reportReturnType]
await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
)
diff --git a/app/models/chat.py b/app/models/chat.py
index 0e2b9da..c77bbdb 100644
--- a/app/models/chat.py
+++ b/app/models/chat.py
@@ -1,8 +1,6 @@
-from typing import Any
-
-from pydantic import BaseModel
+from typing import Any, TypedDict
-class ChatEvent(BaseModel):
+class ChatEvent(TypedDict):
event: str
- data: dict[str, Any] | None = None
+ data: dict[str, Any] | None
diff --git a/app/models/model.py b/app/models/model.py
index 4c28048..a44a9df 100644
--- a/app/models/model.py
+++ b/app/models/model.py
@@ -3,16 +3,16 @@ from datetime import UTC, datetime
from app.models.score import GameMode
-from pydantic import BaseModel, field_serializer
+from pydantic import BaseModel, FieldSerializationInfo, field_serializer
class UTCBaseModel(BaseModel):
- @field_serializer("*", when_used="json")
- def serialize_datetime(self, v, _info):
+ @field_serializer("*", when_used="always")
+ def serialize_datetime(self, v, _info: FieldSerializationInfo):
if isinstance(v, datetime):
if v.tzinfo is None:
v = v.replace(tzinfo=UTC)
- return v.astimezone(UTC).isoformat().replace("+00:00", "Z")
+ return v.astimezone(UTC)
return v
diff --git a/app/router/lio.py b/app/router/lio.py
index 33fbaa1..e9bcc18 100644
--- a/app/router/lio.py
+++ b/app/router/lio.py
@@ -447,7 +447,7 @@ async def create_multiplayer_room(
# 让房主加入频道
host_user = await db.get(User, host_user_id)
if host_user:
- await server.batch_join_channel([host_user], channel, db)
+ await server.batch_join_channel([host_user], channel)
# Add playlist items
await _add_playlist_items(db, room_id, room_data, host_user_id)
diff --git a/app/router/notification/banchobot.py b/app/router/notification/banchobot.py
index 39a01c5..84f41be 100644
--- a/app/router/notification/banchobot.py
+++ b/app/router/notification/banchobot.py
@@ -3,11 +3,14 @@ from collections.abc import Awaitable, Callable
from math import ceil
import random
import shlex
+from typing import TYPE_CHECKING
from app.calculator import calculate_weighted_pp
from app.const import BANCHOBOT_ID
-from app.database import ChatMessageResp
-from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
+
+if TYPE_CHECKING:
+ pass
+from app.database.chat import ChannelType, ChatChannel, ChatMessage, ChatMessageModel, MessageType
from app.database.score import Score, get_best_id
from app.database.statistics import UserStatistics, get_rank
from app.database.user import User
@@ -95,7 +98,7 @@ class Bot:
await session.commit()
await session.refresh(msg)
await session.refresh(bot)
- resp = await ChatMessageResp.from_db(msg, session, bot)
+ resp = await ChatMessageModel.transform(msg, includes=["sender"])
await server.send_message_to_channel(resp)
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
@@ -119,7 +122,7 @@ class Bot:
await session.refresh(channel)
await session.refresh(user)
await session.refresh(bot)
- await server.batch_join_channel([user, bot], channel, session)
+ await server.batch_join_channel([user, bot], channel)
return channel
async def _send_reply(
diff --git a/app/router/notification/channel.py b/app/router/notification/channel.py
index 3d98dda..4428572 100644
--- a/app/router/notification/channel.py
+++ b/app/router/notification/channel.py
@@ -1,37 +1,40 @@
-from typing import Annotated, Any, Literal, Self
+from typing import Annotated, Literal, Self
from app.database.chat import (
ChannelType,
ChatChannel,
- ChatChannelResp,
+ ChatChannelModel,
ChatMessage,
SilenceUser,
UserSilenceResp,
)
-from app.database.user import User, UserResp
+from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
+from app.utils import api_doc
from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, model_validator
from sqlmodel import col, select
-class UpdateResponse(BaseModel):
- presence: list[ChatChannelResp] = Field(default_factory=list)
- silences: list[Any] = Field(default_factory=list)
-
-
@router.get(
"/chat/updates",
- response_model=UpdateResponse,
name="获取更新",
description="获取当前用户所在频道的最新的禁言情况。",
tags=["聊天"],
+ responses={
+ 200: api_doc(
+ "获取更新响应。",
+ {"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]},
+ ChatChannel.LISTING_INCLUDES,
+ name="UpdateResponse",
+ )
+ },
)
async def get_update(
session: Database,
@@ -44,45 +47,44 @@ async def get_update(
Query(alias="includes[]", description="要包含的更新类型"),
] = ["presence", "silences"],
):
- resp = UpdateResponse()
+ resp = {
+ "presence": [],
+ "silences": [],
+ }
if "presence" in includes:
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
# 使用明确的查询避免延迟加载
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel:
- # 提取必要的属性避免惰性加载
- channel_type = db_channel.type
-
- resp.presence.append(
- await ChatChannelResp.from_db(
+ resp["presence"].append(
+ await ChatChannelModel.transform(
db_channel,
- session,
- current_user,
- redis,
- server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
+ user=current_user,
+ server=server,
+ includes=ChatChannel.LISTING_INCLUDES,
)
)
if "silences" in includes:
if history_since:
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
- resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
+ resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
).all()
- resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
+ resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@router.put(
"/chat/channels/{channel}/users/{user}",
- response_model=ChatChannelResp,
name="加入频道",
description="加入指定的公开/房间频道。",
tags=["聊天"],
+ responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)},
)
async def join_channel(
session: Database,
@@ -101,7 +103,7 @@ async def join_channel(
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
- return await server.join_channel(current_user, db_channel, session)
+ return await server.join_channel(current_user, db_channel)
@router.delete(
@@ -128,13 +130,13 @@ async def leave_channel(
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
- await server.leave_channel(current_user, db_channel, session)
+ await server.leave_channel(current_user, db_channel)
return
@router.get(
"/chat/channels",
- response_model=list[ChatChannelResp],
+ responses={200: api_doc("加入的频道", list[ChatChannelModel])},
name="获取频道列表",
description="获取所有公开频道。",
tags=["聊天"],
@@ -142,35 +144,30 @@ async def leave_channel(
async def get_channel_list(
session: Database,
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
- redis: Redis,
):
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
- results = []
- for channel in channels:
- # 提取必要的属性避免惰性加载
- channel_id = channel.channel_id
- channel_type = channel.type
+ results = await ChatChannelModel.transform_many(
+ channels,
+ user=current_user,
+ server=server,
+ )
- results.append(
- await ChatChannelResp.from_db(
- channel,
- session,
- current_user,
- redis,
- server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
- )
- )
return results
-class GetChannelResp(BaseModel):
- channel: ChatChannelResp
- users: list[UserResp] = Field(default_factory=list)
-
-
@router.get(
"/chat/channels/{channel}",
- response_model=GetChannelResp,
+ responses={
+ 200: api_doc(
+ "频道详细信息",
+ {
+ "channel": ChatChannelModel,
+ "users": list[UserModel],
+ },
+ ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES,
+ name="GetChannelResponse",
+ )
+ },
name="获取频道信息",
description="获取指定频道的信息。",
tags=["聊天"],
@@ -191,7 +188,6 @@ async def get_channel(
raise HTTPException(status_code=404, detail="Channel not found")
# 立即提取需要的属性
- channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
@@ -209,15 +205,15 @@ async def get_channel(
users.extend([target_user, current_user])
break
- return GetChannelResp(
- channel=await ChatChannelResp.from_db(
+ return {
+ "channel": await ChatChannelModel.transform(
db_channel,
- session,
- current_user,
- redis,
- server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
- )
- )
+ user=current_user,
+ server=server,
+ includes=ChatChannel.LISTING_INCLUDES,
+ ),
+ "users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES),
+ }
class CreateChannelReq(BaseModel):
@@ -244,7 +240,7 @@ class CreateChannelReq(BaseModel):
@router.post(
"/chat/channels",
- response_model=ChatChannelResp,
+ responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])},
name="创建频道",
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
tags=["聊天"],
@@ -289,21 +285,13 @@ async def create_channel(
await session.refresh(current_user)
if req.type == "PM":
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
- await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
+ await server.batch_join_channel([target, current_user], channel) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
- await server.batch_join_channel([*target_users, current_user], channel, session)
+ await server.batch_join_channel([*target_users, current_user], channel)
- await server.join_channel(current_user, channel, session)
+ await server.join_channel(current_user, channel)
- # 提取必要的属性避免惰性加载
- channel_id = channel.channel_id
-
- return await ChatChannelResp.from_db(
- channel,
- session,
- current_user,
- redis,
- server.channels.get(channel_id, []),
- include_recent_messages=True,
+ return await ChatChannelModel.transform(
+ channel, user=current_user, server=server, includes=["recent_messages.sender"]
)
diff --git a/app/router/notification/message.py b/app/router/notification/message.py
index 4d5312c..f5b2d5b 100644
--- a/app/router/notification/message.py
+++ b/app/router/notification/message.py
@@ -1,11 +1,11 @@
from typing import Annotated
-from app.database import ChatMessageResp
+from app.database import ChatChannelModel
from app.database.chat import (
ChannelType,
ChatChannel,
- ChatChannelResp,
ChatMessage,
+ ChatMessageModel,
MessageType,
SilenceUser,
UserSilenceResp,
@@ -18,6 +18,7 @@ from app.log import log
from app.models.notification import ChannelMessage, ChannelMessageTeam
from app.router.v2 import api_v2_router as router
from app.service.redis_message_system import redis_message_system
+from app.utils import api_doc
from .banchobot import bot
from .server import server
@@ -68,7 +69,7 @@ class MessageReq(BaseModel):
@router.post(
"/chat/channels/{channel}/messages",
- response_model=ChatMessageResp,
+ responses={200: api_doc("发送的消息", ChatMessageModel, ["sender", "is_action"])},
name="发送消息",
description="发送消息到指定频道。",
tags=["聊天"],
@@ -130,7 +131,7 @@ async def send_message(
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
temp_msg = ChatMessage(
- message_id=resp.message_id, # 使用 Redis 系统生成的ID
+ message_id=resp["message_id"], # 使用 Redis 系统生成的ID
channel_id=channel_id,
content=req.message,
sender_id=user_id,
@@ -151,7 +152,7 @@ async def send_message(
@router.get(
"/chat/channels/{channel}/messages",
- response_model=list[ChatMessageResp],
+ responses={200: api_doc("获取的消息", list[ChatMessageModel], ["sender"])},
name="获取消息",
description="获取指定频道的消息列表(统一按时间正序返回)。",
tags=["聊天"],
@@ -177,7 +178,7 @@ async def get_message(
try:
messages = await redis_message_system.get_messages(channel_id, limit, since)
- if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
+ if len(messages) >= 2 and messages[0]["message_id"] > messages[-1]["message_id"]:
messages.reverse()
return messages
except Exception as e:
@@ -189,7 +190,7 @@ async def get_message(
# 向前加载新消息 → 直接 ASC
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
rows = (await session.exec(query)).all()
- resp = [await ChatMessageResp.from_db(m, session) for m in rows]
+ resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
# 已经 ASC,无需反转
return resp
@@ -202,15 +203,14 @@ async def get_message(
rows = (await session.exec(query)).all()
rows = list(rows)
rows.reverse() # 反转为 ASC
- resp = [await ChatMessageResp.from_db(m, session) for m in rows]
+ resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
return resp
query = base.order_by(col(ChatMessage.message_id).desc()).limit(limit)
rows = (await session.exec(query)).all()
rows = list(rows)
rows.reverse() # 反转为 ASC
- resp = [await ChatMessageResp.from_db(m, session) for m in rows]
- return resp
+ resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
return resp
@@ -248,17 +248,23 @@ class PMReq(BaseModel):
uuid: str | None = None
-class NewPMResp(BaseModel):
- channel: ChatChannelResp
- message: ChatMessageResp
- new_channel_id: int
-
-
@router.post(
"/chat/new",
name="创建私聊频道",
description="创建一个新的私聊频道。",
tags=["聊天"],
+ responses={
+ 200: api_doc(
+ "创建私聊频道响应",
+ {
+ "channel": ChatChannelModel,
+ "message": ChatMessageModel,
+ "new_channel_id": int,
+ },
+ ["recent_messages.sender", "sender"],
+ name="NewPMResponse",
+ )
+ },
)
async def create_new_pm(
session: Database,
@@ -290,9 +296,9 @@ async def create_new_pm(
await session.refresh(target)
await session.refresh(current_user)
- await server.batch_join_channel([target, current_user], channel, session)
- channel_resp = await ChatChannelResp.from_db(
- channel, session, current_user, redis, server.channels[channel.channel_id]
+ await server.batch_join_channel([target, current_user], channel)
+ channel_resp = await ChatChannelModel.transform(
+ channel, user=current_user, server=server, includes=["recent_messages.sender"]
)
msg = ChatMessage(
channel_id=channel.channel_id,
@@ -306,10 +312,10 @@ async def create_new_pm(
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(channel)
- message_resp = await ChatMessageResp.from_db(msg, session, current_user)
+ message_resp = await ChatMessageModel.transform(msg, user=current_user, includes=["sender"])
await server.send_message_to_channel(message_resp)
- return NewPMResp(
- channel=channel_resp,
- message=message_resp,
- new_channel_id=channel_resp.channel_id,
- )
+ return {
+ "channel": channel_resp,
+ "message": message_resp,
+ "new_channel_id": channel_resp["channel_id"],
+ }
diff --git a/app/router/notification/server.py b/app/router/notification/server.py
index 7818887..ad4a17a 100644
--- a/app/router/notification/server.py
+++ b/app/router/notification/server.py
@@ -1,7 +1,8 @@
import asyncio
from typing import Annotated, overload
-from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
+from app.database import ChatMessageDict
+from app.database.chat import ChannelType, ChatChannel, ChatChannelDict, ChatChannelModel
from app.database.notification import UserNotification, insert_notification
from app.database.user import User
from app.dependencies.database import (
@@ -16,7 +17,7 @@ from app.log import log
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
-from app.utils import bg_tasks
+from app.utils import bg_tasks, safe_json_dumps
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
@@ -65,7 +66,7 @@ class ChatServer:
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
).first()
if db_channel:
- await self.leave_channel(user, db_channel, session)
+ await self.leave_channel(user, db_channel)
@overload
async def send_event(self, client: int, event: ChatEvent): ...
@@ -80,7 +81,7 @@ class ChatServer:
return
client = client_
if client.client_state == WebSocketState.CONNECTED:
- await client.send_text(event.model_dump_json())
+ await client.send_text(safe_json_dumps(event))
async def broadcast(self, channel_id: int, event: ChatEvent):
users_in_channel = self.channels.get(channel_id, [])
@@ -107,38 +108,38 @@ class ChatServer:
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
- async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
+ async def send_message_to_channel(self, message: ChatMessageDict, is_bot_command: bool = False):
logger.info(
- f"Sending message to channel {message.channel_id}, message_id: "
- f"{message.message_id}, is_bot_command: {is_bot_command}"
+ f"Sending message to channel {message['channel_id']}, message_id: "
+ f"{message['message_id']}, is_bot_command: {is_bot_command}"
)
event = ChatEvent(
event="chat.message.new",
- data={"messages": [message], "users": [message.sender]},
+ data={"messages": [message], "users": [message["sender"]]}, # pyright: ignore[reportTypedDictNotRequiredAccess]
)
if is_bot_command:
- logger.info(f"Sending bot command to user {message.sender_id}")
- bg_tasks.add_task(self.send_event, message.sender_id, event)
+ logger.info(f"Sending bot command to user {message['sender_id']}")
+ bg_tasks.add_task(self.send_event, message["sender_id"], event)
else:
# 总是广播消息,无论是临时ID还是真实ID
- logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
+ logger.info(f"Broadcasting message to all users in channel {message['channel_id']}")
bg_tasks.add_task(
self.broadcast,
- message.channel_id,
+ message["channel_id"],
event,
)
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
- if message.message_id and message.message_id > 0:
- await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
- await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
- logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
+ if message["message_id"] and message["message_id"] > 0:
+ await self.mark_as_read(message["channel_id"], message["sender_id"], message["message_id"])
+ await self.redis.set(f"chat:{message['channel_id']}:last_msg", message["message_id"])
+ logger.info(f"Updated last message ID for channel {message['channel_id']} to {message['message_id']}")
else:
- logger.debug(f"Skipping last message update for message ID: {message.message_id}")
+ logger.debug(f"Skipping last message update for message ID: {message['message_id']}")
- async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
+ async def batch_join_channel(self, users: list[User], channel: ChatChannel):
channel_id = channel.channel_id
not_joined = []
@@ -151,22 +152,18 @@ class ChatServer:
not_joined.append(user)
for user in not_joined:
- channel_resp = await ChatChannelResp.from_db(
- channel,
- session,
- user,
- self.redis,
- self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
+ channel_resp = await ChatChannelModel.transform(
+ channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user.id,
ChatEvent(
event="chat.channel.join",
- data=channel_resp.model_dump(),
+ data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
- async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
+ async def join_channel(self, user: User, channel: ChatChannel) -> ChatChannelDict:
user_id = user.id
channel_id = channel.channel_id
@@ -175,25 +172,21 @@ class ChatServer:
if user_id not in self.channels[channel_id]:
self.channels[channel_id].append(user_id)
- channel_resp = await ChatChannelResp.from_db(
- channel,
- session,
- user,
- self.redis,
- self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
+ channel_resp: ChatChannelDict = await ChatChannelModel.transform(
+ channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.join",
- data=channel_resp.model_dump(),
+ data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
return channel_resp
- async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
+ async def leave_channel(self, user: User, channel: ChatChannel) -> None:
user_id = user.id
channel_id = channel.channel_id
@@ -203,18 +196,14 @@ class ChatServer:
if (c := self.channels.get(channel_id)) is not None and not c:
del self.channels[channel_id]
- channel_resp = await ChatChannelResp.from_db(
- channel,
- session,
- user,
- self.redis,
- self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
+ channel_resp = await ChatChannelModel.transform(
+ channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.part",
- data=channel_resp.model_dump(),
+ data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
@@ -232,7 +221,7 @@ class ChatServer:
return
logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
- await self.join_channel(user, db_channel, session)
+ await self.join_channel(user, db_channel)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
@@ -248,7 +237,7 @@ class ChatServer:
return
logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
- await self.leave_channel(user, db_channel, session)
+ await self.leave_channel(user, db_channel)
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
@@ -336,6 +325,6 @@ async def chat_websocket(
# 使用明确的查询避免延迟加载
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
if db_channel is not None:
- await server.join_channel(user, db_channel, session)
+ await server.join_channel(user, db_channel)
await _listen_stop(websocket, user_id, factory)
diff --git a/app/router/private/team.py b/app/router/private/team.py
index 826176a..1afa61b 100644
--- a/app/router/private/team.py
+++ b/app/router/private/team.py
@@ -2,7 +2,7 @@ import hashlib
from typing import Annotated
from app.database.team import Team, TeamMember, TeamRequest, TeamResp
-from app.database.user import BASE_INCLUDES, User, UserResp
+from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
@@ -14,12 +14,11 @@ from app.models.notification import (
from app.models.score import GameMode
from app.router.notification import server
from app.service.ranking_cache_service import get_ranking_cache_service
-from app.utils import check_image, utcnow
+from app.utils import api_doc, check_image, utcnow
from .router import router
from fastapi import File, Form, HTTPException, Path, Query, Request
-from pydantic import BaseModel
from sqlmodel import col, exists, select
@@ -214,12 +213,22 @@ async def delete_team(
await cache_service.invalidate_team_cache()
-class TeamQueryResp(BaseModel):
- team: TeamResp
- members: list[UserResp]
-
-
-@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
+@router.get(
+ "/team/{team_id}",
+ name="查询战队",
+ tags=["战队", "g0v0 API"],
+ responses={
+ 200: api_doc(
+ "战队信息",
+ {
+ "team": TeamResp,
+ "members": list[UserModel],
+ },
+ ["statistics", "country"],
+ name="TeamQueryResp",
+ )
+ },
+)
async def get_team(
session: Database,
team_id: Annotated[int, Path(..., description="战队 ID")],
@@ -233,10 +242,10 @@ async def get_team(
)
)
).all()
- return TeamQueryResp(
- team=await TeamResp.from_db(members[0].team, session, gamemode),
- members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
- )
+ return {
+ "team": await TeamResp.from_db(members[0].team, session, gamemode),
+ "members": await UserModel.transform_many([m.user for m in members], includes=["statistics", "country"]),
+ }
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])
diff --git a/app/router/v1/user.py b/app/router/v1/user.py
index aee0569..29b642f 100644
--- a/app/router/v1/user.py
+++ b/app/router/v1/user.py
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Annotated, Literal
-from app.database.statistics import UserStatistics, UserStatisticsResp
+from app.database.statistics import UserStatistics, UserStatisticsModel
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.log import logger
@@ -46,7 +46,7 @@ class V1User(AllStrModel):
return f"v1_user:{user_id}"
@classmethod
- async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
+ async def from_db(cls, db_user: User, ruleset: GameMode | None = None) -> "V1User":
ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None
for i in await db_user.awaitable_attrs.statistics:
@@ -54,31 +54,33 @@ class V1User(AllStrModel):
current_statistics = i
break
if current_statistics:
- statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
+ statistics = await UserStatisticsModel.transform(
+ current_statistics, country_code=db_user.country_code, includes=["country_rank"]
+ )
else:
statistics = None
return cls(
user_id=db_user.id,
username=db_user.username,
join_date=db_user.join_date,
- count300=statistics.count_300 if statistics else 0,
- count100=statistics.count_100 if statistics else 0,
- count50=statistics.count_50 if statistics else 0,
- playcount=statistics.play_count if statistics else 0,
- ranked_score=statistics.ranked_score if statistics else 0,
- total_score=statistics.total_score if statistics else 0,
- pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
+ count300=current_statistics.count_300 if current_statistics else 0,
+ count100=current_statistics.count_100 if current_statistics else 0,
+ count50=current_statistics.count_50 if current_statistics else 0,
+ playcount=current_statistics.play_count if current_statistics else 0,
+ ranked_score=current_statistics.ranked_score if current_statistics else 0,
+ total_score=current_statistics.total_score if current_statistics else 0,
+ pp_rank=statistics.get("global_rank") or 0 if statistics else 0,
level=current_statistics.level_current if current_statistics else 0,
- pp_raw=statistics.pp if statistics else 0.0,
- accuracy=statistics.hit_accuracy if statistics else 0,
+ pp_raw=current_statistics.pp if current_statistics else 0.0,
+ accuracy=current_statistics.hit_accuracy if current_statistics else 0,
count_rank_ss=current_statistics.grade_ss if current_statistics else 0,
count_rank_ssh=current_statistics.grade_ssh if current_statistics else 0,
count_rank_s=current_statistics.grade_s if current_statistics else 0,
count_rank_sh=current_statistics.grade_sh if current_statistics else 0,
count_rank_a=current_statistics.grade_a if current_statistics else 0,
country=db_user.country_code,
- total_seconds_played=statistics.play_time if statistics else 0,
- pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
+ total_seconds_played=current_statistics.play_time if current_statistics else 0,
+ pp_country_rank=statistics.get("country_rank") or 0 if statistics else 0,
events=[], # TODO
)
@@ -134,7 +136,7 @@ async def get_user(
try:
# 生成用户数据
- v1_user = await V1User.from_db(session, db_user, ruleset)
+ v1_user = await V1User.from_db(db_user, ruleset)
# 异步缓存结果(如果有用户ID)
if db_user.id is not None:
diff --git a/app/router/v2/beatmap.py b/app/router/v2/beatmap.py
index 53cf70f..943317d 100644
--- a/app/router/v2/beatmap.py
+++ b/app/router/v2/beatmap.py
@@ -5,7 +5,11 @@ from typing import Annotated
from app.calculator import get_calculator
from app.calculators.performance import ConvertError
-from app.database import Beatmap, BeatmapResp, User
+from app.database import (
+ Beatmap,
+ BeatmapModel,
+ User,
+)
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
@@ -19,29 +23,20 @@ from app.models.performance import (
from app.models.score import (
GameMode,
)
+from app.utils import api_doc
from .router import router
from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError
-from pydantic import BaseModel
from sqlmodel import col, select
-class BatchGetResp(BaseModel):
- """批量获取谱面返回模型。
-
- 返回字段说明:
- - beatmaps: 谱面详细信息列表。"""
-
- beatmaps: list[BeatmapResp]
-
-
@router.get(
"/beatmaps/lookup",
tags=["谱面"],
name="查询单个谱面",
- response_model=BeatmapResp,
+ responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
)
@asset_proxy_response
@@ -67,14 +62,14 @@ async def lookup_beatmap(
raise HTTPException(status_code=404, detail="Beatmap not found")
await db.refresh(current_user)
- return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
+ return await BeatmapModel.transform(beatmap, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
@router.get(
"/beatmaps/{beatmap_id}",
tags=["谱面"],
name="获取谱面详情",
- response_model=BeatmapResp,
+ responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
description="获取单个谱面详情。",
)
@asset_proxy_response
@@ -86,7 +81,12 @@ async def get_beatmap(
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
- return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
+ await db.refresh(current_user)
+ return await BeatmapModel.transform(
+ beatmap,
+ user=current_user,
+ includes=BeatmapModel.TRANSFORMER_INCLUDES,
+ )
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -95,7 +95,11 @@ async def get_beatmap(
"/beatmaps/",
tags=["谱面"],
name="批量获取谱面",
- response_model=BatchGetResp,
+ responses={
+ 200: api_doc(
+ "谱面列表", {"beatmaps": list[BeatmapModel]}, BeatmapModel.TRANSFORMER_INCLUDES, name="BatchBeatmapResponse"
+ )
+ },
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
)
@asset_proxy_response
@@ -124,7 +128,12 @@ async def batch_get_beatmaps(
for beatmap in beatmaps:
await db.refresh(beatmap)
await db.refresh(current_user)
- return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
+ return {
+ "beatmaps": [
+ await BeatmapModel.transform(bm, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
+ for bm in beatmaps
+ ]
+ }
@router.post(
diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py
index 5129c32..7fb685a 100644
--- a/app/router/v2/beatmapset.py
+++ b/app/router/v2/beatmapset.py
@@ -2,17 +2,24 @@ import re
from typing import Annotated, Literal
from urllib.parse import parse_qs
-from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
-from app.database.beatmapset import SearchBeatmapsetsResp
+from app.database import (
+ Beatmap,
+ Beatmapset,
+ BeatmapsetModel,
+ FavouriteBeatmapset,
+ SearchBeatmapsetsResp,
+ User,
+)
from app.dependencies.beatmap_download import DownloadService
from app.dependencies.cache import BeatmapsetCacheService, UserCacheService
-from app.dependencies.database import Database, Redis, with_db
+from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.geoip import IPAddress, get_geoip_helper
from app.dependencies.user import ClientUser, get_current_user
from app.helpers.asset_proxy_helper import asset_proxy_response
from app.models.beatmap import SearchQueryModel
from app.service.beatmapset_cache_service import generate_hash
+from app.utils import api_doc
from .router import router
@@ -27,14 +34,7 @@ from fastapi import (
)
from fastapi.responses import RedirectResponse
from httpx import HTTPError
-from sqlmodel import exists, select
-
-
-async def _save_to_db(sets: SearchBeatmapsetsResp):
- async with with_db() as session:
- for s in sets.beatmapsets:
- if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
- await Beatmapset.from_resp(session, s)
+from sqlmodel import select
@router.get(
@@ -105,7 +105,6 @@ async def search_beatmapset(
try:
sets = await fetcher.search_beatmapset(query, cursor, redis)
- background_tasks.add_task(_save_to_db, sets)
# 缓存搜索结果
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
@@ -117,8 +116,8 @@ async def search_beatmapset(
@router.get(
"/beatmapsets/lookup",
tags=["谱面集"],
+ responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
name="查询谱面集 (通过谱面 ID)",
- response_model=BeatmapsetResp,
description=("通过谱面 ID 查询所属谱面集。"),
)
@asset_proxy_response
@@ -137,7 +136,10 @@ async def lookup_beatmapset(
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
- resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
+
+ resp = await BeatmapsetModel.transform(
+ beatmap.beatmapset, user=current_user, includes=BeatmapsetModel.API_INCLUDES
+ )
# 缓存结果
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
@@ -149,8 +151,8 @@ async def lookup_beatmapset(
@router.get(
"/beatmapsets/{beatmapset_id}",
tags=["谱面集"],
+ responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
name="获取谱面集详情",
- response_model=BeatmapsetResp,
description="获取单个谱面集详情。",
)
@asset_proxy_response
@@ -169,7 +171,8 @@ async def get_beatmapset(
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
- resp = await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
+ await db.refresh(current_user)
+ resp = await BeatmapsetModel.transform(beatmapset, includes=BeatmapsetModel.API_INCLUDES, user=current_user)
# 缓存结果
await cache_service.cache_beatmapset(resp)
diff --git a/app/router/v2/me.py b/app/router/v2/me.py
index 7122638..d0e94b6 100644
--- a/app/router/v2/me.py
+++ b/app/router/v2/me.py
@@ -1,20 +1,29 @@
from typing import Annotated
-from app.database import FavouriteBeatmapset, MeResp, User
+from app.database import FavouriteBeatmapset, User
+from app.database.user import UserModel
from app.dependencies.database import Database
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
from app.models.score import GameMode
+from app.utils import api_doc
from .router import router
from fastapi import Path, Security
from fastapi.responses import RedirectResponse
+from pydantic import BaseModel
from sqlmodel import select
+ME_INCLUDES = [*User.USER_INCLUDES, "session_verified", "session_verification_method"]
+
+
+class BeatmapsetIds(BaseModel):
+ beatmapset_ids: list[int]
+
@router.get(
"/me/beatmapset-favourites",
- response_model=list[int],
+ response_model=BeatmapsetIds,
name="获取当前用户收藏的谱面集 ID 列表",
description="获取当前登录用户收藏的谱面集 ID 列表。",
tags=["用户", "谱面集"],
@@ -26,37 +35,39 @@ async def get_user_beatmapset_favourites(
beatmapset_ids = await session.exec(
select(FavouriteBeatmapset.beatmapset_id).where(FavouriteBeatmapset.user_id == current_user.id)
)
- return beatmapset_ids.all()
+ return BeatmapsetIds(beatmapset_ids=list(beatmapset_ids.all()))
@router.get(
"/me/{ruleset}",
- response_model=MeResp,
+ responses={200: api_doc("当前用户信息(含指定 ruleset 统计)", UserModel, ME_INCLUDES)},
name="获取当前用户信息 (指定 ruleset)",
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
tags=["用户"],
)
async def get_user_info_with_ruleset(
- session: Database,
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
- user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
+ user_resp = await UserModel.transform(
+ user_and_token[0], ruleset=ruleset, token_id=user_and_token[1].id, includes=ME_INCLUDES
+ )
return user_resp
@router.get(
"/me/",
- response_model=MeResp,
+ responses={200: api_doc("当前用户信息", UserModel, ME_INCLUDES)},
name="获取当前用户信息",
description="获取当前登录用户信息。",
tags=["用户"],
)
async def get_user_info_default(
- session: Database,
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
- user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
+ user_resp = await UserModel.transform(
+ user_and_token[0], ruleset=None, token_id=user_and_token[1].id, includes=ME_INCLUDES
+ )
return user_resp
diff --git a/app/router/v2/ranking.py b/app/router/v2/ranking.py
index 17aabb6..98682a6 100644
--- a/app/router/v2/ranking.py
+++ b/app/router/v2/ranking.py
@@ -1,11 +1,13 @@
from typing import Annotated, Literal
from app.config import settings
-from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
+from app.database import Team, TeamMember, User, UserStatistics
+from app.database.statistics import UserStatisticsModel
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.service.ranking_cache_service import get_ranking_cache_service
+from app.utils import api_doc
from .router import router
@@ -308,14 +310,16 @@ async def get_country_ranking(
return response
-class TopUsersResponse(BaseModel):
- ranking: list[UserStatisticsResp]
- total: int
-
-
@router.get(
"/rankings/{ruleset}/{sort}",
- response_model=TopUsersResponse,
+ responses={
+ 200: api_doc(
+ "用户排行榜",
+ {"ranking": list[UserStatisticsModel], "total": int},
+ ["user.country", "user.cover"],
+ name="TopUsersResponse",
+ )
+ },
name="获取用户排行榜",
description="获取在指定模式下的用户排行榜",
tags=["排行榜"],
@@ -339,10 +343,10 @@ async def get_user_ranking(
if cached_data and cached_stats:
# 从缓存返回数据
- return TopUsersResponse(
- ranking=[UserStatisticsResp.model_validate(item) for item in cached_data],
- total=cached_stats.get("total", 0),
- )
+ return {
+ "ranking": cached_data,
+ "total": cached_stats.get("total", 0),
+ }
# 缓存未命中,从数据库查询
wheres = [
@@ -350,7 +354,7 @@ async def get_user_ranking(
col(UserStatistics.pp) > 0,
col(UserStatistics.is_ranked),
]
- include = ["user"]
+ include = UserStatistics.RANKING_INCLUDES.copy()
if sort == "performance":
order_by = col(UserStatistics.pp).desc()
include.append("rank_change_since_30_days")
@@ -358,6 +362,7 @@ async def get_user_ranking(
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
+ include.append("country_rank")
# 查询总数
count_query = select(func.count()).select_from(UserStatistics).where(*wheres)
@@ -378,12 +383,14 @@ async def get_user_ranking(
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
- user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
+ user_stats_resp = await UserStatisticsModel.transform(
+ statistics, includes=include, user_country=current_user.country_code
+ )
ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
- cache_data = [item.model_dump() for item in ranking_data]
+ cache_data = ranking_data
stats_data = {"total": total_count}
# 创建后台任务来缓存数据
@@ -407,5 +414,7 @@ async def get_user_ranking(
ttl=settings.ranking_cache_expire_minutes * 60,
)
- resp = TopUsersResponse(ranking=ranking_data, total=total_count)
- return resp
+ return {
+ "ranking": ranking_data,
+ "total": total_count,
+ }
diff --git a/app/router/v2/relationship.py b/app/router/v2/relationship.py
index 554596a..7d0c891 100644
--- a/app/router/v2/relationship.py
+++ b/app/router/v2/relationship.py
@@ -1,15 +1,16 @@
-from typing import Annotated
+from typing import Annotated, Any
-from app.database import Relationship, RelationshipResp, RelationshipType, User
-from app.database.user import UserResp
+from app.database import Relationship, RelationshipType, User
+from app.database.relationship import RelationshipModel
+from app.database.user import UserModel
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database
from app.dependencies.user import ClientUser, get_current_user
+from app.utils import api_doc
from .router import router
from fastapi import HTTPException, Path, Query, Request, Security
-from pydantic import BaseModel
from sqlmodel import col, exists, select
@@ -17,38 +18,19 @@ from sqlmodel import col, exists, select
"/friends",
tags=["用户关系"],
responses={
- 200: {
- "description": "好友列表",
- "content": {
- "application/json": {
- "schema": {
- "oneOf": [
- {
- "type": "array",
- "items": {"$ref": "#/components/schemas/RelationshipResp"},
- "description": "好友列表",
- },
- {
- "type": "array",
- "items": {"$ref": "#/components/schemas/UserResp"},
- "description": "好友列表 (`x-api-version < 20241022`)",
- },
- ]
- }
- }
- },
- }
+ 200: api_doc(
+ "好友列表\n\n如果 `x-api-version < 20241022`,返回值为 `User` 列表,否则为 `Relationship` 列表。",
+ list[RelationshipModel] | list[UserModel],
+ [f"target.{inc}" for inc in User.LIST_INCLUDES],
+ )
},
name="获取好友列表",
- description=(
- "获取当前用户的好友列表。\n\n"
- "如果 `x-api-version < 20241022`,返回值为 `UserResp` 列表,否则为 `RelationshipResp` 列表。"
- ),
+ description="获取当前用户的好友列表。",
)
@router.get(
"/blocks",
tags=["用户关系"],
- response_model=list[RelationshipResp],
+ response_model=list[dict[str, Any]],
name="获取屏蔽列表",
description="获取当前用户的屏蔽用户列表。",
)
@@ -67,35 +49,29 @@ async def get_relationship(
)
)
if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK:
- return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
+ return [
+ await RelationshipModel.transform(
+ rel,
+ includes=[f"target.{inc}" for inc in User.LIST_INCLUDES],
+ ruleset=current_user.playmode,
+ )
+ for rel in relationships.unique()
+ ]
else:
return [
- await UserResp.from_db(
+ await UserModel.transform(
rel.target,
- db,
- include=[
- "team",
- "daily_challenge_user_stats",
- "statistics",
- "statistics_rulesets",
- ],
+ ruleset=current_user.playmode,
+ includes=User.LIST_INCLUDES,
)
for rel in relationships.unique()
]
-class AddFriendResp(BaseModel):
- """添加好友/屏蔽 返回模型。
-
- - user_relation: 新的或更新后的关系对象。"""
-
- user_relation: RelationshipResp
-
-
@router.post(
"/friends",
tags=["用户关系"],
- response_model=AddFriendResp,
+ responses={200: api_doc("好友关系", {"user_relation": RelationshipModel}, name="UserRelationshipResponse")},
name="添加或更新好友关系",
description="\n添加或更新与目标用户的好友关系。",
)
@@ -163,7 +139,13 @@ async def add_relationship(
)
)
).one()
- return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
+ return {
+ "user_relation": await RelationshipModel.transform(
+ relationship,
+ includes=[],
+ ruleset=current_user.playmode,
+ )
+ }
@router.delete(
diff --git a/app/router/v2/room.py b/app/router/v2/room.py
index b8d06ed..94ed5a0 100644
--- a/app/router/v2/room.py
+++ b/app/router/v2/room.py
@@ -1,25 +1,27 @@
from datetime import UTC
from typing import Annotated, Literal
-from app.database.beatmap import Beatmap, BeatmapResp
-from app.database.beatmapset import BeatmapsetResp
-from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsResp
+from app.database.beatmap import (
+ Beatmap,
+ BeatmapModel,
+)
+from app.database.beatmapset import BeatmapsetModel
+from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsCountModel
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
-from app.database.playlists import Playlist, PlaylistResp
-from app.database.room import APIUploadedRoom, Room, RoomResp
+from app.database.playlists import Playlist, PlaylistModel
+from app.database.room import APIUploadedRoom, Room, RoomModel
from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score
-from app.database.user import User, UserResp
+from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser, get_current_user
from app.models.room import MatchType, RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api
-from app.utils import utcnow
+from app.utils import api_doc, utcnow
from .router import router
from fastapi import HTTPException, Path, Query, Security
-from pydantic import BaseModel, Field
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -28,7 +30,19 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get(
"/rooms",
tags=["房间"],
- response_model=list[RoomResp],
+ responses={
+ 200: api_doc(
+ "房间列表",
+ list[RoomModel],
+ [
+ "current_playlist_item.beatmap.beatmapset",
+ "difficulty_range",
+ "host.country",
+ "playlist_item_stats",
+ "recent_participants",
+ ],
+ )
+ },
name="获取房间列表",
description="获取房间列表。支持按状态/模式筛选",
)
@@ -49,7 +63,7 @@ async def get_all_rooms(
] = RoomCategory.NORMAL,
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
):
- resp_list: list[RoomResp] = []
+ resp_list = []
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category, col(Room.type) != MatchType.MATCHMAKING]
now = utcnow()
@@ -90,22 +104,24 @@ async def get_all_rooms(
.all()
)
for room in db_rooms:
- resp = await RoomResp.from_db(room, db)
+ resp = await RoomModel.transform(
+ room,
+ includes=[
+ "current_playlist_item.beatmap.beatmapset",
+ "difficulty_range",
+ "host.country",
+ "playlist_item_stats",
+ "recent_participants",
+ ],
+ )
if category == RoomCategory.REALTIME:
- resp.category = RoomCategory.NORMAL
+ resp["category"] = RoomCategory.NORMAL
resp_list.append(resp)
return resp_list
-class APICreatedRoom(RoomResp):
- """创建房间返回模型,继承 RoomResp。额外字段:
- - error: 错误信息(为空表示成功)。"""
-
- error: str = ""
-
-
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
participated_user = (
await session.exec(
@@ -133,9 +149,15 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
@router.post(
"/rooms",
tags=["房间"],
- response_model=APICreatedRoom,
name="创建房间",
description="\n创建一个新的房间。",
+ responses={
+ 200: api_doc(
+ "创建的房间信息",
+ RoomModel,
+ Room.SHOW_RESPONSE_INCLUDES,
+ )
+ },
)
async def create_room(
db: Database,
@@ -145,23 +167,27 @@ async def create_room(
):
if await current_user.is_restricted(db):
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
-
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db, redis)
await db.commit()
await db.refresh(db_room)
- created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
- created_room.error = ""
+ created_room = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
return created_room
@router.get(
"/rooms/{room_id}",
tags=["房间"],
- response_model=RoomResp,
+ responses={
+ 200: api_doc(
+ "房间详细信息",
+ RoomModel,
+ Room.SHOW_RESPONSE_INCLUDES,
+ )
+ },
name="获取房间详情",
- description="获取单个房间详情。",
+ description="获取指定房间详情。",
)
async def get_room(
db: Database,
@@ -177,7 +203,7 @@ async def get_room(
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
- resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
+ resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES, user=current_user)
return resp
@@ -225,10 +251,10 @@ async def add_user_to_room(
await _participate_room(room_id, user_id, db_room, db, redis)
await db.commit()
await db.refresh(db_room)
- resp = await RoomResp.from_db(db_room, db)
+ resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
return resp
else:
- raise HTTPException(404, "room not found0")
+ raise HTTPException(404, "room not found")
@router.delete(
@@ -268,21 +294,22 @@ async def remove_user_from_room(
raise HTTPException(404, "Room not found")
-class APILeaderboard(BaseModel):
- """房间全局排行榜返回模型。
- - leaderboard: 用户游玩统计(尝试次数/分数等)。
- - user_score: 当前用户对应统计。"""
-
- leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
- user_score: ItemAttemptsResp | None = None
-
-
@router.get(
"/rooms/{room_id}/leaderboard",
tags=["房间"],
- response_model=APILeaderboard,
name="获取房间排行榜",
description="获取房间内累计得分排行榜。",
+ responses={
+ 200: api_doc(
+ "房间排行榜",
+ {
+ "leaderboard": list[ItemAttemptsCountModel],
+ "user_score": ItemAttemptsCountModel | None,
+ },
+ ["user.country", "position"],
+ name="RoomLeaderboardResponse",
+ )
+ },
)
async def get_room_leaderboard(
db: Database,
@@ -300,45 +327,43 @@ async def get_room_leaderboard(
aggs_resp = []
user_agg = None
for i, agg in enumerate(aggs):
- resp = await ItemAttemptsResp.from_db(agg, db)
- resp.position = i + 1
+ includes = ["user.country"]
+ if agg.user_id == current_user.id:
+ includes.append("position")
+ resp = await ItemAttemptsCountModel.transform(agg, includes=includes)
aggs_resp.append(resp)
if agg.user_id == current_user.id:
user_agg = resp
- return APILeaderboard(
- leaderboard=aggs_resp,
- user_score=user_agg,
- )
-
-class RoomEvents(BaseModel):
- """房间事件流返回模型。
- - beatmaps: 本次结果涉及的谱面列表。
- - beatmapsets: 谱面集映射。
- - current_playlist_item_id: 当前游玩列表(项目)项 ID。
- - events: 事件列表。
- - first_event_id / last_event_id: 事件范围。
- - playlist_items: 房间游玩列表(项目)详情。
- - room: 房间详情。
- - user: 关联用户列表。"""
-
- beatmaps: list[BeatmapResp] = Field(default_factory=list)
- beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
- current_playlist_item_id: int = 0
- events: list[MultiplayerEventResp] = Field(default_factory=list)
- first_event_id: int = 0
- last_event_id: int = 0
- playlist_items: list[PlaylistResp] = Field(default_factory=list)
- room: RoomResp
- user: list[UserResp] = Field(default_factory=list)
+ return {
+ "leaderboard": aggs_resp,
+ "user_score": user_agg,
+ }
@router.get(
"/rooms/{room_id}/events",
- response_model=RoomEvents,
tags=["房间"],
name="获取房间事件",
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
+ responses={
+ 200: api_doc(
+ "房间事件",
+ {
+ "beatmaps": list[BeatmapModel],
+ "beatmapsets": list[BeatmapsetModel],
+ "current_playlist_item_id": int,
+ "events": list[MultiplayerEventResp],
+ "first_event_id": int,
+ "last_event_id": int,
+ "playlist_items": list[PlaylistModel],
+ "room": RoomModel,
+ "user": list[UserModel],
+ },
+ ["country", "details", "scores"],
+ name="RoomEventsResponse",
+ )
+ },
)
async def get_room_events(
db: Database,
@@ -402,28 +427,44 @@ async def get_room_events(
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if room is None:
raise HTTPException(404, "Room not found")
- room_resp = await RoomResp.from_db(room, db)
- if room.category == RoomCategory.REALTIME and room_resp.current_playlist_item:
- current_playlist_item_id = room_resp.current_playlist_item.id
+ room_resp = await RoomModel.transform(room, includes=["current_playlist_item"])
+ if room.category == RoomCategory.REALTIME:
+ current_playlist_item_id = (await Room.current_playlist_item(db, room))["id"]
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
- user_resps = [await UserResp.from_db(user, db) for user in users]
+ user_resps = [await UserModel.transform(user, includes=["country"]) for user in users]
+
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
- beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
- beatmapset_resps = {}
- for beatmap_resp in beatmap_resps:
- beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
+ beatmap_resps = [
+ await BeatmapModel.transform(
+ beatmap,
+ )
+ for beatmap in beatmaps
+ ]
- playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
+ beatmapsets = []
+ for beatmap in beatmaps:
+ if beatmap.beatmapset_id not in beatmapsets:
+ beatmapsets.append(beatmap.beatmapset)
+ beatmapset_resps = [
+ await BeatmapsetModel.transform(
+ beatmapset,
+ )
+ for beatmapset in beatmapsets
+ ]
- return RoomEvents(
- beatmaps=beatmap_resps,
- beatmapsets=beatmapset_resps,
- current_playlist_item_id=current_playlist_item_id,
- events=event_resps,
- first_event_id=first_event_id,
- last_event_id=last_event_id,
- playlist_items=playlist_items_resps,
- room=room_resp,
- user=user_resps,
- )
+ playlist_items_resps = [
+ await PlaylistModel.transform(item, includes=["details", "scores"]) for item in playlist_items.values()
+ ]
+
+ return {
+ "beatmaps": beatmap_resps,
+ "beatmapsets": beatmapset_resps,
+ "current_playlist_item_id": current_playlist_item_id,
+ "events": event_resps,
+ "first_event_id": first_event_id,
+ "last_event_id": last_event_id,
+ "playlist_items": playlist_items_resps,
+ "room": room_resp,
+ "user": user_resps,
+ }
diff --git a/app/router/v2/score.py b/app/router/v2/score.py
index 55eb6ec..fffcea2 100644
--- a/app/router/v2/score.py
+++ b/app/router/v2/score.py
@@ -9,7 +9,6 @@ from app.database import (
Playlist,
Room,
Score,
- ScoreResp,
ScoreToken,
ScoreTokenResp,
User,
@@ -27,8 +26,10 @@ from app.database.relationship import Relationship, RelationshipType
from app.database.score import (
LegacyScoreResp,
MultiplayerScores,
- ScoreAround,
+ MultiplayScoreDict,
+ ScoreModel,
get_leaderboard,
+ get_score_position_by_id,
process_score,
process_user,
)
@@ -49,7 +50,7 @@ from app.models.score import (
)
from app.service.beatmap_cache_service import get_beatmap_cache_service
from app.service.user_cache_service import refresh_user_cache_background
-from app.utils import utcnow
+from app.utils import api_doc, utcnow
from .router import router
@@ -72,6 +73,7 @@ from sqlmodel import col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 10
+DEFAULT_SCORE_INCLUDES = ["user", "user.country", "user.cover", "user.team"]
logger = log("Score")
@@ -180,13 +182,15 @@ async def submit_score(
await db.refresh(score)
background_task.add_task(_process_user, score_id, user_id, redis, fetcher)
- resp: ScoreResp = await ScoreResp.from_db(db, score)
+ resp = await ScoreModel.transform(
+ score,
+ )
score_gamemode = score.gamemode
await db.commit()
if user_id is not None:
background_task.add_task(refresh_user_cache_background, redis, user_id, score_gamemode)
- background_task.add_task(_process_user_achievement, resp.id)
+ background_task.add_task(_process_user_achievement, resp["id"])
return resp
@@ -218,27 +222,36 @@ async def _preload_beatmap_for_pp_calculation(beatmap_id: int) -> None:
logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")
-class BeatmapUserScore[T: ScoreResp | LegacyScoreResp](BaseModel):
+LeaderboardScoreType = ScoreModel.generate_typeddict(tuple(DEFAULT_SCORE_INCLUDES)) | LegacyScoreResp
+
+
+class BeatmapUserScore(BaseModel):
position: int
- score: T
+ score: LeaderboardScoreType # pyright: ignore[reportInvalidTypeForm]
-class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
- scores: list[T]
- user_score: BeatmapUserScore[T] | None = None
+class BeatmapScores(BaseModel):
+ scores: list[LeaderboardScoreType] # pyright: ignore[reportInvalidTypeForm]
+ user_score: BeatmapUserScore | None = None
score_count: int = 0
@router.get(
"/beatmaps/{beatmap_id}/scores",
tags=["成绩"],
- response_model=BeatmapScores[ScoreResp] | BeatmapScores[LegacyScoreResp],
+ responses={
+ 200: {
+ "model": BeatmapScores,
+ "description": (
+ "排行榜及当前用户成绩。\n\n"
+ f"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[Score]`"
+ f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}),"
+ "否则为 `BeatmapScores[LegacyScoreResp]`。"
+ ),
+ }
+ },
name="获取谱面排行榜",
- description=(
- "获取指定谱面在特定条件下的排行榜及当前用户成绩。\n\n"
- "如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[ScoreResp]`,"
- "否则为 `BeatmapScores[LegacyScoreResp]`。"
- ),
+ description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
)
async def get_beatmap_scores(
db: Database,
@@ -266,27 +279,46 @@ async def get_beatmap_scores(
mods=sorted(mods),
)
- user_score_resp = await user_score.to_resp(db, api_version) if user_score else None
- resp = BeatmapScores(
- scores=[await score.to_resp(db, api_version) for score in all_scores],
- user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
- if user_score_resp
- else None,
- score_count=count,
- )
- return resp
+ user_score_resp = await user_score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) if user_score else None
+ return {
+ "scores": [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_scores],
+ "user_score": (
+ {
+ "score": user_score_resp,
+ "position": (
+ await get_score_position_by_id(
+ db,
+ user_score.beatmap_id,
+ user_score.id,
+ mode=user_score.gamemode,
+ user=user_score.user,
+ )
+ or 0
+ ),
+ }
+ if user_score and user_score_resp
+ else None
+ ),
+ "score_count": count,
+ }
@router.get(
"/beatmaps/{beatmap_id}/scores/users/{user_id}",
tags=["成绩"],
- response_model=BeatmapUserScore[ScoreResp] | BeatmapUserScore[LegacyScoreResp],
+ responses={
+ 200: {
+ "model": BeatmapUserScore,
+ "description": (
+ "指定用户在指定谱面上的最高成绩\n\n"
+ "如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[Score]`,"
+ f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}),"
+ "否则为 `BeatmapUserScore[LegacyScoreResp]`。"
+ ),
+ }
+ },
name="获取用户谱面最高成绩",
- description=(
- "获取指定用户在指定谱面上的最高成绩。\n\n"
- "如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[ScoreResp]`,"
- "否则为 `BeatmapUserScore[LegacyScoreResp]`。"
- ),
+ description="获取指定用户在指定谱面上的最高成绩。",
)
async def get_user_beatmap_score(
db: Database,
@@ -318,23 +350,38 @@ async def get_user_beatmap_score(
detail=f"Cannot find user {user_id}'s score on this beatmap",
)
else:
- resp = await user_score.to_resp(db, api_version=api_version)
- return BeatmapUserScore(
- position=resp.rank_global or 0,
- score=resp,
- )
+ resp = await user_score.to_resp(db, api_version=api_version, includes=DEFAULT_SCORE_INCLUDES)
+ return {
+ "position": (
+ await get_score_position_by_id(
+ db,
+ user_score.beatmap_id,
+ user_score.id,
+ mode=user_score.gamemode,
+ user=user_score.user,
+ )
+ or 0
+ ),
+ "score": resp,
+ }
@router.get(
"/beatmaps/{beatmap_id}/scores/users/{user_id}/all",
tags=["成绩"],
- response_model=list[ScoreResp] | list[LegacyScoreResp],
+ responses={
+ 200: api_doc(
+ (
+ "用户谱面全部成绩\n\n"
+ "如果 `x-api-version >= 20220705`,返回值为 `Score`列表,"
+ "否则为 `LegacyScoreResp`列表。"
+ ),
+ list[ScoreModel] | list[LegacyScoreResp],
+ DEFAULT_SCORE_INCLUDES,
+ )
+ },
name="获取用户谱面全部成绩",
- description=(
- "获取指定用户在指定谱面上的全部成绩列表。\n\n"
- "如果 `x-api-version >= 20220705`,返回值为 `ScoreResp`列表,"
- "否则为 `LegacyScoreResp`列表。"
- ),
+ description="获取指定用户在指定谱面上的全部成绩列表。",
)
async def get_user_all_beatmap_scores(
db: Database,
@@ -359,7 +406,7 @@ async def get_user_all_beatmap_scores(
)
).all()
- return [await score.to_resp(db, api_version) for score in all_user_scores]
+ return [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_user_scores]
@router.post(
@@ -413,9 +460,9 @@ async def create_solo_score(
@router.put(
"/beatmaps/{beatmap_id}/solo/scores/{token}",
tags=["游玩"],
- response_model=ScoreResp,
name="提交单曲成绩",
description="\n使用令牌提交单曲成绩。",
+ responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
)
async def submit_solo_score(
background_task: BackgroundTasks,
@@ -520,6 +567,7 @@ async def create_playlist_score(
tags=["游玩"],
name="提交房间项目成绩",
description="\n提交房间游玩项目成绩。",
+ responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
)
async def submit_playlist_score(
background_task: BackgroundTasks,
@@ -560,13 +608,13 @@ async def submit_playlist_score(
room_id,
playlist_id,
user_id,
- score_resp.id,
- score_resp.total_score,
+ score_resp["id"],
+ score_resp["total_score"],
session,
redis,
)
await session.commit()
- if room_category == RoomCategory.DAILY_CHALLENGE and score_resp.passed:
+ if room_category == RoomCategory.DAILY_CHALLENGE and score_resp["passed"]:
await process_daily_challenge_score(session, user_id, room_id)
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
await session.commit()
@@ -575,15 +623,23 @@ async def submit_playlist_score(
class IndexedScoreResp(MultiplayerScores):
total: int
- user_score: ScoreResp | None = None
+ user_score: MultiplayScoreDict | None = None # pyright: ignore[reportInvalidTypeForm]
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores",
- response_model=IndexedScoreResp,
+ # response_model=IndexedScoreResp,
name="获取房间项目排行榜",
description="获取房间游玩项目排行榜。",
tags=["成绩"],
+ responses={
+ 200: {
+ "description": (
+ f"房间项目排行榜。\n\n包含:{', '.join([f'`{inc}`' for inc in Score.MULTIPLAYER_BASE_INCLUDES])}"
+ ),
+ "model": IndexedScoreResp,
+ }
+ },
)
async def index_playlist_scores(
session: Database,
@@ -620,16 +676,14 @@ async def index_playlist_scores(
scores = scores[:-1]
user_score = None
- score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
+ score_resp = [await ScoreModel.transform(score.score, includes=Score.MULTIPLAYER_BASE_INCLUDES) for score in scores]
for score in score_resp:
- score.position = await get_position(room_id, playlist_id, score.id, session)
- if score.user_id == user_id:
+ if (room.category == RoomCategory.DAILY_CHALLENGE and score["user_id"] == user_id and score["passed"]) or score[
+ "user_id"
+ ] == user_id:
user_score = score
-
- if room.category == RoomCategory.DAILY_CHALLENGE:
- score_resp = [s for s in score_resp if s.passed]
- if user_score and not user_score.passed:
- user_score = None
+ user_score["position"] = await get_position(room_id, playlist_id, score["id"], session)
+ break
resp = IndexedScoreResp(
scores=score_resp,
@@ -648,10 +702,16 @@ async def index_playlist_scores(
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
- response_model=ScoreResp,
name="获取房间项目单个成绩",
description="获取指定房间游玩项目中单个成绩详情。",
tags=["成绩"],
+ responses={
+ 200: api_doc(
+ "房间项目单个成绩详情。",
+ ScoreModel,
+ [*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
+ )
+ },
)
async def show_playlist_score(
session: Database,
@@ -687,39 +747,25 @@ async def show_playlist_score(
break
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
- resp = await ScoreResp.from_db(session, score_record.score)
- resp.position = await get_position(room_id, playlist_id, score_id, session)
+ includes = [
+ *Score.MULTIPLAYER_BASE_INCLUDES,
+ "position",
+ ]
if completed:
- scores = (
- await session.exec(
- select(PlaylistBestScore).where(
- PlaylistBestScore.playlist_id == playlist_id,
- PlaylistBestScore.room_id == room_id,
- ~User.is_restricted_query(col(PlaylistBestScore.user_id)),
- )
- )
- ).all()
- higher_scores = []
- lower_scores = []
- for score in scores:
- resp = await ScoreResp.from_db(session, score.score)
- if is_playlist and not resp.passed:
- continue
- if score.total_score > resp.total_score:
- higher_scores.append(resp)
- elif score.total_score < resp.total_score:
- lower_scores.append(resp)
- resp.scores_around = ScoreAround(
- higher=MultiplayerScores(scores=higher_scores),
- lower=MultiplayerScores(scores=lower_scores),
- )
-
+ includes.append("scores_around")
+ resp = await ScoreModel.transform(score_record.score, includes=includes)
return resp
@router.get(
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
- response_model=ScoreResp,
+ responses={
+ 200: api_doc(
+ "房间项目单个成绩详情。",
+ ScoreModel,
+ [*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
+ )
+ },
name="获取房间项目用户成绩",
description="获取指定用户在房间游玩项目中的成绩。",
tags=["成绩"],
@@ -749,8 +795,14 @@ async def get_user_playlist_score(
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
- resp = await ScoreResp.from_db(session, score_record.score)
- resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
+ resp = await ScoreModel.transform(
+ score_record.score,
+ includes=[
+ *Score.MULTIPLAYER_BASE_INCLUDES,
+ "position",
+ "scores_around",
+ ],
+ )
return resp
diff --git a/app/router/v2/user.py b/app/router/v2/user.py
index ce3e080..31403f9 100644
--- a/app/router/v2/user.py
+++ b/app/router/v2/user.py
@@ -5,17 +5,16 @@ from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import (
Beatmap,
+ BeatmapModel,
BeatmapPlaycounts,
- BeatmapPlaycountsResp,
- BeatmapResp,
- BeatmapsetResp,
+ BeatmapsetModel,
User,
- UserResp,
)
+from app.database.beatmap_playcounts import BeatmapPlaycountsModel
from app.database.best_scores import BestScore
from app.database.events import Event
-from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores
-from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED
+from app.database.score import Score, get_user_first_scores
+from app.database.user import UserModel
from app.dependencies.api_version import APIVersion
from app.dependencies.cache import UserCacheService
from app.dependencies.database import Database, get_redis
@@ -26,24 +25,15 @@ from app.models.mods import API_MODS
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from app.service.user_cache_service import get_user_cache_service
-from app.utils import utcnow
+from app.utils import api_doc, utcnow
from .router import router
from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security
-from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.sql.expression import col
-class BatchUserResponse(BaseModel):
- users: list[UserResp]
-
-
-class BeatmapsPassedResponse(BaseModel):
- beatmaps_passed: list[BeatmapResp]
-
-
def _get_difficulty_reduction_mods() -> set[str]:
mods: set[str] = set()
for ruleset_mods in API_MODS.values():
@@ -63,13 +53,15 @@ async def visible_to_current_user(user: User, current_user: User | None, session
@router.get(
"/users/",
- response_model=BatchUserResponse,
+ responses={
+ 200: api_doc("批量获取用户信息", {"users": list[UserModel]}, User.CARD_INCLUDES, name="UsersLookupResponse")
+ },
name="批量获取用户信息",
description="通过用户 ID 列表批量获取用户信息。",
tags=["用户"],
)
-@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
-@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
+@router.get("/users/lookup", include_in_schema=False)
+@router.get("/users/lookup/", include_in_schema=False)
@asset_proxy_response
async def get_users(
session: Database,
@@ -108,16 +100,15 @@ async def get_users(
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
- user_resp = await UserResp.from_db(
+ user_resp = await UserModel.transform(
searched_user,
- session,
- include=SEARCH_INCLUDED,
+ includes=User.CARD_INCLUDES,
)
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
background_task.add_task(cache_service.cache_user, user_resp)
- response = BatchUserResponse(users=cached_users)
+ response = {"users": cached_users}
return response
else:
searched_users = (
@@ -127,16 +118,15 @@ async def get_users(
for searched_user in searched_users:
if searched_user.id == BANCHOBOT_ID:
continue
- user_resp = await UserResp.from_db(
+ user_resp = await UserModel.transform(
searched_user,
- session,
- include=SEARCH_INCLUDED,
+ includes=User.CARD_INCLUDES,
)
users.append(user_resp)
# 异步缓存
background_task.add_task(cache_service.cache_user, user_resp)
- response = BatchUserResponse(users=users)
+ response = {"users": users}
return response
@@ -200,10 +190,12 @@ async def get_user_kudosu(
@router.get(
"/users/{user_id}/beatmaps-passed",
- response_model=BeatmapsPassedResponse,
name="获取用户已通过谱面",
description="获取指定用户在给定谱面集中的已通过谱面列表。",
tags=["用户"],
+ responses={
+ 200: api_doc("用户已通过谱面列表", {"beatmaps_passed": list[BeatmapModel]}, name="BeatmapsPassedResponse")
+ },
)
@asset_proxy_response
async def get_user_beatmaps_passed(
@@ -226,7 +218,7 @@ async def get_user_beatmaps_passed(
no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True,
):
if not beatmapset_ids:
- return BeatmapsPassedResponse(beatmaps_passed=[])
+ return {"beatmaps_passed": []}
if len(beatmapset_ids) > 50:
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
@@ -255,7 +247,7 @@ async def get_user_beatmaps_passed(
scores = (await session.exec(score_query)).all()
if not scores:
- return BeatmapsPassedResponse(beatmaps_passed=[])
+ return {"beatmaps_passed": []}
difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set()
passed_beatmap_ids: set[int] = set()
@@ -269,7 +261,7 @@ async def get_user_beatmaps_passed(
continue
passed_beatmap_ids.add(beatmap_id)
if not passed_beatmap_ids:
- return BeatmapsPassedResponse(beatmaps_passed=[])
+ return {"beatmaps_passed": []}
beatmaps = (
await session.exec(
@@ -279,19 +271,24 @@ async def get_user_beatmaps_passed(
)
).all()
- return BeatmapsPassedResponse(
- beatmaps_passed=[
- await BeatmapResp.from_db(beatmap, allowed_mode, session=session, user=user) for beatmap in beatmaps
+ return {
+ "beatmaps_passed": [
+ await BeatmapModel.transform(
+ beatmap,
+ )
+ for beatmap in beatmaps
]
- )
+ }
@router.get(
"/users/{user_id}/{ruleset}",
- response_model=UserResp,
name="获取用户信息(指定ruleset)",
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
tags=["用户"],
+ responses={
+ 200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
+ },
)
@asset_proxy_response
async def get_user_info_ruleset(
@@ -325,29 +322,26 @@ async def get_user_info_ruleset(
if should_not_show:
raise HTTPException(404, detail="User not found")
- include = SEARCH_INCLUDED
- if searched_is_self:
- include = ALL_INCLUDED
- user_resp = await UserResp.from_db(
+ user_resp = await UserModel.transform(
searched_user,
- session,
- include=include,
+ includes=User.USER_INCLUDES,
ruleset=ruleset,
)
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
-
return user_resp
-@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
+@router.get("/users/{user_id}/", include_in_schema=False)
@router.get(
"/users/{user_id}",
- response_model=UserResp,
name="获取用户信息",
description="通过用户 ID 或用户名获取单个用户的详细信息。",
tags=["用户"],
+ responses={
+ 200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
+ },
)
@asset_proxy_response
async def get_user_info(
@@ -381,27 +375,31 @@ async def get_user_info(
if should_not_show:
raise HTTPException(404, detail="User not found")
- include = SEARCH_INCLUDED
- if searched_is_self:
- include = ALL_INCLUDED
- user_resp = await UserResp.from_db(
+ user_resp = await UserModel.transform(
searched_user,
- session,
- include=include,
+ includes=User.USER_INCLUDES,
)
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp)
-
return user_resp
+beatmapset_includes = [*BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES, "beatmaps"]
+
+
@router.get(
"/users/{user_id}/beatmapsets/{type}",
- response_model=list[BeatmapsetResp | BeatmapPlaycountsResp],
name="获取用户谱面集列表",
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
tags=["用户"],
+ responses={
+ 200: api_doc(
+ "当类型为 `most_played` 时返回 `list[BeatmapPlaycountsModel]`,其他为 `list[BeatmapsetModel]`",
+ list[BeatmapsetModel] | list[BeatmapPlaycountsModel],
+ beatmapset_includes,
+ )
+ },
)
@asset_proxy_response
async def get_user_beatmapsets(
@@ -417,11 +415,7 @@ async def get_user_beatmapsets(
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
if cached_result is not None:
- # 根据类型恢复对象
- if type == BeatmapsetType.MOST_PLAYED:
- return [BeatmapPlaycountsResp(**item) for item in cached_result]
- else:
- return [BeatmapsetResp(**item) for item in cached_result]
+ return cached_result
user = await session.get(User, user_id)
if not user or user.id == BANCHOBOT_ID:
@@ -444,7 +438,10 @@ async def get_user_beatmapsets(
raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [
- await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
+ await BeatmapsetModel.transform(
+ favourite.beatmapset, session=session, user=user, includes=beatmapset_includes
+ )
+ for favourite in favourites
]
elif type == BeatmapsetType.MOST_PLAYED:
@@ -459,7 +456,10 @@ async def get_user_beatmapsets(
.limit(limit)
.offset(offset)
)
- resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
+ resp = [
+ await BeatmapPlaycountsModel.transform(most_played_beatmap, user=user, includes=beatmapset_includes)
+ for most_played_beatmap in most_played
+ ]
else:
raise HTTPException(400, detail="Invalid beatmapset type")
@@ -477,7 +477,6 @@ async def get_user_beatmapsets(
@router.get(
"/users/{user_id}/scores/{type}",
- response_model=list[ScoreResp] | list[LegacyScoreResp],
name="获取用户成绩列表",
description=(
"获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n"
@@ -523,6 +522,7 @@ async def get_user_scores(
gamemode = mode or db_user.playmode
order_by = None
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
+ includes = Score.USER_PROFILE_INCLUDES.copy()
if not include_fails:
where_clause &= col(Score.passed).is_(True)
if type == "pinned":
@@ -531,6 +531,7 @@ async def get_user_scores(
elif type == "best":
where_clause &= exists().where(col(BestScore.score_id) == Score.id)
order_by = col(Score.pp).desc()
+ includes.append("weight")
elif type == "recent":
where_clause &= Score.ended_at > utcnow() - timedelta(hours=24)
order_by = col(Score.ended_at).desc()
@@ -551,6 +552,7 @@ async def get_user_scores(
await score.to_resp(
session,
api_version,
+ includes=includes,
)
for score in scores
]
diff --git a/app/service/beatmapset_cache_service.py b/app/service/beatmapset_cache_service.py
index 16ec2e2..dbb7682 100644
--- a/app/service/beatmapset_cache_service.py
+++ b/app/service/beatmapset_cache_service.py
@@ -3,14 +3,14 @@ Beatmapset缓存服务
用于缓存beatmapset数据,减少数据库查询频率
"""
-from datetime import datetime
import hashlib
import json
from typing import TYPE_CHECKING
from app.config import settings
-from app.database.beatmapset import BeatmapsetResp
+from app.database import BeatmapsetDict
from app.log import logger
+from app.utils import safe_json_dumps
from redis.asyncio import Redis
@@ -18,20 +18,6 @@ if TYPE_CHECKING:
pass
-class DateTimeEncoder(json.JSONEncoder):
- """处理datetime序列化的JSON编码器"""
-
- def default(self, obj):
- if isinstance(obj, datetime):
- return obj.isoformat()
- return super().default(obj)
-
-
-def safe_json_dumps(data) -> str:
- """安全的JSON序列化,处理datetime对象"""
- return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
-
-
def generate_hash(data) -> str:
"""生成数据的MD5哈希值"""
content = data if isinstance(data, str) else safe_json_dumps(data)
@@ -57,15 +43,14 @@ class BeatmapsetCacheService:
"""生成搜索结果缓存键"""
return f"beatmapset_search:{query_hash}:{cursor_hash}"
- async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetResp | None:
+ async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetDict | None:
"""从缓存获取beatmapset信息"""
try:
cache_key = self._get_beatmapset_cache_key(beatmapset_id)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(f"Beatmapset cache hit for {beatmapset_id}")
- data = json.loads(cached_data)
- return BeatmapsetResp(**data)
+ return json.loads(cached_data)
return None
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error getting beatmapset from cache: {e}")
@@ -73,24 +58,21 @@ class BeatmapsetCacheService:
async def cache_beatmapset(
self,
- beatmapset_resp: BeatmapsetResp,
+ beatmapset_resp: BeatmapsetDict,
expire_seconds: int | None = None,
):
"""缓存beatmapset信息"""
try:
if expire_seconds is None:
expire_seconds = self._default_ttl
- if beatmapset_resp.id is None:
- logger.warning("Cannot cache beatmapset with None id")
- return
- cache_key = self._get_beatmapset_cache_key(beatmapset_resp.id)
- cached_data = beatmapset_resp.model_dump_json()
+ cache_key = self._get_beatmapset_cache_key(beatmapset_resp["id"])
+ cached_data = safe_json_dumps(beatmapset_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
- logger.debug(f"Cached beatmapset {beatmapset_resp.id} for {expire_seconds}s")
+ logger.debug(f"Cached beatmapset {beatmapset_resp['id']} for {expire_seconds}s")
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error caching beatmapset: {e}")
- async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetResp | None:
+ async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetDict | None:
"""从缓存获取通过beatmap ID查找的beatmapset信息"""
try:
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
@@ -98,7 +80,7 @@ class BeatmapsetCacheService:
if cached_data:
logger.debug(f"Beatmap lookup cache hit for {beatmap_id}")
data = json.loads(cached_data)
- return BeatmapsetResp(**data)
+ return data
return None
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error getting beatmap lookup from cache: {e}")
@@ -107,7 +89,7 @@ class BeatmapsetCacheService:
async def cache_beatmap_lookup(
self,
beatmap_id: int,
- beatmapset_resp: BeatmapsetResp,
+ beatmapset_resp: BeatmapsetDict,
expire_seconds: int | None = None,
):
"""缓存通过beatmap ID查找的beatmapset信息"""
@@ -115,7 +97,7 @@ class BeatmapsetCacheService:
if expire_seconds is None:
expire_seconds = self._default_ttl
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
- cached_data = beatmapset_resp.model_dump_json()
+ cached_data = safe_json_dumps(beatmapset_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
logger.debug(f"Cached beatmap lookup {beatmap_id} for {expire_seconds}s")
except (ValueError, TypeError, AttributeError) as e:
diff --git a/app/service/beatmapset_update_service.py b/app/service/beatmapset_update_service.py
index b642ca7..5c09123 100644
--- a/app/service/beatmapset_update_service.py
+++ b/app/service/beatmapset_update_service.py
@@ -3,12 +3,12 @@ from datetime import timedelta
from enum import Enum
import math
import random
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING, NamedTuple, cast
from app.config import OldScoreProcessingMode, settings
-from app.database.beatmap import Beatmap, BeatmapResp
+from app.database.beatmap import Beatmap, BeatmapDict
from app.database.beatmap_sync import BeatmapSync, SavedBeatmapMeta
-from app.database.beatmapset import Beatmapset, BeatmapsetResp
+from app.database.beatmapset import Beatmapset, BeatmapsetDict
from app.database.score import Score
from app.dependencies.database import get_redis, with_db
from app.dependencies.storage import get_storage_service
@@ -62,10 +62,23 @@ STATUS_FACTOR: dict[BeatmapRankStatus, float] = {
SCHEDULER_INTERVAL_MINUTES = 2
+class EnsuredBeatmap(BeatmapDict):
+ checksum: str
+ ranked: int
+
+
+class EnsuredBeatmapset(BeatmapsetDict):
+ ranked: int
+ ranked_date: datetime.datetime
+ last_updated: datetime.datetime
+ play_count: int
+ beatmaps: list[EnsuredBeatmap]
+
+
class ProcessingBeatmapset:
- def __init__(self, beatmapset: BeatmapsetResp, record: BeatmapSync) -> None:
+ def __init__(self, beatmapset: EnsuredBeatmapset, record: BeatmapSync) -> None:
self.beatmapset = beatmapset
- self.status = BeatmapRankStatus(self.beatmapset.ranked)
+ self.status = BeatmapRankStatus(self.beatmapset["ranked"])
self.record = record
def calculate_next_sync_time(
@@ -76,19 +89,19 @@ class ProcessingBeatmapset:
now = utcnow()
if self.status == BeatmapRankStatus.QUALIFIED:
- assert self.beatmapset.ranked_date is not None, "ranked_date should not be None for qualified maps"
- time_to_ranked = (self.beatmapset.ranked_date + timedelta(days=7) - now).total_seconds()
+ assert self.beatmapset["ranked_date"] is not None, "ranked_date should not be None for qualified maps"
+ time_to_ranked = (self.beatmapset["ranked_date"] + timedelta(days=7) - now).total_seconds()
baseline = max(MIN_DELTA, time_to_ranked / 2)
next_delta = max(MIN_DELTA, baseline)
elif self.status in {BeatmapRankStatus.WIP, BeatmapRankStatus.PENDING}:
- seconds_since_update = (now - self.beatmapset.last_updated).total_seconds()
+ seconds_since_update = (now - self.beatmapset["last_updated"]).total_seconds()
factor_update = max(1.0, seconds_since_update / TAU)
- factor_play = 1.0 + math.log(1.0 + self.beatmapset.play_count)
+ factor_play = 1.0 + math.log(1.0 + self.beatmapset["play_count"])
status_factor = STATUS_FACTOR[self.status]
baseline = BASE * factor_play / factor_update * status_factor
next_delta = max(MIN_DELTA, baseline * (GROWTH ** (self.record.consecutive_no_change + 1)))
elif self.status == BeatmapRankStatus.GRAVEYARD:
- days_since_update = (now - self.beatmapset.last_updated).days
+ days_since_update = (now - self.beatmapset["last_updated"]).days
doubling_periods = days_since_update / GRAVEYARD_DOUBLING_PERIOD_DAYS
delta = MIN_DELTA * (2**doubling_periods)
max_seconds = GRAVEYARD_MAX_DAYS * 86400
@@ -105,21 +118,24 @@ class ProcessingBeatmapset:
@property
def beatmapset_changed(self) -> bool:
- return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset.ranked)
+ return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset["ranked"])
@property
def changed_beatmaps(self) -> list[ChangedBeatmap]:
changed_beatmaps = []
- for bm in self.beatmapset.beatmaps:
- saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
+ for bm in self.beatmapset["beatmaps"]:
+ saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm["id"]), None)
if not saved or saved["is_deleted"]:
- changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
- elif saved["md5"] != bm.checksum:
- changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
- elif saved["beatmap_status"] != BeatmapRankStatus(bm.ranked):
- changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.STATUS_CHANGED))
+ changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_ADDED))
+ elif saved["md5"] != bm["checksum"]:
+ changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_UPDATED))
+ elif saved["beatmap_status"] != BeatmapRankStatus(bm["ranked"]):
+ changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.STATUS_CHANGED))
for saved in self.record.beatmaps:
- if not any(bm.id == saved["beatmap_id"] for bm in self.beatmapset.beatmaps) and not saved["is_deleted"]:
+ if (
+ not any(bm["id"] == saved["beatmap_id"] for bm in self.beatmapset["beatmaps"])
+ and not saved["is_deleted"]
+ ):
changed_beatmaps.append(ChangedBeatmap(saved["beatmap_id"], BeatmapChangeType.MAP_DELETED))
return changed_beatmaps
@@ -132,7 +148,7 @@ class BeatmapsetUpdateService:
async def add_missing_beatmapset(self, beatmapset_id: int, immediate: bool = False) -> bool:
beatmapset = await self.fetcher.get_beatmapset(beatmapset_id)
if immediate:
- await self._sync_immediately(beatmapset)
+ await self._sync_immediately(cast(EnsuredBeatmapset, beatmapset))
logger.debug(f"triggered immediate sync for beatmapset {beatmapset_id} ")
return True
await self.add(beatmapset)
@@ -172,7 +188,7 @@ class BeatmapsetUpdateService:
BeatmapSync(
beatmapset_id=missing,
beatmap_status=BeatmapRankStatus.GRAVEYARD,
- next_sync_time=datetime.datetime.max,
+ next_sync_time=datetime.datetime(year=6000, month=1, day=1),
beatmaps=[],
)
)
@@ -185,11 +201,13 @@ class BeatmapsetUpdateService:
await session.commit()
self._adding_missing = False
- async def add(self, beatmapset: BeatmapsetResp, calculate_next_sync: bool = True):
+ async def add(self, set: BeatmapsetDict, calculate_next_sync: bool = True):
+ beatmapset = cast(EnsuredBeatmapset, set)
async with with_db() as session:
- sync_record = await session.get(BeatmapSync, beatmapset.id)
+ beatmapset_id = beatmapset["id"]
+ sync_record = await session.get(BeatmapSync, beatmapset_id)
if not sync_record:
- database_beatmapset = await session.get(Beatmapset, beatmapset.id)
+ database_beatmapset = await session.get(Beatmapset, beatmapset_id)
if database_beatmapset:
status = BeatmapRankStatus(database_beatmapset.beatmap_status)
await database_beatmapset.awaitable_attrs.beatmaps
@@ -203,19 +221,29 @@ class BeatmapsetUpdateService:
for bm in database_beatmapset.beatmaps
]
else:
- status = BeatmapRankStatus(beatmapset.ranked)
- beatmaps = [
- SavedBeatmapMeta(
- beatmap_id=bm.id,
- md5=bm.checksum,
- is_deleted=False,
- beatmap_status=BeatmapRankStatus(bm.ranked),
+ ranked = beatmapset.get("ranked")
+ if ranked is None:
+ raise ValueError("ranked field is required")
+ status = BeatmapRankStatus(ranked)
+ beatmap_list = beatmapset.get("beatmaps", [])
+ beatmaps = []
+ for bm in beatmap_list:
+ bm_id = bm.get("id")
+ checksum = bm.get("checksum")
+ ranked = bm.get("ranked")
+ if bm_id is None or checksum is None or ranked is None:
+ continue
+ beatmaps.append(
+ SavedBeatmapMeta(
+ beatmap_id=bm_id,
+ md5=checksum,
+ is_deleted=False,
+ beatmap_status=BeatmapRankStatus(ranked),
+ )
)
- for bm in beatmapset.beatmaps
- ]
sync_record = BeatmapSync(
- beatmapset_id=beatmapset.id,
+ beatmapset_id=beatmapset_id,
beatmaps=beatmaps,
beatmap_status=status,
)
@@ -223,13 +251,27 @@ class BeatmapsetUpdateService:
await session.commit()
await session.refresh(sync_record)
else:
- sync_record.beatmaps = [
- SavedBeatmapMeta(
- beatmap_id=bm.id, md5=bm.checksum, is_deleted=False, beatmap_status=BeatmapRankStatus(bm.ranked)
+ ranked = beatmapset.get("ranked")
+ if ranked is None:
+ raise ValueError("ranked field is required")
+ beatmap_list = beatmapset.get("beatmaps", [])
+ beatmaps = []
+ for bm in beatmap_list:
+ bm_id = bm.get("id")
+ checksum = bm.get("checksum")
+ bm_ranked = bm.get("ranked")
+ if bm_id is None or checksum is None or bm_ranked is None:
+ continue
+ beatmaps.append(
+ SavedBeatmapMeta(
+ beatmap_id=bm_id,
+ md5=checksum,
+ is_deleted=False,
+ beatmap_status=BeatmapRankStatus(bm_ranked),
+ )
)
- for bm in beatmapset.beatmaps
- ]
- sync_record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
+ sync_record.beatmaps = beatmaps
+ sync_record.beatmap_status = BeatmapRankStatus(ranked)
if calculate_next_sync:
processing = ProcessingBeatmapset(beatmapset, sync_record)
next_time_delta = processing.calculate_next_sync_time()
@@ -238,17 +280,19 @@ class BeatmapsetUpdateService:
await BeatmapsetUpdateService._sync_immediately(self, beatmapset)
return
sync_record.next_sync_time = utcnow() + next_time_delta
- logger.opt(colors=True).info(f"[{beatmapset.id}] next sync at {sync_record.next_sync_time}")
+ beatmapset_id = beatmapset.get("id")
+ if beatmapset_id:
+ logger.opt(colors=True).debug(f"[{beatmapset_id}] next sync at {sync_record.next_sync_time}")
await session.commit()
- async def _sync_immediately(self, beatmapset: BeatmapsetResp) -> None:
+ async def _sync_immediately(self, beatmapset: EnsuredBeatmapset) -> None:
async with with_db() as session:
- record = await session.get(BeatmapSync, beatmapset.id)
+ record = await session.get(BeatmapSync, beatmapset["id"])
if not record:
record = BeatmapSync(
- beatmapset_id=beatmapset.id,
+ beatmapset_id=beatmapset["id"],
beatmaps=[],
- beatmap_status=BeatmapRankStatus(beatmapset.ranked),
+ beatmap_status=BeatmapRankStatus(beatmapset["ranked"]),
)
session.add(record)
await session.commit()
@@ -261,19 +305,18 @@ class BeatmapsetUpdateService:
record: BeatmapSync,
session: AsyncSession,
*,
- beatmapset: BeatmapsetResp | None = None,
+ beatmapset: EnsuredBeatmapset | None = None,
):
- logger.opt(colors=True).info(f"[{record.beatmapset_id}] syncing...")
+ logger.opt(colors=True).debug(f"[{record.beatmapset_id}] syncing...")
if beatmapset is None:
try:
- beatmapset = await self.fetcher.get_beatmapset(record.beatmapset_id)
+ beatmapset = cast(EnsuredBeatmapset, await self.fetcher.get_beatmapset(record.beatmapset_id))
except Exception as e:
if isinstance(e, HTTPStatusError) and e.response.status_code == 404:
logger.opt(colors=True).warning(
f"[{record.beatmapset_id}] beatmapset not found (404), removing from sync list"
)
await session.delete(record)
- await session.commit()
return
if isinstance(e, HTTPError):
logger.opt(colors=True).warning(
@@ -292,20 +335,20 @@ class BeatmapsetUpdateService:
if changed:
record.beatmaps = [
SavedBeatmapMeta(
- beatmap_id=bm.id,
- md5=bm.checksum,
+ beatmap_id=bm["id"],
+ md5=bm["checksum"],
is_deleted=False,
- beatmap_status=BeatmapRankStatus(bm.ranked),
+ beatmap_status=BeatmapRankStatus(bm["ranked"]),
)
- for bm in beatmapset.beatmaps
+ for bm in beatmapset["beatmaps"]
]
- record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
+ record.beatmap_status = BeatmapRankStatus(beatmapset["ranked"])
record.consecutive_no_change = 0
bg_tasks.add_task(
self._process_changed_beatmaps,
changed_beatmaps,
- beatmapset.beatmaps,
+ beatmapset["beatmaps"],
)
bg_tasks.add_task(
self._process_changed_beatmapset,
@@ -317,13 +360,13 @@ class BeatmapsetUpdateService:
next_time_delta = processing.calculate_next_sync_time()
if not next_time_delta:
logger.opt(colors=True).info(
- f"[{beatmapset.id}] beatmapset has transformed to ranked or loved,"
+ f"[{beatmapset['id']}] beatmapset has transformed to ranked or loved,"
f" removing from sync list"
)
await session.delete(record)
else:
record.next_sync_time = utcnow() + next_time_delta
- logger.opt(colors=True).info(f"[{record.beatmapset_id}] next sync at {record.next_sync_time}")
+ logger.opt(colors=True).debug(f"[{record.beatmapset_id}] next sync at {record.next_sync_time}")
async def _update_beatmaps(self):
async with with_db() as session:
@@ -338,18 +381,18 @@ class BeatmapsetUpdateService:
await self.sync(record, session)
await session.commit()
- async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
+ async def _process_changed_beatmapset(self, beatmapset: EnsuredBeatmapset):
async with with_db() as session:
- db_beatmapset = await session.get(Beatmapset, beatmapset.id)
- new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
+ db_beatmapset = await session.get(Beatmapset, beatmapset["id"])
+ new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset) # pyright: ignore[reportArgumentType]
if db_beatmapset:
await session.merge(new_beatmapset)
- await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset.id)
+ await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset["id"])
await session.commit()
- async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[BeatmapResp]):
+ async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[EnsuredBeatmap]):
storage_service = get_storage_service()
- beatmaps = {bm.id: bm for bm in beatmaps_list}
+ beatmaps = {bm["id"]: bm for bm in beatmaps_list}
async with with_db() as session:
@@ -380,9 +423,9 @@ class BeatmapsetUpdateService:
)
continue
logger.opt(colors=True).info(
- f"[{beatmap.beatmapset_id}] adding beatmap {beatmap.id}"
+ f"[{beatmap['beatmapset_id']}] adding beatmap {beatmap['id']}"
)
- await Beatmap.from_resp_no_save(session, beatmap)
+ await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType]
else:
beatmap = beatmaps.get(change.beatmap_id)
if not beatmap:
@@ -391,10 +434,10 @@ class BeatmapsetUpdateService:
)
continue
logger.opt(colors=True).info(
- f"[{beatmap.beatmapset_id}] processing beatmap {beatmap.id} "
+ f"[{beatmap['beatmapset_id']}] processing beatmap {beatmap['id']} "
f"change {change.type}"
)
- new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap)
+ new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType]
existing_beatmap = await session.get(Beatmap, change.beatmap_id)
if existing_beatmap:
await session.merge(new_db_beatmap)
diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py
index d2702ef..e1c475c 100644
--- a/app/service/ranking_cache_service.py
+++ b/app/service/ranking_cache_service.py
@@ -4,16 +4,15 @@
"""
import asyncio
-from datetime import datetime
import json
from typing import TYPE_CHECKING, Literal
from app.config import settings
-from app.database.statistics import UserStatistics, UserStatisticsResp
+from app.database.statistics import UserStatistics, UserStatisticsModel
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
-from app.utils import utcnow
+from app.utils import safe_json_dumps, utcnow
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -23,20 +22,6 @@ if TYPE_CHECKING:
pass
-class DateTimeEncoder(json.JSONEncoder):
- """自定义 JSON 编码器,支持 datetime 序列化"""
-
- def default(self, obj):
- if isinstance(obj, datetime):
- return obj.isoformat()
- return super().default(obj)
-
-
-def safe_json_dumps(data) -> str:
- """安全的 JSON 序列化,支持 datetime 对象"""
- return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
-
-
class RankingCacheService:
"""用户排行榜缓存服务"""
@@ -311,7 +296,7 @@ class RankingCacheService:
col(UserStatistics.pp) > 0,
col(UserStatistics.is_ranked).is_(True),
]
- include = ["user"]
+ include = UserStatistics.RANKING_INCLUDES.copy()
if type == "performance":
order_by = col(UserStatistics.pp).desc()
@@ -321,6 +306,7 @@ class RankingCacheService:
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
+ include.append("country_rank")
# 获取总用户数用于统计
total_users_query = select(UserStatistics).where(*wheres)
@@ -353,9 +339,9 @@ class RankingCacheService:
# 转换为响应格式并确保正确序列化
ranking_data = []
for statistics in statistics_data:
- user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
+ user_stats_resp = await UserStatisticsModel.transform(statistics, includes=include)
- user_dict = user_stats_resp.model_dump()
+ user_dict = user_stats_resp
# 应用资源代理处理
if settings.enable_asset_proxy:
diff --git a/app/service/redis_message_system.py b/app/service/redis_message_system.py
index 34fefba..2485261 100644
--- a/app/service/redis_message_system.py
+++ b/app/service/redis_message_system.py
@@ -8,14 +8,14 @@
import asyncio
from datetime import datetime
import json
-import time
from typing import Any
-from app.database.chat import ChatMessage, ChatMessageResp, MessageType
-from app.database.user import RANKING_INCLUDES, User, UserResp
+from app.database import ChatMessageDict
+from app.database.chat import ChatMessage, ChatMessageModel, MessageType
+from app.database.user import User, UserModel
from app.dependencies.database import get_redis_message, with_db
from app.log import logger
-from app.utils import bg_tasks
+from app.utils import bg_tasks, safe_json_dumps
class RedisMessageSystem:
@@ -35,7 +35,7 @@ class RedisMessageSystem:
content: str,
is_action: bool = False,
user_uuid: str | None = None,
- ) -> ChatMessageResp:
+ ) -> "ChatMessageDict":
"""
发送消息 - 立即存储到 Redis 并返回
@@ -47,7 +47,7 @@ class RedisMessageSystem:
user_uuid: 用户UUID
Returns:
- ChatMessageResp: 消息响应对象
+ ChatMessage: 消息响应对象
"""
# 生成消息ID和时间戳
message_id = await self._generate_message_id(channel_id)
@@ -57,28 +57,16 @@ class RedisMessageSystem:
if not user.id:
raise ValueError("User ID is required")
- # 获取频道类型以判断是否需要存储到数据库
- async with with_db() as session:
- from app.database.chat import ChannelType, ChatChannel
-
- from sqlmodel import select
-
- channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
- channel_type = channel_result.first()
- is_multiplayer = channel_type == ChannelType.MULTIPLAYER
-
# 准备消息数据
- message_data = {
+ message_data: "ChatMessageDict" = {
"message_id": message_id,
"channel_id": channel_id,
"sender_id": user.id,
"content": content,
- "timestamp": timestamp.isoformat(),
- "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
+ "timestamp": timestamp,
+ "type": MessageType.ACTION if is_action else MessageType.PLAIN,
"uuid": user_uuid or "",
- "status": "cached", # Redis 缓存状态
- "created_at": time.time(),
- "is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
+ "is_action": is_action,
}
# 立即存储到 Redis
@@ -86,51 +74,13 @@ class RedisMessageSystem:
# 创建响应对象
async with with_db() as session:
- user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
+ user_resp = await UserModel.transform(user, session=session, includes=User.LIST_INCLUDES)
+ message_data["sender"] = user_resp
- # 确保 statistics 不为空
- if user_resp.statistics is None:
- from app.database.statistics import UserStatisticsResp
+ logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
+ return message_data
- user_resp.statistics = UserStatisticsResp(
- mode=user.playmode,
- global_rank=0,
- country_rank=0,
- pp=0.0,
- ranked_score=0,
- hit_accuracy=0.0,
- play_count=0,
- play_time=0,
- total_score=0,
- total_hits=0,
- maximum_combo=0,
- replays_watched_by_others=0,
- is_ranked=False,
- grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
- level={"current": 1, "progress": 0},
- )
-
- response = ChatMessageResp(
- message_id=message_id,
- channel_id=channel_id,
- content=content,
- timestamp=timestamp,
- sender_id=user.id,
- sender=user_resp,
- is_action=is_action,
- uuid=user_uuid,
- )
-
- if is_multiplayer:
- logger.info(
- f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
- " will not be persisted to database"
- )
- else:
- logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
- return response
-
- async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
+ async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
"""
获取频道消息 - 优先从 Redis 获取最新消息
@@ -140,9 +90,9 @@ class RedisMessageSystem:
since: 起始消息ID
Returns:
- List[ChatMessageResp]: 消息列表
+ List[ChatMessageDict]: 消息列表
"""
- messages = []
+ messages: list["ChatMessageDict"] = []
try:
# 从 Redis 获取最新消息
@@ -154,45 +104,21 @@ class RedisMessageSystem:
# 获取发送者信息
sender = await session.get(User, msg_data["sender_id"])
if sender:
- user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
+ user_resp = await UserModel.transform(sender, includes=User.LIST_INCLUDES)
- if user_resp.statistics is None:
- from app.database.statistics import UserStatisticsResp
+ from app.database.chat import ChatMessageDict
- user_resp.statistics = UserStatisticsResp(
- mode=sender.playmode,
- global_rank=0,
- country_rank=0,
- pp=0.0,
- ranked_score=0,
- hit_accuracy=0.0,
- play_count=0,
- play_time=0,
- total_score=0,
- total_hits=0,
- maximum_combo=0,
- replays_watched_by_others=0,
- is_ranked=False,
- grade_counts={
- "ssh": 0,
- "ss": 0,
- "sh": 0,
- "s": 0,
- "a": 0,
- },
- level={"current": 1, "progress": 0},
- )
-
- message_resp = ChatMessageResp(
- message_id=msg_data["message_id"],
- channel_id=msg_data["channel_id"],
- content=msg_data["content"],
- timestamp=datetime.fromisoformat(msg_data["timestamp"]),
- sender_id=msg_data["sender_id"],
- sender=user_resp,
- is_action=msg_data["type"] == MessageType.ACTION.value,
- uuid=msg_data.get("uuid") or None,
- )
+ message_resp: ChatMessageDict = {
+ "message_id": msg_data["message_id"],
+ "channel_id": msg_data["channel_id"],
+ "content": msg_data["content"],
+ "timestamp": datetime.fromisoformat(msg_data["timestamp"]), # pyright: ignore[reportArgumentType]
+ "sender_id": msg_data["sender_id"],
+ "sender": user_resp,
+ "is_action": msg_data["type"] == MessageType.ACTION.value,
+ "uuid": msg_data.get("uuid") or None,
+ "type": MessageType(msg_data["type"]),
+ }
messages.append(message_resp)
# 如果 Redis 消息不够,从数据库补充
@@ -216,86 +142,46 @@ class RedisMessageSystem:
return message_id
- async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
+ async def _store_to_redis(self, message_id: int, channel_id: int, message_data: ChatMessageDict):
"""存储消息到 Redis"""
try:
- # 检查是否是多人房间消息
- is_multiplayer = message_data.get("is_multiplayer", False)
-
- # 存储消息数据
- await self.redis.hset(
+ # 存储消息数据为 JSON 字符串
+ await self.redis.set(
f"msg:{channel_id}:{message_id}",
- mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()},
+ safe_json_dumps(message_data),
+ ex=604800, # 7天过期
)
- # 设置消息过期时间(7天)
- await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800)
-
- # 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
+ # 添加到频道消息列表(按时间排序)
channel_messages_key = f"channel:{channel_id}:messages"
- # 更健壮的键类型检查和清理
+ # 检查并清理错误类型的键
try:
key_type = await self.redis.type(channel_messages_key)
- if key_type == "none":
- # 键不存在,这是正常的
- pass
- elif key_type != "zset":
- # 键类型错误,需要清理
+ if key_type not in ("none", "zset"):
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
await self.redis.delete(channel_messages_key)
-
- # 验证删除是否成功
- verify_type = await self.redis.type(channel_messages_key)
- if verify_type != "none":
- logger.error(
- f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
- )
- # 强制删除
- await self.redis.unlink(channel_messages_key)
-
except Exception as type_check_error:
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
- # 如果检查失败,尝试强制删除键以确保清理
- try:
- await self.redis.delete(channel_messages_key)
- except Exception:
- # 最后的努力:使用unlink
- try:
- await self.redis.unlink(channel_messages_key)
- except Exception as final_error:
- logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}")
+ await self.redis.delete(channel_messages_key)
# 添加到频道消息列表(sorted set)
- try:
- await self.redis.zadd(
- channel_messages_key,
- mapping={f"msg:{channel_id}:{message_id}": message_id},
- )
- except Exception as zadd_error:
- logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}")
- # 如果添加失败,再次尝试清理并重试
- await self.redis.delete(channel_messages_key)
- await self.redis.zadd(
- channel_messages_key,
- mapping={f"msg:{channel_id}:{message_id}": message_id},
- )
+ await self.redis.zadd(
+ channel_messages_key,
+ mapping={f"msg:{channel_id}:{message_id}": message_id},
+ )
# 保持频道消息列表大小(最多1000条)
await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)
- # 只有非多人房间消息才添加到待持久化队列
- if not is_multiplayer:
- await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
- logger.debug(f"Message {message_id} added to persistence queue")
- else:
- logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
+ await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
+ logger.debug(f"Message {message_id} added to persistence queue")
except Exception as e:
logger.error(f"Failed to store message to Redis: {e}")
raise
- async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
+ async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
"""从 Redis 获取消息"""
try:
# 获取消息键列表,按消息ID排序
@@ -314,28 +200,16 @@ class RedisMessageSystem:
messages = []
for key in message_keys:
- # 获取消息数据
- raw_data = await self.redis.hgetall(key)
+ # 获取消息数据(JSON 字符串)
+ raw_data = await self.redis.get(key)
if raw_data:
- # 解码数据
- message_data: dict[str, Any] = {}
- for k, v in raw_data.items():
- # 尝试解析 JSON
- try:
- if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
- message_data[k] = json.loads(v)
- elif k in ["message_id", "channel_id", "sender_id"]:
- message_data[k] = int(v)
- elif k == "is_multiplayer":
- message_data[k] = v == "True"
- elif k == "created_at":
- message_data[k] = float(v)
- else:
- message_data[k] = v
- except (json.JSONDecodeError, ValueError):
- message_data[k] = v
-
- messages.append(message_data)
+ try:
+ # 解析 JSON 字符串为字典
+ message_data = json.loads(raw_data)
+ messages.append(message_data)
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to decode message JSON from {key}: {e}")
+ continue
# 确保消息按ID正序排序(时间顺序)
messages.sort(key=lambda x: x.get("message_id", 0))
@@ -350,15 +224,15 @@ class RedisMessageSystem:
logger.error(f"Failed to get messages from Redis: {e}")
return []
- async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
+ async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageDict], limit: int):
"""从数据库补充历史消息"""
try:
# 找到最小的消息ID
min_id = float("inf")
if existing_messages:
for msg in existing_messages:
- if msg.message_id is not None and msg.message_id < min_id:
- min_id = msg.message_id
+ if msg["message_id"] is not None and msg["message_id"] < min_id:
+ min_id = msg["message_id"]
needed = limit - len(existing_messages)
@@ -378,13 +252,13 @@ class RedisMessageSystem:
db_messages = (await session.exec(query)).all()
for msg in reversed(db_messages): # 按时间正序插入
- msg_resp = await ChatMessageResp.from_db(msg, session)
+ msg_resp = await ChatMessageModel.transform(msg, includes=["sender"])
existing_messages.insert(0, msg_resp)
except Exception as e:
logger.error(f"Failed to backfill from database: {e}")
- async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
+ async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageDict]:
"""仅从数据库获取消息(回退方案)"""
try:
async with with_db() as session:
@@ -402,7 +276,7 @@ class RedisMessageSystem:
messages = (await session.exec(query)).all()
- results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
+ results = await ChatMessageModel.transform_many(messages, includes=["sender"])
# 如果是 since > 0,保持正序;否则反转为时间正序
if since == 0:
@@ -450,27 +324,17 @@ class RedisMessageSystem:
# 解析频道ID和消息ID
channel_id, message_id = map(int, key.split(":"))
- # 从 Redis 获取消息数据
- raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}")
+ # 从 Redis 获取消息数据(JSON 字符串)
+ raw_data = await self.redis.get(f"msg:{channel_id}:{message_id}")
if not raw_data:
continue
- # 解码数据
- message_data = {}
- for k, v in raw_data.items():
- message_data[k] = v
-
- # 检查是否是多人房间消息,如果是则跳过数据库存储
- is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
- if is_multiplayer:
- # 多人房间消息不存储到数据库,直接标记为已跳过
- await self.redis.hset(
- f"msg:{channel_id}:{message_id}",
- "status",
- "skipped_multiplayer",
- )
- logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
+ # 解析 JSON 字符串为字典
+ try:
+ message_data = json.loads(raw_data)
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to decode message JSON for {channel_id}:{message_id}: {e}")
continue
# 检查消息是否已存在于数据库
@@ -491,13 +355,6 @@ class RedisMessageSystem:
session.add(db_message)
- # 更新 Redis 中的状态
- await self.redis.hset(
- f"msg:{channel_id}:{message_id}",
- "status",
- "persisted",
- )
-
logger.debug(f"Message {message_id} persisted to database")
except Exception as e:
diff --git a/app/service/room.py b/app/service/room.py
index cec7029..3204383 100644
--- a/app/service/room.py
+++ b/app/service/room.py
@@ -14,8 +14,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room:
- db_room = room.to_room()
- db_room.host_id = host_id
+ db_room = Room.model_validate({"host_id": host_id, **room.model_dump(exclude={"playlist"})})
db_room.starts_at = utcnow()
db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0)
session.add(db_room)
diff --git a/app/service/user_cache_service.py b/app/service/user_cache_service.py
index 3da0ab3..5e634fd 100644
--- a/app/service/user_cache_service.py
+++ b/app/service/user_cache_service.py
@@ -3,19 +3,19 @@
用于缓存用户信息,提供热缓存和实时刷新功能
"""
-from datetime import datetime
import json
from typing import TYPE_CHECKING, Any
from app.config import settings
from app.const import BANCHOBOT_ID
-from app.database import User, UserResp
-from app.database.score import LegacyScoreResp, ScoreResp
-from app.database.user import SEARCH_INCLUDED
+from app.database import User
+from app.database.score import LegacyScoreResp
+from app.database.user import UserDict, UserModel
from app.dependencies.database import with_db
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
+from app.utils import safe_json_dumps
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -25,20 +25,6 @@ if TYPE_CHECKING:
pass
-class DateTimeEncoder(json.JSONEncoder):
- """自定义 JSON 编码器,支持 datetime 序列化"""
-
- def default(self, obj):
- if isinstance(obj, datetime):
- return obj.isoformat()
- return super().default(obj)
-
-
-def safe_json_dumps(data: Any) -> str:
- """安全的 JSON 序列化,支持 datetime 对象"""
- return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
-
-
class UserCacheService:
"""用户缓存服务"""
@@ -125,7 +111,7 @@ class UserCacheService:
"""生成用户谱面集缓存键"""
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
- async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None:
+ async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserDict | None:
"""从缓存获取用户信息"""
try:
cache_key = self._get_user_cache_key(user_id, ruleset)
@@ -133,7 +119,7 @@ class UserCacheService:
if cached_data:
logger.debug(f"User cache hit for user {user_id}")
data = json.loads(cached_data)
- return UserResp(**data)
+ return data
return None
except Exception as e:
logger.error(f"Error getting user from cache: {e}")
@@ -141,7 +127,7 @@ class UserCacheService:
async def cache_user(
self,
- user_resp: UserResp,
+ user_resp: UserDict,
ruleset: GameMode | None = None,
expire_seconds: int | None = None,
):
@@ -149,13 +135,10 @@ class UserCacheService:
try:
if expire_seconds is None:
expire_seconds = settings.user_cache_expire_seconds
- if user_resp.id is None:
- logger.warning("Cannot cache user with None id")
- return
- cache_key = self._get_user_cache_key(user_resp.id, ruleset)
- cached_data = user_resp.model_dump_json()
+ cache_key = self._get_user_cache_key(user_resp["id"], ruleset)
+ cached_data = safe_json_dumps(user_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data)
- logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s")
+ logger.debug(f"Cached user {user_resp['id']} for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching user: {e}")
@@ -168,10 +151,9 @@ class UserCacheService:
limit: int = 100,
offset: int = 0,
is_legacy: bool = False,
- ) -> list[ScoreResp] | list[LegacyScoreResp] | None:
+ ) -> list[UserDict] | list[LegacyScoreResp] | None:
"""从缓存获取用户成绩"""
try:
- model = LegacyScoreResp if is_legacy else ScoreResp
cache_key = self._get_user_scores_cache_key(
user_id, score_type, include_fail, mode, limit, offset, is_legacy
)
@@ -179,7 +161,7 @@ class UserCacheService:
if cached_data:
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
data = json.loads(cached_data)
- return [model(**score_data) for score_data in data] # pyright: ignore[reportReturnType]
+ return [LegacyScoreResp(**score_data) for score_data in data] if is_legacy else data
return None
except Exception as e:
logger.error(f"Error getting user scores from cache: {e}")
@@ -189,7 +171,7 @@ class UserCacheService:
self,
user_id: int,
score_type: str,
- scores: list[ScoreResp] | list[LegacyScoreResp],
+ scores: list[UserDict] | list[LegacyScoreResp],
include_fail: bool,
mode: GameMode | None = None,
limit: int = 100,
@@ -204,8 +186,12 @@ class UserCacheService:
cache_key = self._get_user_scores_cache_key(
user_id, score_type, include_fail, mode, limit, offset, is_legacy
)
- # 使用 model_dump_json() 而不是 model_dump() + json.dumps()
- scores_json_list = [score.model_dump_json() for score in scores]
+ if len(scores) == 0:
+ return
+ if isinstance(scores[0], dict):
+ scores_json_list = [safe_json_dumps(score) for score in scores]
+ else:
+ scores_json_list = [score.model_dump_json() for score in scores] # pyright: ignore[reportAttributeAccessIssue]
cached_data = f"[{','.join(scores_json_list)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
@@ -308,7 +294,7 @@ class UserCacheService:
for user in users:
if user.id != BANCHOBOT_ID:
try:
- await self._cache_single_user(user, session)
+ await self._cache_single_user(user)
cached_count += 1
except Exception as e:
logger.error(f"Failed to cache user {user.id}: {e}")
@@ -320,10 +306,10 @@ class UserCacheService:
finally:
self._refreshing = False
- async def _cache_single_user(self, user: User, session: AsyncSession):
+ async def _cache_single_user(self, user: User):
"""缓存单个用户"""
try:
- user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
+ user_resp = await UserModel.transform(user, includes=User.USER_INCLUDES)
# 应用资源代理处理
if settings.enable_asset_proxy:
@@ -347,7 +333,7 @@ class UserCacheService:
# 立即重新加载用户信息
user = await session.get(User, user_id)
if user and user.id != BANCHOBOT_ID:
- await self._cache_single_user(user, session)
+ await self._cache_single_user(user)
logger.info(f"Refreshed cache for user {user_id} after score submit")
except Exception as e:
logger.error(f"Error refreshing user cache on score submit: {e}")
diff --git a/app/utils.py b/app/utils.py
index c786511..1123353 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -4,10 +4,13 @@ from datetime import UTC, datetime
import functools
import inspect
from io import BytesIO
+import json
import re
-from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
+from types import NoneType, UnionType
+from typing import TYPE_CHECKING, Any, ParamSpec, TypedDict, TypeVar, Union, get_args, get_origin
from fastapi import HTTPException
+from fastapi.encoders import jsonable_encoder
from PIL import Image
if TYPE_CHECKING:
@@ -299,3 +302,51 @@ def hex_to_hue(hex_color: str) -> int:
hue = (60 * ((r - g) / delta) + 240) % 360
return int(hue)
+
+
+def safe_json_dumps(data) -> str:
+ return json.dumps(jsonable_encoder(data), ensure_ascii=False)
+
+
+def type_is_optional(typ: type):
+ origin_type = get_origin(typ)
+ args = get_args(typ)
+ return (origin_type is UnionType or origin_type is Union) and len(args) == 2 and NoneType in args
+
+
+def _get_type(typ: type, includes: tuple[str, ...]) -> Any:
+ from app.database._base import DatabaseModel
+
+ origin = get_origin(typ)
+ if issubclass(typ, DatabaseModel):
+ return typ.generate_typeddict(includes)
+ elif origin is list:
+ item_type = typ.__args__[0]
+ return list[_get_type(item_type, includes)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
+ elif origin is dict:
+ key_type, value_type = typ.__args__
+ return dict[key_type, _get_type(value_type, includes)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
+ elif type_is_optional(typ):
+ inner_type = next(arg for arg in get_args(typ) if arg is not NoneType)
+ return Union[_get_type(inner_type, includes), None] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] # noqa: UP007
+ elif origin is UnionType or origin is Union:
+ new_types = []
+ for arg in get_args(typ):
+ new_types.append(_get_type(arg, includes)) # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
+ return Union[tuple(new_types)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] # noqa: UP007
+ else:
+ return typ
+
+
+def api_doc(desc: str, model: Any, includes: list[str] = [], *, name: str = "APIDict"):
+ if includes:
+ includes_str = ", ".join(f"`{inc}`" for inc in includes)
+ desc += f"\n\n包含:{includes_str}"
+ if isinstance(model, dict):
+ fields = {}
+ for k, v in model.items():
+ fields[k] = _get_type(v, tuple(includes))
+ typed_dict = TypedDict(name, fields) # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
+ else:
+ typed_dict = _get_type(model, tuple(includes))
+ return {"description": desc, "model": typed_dict}
diff --git a/migrations/versions/2025-11-23_23707640303c_project_remove_unused_fields_in_database.py b/migrations/versions/2025-11-23_23707640303c_project_remove_unused_fields_in_database.py
new file mode 100644
index 0000000..4af10eb
--- /dev/null
+++ b/migrations/versions/2025-11-23_23707640303c_project_remove_unused_fields_in_database.py
@@ -0,0 +1,498 @@
+"""project: remove unused fields in database
+
+Revision ID: 23707640303c
+Revises: 3f0f22f38c3d
+Create Date: 2025-11-23 08:14:05.284238
+
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+
+# revision identifiers, used by Alembic.
+revision: str = "23707640303c"
+down_revision: str | Sequence[str] | None = "3f0f22f38c3d"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ """Upgrade schema."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.alter_column(
+ "beatmaps",
+ "mode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.drop_column("beatmaps", "current_user_playcount")
+ op.alter_column("beatmapsync", "updated_at", existing_type=mysql.DATETIME(), nullable=True)
+ op.alter_column(
+ "best_scores",
+ "gamemode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.alter_column(
+ "lazer_user_statistics",
+ "mode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.alter_column("lazer_user_statistics", "ranked_score", existing_type=mysql.BIGINT(), nullable=True)
+ op.alter_column(
+ "lazer_users",
+ "playmode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.alter_column(
+ "lazer_users",
+ "g0v0_playmode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.drop_column("lazer_users", "beatmap_playcounts_count")
+ op.alter_column("login_sessions", "is_new_device", existing_type=mysql.TINYINT(display_width=1), nullable=False)
+ op.alter_column(
+ "rank_history",
+ "mode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.alter_column(
+ "rank_top",
+ "mode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.alter_column(
+ "rooms",
+ "type",
+ existing_type=mysql.ENUM("PLAYLISTS", "HEAD_TO_HEAD", "TEAM_VERSUS", "MATCHMAKING"),
+ nullable=False,
+ )
+ op.execute("UPDATE rooms SET channel_id = 0 WHERE channel_id IS NULL")
+ op.alter_column("rooms", "channel_id", existing_type=mysql.INTEGER(), nullable=False)
+ op.alter_column(
+ "score_tokens",
+ "ruleset_id",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.alter_column(
+ "scores",
+ "gamemode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.alter_column(
+ "teams",
+ "playmode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ op.alter_column(
+ "total_score_best_scores",
+ "gamemode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=False,
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ """Downgrade schema."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.alter_column(
+ "total_score_best_scores",
+ "gamemode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column(
+ "teams",
+ "playmode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column(
+ "scores",
+ "gamemode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column(
+ "score_tokens",
+ "ruleset_id",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column("rooms", "channel_id", existing_type=mysql.INTEGER(), nullable=True)
+ op.alter_column(
+ "rooms",
+ "type",
+ existing_type=mysql.ENUM("PLAYLISTS", "HEAD_TO_HEAD", "TEAM_VERSUS", "MATCHMAKING"),
+ nullable=True,
+ )
+ op.alter_column(
+ "rank_top",
+ "mode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column(
+ "rank_history",
+ "mode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column("login_sessions", "is_new_device", existing_type=mysql.TINYINT(display_width=1), nullable=True)
+ op.add_column(
+ "lazer_users", sa.Column("beatmap_playcounts_count", mysql.INTEGER(), autoincrement=False, nullable=False)
+ )
+ op.alter_column(
+ "lazer_users",
+ "g0v0_playmode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column(
+ "lazer_users",
+ "playmode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column("lazer_user_statistics", "ranked_score", existing_type=mysql.BIGINT(), nullable=False)
+ op.alter_column(
+ "lazer_user_statistics",
+ "mode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column(
+ "best_scores",
+ "gamemode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ op.alter_column("beatmapsync", "updated_at", existing_type=mysql.DATETIME(), nullable=False)
+ op.add_column("beatmaps", sa.Column("current_user_playcount", mysql.INTEGER(), autoincrement=False, nullable=False))
+ op.alter_column(
+ "beatmaps",
+ "mode",
+ existing_type=mysql.ENUM(
+ "OSU",
+ "TAIKO",
+ "FRUITS",
+ "MANIA",
+ "OSURX",
+ "OSUAP",
+ "TAIKORX",
+ "FRUITSRX",
+ "SENTAKKI",
+ "TAU",
+ "RUSH",
+ "HISHIGATA",
+ "SOYOKAZE",
+ ),
+ nullable=True,
+ )
+ # ### end Alembic commands ###