diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6818b66..d6ed022 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,23 +9,26 @@ git clone https://github.com/GooGuTeam/g0v0-server.git 此外,您还需要: - clone 旁观服务器到 g0v0-server 的文件夹。 + ```bash git clone https://github.com/GooGuTeam/osu-server-spectator.git spectator-server ``` + - clone 表现分计算器到 g0v0-server 的文件夹。 ```bash git clone https://github.com/GooGuTeam/osu-performance-server.git performance-server ``` + - 下载并放置自定义规则集 DLL 到 `rulesets/` 目录(如果需要)。 ## 开发环境 为了确保一致的开发环境,我们强烈建议使用提供的 Dev Container。这将设置一个容器化的环境,预先安装所有必要的工具和依赖项。 -1. 安装 [Docker](https://www.docker.com/products/docker-desktop/)。 -2. 在 Visual Studio Code 中安装 [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)。 -3. 在 VS Code 中打开项目。当被提示时,点击“在容器中重新打开”以启动开发容器。 +1. 安装 [Docker](https://www.docker.com/products/docker-desktop/)。 +2. 在 Visual Studio Code 中安装 [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)。 +3. 在 VS Code 中打开项目。当被提示时,点击“在容器中重新打开”以启动开发容器。 ## 配置项目 @@ -67,54 +70,109 @@ uv sync 以下是项目主要目录和文件的结构说明: -- `main.py`: FastAPI 应用的主入口点,负责初始化和启动服务器。 -- `pyproject.toml`: 项目配置文件,用于管理依赖项 (uv)、代码格式化 (Ruff) 和类型检查 (Pyright)。 -- `alembic.ini`: Alembic 数据库迁移工具的配置文件。 -- `app/`: 存放所有核心应用代码。 - - `router/`: 包含所有 API 端点的定义,根据 API 版本和功能进行组织。 - - `service/`: 存放核心业务逻辑,例如用户排名计算、每日挑战处理等。 - - `database/`: 定义数据库模型 (SQLModel) 和会话管理。 - - `models/`: 定义非数据库模型和其他模型。 - - `tasks/`: 包含由 APScheduler 调度的后台任务和启动/关闭任务。 - - `dependencies/`: 管理 FastAPI 的依赖项注入。 - - `achievements/`: 存放与成就相关的逻辑。 - - `storage/`: 存储服务代码。 - - `fetcher/`: 用于从外部服务(如 osu! 官网)获取数据的模块。 - - `middleware/`: 定义中间件,例如会话验证。 - - `helpers/`: 存放辅助函数和工具类。 - - `config.py`: 应用配置,使用 pydantic-settings 管理。 - - `calculator.py`: 存放所有的计算逻辑,例如 pp 和等级。 - - `log.py`: 日志记录模块,提供统一的日志接口。 - - `const.py`: 定义常量。 - - `path.py`: 定义跨文件使用的常量。 -- `migrations/`: 存放 Alembic 生成的数据库迁移脚本。 -- `static/`: 存放静态文件,如 `mods.json`。 +- `main.py`: FastAPI 应用的主入口点,负责初始化和启动服务器。 +- `pyproject.toml`: 项目配置文件,用于管理依赖项 (uv)、代码格式化 (Ruff) 和类型检查 (Pyright)。 +- `alembic.ini`: Alembic 数据库迁移工具的配置文件。 +- `app/`: 存放所有核心应用代码。 + - `router/`: 包含所有 API 端点的定义,根据 API 版本和功能进行组织。 + - `service/`: 存放核心业务逻辑,例如用户排名计算、每日挑战处理等。 + - `database/`: 定义数据库模型 (SQLModel) 和会话管理。 + - `models/`: 定义非数据库模型和其他模型。 + - `tasks/`: 包含由 APScheduler 调度的后台任务和启动/关闭任务。 + - `dependencies/`: 管理 FastAPI 的依赖项注入。 + - `achievements/`: 存放与成就相关的逻辑。 + - `storage/`: 存储服务代码。 + - `fetcher/`: 用于从外部服务(如 osu! 官网)获取数据的模块。 + - `middleware/`: 定义中间件,例如会话验证。 + - `helpers/`: 存放辅助函数和工具类。 + - `config.py`: 应用配置,使用 pydantic-settings 管理。 + - `calculator.py`: 存放所有的计算逻辑,例如 pp 和等级。 + - `log.py`: 日志记录模块,提供统一的日志接口。 + - `const.py`: 定义常量。 + - `path.py`: 定义跨文件使用的常量。 +- `migrations/`: 存放 Alembic 生成的数据库迁移脚本。 +- `static/`: 存放静态文件,如 `mods.json`。 ### 数据库模型定义 所有的数据库模型定义在 `app.database` 里,并且在 `__init__.py` 中导出。 -如果这个模型的数据表结构和响应不完全相同,遵循 `Base` - `Table` - `Resp` 结构: +本项目使用一种“按需返回”的设计模式,遵循 `Dict` - `Model` - `Table` 结构。详细设计思路请参考[这篇文章](https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/)。 + +#### 基本结构 + +1. **Dict**: 定义模型转换后的字典结构,用于类型检查。必须在 Model 之前定义。 +2. **Model**: 继承自 `DatabaseModel[Dict]`,定义字段和计算属性。 +3. **Table**: 继承自 `Model`,定义数据库表结构。 ```python -class ModelBase(SQLModel): - # 定义共有内容 - ... +from typing import TypedDict, NotRequired +from app.database._base import DatabaseModel, OnDemand, included, ondemand +from sqlmodel import Field +# 1. 定义 Dict +class UserDict(TypedDict): + id: int + username: str + email: NotRequired[str] # 可选字段 + followers_count: int # 计算属性 -class Model(ModelBase, table=True): - # 定义数据库表内容 - ... +# 2. 定义 Model +class UserModel(DatabaseModel[UserDict]): + id: int = Field(primary_key=True) + username: str + email: OnDemand[str] # 使用 OnDemand 标记可选字段 - -class ModelResp(ModelBase): - # 定义响应内容 - ... - - @classmethod - def from_db(cls, db: Model) -> "ModelResp": - # 从数据库模型转换 + # 普通计算属性 (总是返回) + @included + @staticmethod + async def followers_count(session: AsyncSession, instance: "User") -> int: + return await session.scalar(select(func.count()).where(Follower.followed_id == instance.id)) + + # 可选计算属性 (仅在 includes 中指定时返回) + @ondemand + @staticmethod + async def some_optional_property(session: AsyncSession, instance: "User") -> str: ... + +# 3. 定义 Table +class User(UserModel, table=True): + password: str # 仅在数据库中存在的字段 + ... +``` + +#### 字段类型 + +- **普通属性**: 直接定义在 Model 中,总是返回。 +- **可选属性**: 使用 `OnDemand[T]` 标记,仅在 `includes` 中指定时返回。 +- **普通计算属性**: 使用 `@included` 装饰的静态方法,总是返回。 +- **可选计算属性**: 使用 `@ondemand` 装饰的静态方法,仅在 `includes` 中指定时返回。 + +#### 使用方法 + +**转换模型**: + +使用 `Model.transform` 方法将数据库实例转换为字典: + +```python +user = await session.get(User, 1) +user_dict = await UserModel.transform( + user, + includes=["email"], # 指定需要返回的可选字段 + some_context="foo-bar", # 如果计算属性需要上下文,可以传入额外参数 + session=session # 可选传入自己的 session +) +``` + +**API 文档**: + +在 FastAPI 路由中,使用 `Model.generate_typeddict` 生成准确的响应文档: + +```python +@router.get("/users/{id}", response_model=UserModel.generate_typeddict(includes=("email",))) +async def get_user(id: int) -> dict: + ... + return await UserModel.transform(user, includes=["email"]) ``` 数据库模块名应与表名相同,定义了多个模型的除外。 @@ -227,16 +285,16 @@ pre-commit 不提供 pyright 的 hook,您需要手动运行 `pyright` 检查 **类型** 必须是以下之一: -* **feat**:新功能 -* **fix**:错误修复 -* **docs**:仅文档更改 -* **style**:不影响代码含义的更改(空格、格式、缺少分号等) -* **refactor**:代码重构 -* **perf**:改善性能的代码更改 -* **test**:添加缺失的测试或修正现有测试 -* **chore**:对构建过程或辅助工具和库(如文档生成)的更改 -* **ci**:持续集成相关的更改 -* **deploy**: 部署相关的更改 +- **feat**:新功能 +- **fix**:错误修复 +- **docs**:仅文档更改 +- **style**:不影响代码含义的更改(空格、格式、缺少分号等) +- **refactor**:代码重构 +- **perf**:改善性能的代码更改 +- **test**:添加缺失的测试或修正现有测试 +- **chore**:对构建过程或辅助工具和库(如文档生成)的更改 +- **ci**:持续集成相关的更改 +- **deploy**: 部署相关的更改 **范围** 可以是任何指定提交更改位置的内容。例如 `api`、`db`、`auth` 等等。对整个项目的更改使用 `project`。 diff --git a/app/database/__init__.py b/app/database/__init__.py index 6922184..601b235 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -2,23 +2,31 @@ from .achievement import UserAchievement, UserAchievementResp from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys from .beatmap import ( Beatmap, - BeatmapResp, + BeatmapDict, + BeatmapModel, +) +from .beatmap_playcounts import ( + BeatmapPlaycounts, + BeatmapPlaycountsDict, + BeatmapPlaycountsModel, ) -from .beatmap_playcounts import BeatmapPlaycounts, BeatmapPlaycountsResp from .beatmap_sync import BeatmapSync from .beatmap_tags import BeatmapTagVote from .beatmapset import ( Beatmapset, - BeatmapsetResp, + BeatmapsetDict, + BeatmapsetModel, ) from .beatmapset_ratings import BeatmapRating from .best_scores import BestScore from .chat import ( ChannelType, ChatChannel, - ChatChannelResp, + ChatChannelDict, + ChatChannelModel, ChatMessage, - ChatMessageResp, + ChatMessageDict, + ChatMessageModel, ) from .counts import ( CountResp, @@ -30,8 +38,8 @@ from .events import Event from .favourite_beatmapset import FavouriteBeatmapset from .item_attempts_count import ( ItemAttemptsCount, - ItemAttemptsResp, - PlaylistAggregateScore, + ItemAttemptsCountDict, + ItemAttemptsCountModel, ) from .matchmaking import ( MatchmakingPool, @@ -42,30 +50,32 @@ from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp from .notification import Notification, UserNotification from .password_reset import PasswordReset from .playlist_best_score import PlaylistBestScore -from .playlists import Playlist, PlaylistResp +from .playlists import Playlist, PlaylistDict, PlaylistModel from .rank_history import RankHistory, RankHistoryResp, RankTop -from .relationship import Relationship, RelationshipResp, RelationshipType -from .room import APIUploadedRoom, Room, RoomResp +from .relationship import Relationship, RelationshipDict, RelationshipModel, RelationshipType +from .room import APIUploadedRoom, Room, RoomDict, RoomModel from .room_participated_user import RoomParticipatedUser from .score import ( MultiplayerScores, Score, ScoreAround, - ScoreBase, - ScoreResp, + ScoreDict, + ScoreModel, ScoreStatistics, ) from .score_token import ScoreToken, ScoreTokenResp +from .search_beatmapset import SearchBeatmapsetsResp from .statistics import ( UserStatistics, - UserStatisticsResp, + UserStatisticsDict, + UserStatisticsModel, ) from .team import Team, TeamMember, TeamRequest, TeamResp from .total_score_best_scores import TotalScoreBestScore from .user import ( - MeResp, User, - UserResp, + UserDict, + UserModel, ) from .user_account_history import ( UserAccountHistory, @@ -79,20 +89,25 @@ from .verification import EmailVerification, LoginSession, LoginSessionResp, Tru __all__ = [ "APIUploadedRoom", "Beatmap", + "BeatmapDict", + "BeatmapModel", "BeatmapPlaycounts", - "BeatmapPlaycountsResp", + "BeatmapPlaycountsDict", + "BeatmapPlaycountsModel", "BeatmapRating", - "BeatmapResp", "BeatmapSync", "BeatmapTagVote", "Beatmapset", - "BeatmapsetResp", + "BeatmapsetDict", + "BeatmapsetModel", "BestScore", "ChannelType", "ChatChannel", - "ChatChannelResp", + "ChatChannelDict", + "ChatChannelModel", "ChatMessage", - "ChatMessageResp", + "ChatMessageDict", + "ChatMessageModel", "CountResp", "DailyChallengeStats", "DailyChallengeStatsResp", @@ -100,13 +115,13 @@ __all__ = [ "Event", "FavouriteBeatmapset", "ItemAttemptsCount", - "ItemAttemptsResp", + "ItemAttemptsCountDict", + "ItemAttemptsCountModel", "LoginSession", "LoginSessionResp", "MatchmakingPool", "MatchmakingPoolBeatmap", "MatchmakingUserStats", - "MeResp", "MonthlyPlaycounts", "MultiplayerEvent", "MultiplayerEventResp", @@ -116,26 +131,29 @@ __all__ = [ "OAuthToken", "PasswordReset", "Playlist", - "PlaylistAggregateScore", "PlaylistBestScore", - "PlaylistResp", + "PlaylistDict", + "PlaylistModel", "RankHistory", "RankHistoryResp", "RankTop", "Relationship", - "RelationshipResp", + "RelationshipDict", + "RelationshipModel", "RelationshipType", "ReplayWatchedCount", "Room", + "RoomDict", + "RoomModel", "RoomParticipatedUser", - "RoomResp", "Score", "ScoreAround", - "ScoreBase", - "ScoreResp", + "ScoreDict", + "ScoreModel", "ScoreStatistics", "ScoreToken", "ScoreTokenResp", + "SearchBeatmapsetsResp", "Team", "TeamMember", "TeamRequest", @@ -149,17 +167,18 @@ __all__ = [ "UserAccountHistoryResp", "UserAccountHistoryType", "UserAchievement", - "UserAchievement", "UserAchievementResp", + "UserDict", "UserLoginLog", + "UserModel", "UserNotification", "UserPreference", - "UserResp", "UserStatistics", - "UserStatisticsResp", + "UserStatisticsDict", + "UserStatisticsModel", "V1APIKeys", ] for i in __all__: - if i.endswith("Resp"): - globals()[i].model_rebuild() # type: ignore[call-arg] + if i.endswith("Model") or i.endswith("Resp"): + globals()[i].model_rebuild() diff --git a/app/database/_base.py b/app/database/_base.py new file mode 100644 index 0000000..7a05010 --- /dev/null +++ b/app/database/_base.py @@ -0,0 +1,499 @@ +from collections.abc import Awaitable, Callable, Sequence +from functools import lru_cache, wraps +import inspect +import sys +from types import NoneType, get_original_bases +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Concatenate, + ForwardRef, + ParamSpec, + TypedDict, + cast, + get_args, + get_origin, + overload, +) + +from app.models.model import UTCBaseModel +from app.utils import type_is_optional + +from sqlalchemy.ext.asyncio import async_object_session +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlmodel.main import SQLModelMetaclass + +_dict_to_model: dict[type, type["DatabaseModel"]] = {} + + +def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any: + """Safely evaluate a ForwardRef, with fallback to app.database module""" + if isinstance(type_, str): + type_ = ForwardRef(type_) + + try: + return evaluate_forwardref( + type_, + globalns=vars(sys.modules[module_name]), + localns={}, + ) + except (NameError, AttributeError, KeyError): + # Fallback to app.database module + try: + import app.database + + return evaluate_forwardref( + type_, + globalns=vars(app.database), + localns={}, + ) + except (NameError, AttributeError, KeyError): + return None + + +class OnDemand[T]: + if TYPE_CHECKING: + + def __get__(self, instance: object | None, owner: Any) -> T: ... + + def __set__(self, instance: Any, value: T) -> None: ... + + def __delete__(self, instance: Any) -> None: ... + + +class Exclude[T]: + if TYPE_CHECKING: + + def __get__(self, instance: object | None, owner: Any) -> T: ... + + def __set__(self, instance: Any, value: T) -> None: ... + + def __delete__(self, instance: Any) -> None: ... + + +# https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140 +def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: + raw_annotations: dict[str, Any] = class_dict.get("__annotations__", {}) + if sys.version_info >= (3, 14) and "__annotations__" not in class_dict: + # See https://github.com/pydantic/pydantic/pull/11991 + from annotationlib import ( + Format, + call_annotate_function, + get_annotate_from_class_namespace, + ) + + if annotate := get_annotate_from_class_namespace(class_dict): + raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF) + return raw_annotations + + +# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77 +if sys.version_info < (3, 12, 4): + + def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: + # Even though it is the right signature for python 3.9, mypy complains with + # `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast... + # Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid: + # TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard' + return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set()) + +else: + + def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: + # Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid + # warnings: + return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set()) + + +class DatabaseModelMetaclass(SQLModelMetaclass): + def __new__( + cls, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> "DatabaseModelMetaclass": + original_annotations = _get_annotations(namespace) + new_annotations = {} + ondemands = [] + excludes = [] + + for k, v in original_annotations.items(): + if get_origin(v) is OnDemand: + inner_type = v.__args__[0] + new_annotations[k] = inner_type + ondemands.append(k) + elif get_origin(v) is Exclude: + inner_type = v.__args__[0] + new_annotations[k] = inner_type + excludes.append(k) + else: + new_annotations[k] = v + + new_class = super().__new__( + cls, + name, + bases, + { + **namespace, + "__annotations__": new_annotations, + }, + **kwargs, + ) + + new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {})) + new_class._ONDEMAND_DATABASE_FIELDS = list(getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", [])) + list( + ondemands + ) + new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {})) + new_class._EXCLUDED_DATABASE_FIELDS = list(getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", [])) + list(excludes) + + for attr_name, attr_value in namespace.items(): + target = _get_callable_target(attr_value) + if target is None: + continue + + if getattr(target, "__included__", False): + new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target) + _pre_calculate_context_params(target, attr_value) + + if getattr(target, "__calculated_ondemand__", False): + new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target) + _pre_calculate_context_params(target, attr_value) + + # Register TDict to DatabaseModel mapping + for base in get_original_bases(new_class): + cls_name = base.__name__ + if "DatabaseModel" in cls_name and "[" in cls_name and "]" in cls_name: + generic_type_name = cls_name[cls_name.index("[") : cls_name.rindex("]") + 1] + generic_type = evaluate_forwardref( + ForwardRef(generic_type_name), + globalns=vars(sys.modules[new_class.__module__]), + localns={}, + ) + _dict_to_model[generic_type[0]] = new_class + + return new_class + + +def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None: + if hasattr(target, "__context_params__"): + return + + sig = inspect.signature(target) + params = list(sig.parameters.keys()) + + start_index = 2 + if isinstance(attr_value, classmethod): + start_index = 3 + + context_params = [] if len(params) < start_index else params[start_index:] + + setattr(target, "__context_params__", context_params) + + +def _get_callable_target(value: Any) -> Callable | None: + if isinstance(value, (staticmethod, classmethod)): + return value.__func__ + if inspect.isfunction(value): + return value + if inspect.ismethod(value): + return value.__func__ + return None + + +def _mark_callable(value: Any, flag: str) -> Callable | None: + target = _get_callable_target(value) + if target is None: + return None + setattr(target, flag, True) + return target + + +def _get_return_type(func: Callable) -> type: + sig = inspect.get_annotations(func) + return sig.get("return", Any) + + +P = ParamSpec("P") +CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]] +DecoratorTarget = CalculatedField | staticmethod | classmethod + + +def included(func: DecoratorTarget) -> DecoratorTarget: + marker = _mark_callable(func, "__included__") + if marker is None: + raise RuntimeError("@included is only usable on callables.") + + @wraps(marker) + async def wrapper(*args, **kwargs): + return await marker(*args, **kwargs) + + if isinstance(func, staticmethod): + return staticmethod(wrapper) + if isinstance(func, classmethod): + return classmethod(wrapper) + return wrapper + + +def ondemand(func: DecoratorTarget) -> DecoratorTarget: + marker = _mark_callable(func, "__calculated_ondemand__") + if marker is None: + raise RuntimeError("@ondemand is only usable on callables.") + + @wraps(marker) + async def wrapper(*args, **kwargs): + return await marker(*args, **kwargs) + + if isinstance(func, staticmethod): + return staticmethod(wrapper) + if isinstance(func, classmethod): + return classmethod(wrapper) + return wrapper + + +async def call_awaitable_with_context( + func: CalculatedField, + session: AsyncSession, + instance: Any, + context: dict[str, Any], +) -> Any: + context_params: list[str] | None = getattr(func, "__context_params__", None) + + if context_params is None: + # Fallback if not pre-calculated + sig = inspect.signature(func) + if len(sig.parameters) == 2: + return await func(session, instance) + else: + call_params = {} + for param in sig.parameters.values(): + if param.name in context: + call_params[param.name] = context[param.name] + return await func(session, instance, **call_params) + + if not context_params: + return await func(session, instance) + + call_params = {} + for name in context_params: + if name in context: + call_params[name] = context[name] + return await func(session, instance, **call_params) + + +class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass): + _CALCULATED_FIELDS: ClassVar[dict[str, type]] = {} + + _ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = [] + _ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {} + + _EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = [] + + @overload + @classmethod + async def transform( + cls, + db_instance: "DatabaseModel", + *, + session: AsyncSession, + includes: list[str] | None = None, + **context: Any, + ) -> TDict: ... + + @overload + @classmethod + async def transform( + cls, + db_instance: "DatabaseModel", + *, + includes: list[str] | None = None, + **context: Any, + ) -> TDict: ... + + @classmethod + async def transform( + cls, + db_instance: "DatabaseModel", + *, + session: AsyncSession | None = None, + includes: list[str] | None = None, + **context: Any, + ) -> TDict: + includes = includes.copy() if includes is not None else [] + session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session + if session is None: + raise RuntimeError("DatabaseModel.transform requires a session-bound instance.") + resp_obj = cls.model_validate(db_instance.model_dump()) + data = resp_obj.model_dump() + + for field in cls._CALCULATED_FIELDS: + func = getattr(cls, field) + value = await call_awaitable_with_context(func, session, db_instance, context) + data[field] = value + + sub_include_map: dict[str, list[str]] = {} + for include in [i for i in includes if "." in i]: + parent, sub_include = include.split(".", 1) + if parent not in sub_include_map: + sub_include_map[parent] = [] + sub_include_map[parent].append(sub_include) + includes.remove(include) # pyright: ignore[reportOptionalMemberAccess] + + for field, sub_includes in sub_include_map.items(): + if field in cls._ONDEMAND_CALCULATED_FIELDS: + func = getattr(cls, field) + value = await call_awaitable_with_context( + func, session, db_instance, {**context, "includes": sub_includes} + ) + data[field] = value + + for include in includes: + if include in data: + continue + + if include in cls._ONDEMAND_CALCULATED_FIELDS: + func = getattr(cls, include) + value = await call_awaitable_with_context(func, session, db_instance, context) + data[include] = value + + for field in cls._ONDEMAND_DATABASE_FIELDS: + if field not in includes: + del data[field] + + for field in cls._EXCLUDED_DATABASE_FIELDS: + if field in data: + del data[field] + + return cast(TDict, data) + + @classmethod + async def transform_many( + cls, + db_instances: Sequence["DatabaseModel"], + *, + session: AsyncSession | None = None, + includes: list[str] | None = None, + **context: Any, + ) -> list[TDict]: + if not db_instances: + return [] + + # SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here + # if the transform method performs any database operations using the shared session. + # Since we don't know if the transform method (or its calculated fields) will use the DB, + # we must execute them serially to be safe. + results = [] + for instance in db_instances: + results.append(await cls.transform(instance, session=session, includes=includes, **context)) + return results + + @classmethod + @lru_cache + def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm] + def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any: + # Evaluate ForwardRef if present + if isinstance(field_type, (str, ForwardRef)): + resolved = _safe_evaluate_forwardref(field_type, cls.__module__) + if resolved is not None: + field_type = resolved + + origin_type = get_origin(field_type) + inner_type = field_type + args = get_args(field_type) + + is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType] + if is_optional: + inner_type = next((arg for arg in args if arg is not NoneType), field_type) + + is_list = False + if origin_type is list: + is_list = True + inner_type = args[0] + + # Evaluate ForwardRef in inner_type if present + if isinstance(inner_type, (str, ForwardRef)): + resolved = _safe_evaluate_forwardref(inner_type, cls.__module__) + if resolved is not None: + inner_type = resolved + + if not resolve_database_model: + if is_optional: + return inner_type | None # pyright: ignore[reportOperatorIssue] + elif is_list: + return list[inner_type] + return inner_type + + model_class = None + + # First check if inner_type is directly a DatabaseModel subclass + try: + if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore + model_class = inner_type + except TypeError: + pass + + # If not found, look up in _dict_to_model + if model_class is None: + model_class = _dict_to_model.get(inner_type) # type: ignore + + if model_class is not None: + nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ()))) + resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore + + if is_optional: + resolved_type = resolved_type | None # type: ignore + + return resolved_type + + # Fallback: use the resolved inner_type + resolved_type = list[inner_type] if is_list else inner_type # type: ignore + if is_optional: + resolved_type = resolved_type | None # type: ignore + return resolved_type + + if includes is None: + includes = () + + # Parse nested includes + direct_includes = [] + sub_include_map: dict[str, list[str]] = {} + for include in includes: + if "." in include: + parent, sub_include = include.split(".", 1) + if parent not in sub_include_map: + sub_include_map[parent] = [] + sub_include_map[parent].append(sub_include) + if parent not in direct_includes: + direct_includes.append(parent) + else: + direct_includes.append(include) + + fields = {} + + # Process model fields + for field_name, field_info in cls.model_fields.items(): + field_type = field_info.annotation or Any + field_type = _evaluate_type(field_type, field_name=field_name) + + if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes: + continue + else: + fields[field_name] = field_type + + # Process calculated fields + for field_name, field_type in cls._CALCULATED_FIELDS.items(): + field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name) + fields[field_name] = field_type + + # Process ondemand calculated fields + for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items(): + if field_name not in direct_includes: + continue + + field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name) + fields[field_name] = field_type + + return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType] diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 695e8a7..8072d46 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -1,117 +1,339 @@ from datetime import datetime import hashlib -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict from app.calculator import get_calculator from app.config import settings -from app.database.beatmap_tags import BeatmapTagVote -from app.database.failtime import FailTime, FailTimeResp from app.models.beatmap import BeatmapRankStatus from app.models.mods import APIMod from app.models.performance import DifficultyAttributesUnion from app.models.score import GameMode +from ._base import DatabaseModel, OnDemand, included, ondemand from .beatmap_playcounts import BeatmapPlaycounts -from .beatmapset import Beatmapset, BeatmapsetResp +from .beatmap_tags import BeatmapTagVote +from .beatmapset import Beatmapset, BeatmapsetDict, BeatmapsetModel +from .failtime import FailTime, FailTimeResp +from .user import User, UserDict, UserModel from pydantic import BaseModel, TypeAdapter from redis.asyncio import Redis from sqlalchemy import Column, DateTime +from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.fetcher import Fetcher - from .user import User - class BeatmapOwner(SQLModel): id: int username: str -class BeatmapBase(SQLModel): - # Beatmap - url: str +class BeatmapDict(TypedDict): + beatmapset_id: int + difficulty_rating: float + id: int mode: GameMode + total_length: int + user_id: int + version: str + url: str + + checksum: NotRequired[str] + max_combo: NotRequired[int | None] + ar: NotRequired[float] + cs: NotRequired[float] + drain: NotRequired[float] + accuracy: NotRequired[float] + bpm: NotRequired[float] + count_circles: NotRequired[int] + count_sliders: NotRequired[int] + count_spinners: NotRequired[int] + deleted_at: NotRequired[datetime | None] + hit_length: NotRequired[int] + last_updated: NotRequired[datetime] + + status: NotRequired[str] + beatmapset: NotRequired[BeatmapsetDict] + current_user_playcount: NotRequired[int] + current_user_tag_ids: NotRequired[list[int]] + failtimes: NotRequired[FailTimeResp] + top_tag_ids: NotRequired[list[dict[str, int]]] + user: NotRequired[UserDict] + convert: NotRequired[bool] + is_scoreable: NotRequired[bool] + mode_int: NotRequired[int] + ranked: NotRequired[int] + playcount: NotRequired[int] + passcount: NotRequired[int] + + +class BeatmapModel(DatabaseModel[BeatmapDict]): + BEATMAP_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [ + "checksum", + "accuracy", + "ar", + "bpm", + "convert", + "count_circles", + "count_sliders", + "count_spinners", + "cs", + "deleted_at", + "drain", + "hit_length", + "is_scoreable", + "last_updated", + "mode_int", + "passcount", + "playcount", + "ranked", + "url", + ] + DEFAULT_API_INCLUDES: ClassVar[list[str]] = [ + "beatmapset.ratings", + "current_user_playcount", + "failtimes", + "max_combo", + "owners", + ] + TRANSFORMER_INCLUDES: ClassVar[list[str]] = [*DEFAULT_API_INCLUDES, *BEATMAP_TRANSFORMER_INCLUDES] + + # Beatmap beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) difficulty_rating: float = Field(default=0.0, index=True) + id: int = Field(primary_key=True, index=True) + mode: GameMode total_length: int user_id: int = Field(index=True) version: str = Field(index=True) + url: OnDemand[str] # optional - checksum: str = Field(sa_column=Column(VARCHAR(32), index=True)) - current_user_playcount: int = Field(default=0) - max_combo: int | None = Field(default=0) - # TODO: failtimes, owners + checksum: OnDemand[str] = Field(sa_column=Column(VARCHAR(32), index=True)) + max_combo: OnDemand[int | None] = Field(default=0) + # TODO: owners # BeatmapExtended - ar: float = Field(default=0.0) - cs: float = Field(default=0.0) - drain: float = Field(default=0.0) # hp - accuracy: float = Field(default=0.0) # od - bpm: float = Field(default=0.0) - count_circles: int = Field(default=0) - count_sliders: int = Field(default=0) - count_spinners: int = Field(default=0) - deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime)) - hit_length: int = Field(default=0) - last_updated: datetime = Field(sa_column=Column(DateTime, index=True)) + ar: OnDemand[float] = Field(default=0.0) + cs: OnDemand[float] = Field(default=0.0) + drain: OnDemand[float] = Field(default=0.0) # hp + accuracy: OnDemand[float] = Field(default=0.0) # od + bpm: OnDemand[float] = Field(default=0.0) + count_circles: OnDemand[int] = Field(default=0) + count_sliders: OnDemand[int] = Field(default=0) + count_spinners: OnDemand[int] = Field(default=0) + deleted_at: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime)) + hit_length: OnDemand[int] = Field(default=0) + last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True)) + + @included + @staticmethod + async def status(_session: AsyncSession, beatmap: "Beatmap") -> str: + if settings.enable_all_beatmap_leaderboard and not beatmap.beatmap_status.has_leaderboard(): + return BeatmapRankStatus.APPROVED.name.lower() + return beatmap.beatmap_status.name.lower() + + @ondemand + @staticmethod + async def beatmapset( + _session: AsyncSession, + beatmap: "Beatmap", + includes: list[str] | None = None, + ) -> BeatmapsetDict | None: + if beatmap.beatmapset is not None: + return await BeatmapsetModel.transform( + beatmap.beatmapset, includes=(includes or []) + Beatmapset.BEATMAPSET_TRANSFORMER_INCLUDES + ) + + @ondemand + @staticmethod + async def current_user_playcount(_session: AsyncSession, beatmap: "Beatmap", user: "User") -> int: + playcount = ( + await _session.exec( + select(BeatmapPlaycounts.playcount).where( + BeatmapPlaycounts.beatmap_id == beatmap.id, BeatmapPlaycounts.user_id == user.id + ) + ) + ).first() + return int(playcount or 0) + + @ondemand + @staticmethod + async def current_user_tag_ids(_session: AsyncSession, beatmap: "Beatmap", user: "User | None" = None) -> list[int]: + if user is None: + return [] + tag_ids = ( + await _session.exec( + select(BeatmapTagVote.tag_id).where( + BeatmapTagVote.beatmap_id == beatmap.id, + BeatmapTagVote.user_id == user.id, + ) + ) + ).all() + return list(tag_ids) + + @ondemand + @staticmethod + async def failtimes(_session: AsyncSession, beatmap: "Beatmap") -> FailTimeResp: + if beatmap.failtimes is not None: + return FailTimeResp.from_db(beatmap.failtimes) + return FailTimeResp() + + @ondemand + @staticmethod + async def top_tag_ids(_session: AsyncSession, beatmap: "Beatmap") -> list[dict[str, int]]: + all_votes = ( + await _session.exec( + select(BeatmapTagVote.tag_id, func.count().label("vote_count")) + .where(BeatmapTagVote.beatmap_id == beatmap.id) + .group_by(col(BeatmapTagVote.tag_id)) + .having(func.count() > settings.beatmap_tag_top_count) + ) + ).all() + top_tag_ids: list[dict[str, int]] = [] + for id, votes in all_votes: + top_tag_ids.append({"tag_id": id, "count": votes}) + top_tag_ids.sort(key=lambda x: x["count"], reverse=True) + return top_tag_ids + + @ondemand + @staticmethod + async def user( + _session: AsyncSession, + beatmap: "Beatmap", + includes: list[str] | None = None, + ) -> UserDict | None: + from .user import User + + user = await _session.get(User, beatmap.user_id) + if user is None: + return None + return await UserModel.transform(user, includes=includes) + + @ondemand + @staticmethod + async def convert(_session: AsyncSession, _beatmap: "Beatmap") -> bool: + return False + + @ondemand + @staticmethod + async def is_scoreable(_session: AsyncSession, beatmap: "Beatmap") -> bool: + beatmap_status = beatmap.beatmap_status + if settings.enable_all_beatmap_leaderboard: + return True + return beatmap_status.has_leaderboard() + + @ondemand + @staticmethod + async def mode_int(_session: AsyncSession, beatmap: "Beatmap") -> int: + return int(beatmap.mode) + + @ondemand + @staticmethod + async def ranked(_session: AsyncSession, beatmap: "Beatmap") -> int: + beatmap_status = beatmap.beatmap_status + if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard(): + return BeatmapRankStatus.APPROVED.value + return beatmap_status.value + + @ondemand + @staticmethod + async def playcount(_session: AsyncSession, beatmap: "Beatmap") -> int: + result = ( + await _session.exec( + select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id) + ) + ).first() + return int(result or 0) + + @ondemand + @staticmethod + async def passcount(_session: AsyncSession, beatmap: "Beatmap") -> int: + from .score import Score + + return ( + await _session.exec( + select(func.count()) + .select_from(Score) + .where( + Score.beatmap_id == beatmap.id, + col(Score.passed).is_(True), + ) + ) + ).one() -class Beatmap(BeatmapBase, table=True): +class Beatmap(AsyncAttrs, BeatmapModel, table=True): __tablename__: str = "beatmaps" - id: int = Field(primary_key=True, index=True) - beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) + beatmap_status: BeatmapRankStatus = Field(index=True) # optional - beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}) + beatmapset: "Beatmapset" = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}) failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}) @classmethod - async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap": - d = resp.model_dump() - del d["beatmapset"] + async def from_resp_no_save(cls, _session: AsyncSession, resp: BeatmapDict) -> "Beatmap": + d = {k: v for k, v in resp.items() if k != "beatmapset"} + beatmapset_id = resp.get("beatmapset_id") + bid = resp.get("id") + ranked = resp.get("ranked") + if beatmapset_id is None or bid is None or ranked is None: + raise ValueError("beatmapset_id, id and ranked are required") beatmap = cls.model_validate( { **d, - "beatmapset_id": resp.beatmapset_id, - "id": resp.id, - "beatmap_status": BeatmapRankStatus(resp.ranked), + "beatmapset_id": beatmapset_id, + "id": bid, + "beatmap_status": BeatmapRankStatus(ranked), } ) return beatmap @classmethod - async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap": + async def from_resp(cls, session: AsyncSession, resp: BeatmapDict) -> "Beatmap": beatmap = await cls.from_resp_no_save(session, resp) - if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first(): + resp_id = resp.get("id") + if resp_id is None: + raise ValueError("id is required") + if not (await session.exec(select(exists()).where(Beatmap.id == resp_id))).first(): session.add(beatmap) await session.commit() - return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one() + return (await session.exec(select(Beatmap).where(Beatmap.id == resp_id))).one() @classmethod - async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]: + async def from_resp_batch(cls, session: AsyncSession, inp: list[BeatmapDict], from_: int = 0) -> list["Beatmap"]: beatmaps = [] - for resp in inp: - if resp.id == from_: + for resp_dict in inp: + bid = resp_dict.get("id") + if bid == from_ or bid is None: continue - d = resp.model_dump() - del d["beatmapset"] + + beatmapset_id = resp_dict.get("beatmapset_id") + ranked = resp_dict.get("ranked") + if beatmapset_id is None or ranked is None: + continue + + # 创建 beatmap 字典,移除 beatmapset + d = {k: v for k, v in resp_dict.items() if k != "beatmapset"} + beatmap = Beatmap.model_validate( { **d, - "beatmapset_id": resp.beatmapset_id, - "id": resp.id, - "beatmap_status": BeatmapRankStatus(resp.ranked), + "beatmapset_id": beatmapset_id, + "id": bid, + "beatmap_status": BeatmapRankStatus(ranked), } ) - if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first(): + if not (await session.exec(select(exists()).where(Beatmap.id == bid))).first(): session.add(beatmap) beatmaps.append(beatmap) await session.commit() + for beatmap in beatmaps: + await session.refresh(beatmap) return beatmaps @classmethod @@ -132,10 +354,14 @@ class Beatmap(BeatmapBase, table=True): beatmap = (await session.exec(stmt)).first() if not beatmap: resp = await fetcher.get_beatmap(bid, md5) - r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)) + beatmapset_id = resp.get("beatmapset_id") + if beatmapset_id is None: + raise ValueError("beatmapset_id is required") + r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == beatmapset_id)) if not r.first(): - set_resp = await fetcher.get_beatmapset(resp.beatmapset_id) - await Beatmapset.from_resp(session, set_resp, from_=resp.id) + set_resp = await fetcher.get_beatmapset(beatmapset_id) + resp_id = resp.get("id") + await Beatmapset.from_resp(session, set_resp, from_=resp_id or 0) return await Beatmap.from_resp(session, resp) return beatmap @@ -145,97 +371,6 @@ class APIBeatmapTag(BaseModel): count: int -class BeatmapResp(BeatmapBase): - id: int - beatmapset_id: int - beatmapset: BeatmapsetResp | None = None - convert: bool = False - is_scoreable: bool - status: str - mode_int: int - ranked: int - url: str = "" - playcount: int = 0 - passcount: int = 0 - failtimes: FailTimeResp | None = None - top_tag_ids: list[APIBeatmapTag] | None = None - current_user_tag_ids: list[int] | None = None - is_deleted: bool = False - - @classmethod - async def from_db( - cls, - beatmap: Beatmap, - query_mode: GameMode | None = None, - from_set: bool = False, - session: AsyncSession | None = None, - user: "User | None" = None, - ) -> "BeatmapResp": - from .score import Score - - beatmap_ = beatmap.model_dump() - beatmap_status = beatmap.beatmap_status - if query_mode is not None and beatmap.mode != query_mode: - beatmap_["convert"] = True - beatmap_["is_scoreable"] = beatmap_status.has_leaderboard() - if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard(): - beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value - beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower() - else: - beatmap_["status"] = beatmap_status.name.lower() - beatmap_["ranked"] = beatmap_status.value - beatmap_["mode_int"] = int(beatmap.mode) - beatmap_["is_deleted"] = beatmap.deleted_at is not None - if not from_set: - beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user) - if beatmap.failtimes is not None: - beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes) - else: - beatmap_["failtimes"] = FailTimeResp() - if session: - beatmap_["playcount"] = ( - await session.exec( - select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id) - ) - ).first() or 0 - beatmap_["passcount"] = ( - await session.exec( - select(func.count()) - .select_from(Score) - .where( - Score.beatmap_id == beatmap.id, - col(Score.passed).is_(True), - ) - ) - ).one() - - all_votes = ( - await session.exec( - select(BeatmapTagVote.tag_id, func.count().label("vote_count")) - .where(BeatmapTagVote.beatmap_id == beatmap.id) - .group_by(col(BeatmapTagVote.tag_id)) - .having(func.count() > settings.beatmap_tag_top_count) - ) - ).all() - top_tag_ids: list[dict[str, int]] = [] - for id, votes in all_votes: - top_tag_ids.append({"tag_id": id, "count": votes}) - top_tag_ids.sort(key=lambda x: x["count"], reverse=True) - beatmap_["top_tag_ids"] = top_tag_ids - - if user is not None: - beatmap_["current_user_tag_ids"] = ( - await session.exec( - select(BeatmapTagVote.tag_id) - .where(BeatmapTagVote.beatmap_id == beatmap.id) - .where(BeatmapTagVote.user_id == user.id) - ) - ).all() - else: - beatmap_["current_user_tag_ids"] = [] - return cls.model_validate(beatmap_) - - class BannedBeatmaps(SQLModel, table=True): __tablename__: str = "banned_beatmaps" id: int | None = Field(primary_key=True, index=True, default=None) diff --git a/app/database/beatmap_playcounts.py b/app/database/beatmap_playcounts.py index 2635d6a..cf2fb6c 100644 --- a/app/database/beatmap_playcounts.py +++ b/app/database/beatmap_playcounts.py @@ -1,10 +1,11 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NotRequired, TypedDict from app.config import settings -from app.database.events import Event, EventType from app.utils import utcnow -from pydantic import BaseModel +from ._base import DatabaseModel, included +from .events import Event, EventType + from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( BigInteger, @@ -12,52 +13,65 @@ from sqlmodel import ( Field, ForeignKey, Relationship, - SQLModel, select, ) from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: - from .beatmap import Beatmap, BeatmapResp - from .beatmapset import BeatmapsetResp + from .beatmap import Beatmap, BeatmapDict + from .beatmapset import BeatmapsetDict from .user import User -class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True): +class BeatmapPlaycountsDict(TypedDict): + user_id: int + beatmap_id: int + count: NotRequired[int] + beatmap: NotRequired["BeatmapDict"] + beatmapset: NotRequired["BeatmapsetDict"] + + +class BeatmapPlaycountsModel(AsyncAttrs, DatabaseModel[BeatmapPlaycountsDict]): __tablename__: str = "beatmap_playcounts" id: int | None = Field( - default=None, - sa_column=Column(BigInteger, primary_key=True, autoincrement=True), + default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), exclude=True ) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) - playcount: int = Field(default=0) + playcount: int = Field(default=0, exclude=True) + @included + @staticmethod + async def count(_session: AsyncSession, obj: "BeatmapPlaycounts") -> int: + return obj.playcount + + @included + @staticmethod + async def beatmap( + _session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None + ) -> "BeatmapDict": + from .beatmap import BeatmapModel + + await obj.awaitable_attrs.beatmap + return await BeatmapModel.transform(obj.beatmap, includes=includes) + + @included + @staticmethod + async def beatmapset( + _session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None + ) -> "BeatmapsetDict": + from .beatmap import BeatmapsetModel + + await obj.awaitable_attrs.beatmap + return await BeatmapsetModel.transform(obj.beatmap.beatmapset, includes=includes) + + +class BeatmapPlaycounts(BeatmapPlaycountsModel, table=True): user: "User" = Relationship() beatmap: "Beatmap" = Relationship() -class BeatmapPlaycountsResp(BaseModel): - beatmap_id: int - beatmap: "BeatmapResp | None" = None - beatmapset: "BeatmapsetResp | None" = None - count: int - - @classmethod - async def from_db(cls, db_model: BeatmapPlaycounts) -> "BeatmapPlaycountsResp": - from .beatmap import BeatmapResp - from .beatmapset import BeatmapsetResp - - await db_model.awaitable_attrs.beatmap - return cls( - beatmap_id=db_model.beatmap_id, - count=db_model.playcount, - beatmap=await BeatmapResp.from_db(db_model.beatmap), - beatmapset=await BeatmapsetResp.from_db(db_model.beatmap.beatmapset), - ) - - async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None: existing_playcount = ( await session.exec( diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index fd58c8b..ff8346d 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -1,14 +1,15 @@ from datetime import datetime -from typing import TYPE_CHECKING, NotRequired, Self, TypedDict +from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict from app.config import settings -from app.database.beatmap_playcounts import BeatmapPlaycounts from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.score import GameMode -from .user import BASE_INCLUDES, User, UserResp +from ._base import DatabaseModel, OnDemand, included, ondemand +from .beatmap_playcounts import BeatmapPlaycounts +from .user import User, UserDict -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel from sqlalchemy import JSON, Boolean, Column, DateTime, Text from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select @@ -17,7 +18,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.fetcher import Fetcher - from .beatmap import Beatmap, BeatmapResp + from .beatmap import Beatmap, BeatmapDict from .favourite_beatmapset import FavouriteBeatmapset @@ -68,8 +69,99 @@ class BeatmapTranslationText(BaseModel): id: int | None = None -class BeatmapsetBase(SQLModel): +class BeatmapsetDict(TypedDict): + id: int + artist: str + artist_unicode: str + covers: BeatmapCovers | None + creator: str + nsfw: bool + preview_url: str + source: str + spotlight: bool + title: str + title_unicode: str + track_id: int | None + user_id: int + video: bool + current_nominations: list[BeatmapNomination] | None + description: BeatmapDescription | None + pack_tags: list[str] + + bpm: NotRequired[float] + can_be_hyped: NotRequired[bool] + discussion_locked: NotRequired[bool] + last_updated: NotRequired[datetime] + ranked_date: NotRequired[datetime | None] + storyboard: NotRequired[bool] + submitted_date: NotRequired[datetime] + tags: NotRequired[str] + discussion_enabled: NotRequired[bool] + legacy_thread_url: NotRequired[str | None] + status: NotRequired[str] + ranked: NotRequired[int] + is_scoreable: NotRequired[bool] + favourite_count: NotRequired[int] + genre_id: NotRequired[int] + hype: NotRequired[BeatmapHype] + language_id: NotRequired[int] + play_count: NotRequired[int] + availability: NotRequired[BeatmapAvailability] + beatmaps: NotRequired[list["BeatmapDict"]] + has_favourited: NotRequired[bool] + recent_favourites: NotRequired[list[UserDict]] + genre: NotRequired[BeatmapTranslationText] + language: NotRequired[BeatmapTranslationText] + nominations: NotRequired["BeatmapNominations"] + ratings: NotRequired[list[int]] + + +class BeatmapsetModel(DatabaseModel[BeatmapsetDict]): + BEATMAPSET_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [ + "availability", + "has_favourited", + "bpm", + "deleted_atcan_be_hyped", + "discussion_locked", + "is_scoreable", + "last_updated", + "legacy_thread_url", + "ranked", + "ranked_date", + "submitted_date", + "tags", + "rating", + "storyboard", + ] + API_INCLUDES: ClassVar[list[str]] = [ + *BEATMAPSET_TRANSFORMER_INCLUDES, + "beatmaps.current_user_playcount", + "beatmaps.current_user_tag_ids", + "beatmaps.max_combo", + "current_nominations", + "current_user_attributes", + "description", + "genre", + "language", + "pack_tags", + "ratings", + "recent_favourites", + "related_tags", + "related_users", + "user", + "version_count", + *[ + f"beatmaps.{inc}" + for inc in { + "failtimes", + "owners", + "top_tag_ids", + } + ], + ] + # Beatmapset + id: int = Field(default=None, primary_key=True, index=True) artist: str = Field(index=True) artist_unicode: str = Field(index=True) covers: BeatmapCovers | None = Field(sa_column=Column(JSON)) @@ -77,41 +169,285 @@ class BeatmapsetBase(SQLModel): nsfw: bool = Field(default=False, sa_column=Column(Boolean)) preview_url: str source: str = Field(default="") - spotlight: bool = Field(default=False, sa_column=Column(Boolean)) title: str = Field(index=True) title_unicode: str = Field(index=True) + track_id: int | None = Field(default=None, index=True) # feature artist? user_id: int = Field(index=True) video: bool = Field(sa_column=Column(Boolean, index=True)) # optional # converts: list[Beatmap] = Relationship(back_populates="beatmapset") - current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON)) - description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON)) + current_nominations: OnDemand[list[BeatmapNomination] | None] = Field(None, sa_column=Column(JSON)) + description: OnDemand[BeatmapDescription | None] = Field(default=None, sa_column=Column(JSON)) # TODO: discussions: list[BeatmapsetDiscussion] = None # TODO: current_user_attributes: Optional[CurrentUserAttributes] = None # TODO: events: Optional[list[BeatmapsetEvent]] = None - pack_tags: list[str] = Field(default=[], sa_column=Column(JSON)) + pack_tags: OnDemand[list[str]] = Field(default=[], sa_column=Column(JSON)) # TODO: related_users: Optional[list[User]] = None # TODO: user: Optional[User] = Field(default=None) - track_id: int | None = Field(default=None, index=True) # feature artist? # BeatmapsetExtended - bpm: float = Field(default=0.0) - can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean)) - discussion_locked: bool = Field(default=False, sa_column=Column(Boolean)) - last_updated: datetime = Field(sa_column=Column(DateTime, index=True)) - ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True)) - storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True)) - submitted_date: datetime = Field(sa_column=Column(DateTime, index=True)) - tags: str = Field(default="", sa_column=Column(Text)) + bpm: OnDemand[float] = Field(default=0.0) + can_be_hyped: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean)) + discussion_locked: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean)) + last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True)) + ranked_date: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime, index=True)) + storyboard: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean, index=True)) + submitted_date: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True)) + tags: OnDemand[str] = Field(default="", sa_column=Column(Text)) + + @ondemand + @staticmethod + async def legacy_thread_url( + _session: AsyncSession, + _beatmapset: "Beatmapset", + ) -> str | None: + return None + + @included + @staticmethod + async def discussion_enabled( + _session: AsyncSession, + _beatmapset: "Beatmapset", + ) -> bool: + return True + + @included + @staticmethod + async def status( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> str: + beatmap_status = beatmapset.beatmap_status + if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard(): + return BeatmapRankStatus.APPROVED.name.lower() + return beatmap_status.name.lower() + + @included + @staticmethod + async def ranked( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> int: + beatmap_status = beatmapset.beatmap_status + if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard(): + return BeatmapRankStatus.APPROVED.value + return beatmap_status.value + + @ondemand + @staticmethod + async def is_scoreable( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> bool: + beatmap_status = beatmapset.beatmap_status + if settings.enable_all_beatmap_leaderboard: + return True + return beatmap_status.has_leaderboard() + + @included + @staticmethod + async def favourite_count( + session: AsyncSession, + beatmapset: "Beatmapset", + ) -> int: + from .favourite_beatmapset import FavouriteBeatmapset + + count = await session.exec( + select(func.count()) + .select_from(FavouriteBeatmapset) + .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) + ) + return count.one() + + @included + @staticmethod + async def genre_id( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> int: + return beatmapset.beatmap_genre.value + + @ondemand + @staticmethod + async def hype( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> BeatmapHype: + return BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required) + + @included + @staticmethod + async def language_id( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> int: + return beatmapset.beatmap_language.value + + @included + @staticmethod + async def play_count( + session: AsyncSession, + beatmapset: "Beatmapset", + ) -> int: + from .beatmap import Beatmap + + playcount = await session.exec( + select(func.sum(BeatmapPlaycounts.playcount)).where( + col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id) + ) + ) + return int(playcount.first() or 0) + + @ondemand + @staticmethod + async def availability( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> BeatmapAvailability: + return BeatmapAvailability( + more_information=beatmapset.availability_info, + download_disabled=beatmapset.download_disabled, + ) + + @ondemand + @staticmethod + async def beatmaps( + _session: AsyncSession, + beatmapset: "Beatmapset", + includes: list[str] | None = None, + user: "User | None" = None, + ) -> list["BeatmapDict"]: + from .beatmap import BeatmapModel + + return [ + await BeatmapModel.transform( + beatmap, includes=(includes or []) + BeatmapModel.BEATMAP_TRANSFORMER_INCLUDES, user=user + ) + for beatmap in await beatmapset.awaitable_attrs.beatmaps + ] + + # @ondemand + # @staticmethod + # async def current_nominations( + # _session: AsyncSession, + # beatmapset: "Beatmapset", + # ) -> list[BeatmapNomination] | None: + # return beatmapset.current_nominations or [] + + @ondemand + @staticmethod + async def has_favourited( + session: AsyncSession, + beatmapset: "Beatmapset", + user: User | None = None, + ) -> bool: + from .favourite_beatmapset import FavouriteBeatmapset + + if session is None: + return False + query = select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) + if user is not None: + query = query.where(FavouriteBeatmapset.user_id == user.id) + existing = (await session.exec(query)).first() + return existing is not None + + @ondemand + @staticmethod + async def recent_favourites( + session: AsyncSession, + beatmapset: "Beatmapset", + includes: list[str] | None = None, + ) -> list[UserDict]: + from .favourite_beatmapset import FavouriteBeatmapset + + recent_favourites = ( + await session.exec( + select(FavouriteBeatmapset) + .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) + .order_by(col(FavouriteBeatmapset.date).desc()) + .limit(50) + ) + ).all() + return [ + await User.transform( + (await favourite.awaitable_attrs.user), + includes=includes, + ) + for favourite in recent_favourites + ] + + @ondemand + @staticmethod + async def genre( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> BeatmapTranslationText: + return BeatmapTranslationText( + name=beatmapset.beatmap_genre.name, + id=beatmapset.beatmap_genre.value, + ) + + @ondemand + @staticmethod + async def language( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> BeatmapTranslationText: + return BeatmapTranslationText( + name=beatmapset.beatmap_language.name, + id=beatmapset.beatmap_language.value, + ) + + @ondemand + @staticmethod + async def nominations( + _session: AsyncSession, + beatmapset: "Beatmapset", + ) -> BeatmapNominations: + return BeatmapNominations( + required=beatmapset.nominations_required, + current=beatmapset.nominations_current, + ) + + # @ondemand + # @staticmethod + # async def user( + # session: AsyncSession, + # beatmapset: Beatmapset, + # includes: list[str] | None = None, + # ) -> dict[str, Any] | None: + # db_user = await session.get(User, beatmapset.user_id) + # if not db_user: + # return None + # return await UserResp.transform(db_user, includes=includes) + + @ondemand + @staticmethod + async def ratings( + session: AsyncSession, + beatmapset: "Beatmapset", + ) -> list[int]: + # Provide a stable default shape if no session is available + if session is None: + return [] + + from .beatmapset_ratings import BeatmapRating + + beatmapset_all_ratings = ( + await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id)) + ).all() + ratings_list = [0] * 11 + for rating in beatmapset_all_ratings: + ratings_list[rating.rating] += 1 + return ratings_list -class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): +class Beatmapset(AsyncAttrs, BeatmapsetModel, table=True): __tablename__: str = "beatmapsets" - id: int = Field(default=None, primary_key=True, index=True) # Beatmapset beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True) @@ -130,29 +466,45 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod - async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset": - d = resp.model_dump() - if resp.nominations: - d["nominations_required"] = resp.nominations.required - d["nominations_current"] = resp.nominations.current - if resp.hype: - d["hype_current"] = resp.hype.current - d["hype_required"] = resp.hype.required - if resp.genre_id: - d["beatmap_genre"] = Genre(resp.genre_id) - elif resp.genre: - d["beatmap_genre"] = Genre(resp.genre.id) - if resp.language_id: - d["beatmap_language"] = Language(resp.language_id) - elif resp.language: - d["beatmap_language"] = Language(resp.language.id) + async def from_resp_no_save(cls, resp: BeatmapsetDict) -> "Beatmapset": + # make a shallow copy so we can mutate safely + d: dict[str, Any] = dict(resp) + + # nominations = resp.get("nominations") + # if nominations is not None: + # d["nominations_required"] = nominations.required + # d["nominations_current"] = nominations.current + + hype = resp.get("hype") + if hype is not None: + d["hype_current"] = hype.current + d["hype_required"] = hype.required + + genre_id = resp.get("genre_id") + genre = resp.get("genre") + if genre_id is not None: + d["beatmap_genre"] = Genre(genre_id) + elif genre is not None: + d["beatmap_genre"] = Genre(genre.id) + + language_id = resp.get("language_id") + language = resp.get("language") + if language_id is not None: + d["beatmap_language"] = Language(language_id) + elif language is not None: + d["beatmap_language"] = Language(language.id) + + availability = resp.get("availability") + ranked = resp.get("ranked") + if ranked is None: + raise ValueError("ranked field is required") + beatmapset = Beatmapset.model_validate( { **d, - "id": resp.id, - "beatmap_status": BeatmapRankStatus(resp.ranked), - "availability_info": resp.availability.more_information, - "download_disabled": resp.availability.download_disabled or False, + "beatmap_status": BeatmapRankStatus(ranked), + "availability_info": availability.more_information if availability is not None else None, + "download_disabled": bool(availability.download_disabled) if availability is not None else False, } ) return beatmapset @@ -161,17 +513,19 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): async def from_resp( cls, session: AsyncSession, - resp: "BeatmapsetResp", + resp: BeatmapsetDict, from_: int = 0, ) -> "Beatmapset": from .beatmap import Beatmap + beatmapset_id = resp["id"] beatmapset = await cls.from_resp_no_save(resp) - if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first(): + if not (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first(): session.add(beatmapset) await session.commit() - await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) - beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one() + beatmaps = resp.get("beatmaps", []) + await Beatmap.from_resp_batch(session, beatmaps, from_=from_) + beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))).one() return beatmapset @classmethod @@ -183,170 +537,5 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): resp = await fetcher.get_beatmapset(sid) beatmapset = await cls.from_resp(session, resp) await get_beatmapset_update_service().add(resp) + await session.refresh(beatmapset) return beatmapset - - -class BeatmapsetResp(BeatmapsetBase): - id: int - beatmaps: list["BeatmapResp"] = Field(default_factory=list) - discussion_enabled: bool = True - status: str - ranked: int - legacy_thread_url: str | None = "" - is_scoreable: bool - hype: BeatmapHype | None = None - availability: BeatmapAvailability - genre: BeatmapTranslationText | None = None - genre_id: int - language: BeatmapTranslationText | None = None - language_id: int - nominations: BeatmapNominations | None = None - has_favourited: bool = False - favourite_count: int = 0 - recent_favourites: list[UserResp] = Field(default_factory=list) - play_count: int = 0 - - @field_validator( - "nsfw", - "spotlight", - "video", - "can_be_hyped", - "discussion_locked", - "storyboard", - "discussion_enabled", - "is_scoreable", - "has_favourited", - mode="before", - ) - @classmethod - def validate_bool_fields(cls, v): - """将整数 0/1 转换为布尔值,处理数据库中的布尔字段""" - if isinstance(v, int): - return bool(v) - return v - - @model_validator(mode="after") - def fix_genre_language(self) -> Self: - if self.genre is None: - self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id) - if self.language is None: - self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id) - return self - - @classmethod - async def from_db( - cls, - beatmapset: Beatmapset, - include: list[str] = [], - session: AsyncSession | None = None, - user: User | None = None, - ) -> "BeatmapsetResp": - from .beatmap import Beatmap, BeatmapResp - from .favourite_beatmapset import FavouriteBeatmapset - - update = { - "beatmaps": [ - await BeatmapResp.from_db(beatmap, from_set=True, session=session) - for beatmap in await beatmapset.awaitable_attrs.beatmaps - ], - "hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required), - "availability": BeatmapAvailability( - more_information=beatmapset.availability_info, - download_disabled=beatmapset.download_disabled, - ), - "genre": BeatmapTranslationText( - name=beatmapset.beatmap_genre.name, - id=beatmapset.beatmap_genre.value, - ), - "language": BeatmapTranslationText( - name=beatmapset.beatmap_language.name, - id=beatmapset.beatmap_language.value, - ), - "genre_id": beatmapset.beatmap_genre.value, - "language_id": beatmapset.beatmap_language.value, - "nominations": BeatmapNominations( - required=beatmapset.nominations_required, - current=beatmapset.nominations_current, - ), - "is_scoreable": beatmapset.beatmap_status.has_leaderboard(), - **beatmapset.model_dump(), - } - - if session is not None: - # 从数据库读取对应谱面集的评分 - from .beatmapset_ratings import BeatmapRating - - beatmapset_all_ratings = ( - await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id)) - ).all() - ratings_list = [0] * 11 - for rating in beatmapset_all_ratings: - ratings_list[rating.rating] += 1 - update["ratings"] = ratings_list - else: - # 返回非空值避免客户端崩溃 - if update.get("ratings") is None: - update["ratings"] = [] - - beatmap_status = beatmapset.beatmap_status - if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard(): - update["status"] = BeatmapRankStatus.APPROVED.name.lower() - update["ranked"] = BeatmapRankStatus.APPROVED.value - else: - update["status"] = beatmap_status.name.lower() - update["ranked"] = beatmap_status.value - - if session and user: - existing_favourite = ( - await session.exec( - select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) - ) - ).first() - update["has_favourited"] = existing_favourite is not None - - if session and "recent_favourites" in include: - recent_favourites = ( - await session.exec( - select(FavouriteBeatmapset) - .where( - FavouriteBeatmapset.beatmapset_id == beatmapset.id, - ) - .order_by(col(FavouriteBeatmapset.date).desc()) - .limit(50) - ) - ).all() - update["recent_favourites"] = [ - await UserResp.from_db( - await favourite.awaitable_attrs.user, - session=session, - include=BASE_INCLUDES, - ) - for favourite in recent_favourites - ] - - if session: - update["favourite_count"] = ( - await session.exec( - select(func.count()) - .select_from(FavouriteBeatmapset) - .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) - ) - ).one() - - update["play_count"] = ( - await session.exec( - select(func.sum(BeatmapPlaycounts.playcount)).where( - col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id) - ) - ) - ).first() or 0 - return cls.model_validate( - update, - ) - - -class SearchBeatmapsetsResp(SQLModel): - beatmapsets: list[BeatmapsetResp] - total: int - cursor: dict[str, int | float | str] | None = None - cursor_string: str | None = None diff --git a/app/database/beatmapset_ratings.py b/app/database/beatmapset_ratings.py index 8b63d88..c6e1d75 100644 --- a/app/database/beatmapset_ratings.py +++ b/app/database/beatmapset_ratings.py @@ -1,5 +1,5 @@ -from app.database.beatmapset import Beatmapset -from app.database.user import User +from .beatmapset import Beatmapset +from .user import User from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel diff --git a/app/database/best_scores.py b/app/database/best_scores.py index b1f9059..13fead7 100644 --- a/app/database/best_scores.py +++ b/app/database/best_scores.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING -from app.database.statistics import UserStatistics from app.models.score import GameMode +from .statistics import UserStatistics from .user import User from sqlmodel import ( diff --git a/app/database/chat.py b/app/database/chat.py index b05c790..3dfb0cb 100644 --- a/app/database/chat.py +++ b/app/database/chat.py @@ -1,13 +1,14 @@ from datetime import datetime from enum import Enum -from typing import Self +from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict -from app.database.user import RANKING_INCLUDES, User, UserResp from app.models.model import UTCBaseModel from app.utils import utcnow +from ._base import DatabaseModel, included, ondemand +from .user import User, UserDict, UserModel + from pydantic import BaseModel -from redis.asyncio import Redis from sqlmodel import ( VARCHAR, BigInteger, @@ -22,6 +23,8 @@ from sqlmodel import ( ) from sqlmodel.ext.asyncio.session import AsyncSession +if TYPE_CHECKING: + from app.router.notification.server import ChatServer # ChatChannel @@ -44,16 +47,168 @@ class ChannelType(str, Enum): TEAM = "TEAM" -class ChatChannelBase(SQLModel): - name: str = Field(sa_column=Column(VARCHAR(50), index=True)) +class MessageType(str, Enum): + ACTION = "action" + MARKDOWN = "markdown" + PLAIN = "plain" + + +class ChatChannelDict(TypedDict): + channel_id: int + description: str + name: str + icon: str | None + type: ChannelType + uuid: NotRequired[str | None] + message_length_limit: NotRequired[int] + moderated: NotRequired[bool] + current_user_attributes: NotRequired[ChatUserAttributes] + last_read_id: NotRequired[int | None] + last_message_id: NotRequired[int | None] + recent_messages: NotRequired[list["ChatMessageDict"]] + users: NotRequired[list[int]] + + +class ChatChannelModel(DatabaseModel[ChatChannelDict]): + CONVERSATION_INCLUDES: ClassVar[list[str]] = [ + "last_message_id", + "users", + ] + LISTING_INCLUDES: ClassVar[list[str]] = [ + *CONVERSATION_INCLUDES, + "current_user_attributes", + "last_read_id", + ] + + channel_id: int = Field(primary_key=True, index=True, default=None) description: str = Field(sa_column=Column(VARCHAR(255), index=True)) icon: str | None = Field(default=None) type: ChannelType = Field(index=True) + @included + @staticmethod + async def name(session: AsyncSession, channel: "ChatChannel", user: User, server: "ChatServer") -> str: + users = server.channels.get(channel.channel_id, []) + if channel.type == ChannelType.PM and users and len(users) == 2: + target_user_id = next(u for u in users if u != user.id) + target_name = await session.exec(select(User.username).where(User.id == target_user_id)) + return target_name.one() + return channel.name -class ChatChannel(ChatChannelBase, table=True): + @included + @staticmethod + async def moderated(session: AsyncSession, channel: "ChatChannel", user: User) -> bool: + silence = ( + await session.exec( + select(SilenceUser).where( + SilenceUser.channel_id == channel.channel_id, + SilenceUser.user_id == user.id, + ) + ) + ).first() + + return silence is not None + + @ondemand + @staticmethod + async def current_user_attributes( + session: AsyncSession, + channel: "ChatChannel", + user: User, + ) -> ChatUserAttributes: + from app.dependencies.database import get_redis + + silence = ( + await session.exec( + select(SilenceUser).where( + SilenceUser.channel_id == channel.channel_id, + SilenceUser.user_id == user.id, + ) + ) + ).first() + can_message = silence is None + can_message_error = "You are silenced in this channel" if not can_message else None + + redis = get_redis() + last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") + last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg") + last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None + last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else (last_msg or 0) + + return ChatUserAttributes( + can_message=can_message, + can_message_error=can_message_error, + last_read_id=last_read_id, + ) + + @ondemand + @staticmethod + async def last_read_id(_session: AsyncSession, channel: "ChatChannel", user: User) -> int | None: + from app.dependencies.database import get_redis + + redis = get_redis() + last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") + last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg") + last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None + return int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg + + @ondemand + @staticmethod + async def last_message_id(_session: AsyncSession, channel: "ChatChannel") -> int | None: + from app.dependencies.database import get_redis + + redis = get_redis() + last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg") + return int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None + + @ondemand + @staticmethod + async def recent_messages( + session: AsyncSession, + channel: "ChatChannel", + ) -> list["ChatMessageDict"]: + messages = ( + await session.exec( + select(ChatMessage) + .where(ChatMessage.channel_id == channel.channel_id) + .order_by(col(ChatMessage.message_id).desc()) + .limit(50) + ) + ).all() + result = [ + await ChatMessageModel.transform( + msg, + ) + for msg in reversed(messages) + ] + return result + + @ondemand + @staticmethod + async def users( + _session: AsyncSession, + channel: "ChatChannel", + server: "ChatServer", + user: User, + ) -> list[int]: + if channel.type == ChannelType.PUBLIC: + return [] + users = server.channels.get(channel.channel_id, []).copy() + if channel.type == ChannelType.PM and users and len(users) == 2: + target_user_id = next(u for u in users if u != user.id) + users = [target_user_id, user.id] + return users + + @included + @staticmethod + async def message_length_limit(_session: AsyncSession, _channel: "ChatChannel") -> int: + return 1000 + + +class ChatChannel(ChatChannelModel, table=True): __tablename__: str = "chat_channels" - channel_id: int = Field(primary_key=True, index=True, default=None) + + name: str = Field(sa_column=Column(VARCHAR(50), index=True)) @classmethod async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None": @@ -74,93 +229,20 @@ class ChatChannel(ChatChannelBase, table=True): return channel -class ChatChannelResp(ChatChannelBase): - channel_id: int - moderated: bool = False - uuid: str | None = None - current_user_attributes: ChatUserAttributes | None = None - last_read_id: int | None = None - last_message_id: int | None = None - recent_messages: list["ChatMessageResp"] = Field(default_factory=list) - users: list[int] = Field(default_factory=list) - message_length_limit: int = 1000 - - @classmethod - async def from_db( - cls, - channel: ChatChannel, - session: AsyncSession, - user: User, - redis: Redis, - users: list[int] | None = None, - include_recent_messages: bool = False, - ) -> Self: - c = cls.model_validate(channel) - silence = ( - await session.exec( - select(SilenceUser).where( - SilenceUser.channel_id == channel.channel_id, - SilenceUser.user_id == user.id, - ) - ) - ).first() - - last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg") - last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None - - last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") - last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg - - if silence is not None: - attribute = ChatUserAttributes( - can_message=False, - can_message_error=silence.reason or "You are muted in this channel.", - last_read_id=last_read_id or 0, - ) - c.moderated = True - else: - attribute = ChatUserAttributes( - can_message=True, - last_read_id=last_read_id or 0, - ) - c.moderated = False - - c.current_user_attributes = attribute - if c.type != ChannelType.PUBLIC and users is not None: - c.users = users - c.last_message_id = last_msg - c.last_read_id = last_read_id - - if include_recent_messages: - messages = ( - await session.exec( - select(ChatMessage) - .where(ChatMessage.channel_id == channel.channel_id) - .order_by(col(ChatMessage.timestamp).desc()) - .limit(10) - ) - ).all() - c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages] - c.recent_messages.reverse() - - if c.type == ChannelType.PM and users and len(users) == 2: - target_user_id = next(u for u in users if u != user.id) - target_name = await session.exec(select(User.username).where(User.id == target_user_id)) - c.name = target_name.one() - c.users = [target_user_id, user.id] - return c - - # ChatMessage +class ChatMessageDict(TypedDict): + channel_id: int + content: str + message_id: int + sender_id: int + timestamp: datetime + type: MessageType + uuid: str | None + is_action: NotRequired[bool] + sender: NotRequired[UserDict] -class MessageType(str, Enum): - ACTION = "action" - MARKDOWN = "markdown" - PLAIN = "plain" - - -class ChatMessageBase(UTCBaseModel, SQLModel): +class ChatMessageModel(DatabaseModel[ChatMessageDict]): channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id") content: str = Field(sa_column=Column(VARCHAR(1000))) message_id: int = Field(index=True, primary_key=True, default=None) @@ -169,31 +251,21 @@ class ChatMessageBase(UTCBaseModel, SQLModel): type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True) uuid: str | None = Field(default=None) + @included + @staticmethod + async def is_action(_session: AsyncSession, db_message: "ChatMessage") -> bool: + return db_message.type == MessageType.ACTION -class ChatMessage(ChatMessageBase, table=True): + @ondemand + @staticmethod + async def sender(_session: AsyncSession, db_message: "ChatMessage") -> UserDict: + return await UserModel.transform(db_message.user) + + +class ChatMessage(ChatMessageModel, table=True): __tablename__: str = "chat_messages" user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) - channel: ChatChannel = Relationship() - - -class ChatMessageResp(ChatMessageBase): - sender: UserResp | None = None - is_action: bool = False - - @classmethod - async def from_db( - cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None - ) -> "ChatMessageResp": - m = cls.model_validate(db_message.model_dump()) - m.is_action = db_message.type == MessageType.ACTION - if user: - m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES) - else: - m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES) - return m - - -# SilenceUser + channel: "ChatChannel" = Relationship() class SilenceUser(UTCBaseModel, SQLModel, table=True): diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py index 308bc30..2925207 100644 --- a/app/database/favourite_beatmapset.py +++ b/app/database/favourite_beatmapset.py @@ -1,7 +1,7 @@ import datetime -from app.database.beatmapset import Beatmapset -from app.database.user import User +from .beatmapset import Beatmapset +from .user import User from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( diff --git a/app/database/item_attempts_count.py b/app/database/item_attempts_count.py index a4487c2..4522787 100644 --- a/app/database/item_attempts_count.py +++ b/app/database/item_attempts_count.py @@ -1,7 +1,9 @@ -from .playlist_best_score import PlaylistBestScore -from .user import User, UserResp +from typing import Any, NotRequired, TypedDict + +from ._base import DatabaseModel, ondemand +from .playlist_best_score import PlaylistBestScore +from .user import User, UserDict, UserModel -from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( BigInteger, @@ -9,7 +11,6 @@ from sqlmodel import ( Field, ForeignKey, Relationship, - SQLModel, col, func, select, @@ -17,17 +18,66 @@ from sqlmodel import ( from sqlmodel.ext.asyncio.session import AsyncSession -class ItemAttemptsCountBase(SQLModel): - room_id: int = Field(foreign_key="rooms.id", index=True) +class ItemAttemptsCountDict(TypedDict): + accuracy: float + attempts: int + completed: int + pp: float + room_id: int + total_score: int + user_id: int + user: NotRequired[UserDict] + position: NotRequired[int] + playlist_item_attempts: NotRequired[list[dict[str, Any]]] + + +class ItemAttemptsCountModel(DatabaseModel[ItemAttemptsCountDict]): + accuracy: float = 0.0 attempts: int = Field(default=0) completed: int = Field(default=0) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) - accuracy: float = 0.0 pp: float = 0 + room_id: int = Field(foreign_key="rooms.id", index=True) total_score: int = 0 + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) + + @ondemand + @staticmethod + async def user(_session: AsyncSession, item_attempts: "ItemAttemptsCount") -> UserDict: + user_instance = await item_attempts.awaitable_attrs.user + return await UserModel.transform(user_instance) + + @ondemand + @staticmethod + async def position(session: AsyncSession, item_attempts: "ItemAttemptsCount") -> int: + return await item_attempts.get_position(session) + + @ondemand + @staticmethod + async def playlist_item_attempts( + session: AsyncSession, + item_attempts: "ItemAttemptsCount", + ) -> list[dict[str, Any]]: + playlist_scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == item_attempts.room_id, + PlaylistBestScore.user_id == item_attempts.user_id, + ) + ) + ).all() + result: list[dict[str, Any]] = [] + for score in playlist_scores: + result.append( + { + "id": score.playlist_id, + "attempts": score.attempts, + "passed": score.score.passed, + } + ) + return result -class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True): +class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountModel, table=True): __tablename__: str = "item_attempts_count" id: int | None = Field(default=None, primary_key=True) @@ -37,15 +87,15 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True): rownum = ( func.row_number() .over( - partition_by=col(ItemAttemptsCountBase.room_id), - order_by=col(ItemAttemptsCountBase.total_score).desc(), + partition_by=col(ItemAttemptsCount.room_id), + order_by=col(ItemAttemptsCount.total_score).desc(), ) .label("rn") ) - subq = select(ItemAttemptsCountBase, rownum).subquery() + subq = select(ItemAttemptsCount, rownum).subquery() stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id) result = await session.exec(stmt) - return result.one() + return result.first() or 0 async def update(self, session: AsyncSession): playlist_scores = ( @@ -88,62 +138,3 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True): await session.refresh(item_attempts) await item_attempts.update(session) return item_attempts - - -class ItemAttemptsResp(ItemAttemptsCountBase): - user: UserResp | None = None - position: int | None = None - - @classmethod - async def from_db( - cls, - item_attempts: ItemAttemptsCount, - session: AsyncSession, - include: list[str] = [], - ) -> "ItemAttemptsResp": - resp = cls.model_validate(item_attempts.model_dump()) - resp.user = await UserResp.from_db( - await item_attempts.awaitable_attrs.user, - session=session, - include=["statistics", "team", "daily_challenge_user_stats"], - ) - if "position" in include: - resp.position = await item_attempts.get_position(session) - # resp.accuracy *= 100 - return resp - - -class ItemAttemptsCountForItem(BaseModel): - id: int - attempts: int - passed: bool - - -class PlaylistAggregateScore(BaseModel): - playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list) - - @classmethod - async def from_db( - cls, - room_id: int, - user_id: int, - session: AsyncSession, - ) -> "PlaylistAggregateScore": - playlist_scores = ( - await session.exec( - select(PlaylistBestScore).where( - PlaylistBestScore.room_id == room_id, - PlaylistBestScore.user_id == user_id, - ) - ) - ).all() - playlist_item_attempts = [] - for score in playlist_scores: - playlist_item_attempts.append( - ItemAttemptsCountForItem( - id=score.playlist_id, - attempts=score.attempts, - passed=score.score.passed, - ) - ) - return cls(playlist_item_attempts=playlist_item_attempts) diff --git a/app/database/playlists.py b/app/database/playlists.py index 6edf16f..6d5b6f7 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -1,11 +1,11 @@ from datetime import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, NotRequired, TypedDict -from app.models.model import UTCBaseModel from app.models.mods import APIMod from app.models.playlist import PlaylistItem -from .beatmap import Beatmap, BeatmapResp +from ._base import DatabaseModel, ondemand +from .beatmap import Beatmap, BeatmapDict, BeatmapModel from sqlmodel import ( JSON, @@ -15,7 +15,6 @@ from sqlmodel import ( Field, ForeignKey, Relationship, - SQLModel, func, select, ) @@ -23,18 +22,34 @@ from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .room import Room + from .score import ScoreDict -class PlaylistBase(SQLModel, UTCBaseModel): - id: int = Field(index=True) - owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) +class PlaylistDict(TypedDict): + id: int + room_id: int + beatmap_id: int + created_at: datetime | None ruleset_id: int - expired: bool = Field(default=False) - playlist_order: int = Field(default=0) - played_at: datetime | None = Field( - sa_column=Column(DateTime(timezone=True)), - default=None, + allowed_mods: list[APIMod] + required_mods: list[APIMod] + freestyle: bool + expired: bool + owner_id: int + playlist_order: int + played_at: datetime | None + beatmap: NotRequired["BeatmapDict"] + scores: NotRequired[list[dict[str, Any]]] + + +class PlaylistModel(DatabaseModel[PlaylistDict]): + id: int = Field(index=True) + room_id: int = Field(foreign_key="rooms.id") + beatmap_id: int = Field( + foreign_key="beatmaps.id", ) + created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()}) + ruleset_id: int allowed_mods: list[APIMod] = Field( default_factory=list, sa_column=Column(JSON), @@ -43,16 +58,46 @@ class PlaylistBase(SQLModel, UTCBaseModel): default_factory=list, sa_column=Column(JSON), ) - beatmap_id: int = Field( - foreign_key="beatmaps.id", - ) freestyle: bool = Field(default=False) + expired: bool = Field(default=False) + owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) + playlist_order: int = Field(default=0) + played_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True)), + default=None, + ) + + @ondemand + @staticmethod + async def beatmap(_session: AsyncSession, playlist: "Playlist", includes: list[str] | None = None) -> BeatmapDict: + return await BeatmapModel.transform(playlist.beatmap, includes=includes) + + @ondemand + @staticmethod + async def scores(session: AsyncSession, playlist: "Playlist") -> list["ScoreDict"]: + from .score import Score, ScoreModel + + scores = ( + await session.exec( + select(Score).where( + Score.playlist_item_id == playlist.id, + Score.room_id == playlist.room_id, + ) + ) + ).all() + result: list[ScoreDict] = [] + for score in scores: + result.append( + await ScoreModel.transform( + score, + ) + ) + return result -class Playlist(PlaylistBase, table=True): +class Playlist(PlaylistModel, table=True): __tablename__: str = "room_playlists" db_id: int = Field(default=None, primary_key=True, index=True, exclude=True) - room_id: int = Field(foreign_key="rooms.id", exclude=True) beatmap: Beatmap = Relationship( sa_relationship_kwargs={ @@ -60,7 +105,6 @@ class Playlist(PlaylistBase, table=True): } ) room: "Room" = Relationship() - created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()}) updated_at: datetime | None = Field( default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()} ) @@ -121,15 +165,3 @@ class Playlist(PlaylistBase, table=True): raise ValueError("Playlist item not found") await session.delete(db_playlist) await session.commit() - - -class PlaylistResp(PlaylistBase): - beatmap: BeatmapResp | None = None - - @classmethod - async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp": - data = playlist.model_dump() - if "beatmap" in include: - data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap) - resp = cls.model_validate(data) - return resp diff --git a/app/database/relationship.py b/app/database/relationship.py index f792f31..55ecf20 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -1,26 +1,39 @@ from enum import Enum +from typing import TYPE_CHECKING, NotRequired, TypedDict -from .user import User, UserResp +from app.models.score import GameMode + +from ._base import DatabaseModel, included, ondemand -from pydantic import BaseModel from sqlmodel import ( BigInteger, Column, Field, ForeignKey, Relationship as SQLRelationship, - SQLModel, select, ) from sqlmodel.ext.asyncio.session import AsyncSession +if TYPE_CHECKING: + from .user import User, UserDict + class RelationshipType(str, Enum): - FOLLOW = "Friend" - BLOCK = "Block" + FOLLOW = "friend" + BLOCK = "block" -class Relationship(SQLModel, table=True): +class RelationshipDict(TypedDict): + target_id: int | None + type: RelationshipType + id: NotRequired[int | None] + user_id: NotRequired[int | None] + mutual: NotRequired[bool] + target: NotRequired["UserDict"] + + +class RelationshipModel(DatabaseModel[RelationshipDict]): __tablename__: str = "relationship" id: int | None = Field( default=None, @@ -34,6 +47,7 @@ class Relationship(SQLModel, table=True): ForeignKey("lazer_users.id"), index=True, ), + exclude=True, ) target_id: int = Field( default=None, @@ -44,22 +58,10 @@ class Relationship(SQLModel, table=True): ), ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) - target: User = SQLRelationship( - sa_relationship_kwargs={ - "foreign_keys": "[Relationship.target_id]", - "lazy": "selectin", - } - ) - -class RelationshipResp(BaseModel): - target_id: int - target: UserResp - mutual: bool = False - type: RelationshipType - - @classmethod - async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp": + @included + @staticmethod + async def mutual(session: AsyncSession, relationship: "Relationship") -> bool: target_relationship = ( await session.exec( select(Relationship).where( @@ -68,23 +70,29 @@ class RelationshipResp(BaseModel): ) ) ).first() - mutual = bool( + return bool( target_relationship is not None and relationship.type == RelationshipType.FOLLOW and target_relationship.type == RelationshipType.FOLLOW ) - return cls( - target_id=relationship.target_id, - target=await UserResp.from_db( - relationship.target, - session, - include=[ - "team", - "daily_challenge_user_stats", - "statistics", - "statistics_rulesets", - ], - ), - mutual=mutual, - type=relationship.type, - ) + + @ondemand + @staticmethod + async def target( + _session: AsyncSession, + relationship: "Relationship", + ruleset: GameMode | None = None, + includes: list[str] | None = None, + ) -> "UserDict": + from .user import UserModel + + return await UserModel.transform(relationship.target, ruleset=ruleset, includes=includes) + + +class Relationship(RelationshipModel, table=True): + target: "User" = SQLRelationship( + sa_relationship_kwargs={ + "foreign_keys": "[Relationship.target_id]", + "lazy": "selectin", + } + ) diff --git a/app/database/room.py b/app/database/room.py index add0dbc..afa609b 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -1,8 +1,6 @@ from datetime import datetime +from typing import ClassVar, NotRequired, TypedDict -from app.database.item_attempts_count import PlaylistAggregateScore -from app.database.room_participated_user import RoomParticipatedUser -from app.models.model import UTCBaseModel from app.models.room import ( MatchType, QueueMode, @@ -13,28 +11,58 @@ from app.models.room import ( ) from app.utils import utcnow -from .playlists import Playlist, PlaylistResp -from .user import User, UserResp +from ._base import DatabaseModel, included, ondemand +from .item_attempts_count import ItemAttemptsCount, ItemAttemptsCountDict, ItemAttemptsCountModel +from .playlists import Playlist, PlaylistDict, PlaylistModel +from .room_participated_user import RoomParticipatedUser +from .user import User, UserDict, UserModel +from pydantic import field_validator from sqlalchemy.ext.asyncio import AsyncAttrs -from sqlmodel import ( - BigInteger, - Column, - DateTime, - Field, - ForeignKey, - Relationship, - SQLModel, - col, - func, - select, -) +from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, Relationship, SQLModel, col, select from sqlmodel.ext.asyncio.session import AsyncSession -class RoomBase(SQLModel, UTCBaseModel): +class RoomDict(TypedDict): + id: int + name: str + category: RoomCategory + status: RoomStatus + type: MatchType + duration: int | None + starts_at: datetime | None + ends_at: datetime | None + max_attempts: int | None + participant_count: int + channel_id: int + queue_mode: QueueMode + auto_skip: bool + auto_start_duration: int + has_password: NotRequired[bool] + current_playlist_item: NotRequired["PlaylistDict | None"] + playlist: NotRequired[list["PlaylistDict"]] + playlist_item_stats: NotRequired[RoomPlaylistItemStats] + difficulty_range: NotRequired[RoomDifficultyRange] + host: NotRequired[UserDict] + recent_participants: NotRequired[list[UserDict]] + current_user_score: NotRequired["ItemAttemptsCountDict | None"] + + +class RoomModel(DatabaseModel[RoomDict]): + SHOW_RESPONSE_INCLUDES: ClassVar[list[str]] = [ + "current_user_score.playlist_item_attempts", + "host.country", + "playlist.beatmap.beatmapset", + "playlist.beatmap.checksum", + "playlist.beatmap.max_combo", + "recent_participants", + ] + + id: int = Field(default=None, primary_key=True, index=True) name: str = Field(index=True) category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True) + status: RoomStatus + type: MatchType duration: int | None = Field(default=None) # minutes starts_at: datetime | None = Field( sa_column=Column( @@ -48,76 +76,88 @@ class RoomBase(SQLModel, UTCBaseModel): ), default=None, ) - participant_count: int = Field(default=0) max_attempts: int | None = Field(default=None) # playlists - type: MatchType + participant_count: int = Field(default=0) + channel_id: int = 0 queue_mode: QueueMode auto_skip: bool + auto_start_duration: int - status: RoomStatus - channel_id: int | None = None - - -class Room(AsyncAttrs, RoomBase, table=True): - __tablename__: str = "rooms" - id: int = Field(default=None, primary_key=True, index=True) - host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) - password: str | None = Field(default=None) - - host: User = Relationship() - playlist: list[Playlist] = Relationship( - sa_relationship_kwargs={ - "lazy": "selectin", - "cascade": "all, delete-orphan", - "overlaps": "room", - } - ) - - -class RoomResp(RoomBase): - id: int - has_password: bool = False - host: UserResp | None = None - playlist: list[PlaylistResp] = [] - playlist_item_stats: RoomPlaylistItemStats | None = None - difficulty_range: RoomDifficultyRange | None = None - current_playlist_item: PlaylistResp | None = None - current_user_score: PlaylistAggregateScore | None = None - recent_participants: list[UserResp] = Field(default_factory=list) - channel_id: int = 0 + @field_validator("channel_id", mode="before") @classmethod - async def from_db( - cls, - room: Room, - session: AsyncSession, - include: list[str] = [], - user: User | None = None, - ) -> "RoomResp": - d = room.model_dump() - d["channel_id"] = d.get("channel_id", 0) or 0 - d["has_password"] = bool(room.password) - resp = cls.model_validate(d) + def validate_channel_id(cls, v): + """将 None 转换为 0""" + if v is None: + return 0 + return v - stats = RoomPlaylistItemStats(count_active=0, count_total=0) - difficulty_range = RoomDifficultyRange( - min=0, - max=0, - ) - rulesets = set() - for playlist in room.playlist: + @included + @staticmethod + async def has_password(_session: AsyncSession, room: "Room") -> bool: + return bool(room.password) + + @ondemand + @staticmethod + async def current_playlist_item( + _session: AsyncSession, room: "Room", includes: list[str] | None = None + ) -> "PlaylistDict | None": + playlists = await room.awaitable_attrs.playlist + if not playlists: + return None + return await PlaylistModel.transform(playlists[-1], includes=includes) + + @ondemand + @staticmethod + async def playlist(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> list["PlaylistDict"]: + playlists = await room.awaitable_attrs.playlist + result: list[PlaylistDict] = [] + for playlist_item in playlists: + result.append(await PlaylistModel.transform(playlist_item, includes=includes)) + return result + + @ondemand + @staticmethod + async def playlist_item_stats(_session: AsyncSession, room: "Room") -> RoomPlaylistItemStats: + playlists = await room.awaitable_attrs.playlist + stats = RoomPlaylistItemStats(count_active=0, count_total=0, ruleset_ids=[]) + rulesets: set[int] = set() + for playlist in playlists: stats.count_total += 1 if not playlist.expired: stats.count_active += 1 rulesets.add(playlist.ruleset_id) - difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating) - difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating) - resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"])) stats.ruleset_ids = list(rulesets) - resp.playlist_item_stats = stats - resp.difficulty_range = difficulty_range - resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None - resp.recent_participants = [] + return stats + + @ondemand + @staticmethod + async def difficulty_range(_session: AsyncSession, room: "Room") -> RoomDifficultyRange: + playlists = await room.awaitable_attrs.playlist + if not playlists: + return RoomDifficultyRange(min=0.0, max=0.0) + min_diff = float("inf") + max_diff = float("-inf") + for playlist in playlists: + rating = playlist.beatmap.difficulty_rating + min_diff = min(min_diff, rating) + max_diff = max(max_diff, rating) + if min_diff == float("inf"): + min_diff = 0.0 + if max_diff == float("-inf"): + max_diff = 0.0 + return RoomDifficultyRange(min=min_diff, max=max_diff) + + @ondemand + @staticmethod + async def host(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> UserDict: + host_user = await room.awaitable_attrs.host + return await UserModel.transform(host_user, includes=includes) + + @ondemand + @staticmethod + async def recent_participants(session: AsyncSession, room: "Room") -> list[UserDict]: + participants: list[UserDict] = [] if room.category == RoomCategory.REALTIME: query = ( select(RoomParticipatedUser) @@ -137,39 +177,67 @@ class RoomResp(RoomBase): .limit(8) .order_by(col(RoomParticipatedUser.joined_at).desc()) ) - resp.participant_count = ( - await session.exec( - select(func.count()) - .select_from(RoomParticipatedUser) - .where( - RoomParticipatedUser.room_id == room.id, - ) - ) - ).first() or 0 for recent_participant in await session.exec(query): - resp.recent_participants.append( - await UserResp.from_db( - await recent_participant.awaitable_attrs.user, - session, - include=["statistics"], + user_instance = await recent_participant.awaitable_attrs.user + participants.append(await UserModel.transform(user_instance)) + return participants + + @ondemand + @staticmethod + async def current_user_score( + session: AsyncSession, room: "Room", includes: list[str] | None = None + ) -> "ItemAttemptsCountDict | None": + item_attempt = ( + await session.exec( + select(ItemAttemptsCount).where( + ItemAttemptsCount.room_id == room.id, ) ) - resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"]) - if "current_user_score" in include and user: - resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session) - return resp + ).first() + if item_attempt is None: + return None + + return await ItemAttemptsCountModel.transform(item_attempt, includes=includes) -class APIUploadedRoom(RoomBase): - def to_room(self) -> Room: - """ - 将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。 - """ - room_dict = self.model_dump() - room_dict.pop("playlist", None) - # host_id 已在字段中 - return Room(**room_dict) +class Room(AsyncAttrs, RoomModel, table=True): + __tablename__: str = "rooms" - id: int | None - host_id: int | None = None + host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) + password: str | None = Field(default=None) + + host: User = Relationship() + playlist: list[Playlist] = Relationship( + sa_relationship_kwargs={ + "lazy": "selectin", + "cascade": "all, delete-orphan", + "overlaps": "room", + } + ) + + +class APIUploadedRoom(SQLModel): + name: str = Field(index=True) + category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True) + status: RoomStatus + type: MatchType + duration: int | None = Field(default=None) # minutes + starts_at: datetime | None = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default_factory=utcnow, + ) + ends_at: datetime | None = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=None, + ) + max_attempts: int | None = Field(default=None) # playlists + participant_count: int = Field(default=0) + channel_id: int = 0 + queue_mode: QueueMode + auto_skip: bool + auto_start_duration: int playlist: list[Playlist] = Field(default_factory=list) diff --git a/app/database/score.py b/app/database/score.py index 49bdecd..6bb4a28 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -3,7 +3,7 @@ from datetime import date, datetime import json import math import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict from app.calculator import ( calculate_pp_weight, @@ -15,8 +15,6 @@ from app.calculator import ( pre_fetch_and_calculate_pp, ) from app.config import settings -from app.database.beatmap_playcounts import BeatmapPlaycounts -from app.database.team import TeamMember from app.dependencies.database import get_redis from app.log import log from app.models.beatmap import BeatmapRankStatus @@ -39,8 +37,10 @@ from app.models.scoring_mode import ScoringMode from app.storage import StorageService from app.utils import utcnow -from .beatmap import Beatmap, BeatmapResp -from .beatmapset import BeatmapsetResp +from ._base import DatabaseModel, OnDemand, included, ondemand +from .beatmap import Beatmap, BeatmapDict, BeatmapModel +from .beatmap_playcounts import BeatmapPlaycounts +from .beatmapset import BeatmapsetDict, BeatmapsetModel from .best_scores import BestScore from .counts import MonthlyPlaycounts from .events import Event, EventType @@ -50,8 +50,9 @@ from .relationship import ( RelationshipType, ) from .score_token import ScoreToken +from .team import TeamMember from .total_score_best_scores import TotalScoreBestScore -from .user import User, UserResp +from .user import User, UserDict, UserModel from pydantic import BaseModel, field_serializer, field_validator from redis.asyncio import Redis @@ -80,30 +81,290 @@ if TYPE_CHECKING: logger = log("Score") -class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): - # 基本字段 +class ScoreDict(TypedDict): + beatmap_id: int + id: int + rank: Rank + type: str + user_id: int + accuracy: float + build_id: int | None + ended_at: datetime + has_replay: bool + max_combo: int + passed: bool + pp: float + started_at: datetime + total_score: int + maximum_statistics: ScoreStatistics + mods: list[APIMod] + classic_total_score: int | None + preserve: bool + processed: bool + ranked: bool + playlist_item_id: NotRequired[int | None] + room_id: NotRequired[int | None] + best_id: NotRequired[int | None] + legacy_perfect: NotRequired[bool] + is_perfect_combo: NotRequired[bool] + ruleset_id: NotRequired[int] + statistics: NotRequired[ScoreStatistics] + beatmapset: NotRequired[BeatmapsetDict] + beatmap: NotRequired[BeatmapDict] + current_user_attributes: NotRequired[CurrentUserAttributes] + position: NotRequired[int | None] + scores_around: NotRequired["ScoreAround | None"] + rank_country: NotRequired[int | None] + rank_global: NotRequired[int | None] + user: NotRequired[UserDict] + weight: NotRequired[float | None] + + # ScoreResp 字段 + legacy_total_score: NotRequired[int] + + +class ScoreModel(AsyncAttrs, DatabaseModel[ScoreDict]): + # https://github.com/ppy/osu-web/blob/master/app/Transformers/ScoreTransformer.php#L72-L84 + MULTIPLAYER_SCORE_INCLUDE: ClassVar[list[str]] = ["playlist_item_id", "room_id", "solo_score_id"] + MULTIPLAYER_BASE_INCLUDES: ClassVar[list[str]] = [ + "user.country", + "user.cover", + "user.team", + *MULTIPLAYER_SCORE_INCLUDE, + ] + USER_PROFILE_INCLUDES: ClassVar[list[str]] = ["beatmap", "beatmapset", "user"] + + # 基本字段 + beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") + id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)) + rank: Rank + type: str + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + index=True, + ), + ) accuracy: float - map_md5: str = Field(max_length=32, index=True) build_id: int | None = Field(default=None) - classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score ended_at: datetime = Field(sa_column=Column(DateTime)) has_replay: bool = Field(sa_column=Column(Boolean)) max_combo: int - mods: list[APIMod] = Field(sa_column=Column(JSON)) passed: bool = Field(sa_column=Column(Boolean)) - playlist_item_id: int | None = Field(default=None) # multiplayer pp: float = Field(default=0.0) - preserve: bool = Field(default=True, sa_column=Column(Boolean)) - rank: Rank - room_id: int | None = Field(default=None) # multiplayer started_at: datetime = Field(sa_column=Column(DateTime)) total_score: int = Field(default=0, sa_column=Column(BigInteger)) - total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True) - type: str - beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict) - processed: bool = False # solo_score - ranked: bool = False + mods: list[APIMod] = Field(sa_column=Column(JSON)) + total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True) + + # solo + classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) + preserve: bool = Field(default=True, sa_column=Column(Boolean)) + processed: bool = Field(default=False) + ranked: bool = Field(default=False) + + # multiplayer + playlist_item_id: OnDemand[int | None] = Field(default=None) + room_id: OnDemand[int | None] = Field(default=None) + + @included + @staticmethod + async def best_id( + session: AsyncSession, + score: "Score", + ) -> int | None: + return await get_best_id(session, score.id) + + @included + @staticmethod + async def legacy_perfect( + _session: AsyncSession, + score: "Score", + ) -> bool: + await score.awaitable_attrs.beatmap + return score.max_combo == score.beatmap.max_combo + + @included + @staticmethod + async def is_perfect_combo( + _session: AsyncSession, + score: "Score", + ) -> bool: + await score.awaitable_attrs.beatmap + return score.max_combo == score.beatmap.max_combo + + @included + @staticmethod + async def ruleset_id( + _session: AsyncSession, + score: "Score", + ) -> int: + return int(score.gamemode) + + @included + @staticmethod + async def statistics( + _session: AsyncSession, + score: "Score", + ) -> ScoreStatistics: + stats = { + HitResult.MISS: score.nmiss, + HitResult.MEH: score.n50, + HitResult.OK: score.n100, + HitResult.GREAT: score.n300, + HitResult.PERFECT: score.ngeki, + HitResult.GOOD: score.nkatu, + } + if score.nlarge_tick_miss is not None: + stats[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss + if score.nslider_tail_hit is not None: + stats[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit + if score.nsmall_tick_hit is not None: + stats[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit + if score.nlarge_tick_hit is not None: + stats[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit + return stats + + @ondemand + @staticmethod + async def beatmapset( + _session: AsyncSession, + score: "Score", + includes: list[str] | None = None, + ) -> BeatmapsetDict: + await score.awaitable_attrs.beatmap + return await BeatmapsetModel.transform(score.beatmap.beatmapset, includes=includes) + + # reorder beatmapset and beatmap + # https://github.com/ppy/osu/blob/d8900defd34690de92be3406003fb3839fc0df1d/osu.Game/Online/API/Requests/Responses/SoloScoreInfo.cs#L111-L112 + @ondemand + @staticmethod + async def beatmap( + _session: AsyncSession, + score: "Score", + includes: list[str] | None = None, + ) -> BeatmapDict: + await score.awaitable_attrs.beatmap + return await BeatmapModel.transform(score.beatmap, includes=includes) + + @ondemand + @staticmethod + async def current_user_attributes( + _session: AsyncSession, + score: "Score", + ) -> CurrentUserAttributes: + return CurrentUserAttributes(pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)) + + @ondemand + @staticmethod + async def position( + session: AsyncSession, + score: "Score", + ) -> int | None: + return await get_score_position_by_id( + session, + score.beatmap_id, + score.id, + mode=score.gamemode, + user=score.user, + ) + + @ondemand + @staticmethod + async def scores_around( + session: AsyncSession, _score: "Score", playlist_id: int, room_id: int, is_playlist: bool + ) -> "ScoreAround | None": + scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ~User.is_restricted_query(col(PlaylistBestScore.user_id)), + col(PlaylistBestScore.score).has(col(Score.passed).is_(True)) if not is_playlist else True, + ) + ) + ).all() + + higher_scores = [] + lower_scores = [] + for score in scores: + total_score = score.score.total_score + resp = await ScoreModel.transform(score.score, includes=ScoreModel.MULTIPLAYER_BASE_INCLUDES) + if score.total_score > total_score: + higher_scores.append(resp) + elif score.total_score < total_score: + lower_scores.append(resp) + + return ScoreAround( + higher=MultiplayerScores(scores=higher_scores), + lower=MultiplayerScores(scores=lower_scores), + ) + + @ondemand + @staticmethod + async def rank_country( + session: AsyncSession, + score: "Score", + ) -> int | None: + return ( + await get_score_position_by_id( + session, + score.beatmap_id, + score.id, + score.gamemode, + score.user, + type=LeaderboardType.COUNTRY, + ) + or None + ) + + @ondemand + @staticmethod + async def rank_global( + session: AsyncSession, + score: "Score", + ) -> int | None: + return ( + await get_score_position_by_id( + session, + score.beatmap_id, + score.id, + mode=score.gamemode, + user=score.user, + ) + or None + ) + + @ondemand + @staticmethod + async def user( + _session: AsyncSession, + score: "Score", + includes: list[str] | None = None, + ) -> UserDict: + return await UserModel.transform(score.user, ruleset=score.gamemode, includes=includes or []) + + @ondemand + @staticmethod + async def weight( + session: AsyncSession, + score: "Score", + ) -> float | None: + best_id = await get_best_id(session, score.id) + if best_id: + return calculate_pp_weight(best_id - 1) + return None + + @ondemand + @staticmethod + async def legacy_total_score( + _session: AsyncSession, + _score: "Score", + ) -> int: + return 0 @field_validator("maximum_statistics", mode="before") @classmethod @@ -151,17 +412,9 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): # TODO: current_user_attributes -class Score(ScoreBase, table=True): +class Score(ScoreModel, table=True): __tablename__: str = "scores" - id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)) - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("lazer_users.id"), - index=True, - ), - ) + # ScoreStatistics n300: int = Field(exclude=True) n100: int = Field(exclude=True) @@ -175,6 +428,7 @@ class Score(ScoreBase, table=True): nsmall_tick_hit: int | None = Field(default=None, exclude=True) gamemode: GameMode = Field(index=True) pinned_order: int = Field(default=0, exclude=True) + map_md5: str = Field(max_length=32, index=True, exclude=True) @field_validator("gamemode", mode="before") @classmethod @@ -245,9 +499,11 @@ class Score(ScoreBase, table=True): maximum_statistics=self.maximum_statistics, ) - async def to_resp(self, session: AsyncSession, api_version: int) -> "ScoreResp | LegacyScoreResp": + async def to_resp( + self, session: AsyncSession, api_version: int, includes: list[str] = [] + ) -> "ScoreDict | LegacyScoreResp": if api_version >= 20220705: - return await ScoreResp.from_db(session, self) + return await ScoreModel.transform(self, includes=includes) return await LegacyScoreResp.from_db(session, self) async def delete( @@ -270,141 +526,7 @@ class Score(ScoreBase, table=True): await session.delete(self) -class ScoreResp(ScoreBase): - id: int - user_id: int - is_perfect_combo: bool = False - legacy_perfect: bool = False - legacy_total_score: int = 0 # FIXME - weight: float = 0.0 - best_id: int | None = None - ruleset_id: int | None = None - beatmap: BeatmapResp | None = None - beatmapset: BeatmapsetResp | None = None - user: UserResp | None = None - statistics: ScoreStatistics | None = None - maximum_statistics: ScoreStatistics | None = None - rank_global: int | None = None - rank_country: int | None = None - position: int | None = None - scores_around: "ScoreAround | None" = None - current_user_attributes: CurrentUserAttributes | None = None - - @field_validator( - "has_replay", - "passed", - "preserve", - "is_perfect_combo", - "legacy_perfect", - "processed", - "ranked", - mode="before", - ) - @classmethod - def validate_bool_fields(cls, v): - """将整数 0/1 转换为布尔值,处理数据库中的布尔字段""" - if isinstance(v, int): - return bool(v) - return v - - @field_validator("statistics", "maximum_statistics", mode="before") - @classmethod - def validate_statistics_fields(cls, v): - """处理统计字段中的字符串键,转换为 HitResult 枚举""" - if isinstance(v, dict): - converted = {} - for key, value in v.items(): - if isinstance(key, str): - try: - # 尝试将字符串转换为 HitResult 枚举 - enum_key = HitResult(key) - converted[enum_key] = value - except ValueError: - # 如果转换失败,跳过这个键值对 - continue - else: - converted[key] = value - return converted - return v - - @field_serializer("statistics", when_used="json") - def serialize_statistics_fields(self, v): - """序列化统计字段,确保枚举值正确转换为字符串""" - if isinstance(v, dict): - serialized = {} - for key, value in v.items(): - if hasattr(key, "value"): - # 如果是枚举,使用其值 - serialized[key.value] = value - else: - # 否则直接使用键 - serialized[str(key)] = value - return serialized - return v - - @classmethod - async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": - # 确保 score 对象完全加载,避免懒加载问题 - await session.refresh(score) - - s = cls.model_validate(score.model_dump()) - await score.awaitable_attrs.beatmap - s.beatmap = await BeatmapResp.from_db(score.beatmap) - s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user) - s.is_perfect_combo = s.max_combo == s.beatmap.max_combo - s.legacy_perfect = s.max_combo == s.beatmap.max_combo - s.ruleset_id = int(score.gamemode) - best_id = await get_best_id(session, score.id) - if best_id: - s.best_id = best_id - s.weight = calculate_pp_weight(best_id - 1) - s.statistics = { - HitResult.MISS: score.nmiss, - HitResult.MEH: score.n50, - HitResult.OK: score.n100, - HitResult.GREAT: score.n300, - HitResult.PERFECT: score.ngeki, - HitResult.GOOD: score.nkatu, - } - if score.nlarge_tick_miss is not None: - s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss - if score.nslider_tail_hit is not None: - s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit - if score.nsmall_tick_hit is not None: - s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit - if score.nlarge_tick_hit is not None: - s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit - s.user = await UserResp.from_db( - score.user, - session, - include=["statistics", "team", "daily_challenge_user_stats"], - ruleset=score.gamemode, - ) - s.rank_global = ( - await get_score_position_by_id( - session, - score.beatmap_id, - score.id, - mode=score.gamemode, - user=score.user, - ) - or None - ) - s.rank_country = ( - await get_score_position_by_id( - session, - score.beatmap_id, - score.id, - score.gamemode, - score.user, - type=LeaderboardType.COUNTRY, - ) - or None - ) - s.current_user_attributes = CurrentUserAttributes( - pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id) - ) - return s +MultiplayScoreDict = ScoreModel.generate_typeddict(tuple(Score.MULTIPLAYER_BASE_INCLUDES)) # pyright: ignore[reportGeneralTypeIssues] class LegacyStatistics(BaseModel): @@ -417,31 +539,25 @@ class LegacyStatistics(BaseModel): class LegacyScoreResp(UTCBaseModel): - accuracy: float - best_id: int - created_at: datetime id: int - max_combo: int - mode: GameMode - mode_int: int + best_id: int + user_id: int + accuracy: float mods: list[str] # acronym - passed: bool + score: int + max_combo: int perfect: bool = False + statistics: LegacyStatistics + passed: bool pp: float rank: Rank + created_at: datetime + mode: GameMode + mode_int: int replay: bool - score: int - statistics: LegacyStatistics - type: str - user_id: int - current_user_attributes: CurrentUserAttributes - user: UserResp - beatmap: BeatmapResp - rank_global: int | None = Field(default=None, exclude=True) @classmethod - async def from_db(cls, session: AsyncSession, score: Score) -> "LegacyScoreResp": - await session.refresh(score) + async def from_db(cls, session: AsyncSession, score: "Score") -> "LegacyScoreResp": await score.awaitable_attrs.beatmap return cls( accuracy=score.accuracy, @@ -465,34 +581,13 @@ class LegacyScoreResp(UTCBaseModel): count_geki=score.ngeki or 0, count_katu=score.nkatu or 0, ), - type=score.type, user_id=score.user_id, - current_user_attributes=CurrentUserAttributes( - pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id) - ), - user=await UserResp.from_db( - score.user, - session, - include=["statistics", "team", "daily_challenge_user_stats"], - ruleset=score.gamemode, - ), - beatmap=await BeatmapResp.from_db(score.beatmap), perfect=score.is_perfect_combo, - rank_global=( - await get_score_position_by_id( - session, - score.beatmap_id, - score.id, - mode=score.gamemode, - user=score.user, - ) - or None - ), ) class MultiplayerScores(RespWithCursor): - scores: list[ScoreResp] = Field(default_factory=list) + scores: list[MultiplayScoreDict] = Field(default_factory=list) # pyright: ignore[reportInvalidTypeForm] params: dict[str, Any] = Field(default_factory=dict) @@ -842,13 +937,13 @@ async def get_user_best_pp( # https://github.com/ppy/osu-queue-score-statistics/blob/master/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/PlayValidityHelper.cs -def get_play_length(score: Score, beatmap_length: int): +def get_play_length(score: "Score", beatmap_length: int): speed_rate = get_speed_rate(score.mods) length = beatmap_length / speed_rate return int(min(length, (score.ended_at - score.started_at).total_seconds())) -def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]: +def calculate_playtime(score: "Score", beatmap_length: int) -> tuple[int, bool]: total_length = get_play_length(score, beatmap_length) total_obj_hited = ( score.n300 @@ -937,7 +1032,7 @@ async def process_score( return score -async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"): +async def _process_score_pp(score: "Score", session: AsyncSession, redis: Redis, fetcher: "Fetcher"): if score.pp != 0: logger.debug( "Skipping PP calculation for score {score_id} | already set {pp:.2f}", @@ -984,7 +1079,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f ) -async def _process_score_events(score: Score, session: AsyncSession): +async def _process_score_events(score: "Score", session: AsyncSession): total_users = (await session.exec(select(func.count()).select_from(User))).one() rank_global = await get_score_position_by_id( session, @@ -1088,7 +1183,7 @@ async def _process_statistics( session: AsyncSession, redis: Redis, user: User, - score: Score, + score: "Score", score_token: int, beatmap_length: int, beatmap_status: BeatmapRankStatus, @@ -1318,7 +1413,7 @@ async def process_user( redis: Redis, fetcher: "Fetcher", user: User, - score: Score, + score: "Score", score_token: int, beatmap_length: int, beatmap_status: BeatmapRankStatus, diff --git a/app/database/search_beatmapset.py b/app/database/search_beatmapset.py new file mode 100644 index 0000000..c680175 --- /dev/null +++ b/app/database/search_beatmapset.py @@ -0,0 +1,13 @@ +from . import beatmap # noqa: F401 +from .beatmapset import BeatmapsetModel + +from sqlmodel import SQLModel + +SearchBeatmapset = BeatmapsetModel.generate_typeddict(("beatmaps.max_combo", "pack_tags")) + + +class SearchBeatmapsetsResp(SQLModel): + beatmapsets: list[SearchBeatmapset] # pyright: ignore[reportInvalidTypeForm] + total: int + cursor: dict[str, int | float | str] | None = None + cursor_string: str | None = None diff --git a/app/database/statistics.py b/app/database/statistics.py index 3d0b1dd..a1c90d9 100644 --- a/app/database/statistics.py +++ b/app/database/statistics.py @@ -1,10 +1,11 @@ from datetime import timedelta import math -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict from app.models.score import GameMode from app.utils import utcnow +from ._base import DatabaseModel, included, ondemand from .rank_history import RankHistory from pydantic import field_validator @@ -15,7 +16,6 @@ from sqlmodel import ( Field, ForeignKey, Relationship, - SQLModel, col, func, select, @@ -23,10 +23,40 @@ from sqlmodel import ( from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: - from .user import User, UserResp + from .user import User, UserDict -class UserStatisticsBase(SQLModel): +class UserStatisticsDict(TypedDict): + mode: GameMode + count_100: int + count_300: int + count_50: int + count_miss: int + pp: float + ranked_score: int + hit_accuracy: float + total_score: int + total_hits: int + maximum_combo: int + play_count: int + play_time: int + replays_watched_by_others: int + is_ranked: bool + level: NotRequired[dict[str, int]] + global_rank: NotRequired[int | None] + grade_counts: NotRequired[dict[str, int]] + rank_change_since_30_days: NotRequired[int] + country_rank: NotRequired[int | None] + user: NotRequired["UserDict"] + + +class UserStatisticsModel(DatabaseModel[UserStatisticsDict]): + RANKING_INCLUDES: ClassVar[list[str]] = [ + "user.country", + "user.cover", + "user.team", + ] + mode: GameMode = Field(index=True) count_100: int = Field(default=0, sa_column=Column(BigInteger)) count_300: int = Field(default=0, sa_column=Column(BigInteger)) @@ -57,8 +87,63 @@ class UserStatisticsBase(SQLModel): return GameMode.OSU return v + @included + @staticmethod + async def level(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]: + return { + "current": int(statistics.level_current), + "progress": int(math.fmod(statistics.level_current, 1) * 100), + } -class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True): + @included + @staticmethod + async def global_rank(session: AsyncSession, statistics: "UserStatistics") -> int | None: + return await get_rank(session, statistics) + + @included + @staticmethod + async def grade_counts(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]: + return { + "ss": statistics.grade_ss, + "ssh": statistics.grade_ssh, + "s": statistics.grade_s, + "sh": statistics.grade_sh, + "a": statistics.grade_a, + } + + @ondemand + @staticmethod + async def rank_change_since_30_days(session: AsyncSession, statistics: "UserStatistics") -> int: + global_rank = await get_rank(session, statistics) + rank_best = ( + await session.exec( + select(func.max(RankHistory.rank)).where( + RankHistory.date > utcnow() - timedelta(days=30), + RankHistory.user_id == statistics.user_id, + ) + ) + ).first() + if rank_best is None or global_rank is None: + return 0 + return rank_best - global_rank + + @ondemand + @staticmethod + async def country_rank( + session: AsyncSession, statistics: "UserStatistics", user_country: str | None = None + ) -> int | None: + return await get_rank(session, statistics, user_country) + + @ondemand + @staticmethod + async def user(_session: AsyncSession, statistics: "UserStatistics") -> "UserDict": + from .user import UserModel + + user_instance = await statistics.awaitable_attrs.user + return await UserModel.transform(user_instance) + + +class UserStatistics(AsyncAttrs, UserStatisticsModel, table=True): __tablename__: str = "lazer_user_statistics" id: int | None = Field(default=None, primary_key=True) user_id: int = Field( @@ -80,74 +165,6 @@ class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True): user: "User" = Relationship(back_populates="statistics") -class UserStatisticsResp(UserStatisticsBase): - user: "UserResp | None" = None - rank_change_since_30_days: int | None = 0 - global_rank: int | None = Field(default=None) - country_rank: int | None = Field(default=None) - grade_counts: dict[str, int] = Field( - default_factory=lambda: { - "ss": 0, - "ssh": 0, - "s": 0, - "sh": 0, - "a": 0, - } - ) - level: dict[str, int] = Field( - default_factory=lambda: { - "current": 1, - "progress": 0, - } - ) - - @classmethod - async def from_db( - cls, - obj: UserStatistics, - session: AsyncSession, - user_country: str | None = None, - include: list[str] = [], - ) -> "UserStatisticsResp": - s = cls.model_validate(obj.model_dump()) - s.grade_counts = { - "ss": obj.grade_ss, - "ssh": obj.grade_ssh, - "s": obj.grade_s, - "sh": obj.grade_sh, - "a": obj.grade_a, - } - s.level = { - "current": int(obj.level_current), - "progress": int(math.fmod(obj.level_current, 1) * 100), - } - if "user" in include: - from .user import RANKING_INCLUDES, UserResp - - user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES) - s.user = user - user_country = user.country_code - - s.global_rank = await get_rank(session, obj) - s.country_rank = await get_rank(session, obj, user_country) - - if "rank_change_since_30_days" in include: - rank_best = ( - await session.exec( - select(func.max(RankHistory.rank)).where( - RankHistory.date > utcnow() - timedelta(days=30), - RankHistory.user_id == obj.user_id, - ) - ) - ).first() - if rank_best is None or s.global_rank is None: - s.rank_change_since_30_days = 0 - else: - s.rank_change_since_30_days = rank_best - s.global_rank - - return s - - async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None: from .user import User @@ -164,7 +181,6 @@ async def get_rank(session: AsyncSession, statistics: UserStatistics, country: s query = query.join(User).where(User.country_code == country) subq = query.subquery() - result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id)) rank = result.first() diff --git a/app/database/total_score_best_scores.py b/app/database/total_score_best_scores.py index 1ab1d5a..f2a8f2d 100644 --- a/app/database/total_score_best_scores.py +++ b/app/database/total_score_best_scores.py @@ -1,9 +1,9 @@ from typing import TYPE_CHECKING from app.calculator import calculate_score_to_level -from app.database.statistics import UserStatistics from app.models.score import GameMode, Rank +from .statistics import UserStatistics from .user import User from sqlmodel import ( diff --git a/app/database/user.py b/app/database/user.py index fd40c35..2f16390 100644 --- a/app/database/user.py +++ b/app/database/user.py @@ -1,22 +1,25 @@ from datetime import datetime, timedelta import json -from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, overload +from typing import TYPE_CHECKING, ClassVar, Literal, NotRequired, TypedDict, overload from app.config import settings -from app.models.model import UTCBaseModel +from app.models.notification import NotificationName from app.models.score import GameMode from app.models.user import Country, Page from app.path import STATIC_DIR from app.utils import utcnow +from ._base import DatabaseModel, OnDemand, included, ondemand from .achievement import UserAchievement, UserAchievementResp from .auth import TotpKeys from .beatmap_playcounts import BeatmapPlaycounts from .counts import CountResp, MonthlyPlaycounts, ReplayWatchedCount from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp from .events import Event +from .notification import Notification, UserNotification from .rank_history import RankHistory, RankHistoryResp, RankTop -from .statistics import UserStatistics, UserStatisticsResp +from .relationship import RelationshipModel +from .statistics import UserStatistics, UserStatisticsModel from .team import Team, TeamMember from .user_account_history import UserAccountHistory, UserAccountHistoryResp, UserAccountHistoryType from .user_preference import DEFAULT_ORDER, UserPreference @@ -31,7 +34,6 @@ from sqlmodel import ( DateTime, Field, Relationship, - SQLModel, col, exists, func, @@ -43,7 +45,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .favourite_beatmapset import FavouriteBeatmapset from .matchmaking import MatchmakingUserStats - from .relationship import RelationshipResp + from .relationship import Relationship, RelationshipDict + from .statistics import UserStatisticsDict class Kudosu(TypedDict): @@ -76,10 +79,194 @@ Badge = TypedDict( COUNTRIES = json.loads((STATIC_DIR / "iso3166.json").read_text()) -class UserBase(UTCBaseModel, SQLModel): +class UserDict(TypedDict): + avatar_url: str + country_code: str + id: int + is_active: bool + is_bot: bool + is_supporter: bool + last_visit: datetime | None + pm_friends_only: bool + profile_colour: str | None + username: str + g0v0_playmode: GameMode + page: NotRequired[Page] + previous_usernames: NotRequired[list[str]] + support_level: NotRequired[int] + badges: NotRequired[list[Badge]] + cover: NotRequired[UserProfileCover] + beatmap_playcounts_count: NotRequired[int] + playmode: NotRequired[GameMode] + discord: NotRequired[str | None] + has_supported: NotRequired[bool] + interests: NotRequired[str | None] + join_date: NotRequired[datetime] + location: NotRequired[str | None] + max_blocks: NotRequired[int] + max_friends: NotRequired[int] + occupation: NotRequired[str | None] + playstyle: NotRequired[list[str]] + profile_hue: NotRequired[int | None] + title: NotRequired[str | None] + title_url: NotRequired[str | None] + twitter: NotRequired[str | None] + website: NotRequired[str | None] + comments_count: NotRequired[int] + post_count: NotRequired[int] + is_admin: NotRequired[bool] + is_gmt: NotRequired[bool] + is_qat: NotRequired[bool] + is_bng: NotRequired[bool] + groups: NotRequired[list[str]] + active_tournament_banners: NotRequired[list[dict]] + graveyard_beatmapset_count: NotRequired[int] + loved_beatmapset_count: NotRequired[int] + mapping_follower_count: NotRequired[int] + nominated_beatmapset_count: NotRequired[int] + guest_beatmapset_count: NotRequired[int] + pending_beatmapset_count: NotRequired[int] + ranked_beatmapset_count: NotRequired[int] + follow_user_mapping: NotRequired[list[int]] + is_deleted: NotRequired[bool] + country: NotRequired[Country] + favourite_beatmapset_count: NotRequired[int] + follower_count: NotRequired[int] + scores_best_count: NotRequired[int] + scores_pinned_count: NotRequired[int] + scores_recent_count: NotRequired[int] + scores_first_count: NotRequired[int] + cover_url: NotRequired[str] + profile_order: NotRequired[list[str]] + user_preference: NotRequired[UserPreference | None] + friends: NotRequired[list["RelationshipDict"]] + team: NotRequired[Team | None] + account_history: NotRequired[list[UserAccountHistoryResp]] + daily_challenge_user_stats: NotRequired[DailyChallengeStatsResp | None] + statistics: NotRequired["UserStatisticsDict | None"] + statistics_rulesets: NotRequired[dict[str, "UserStatisticsDict"]] + monthly_playcounts: NotRequired[list[CountResp]] + replay_watched_counts: NotRequired[list[CountResp]] + user_achievements: NotRequired[list[UserAchievementResp]] + rank_history: NotRequired[RankHistoryResp | None] + rank_highest: NotRequired[RankHighest | None] + is_restricted: NotRequired[bool] + kudosu: NotRequired[Kudosu] + unread_pm_count: NotRequired[int] + default_group: NotRequired[str] + is_online: NotRequired[bool] + session_verified: NotRequired[bool] + session_verification_method: NotRequired[Literal["totp", "mail"] | None] + + +class UserModel(DatabaseModel[UserDict]): + # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserCompactTransformer.php#L22-L39 + CARD_INCLUDES: ClassVar[list[str]] = [ + "country", + "cover", + "groups", + "team", + ] + LIST_INCLUDES: ClassVar[list[str]] = [ + *CARD_INCLUDES, + "statistics", + "support_level", + ] + + # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserTransformer.php#L36-L53 + USER_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [ + "cover_url", + "discord", + "has_supported", + "interests", + "join_date", + "location", + "max_blocks", + "max_friends", + "occupation", + "playmode", + "playstyle", + "post_count", + "profile_hue", + "profile_order", + "title", + "title_url", + "twitter", + "website", + # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserTransformer.php#L13C22-L25 + "cover", + "country", + "is_admin", + "is_bng", + "is_full_bn", + "is_gmt", + "is_limited_bn", + "is_moderator", + "is_nat", + "is_restricted", + "is_silenced", + "kudosu", + ] + + # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserCompactTransformer.php#L41-L51 + PROFILE_HEADER_INCLUDES: ClassVar[list[str]] = [ + "active_tournament_banner", + "active_tournament_banners", + "badges", + "comments_count", + "follower_count", + "groups", + "mapping_follower_count", + "previous_usernames", + "support_level", + ] + + # https://github.com/ppy/osu-web/blob/3f08fe12d70bcac1e32455c31e984eb6ef589b42/app/Http/Controllers/UsersController.php#L900-L937 + USER_INCLUDES: ClassVar[list[str]] = [ + # == apiIncludes == + # historical + "beatmap_playcounts_count", + "monthly_playcounts", + "replays_watched_counts", + "scores_recent_count", + # beatmapsets + "favourite_beatmapset_count", + "graveyard_beatmapset_count", + "guest_beatmapset_count", + "loved_beatmapset_count", + "nominated_beatmapset_count", + "pending_beatmapset_count", + "ranked_beatmapset_count", + # top scores + "scores_best_count", + "scores_first_count", + "scores_pinned_count", + # others + "account_history", + "current_season_stats", + "daily_challenge_user_stats", + "page", + "pending_beatmapset_count", + "rank_highest", + "rank_history", + "statistics", + "statistics.country_rank", + "statistics.rank", + "statistics.variants", + "team", + "user_achievements", + *PROFILE_HEADER_INCLUDES, + *USER_TRANSFORMER_INCLUDES, + ] + + # https://github.com/ppy/osu-web/blob/d0407b1f2846dfd8b85ec0cf20e3fe3028a7b486/app/Transformers/UserCompactTransformer.php#L133-L150 avatar_url: str = "https://lazer-data.g0v0.top/default.jpg" country_code: str = Field(default="CN", max_length=2, index=True) # ? default_group: str|None + id: int = Field( + default=None, + sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), + ) is_active: bool = True is_bot: bool = False is_supporter: bool = False @@ -87,45 +274,45 @@ class UserBase(UTCBaseModel, SQLModel): pm_friends_only: bool = False profile_colour: str | None = None username: str = Field(max_length=32, unique=True, index=True) - page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw="")) - previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON)) - support_level: int = 0 - badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON)) + + page: OnDemand[Page] = Field(sa_column=Column(JSON), default=Page(html="", raw="")) + previous_usernames: OnDemand[list[str]] = Field(default_factory=list, sa_column=Column(JSON)) + support_level: OnDemand[int] = Field(default=0) + badges: OnDemand[list[Badge]] = Field(default_factory=list, sa_column=Column(JSON)) # optional # blocks - cover: UserProfileCover = Field( + cover: OnDemand[UserProfileCover] = Field( default=UserProfileCover(url=""), sa_column=Column(JSON), ) - beatmap_playcounts_count: int = 0 # kudosu # UserExtended - playmode: GameMode = GameMode.OSU - discord: str | None = None - has_supported: bool = False - interests: str | None = None - join_date: datetime = Field(default_factory=utcnow) - location: str | None = None - max_blocks: int = 50 - max_friends: int = 500 - occupation: str | None = None - playstyle: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + playmode: OnDemand[GameMode] = Field(default=GameMode.OSU) + discord: OnDemand[str | None] = Field(default=None) + has_supported: OnDemand[bool] = Field(default=False) + interests: OnDemand[str | None] = Field(default=None) + join_date: OnDemand[datetime] = Field(default_factory=utcnow) + location: OnDemand[str | None] = Field(default=None) + max_blocks: OnDemand[int] = Field(default=50) + max_friends: OnDemand[int] = Field(default=500) + occupation: OnDemand[str | None] = Field(default=None) + playstyle: OnDemand[list[str]] = Field(default_factory=list, sa_column=Column(JSON)) # TODO: post_count - profile_hue: int | None = None - title: str | None = None - title_url: str | None = None - twitter: str | None = None - website: str | None = None + profile_hue: OnDemand[int | None] = Field(default=None) + title: OnDemand[str | None] = Field(default=None) + title_url: OnDemand[str | None] = Field(default=None) + twitter: OnDemand[str | None] = Field(default=None) + website: OnDemand[str | None] = Field(default=None) # undocumented - comments_count: int = 0 - post_count: int = 0 - is_admin: bool = False - is_gmt: bool = False - is_qat: bool = False - is_bng: bool = False + comments_count: OnDemand[int] = Field(default=0) + post_count: OnDemand[int] = Field(default=0) + is_admin: OnDemand[bool] = Field(default=False) + is_gmt: OnDemand[bool] = Field(default=False) + is_qat: OnDemand[bool] = Field(default=False) + is_bng: OnDemand[bool] = Field(default=False) # g0v0-extra g0v0_playmode: GameMode = GameMode.OSU @@ -142,14 +329,388 @@ class UserBase(UTCBaseModel, SQLModel): return GameMode.OSU return v + @ondemand + @staticmethod + async def groups(_session: AsyncSession, _obj: "User") -> list[str]: + return [] -class User(AsyncAttrs, UserBase, table=True): + @ondemand + @staticmethod + async def active_tournament_banners(_session: AsyncSession, _obj: "User") -> list[dict]: + return [] + + @ondemand + @staticmethod + async def graveyard_beatmapset_count(_session: AsyncSession, _obj: "User") -> int: + return 0 + + @ondemand + @staticmethod + async def loved_beatmapset_count(_session: AsyncSession, _obj: "User") -> int: + return 0 + + @ondemand + @staticmethod + async def mapping_follower_count(_session: AsyncSession, _obj: "User") -> int: + return 0 + + @ondemand + @staticmethod + async def nominated_beatmapset_count(_session: AsyncSession, _obj: "User") -> int: + return 0 + + @ondemand + @staticmethod + async def guest_beatmapset_count(_session: AsyncSession, _obj: "User") -> int: + return 0 + + @ondemand + @staticmethod + async def pending_beatmapset_count(_session: AsyncSession, _obj: "User") -> int: + return 0 + + @ondemand + @staticmethod + async def ranked_beatmapset_count(_session: AsyncSession, _obj: "User") -> int: + return 0 + + @ondemand + @staticmethod + async def follow_user_mapping(_session: AsyncSession, _obj: "User") -> list[int]: + return [] + + @ondemand + @staticmethod + async def is_deleted(_session: AsyncSession, _obj: "User") -> bool: + return False + + @ondemand + @staticmethod + async def country(_session: AsyncSession, obj: "User") -> Country: + return Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown")) + + @ondemand + @staticmethod + async def favourite_beatmapset_count(session: AsyncSession, obj: "User") -> int: + from .favourite_beatmapset import FavouriteBeatmapset + + return ( + await session.exec( + select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id) + ) + ).one() + + @ondemand + @staticmethod + async def follower_count(session: AsyncSession, obj: "User") -> int: + from .relationship import Relationship, RelationshipType + + stmt = ( + select(func.count()) + .select_from(Relationship) + .where( + Relationship.target_id == obj.id, + Relationship.type == RelationshipType.FOLLOW, + ) + ) + return (await session.exec(stmt)).one() + + @ondemand + @staticmethod + async def scores_best_count( + session: AsyncSession, + obj: "User", + ruleset: GameMode | None = None, + ) -> int: + from .best_scores import BestScore + + mode = ruleset or obj.playmode + stmt = ( + select(func.count()) + .select_from(BestScore) + .where( + BestScore.user_id == obj.id, + BestScore.gamemode == mode, + ) + .limit(200) + ) + return (await session.exec(stmt)).one() + + @ondemand + @staticmethod + async def scores_pinned_count( + session: AsyncSession, + obj: "User", + ruleset: GameMode | None = None, + ) -> int: + from .score import Score + + mode = ruleset or obj.playmode + stmt = ( + select(func.count()) + .select_from(Score) + .where( + Score.user_id == obj.id, + Score.gamemode == mode, + Score.pinned_order > 0, + col(Score.passed).is_(True), + ) + ) + return (await session.exec(stmt)).one() + + @ondemand + @staticmethod + async def scores_recent_count( + session: AsyncSession, + obj: "User", + ruleset: GameMode | None = None, + ) -> int: + from .score import Score + + mode = ruleset or obj.playmode + stmt = ( + select(func.count()) + .select_from(Score) + .where( + Score.user_id == obj.id, + Score.gamemode == mode, + col(Score.passed).is_(True), + Score.ended_at > utcnow() - timedelta(hours=24), + ) + ) + return (await session.exec(stmt)).one() + + @ondemand + @staticmethod + async def scores_first_count( + session: AsyncSession, + obj: "User", + ruleset: GameMode | None = None, + ) -> int: + from .score import get_user_first_score_count + + mode = ruleset or obj.playmode + return await get_user_first_score_count(session, obj.id, mode) + + @ondemand + @staticmethod + async def beatmap_playcounts_count(session: AsyncSession, obj: "User") -> int: + stmt = select(func.count()).select_from(BeatmapPlaycounts).where(BeatmapPlaycounts.user_id == obj.id) + return (await session.exec(stmt)).one() + + @ondemand + @staticmethod + async def cover_url(_session: AsyncSession, obj: "User") -> str: + return obj.cover.get("url", "") if obj.cover else "" + + @ondemand + @staticmethod + async def profile_order(_session: AsyncSession, obj: "User") -> list[str]: + await obj.awaitable_attrs.user_preference + if obj.user_preference: + return list(obj.user_preference.extras_order) + return list(DEFAULT_ORDER) + + @ondemand + @staticmethod + async def user_preference(_session: AsyncSession, obj: "User") -> UserPreference | None: + await obj.awaitable_attrs.user_preference + return obj.user_preference + + @ondemand + @staticmethod + async def friends(session: AsyncSession, obj: "User") -> list["RelationshipDict"]: + from .relationship import Relationship, RelationshipType + + relationships = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == obj.id, + Relationship.type == RelationshipType.FOLLOW, + ) + ) + ).all() + return [await RelationshipModel.transform(rel, ruleset=obj.playmode) for rel in relationships] + + @ondemand + @staticmethod + async def team(_session: AsyncSession, obj: "User") -> Team | None: + membership = await obj.awaitable_attrs.team_membership + return membership.team if membership else None + + @ondemand + @staticmethod + async def account_history(_session: AsyncSession, obj: "User") -> list[UserAccountHistoryResp]: + await obj.awaitable_attrs.account_history + return [UserAccountHistoryResp.from_db(ah) for ah in obj.account_history] + + @ondemand + @staticmethod + async def daily_challenge_user_stats(_session: AsyncSession, obj: "User") -> DailyChallengeStatsResp | None: + stats = await obj.awaitable_attrs.daily_challenge_stats + return DailyChallengeStatsResp.from_db(stats) if stats else None + + @ondemand + @staticmethod + async def statistics( + _session: AsyncSession, + obj: "User", + ruleset: GameMode | None = None, + includes: list[str] | None = None, + ) -> "UserStatisticsDict | None": + mode = ruleset or obj.playmode + for stat in await obj.awaitable_attrs.statistics: + if stat.mode == mode: + return await UserStatisticsModel.transform(stat, user_country=obj.country_code, includes=includes) + return None + + @ondemand + @staticmethod + async def statistics_rulesets( + _session: AsyncSession, + obj: "User", + includes: list[str] | None = None, + ) -> dict[str, "UserStatisticsDict"]: + stats = await obj.awaitable_attrs.statistics + result: dict[str, UserStatisticsDict] = {} + for stat in stats: + result[stat.mode.value] = await UserStatisticsModel.transform( + stat, user_country=obj.country_code, includes=includes + ) + return result + + @ondemand + @staticmethod + async def monthly_playcounts(_session: AsyncSession, obj: "User") -> list[CountResp]: + playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts] + if len(playcounts) == 1: + d = playcounts[0].start_date + playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0)) + return playcounts + + @ondemand + @staticmethod + async def replay_watched_counts(_session: AsyncSession, obj: "User") -> list[CountResp]: + counts = [CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts] + if len(counts) == 1: + d = counts[0].start_date + counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0)) + return counts + + @ondemand + @staticmethod + async def user_achievements(_session: AsyncSession, obj: "User") -> list[UserAchievementResp]: + return [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement] + + @ondemand + @staticmethod + async def rank_history( + session: AsyncSession, + obj: "User", + ruleset: GameMode | None = None, + ) -> RankHistoryResp | None: + mode = ruleset or obj.playmode + rank_history = await RankHistoryResp.from_db(session, obj.id, mode) + return rank_history if len(rank_history.data) != 0 else None + + @ondemand + @staticmethod + async def rank_highest( + session: AsyncSession, + obj: "User", + ruleset: GameMode | None = None, + ) -> RankHighest | None: + mode = ruleset or obj.playmode + rank_top = (await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == mode))).first() + if not rank_top: + return None + return RankHighest( + rank=rank_top.rank, + updated_at=datetime.combine(rank_top.date, datetime.min.time()), + ) + + @ondemand + @staticmethod + async def is_restricted(session: AsyncSession, obj: "User") -> bool: + return await obj.is_restricted(session) + + @ondemand + @staticmethod + async def kudosu(_session: AsyncSession, _obj: "User") -> Kudosu: + return Kudosu(available=0, total=0) # TODO + + @ondemand + @staticmethod + async def unread_pm_count(session: AsyncSession, obj: "User") -> int: + return ( + await session.exec( + select(func.count()) + .join(Notification, col(Notification.id) == UserNotification.notification_id) + .select_from(UserNotification) + .where( + col(UserNotification.is_read).is_(False), + UserNotification.user_id == obj.id, + Notification.name == NotificationName.CHANNEL_MESSAGE, + text("details->>'$.type' = 'pm'"), + ) + ) + ).one() + + @included + @staticmethod + async def default_group(_session: AsyncSession, obj: "User") -> str: + return "default" if not obj.is_bot else "bot" + + @included + @staticmethod + async def is_online(_session: AsyncSession, obj: "User") -> bool: + from app.dependencies.database import get_redis + + redis = get_redis() + return bool(await redis.exists(f"metadata:online:{obj.id}")) + + @ondemand + @staticmethod + async def session_verified( + session: AsyncSession, + obj: "User", + token_id: int | None = None, + ) -> bool: + from app.service.verification_service import LoginSessionService + + return ( + not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id) + if token_id + else True + ) + + @ondemand + @staticmethod + async def session_verification_method( + session: AsyncSession, + obj: "User", + token_id: int | None = None, + ) -> Literal["totp", "mail"] | None: + from app.dependencies.database import get_redis + from app.service.verification_service import LoginSessionService + + if (settings.enable_totp_verification or settings.enable_email_verification) and token_id: + redis = get_redis() + if not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id): + return None + return await LoginSessionService.get_login_method(obj.id, token_id, redis) + return None + + +class User(AsyncAttrs, UserModel, table=True): __tablename__: str = "lazer_users" - id: int = Field( - default=None, - sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), - ) + email: str = Field(max_length=254, unique=True, index=True) + priv: int = Field(default=1) + pw_bcrypt: str = Field(max_length=60) + silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True))) + donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True))) + account_history: list[UserAccountHistory] = Relationship(back_populates="user") statistics: list[UserStatistics] = Relationship(back_populates="user") achievement: list[UserAchievement] = Relationship(back_populates="user") @@ -166,12 +727,6 @@ class User(AsyncAttrs, UserBase, table=True): totp_key: TotpKeys | None = Relationship(back_populates="user") user_preference: UserPreference | None = Relationship(back_populates="user") - email: str = Field(max_length=254, unique=True, index=True, exclude=True) - priv: int = Field(default=1, exclude=True) - pw_bcrypt: str = Field(max_length=60, exclude=True) - silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True) - donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True) - async def is_user_can_pm(self, from_user: "User", session: AsyncSession) -> tuple[bool, str]: from .relationship import Relationship, RelationshipType @@ -241,303 +796,8 @@ class User(AsyncAttrs, UserBase, table=True): return active_restrictions or False -class UserResp(UserBase): - id: int | None = None - is_online: bool = False - groups: list = [] # TODO - country: Country = Field(default_factory=lambda: Country(code="CN", name="China")) - favourite_beatmapset_count: int = 0 - graveyard_beatmapset_count: int = 0 # TODO - guest_beatmapset_count: int = 0 # TODO - loved_beatmapset_count: int = 0 # TODO - mapping_follower_count: int = 0 # TODO - nominated_beatmapset_count: int = 0 # TODO - pending_beatmapset_count: int = 0 # TODO - ranked_beatmapset_count: int = 0 # TODO - follow_user_mapping: list[int] = Field(default_factory=list) - follower_count: int = 0 - friends: list["RelationshipResp"] | None = None - scores_best_count: int = 0 - scores_first_count: int = 0 - scores_recent_count: int = 0 - scores_pinned_count: int = 0 - beatmap_playcounts_count: int = 0 - account_history: list[UserAccountHistoryResp] = [] - active_tournament_banners: list[dict] = [] # TODO - kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO - monthly_playcounts: list[CountResp] = Field(default_factory=list) - replay_watched_counts: list[CountResp] = Field(default_factory=list) - unread_pm_count: int = 0 # TODO - rank_history: RankHistoryResp | None = None - rank_highest: RankHighest | None = None - statistics: UserStatisticsResp | None = None - statistics_rulesets: dict[str, UserStatisticsResp] | None = None - user_achievements: list[UserAchievementResp] = Field(default_factory=list) - cover_url: str = "" # deprecated - team: Team | None = None - daily_challenge_user_stats: DailyChallengeStatsResp | None = None - default_group: str = "" - is_deleted: bool = False # TODO - is_restricted: bool = False - user_preference: UserPreference | None = None - profile_order: list[str] = Field( - default_factory=lambda: DEFAULT_ORDER, - ) - - # TODO: unread_pm_count - - @classmethod - async def from_db( - cls, - obj: User, - session: AsyncSession, - include: list[str] = [], - ruleset: GameMode | None = None, - ) -> "UserResp": - from app.dependencies.database import get_redis - - from .best_scores import BestScore - from .favourite_beatmapset import FavouriteBeatmapset - from .relationship import Relationship, RelationshipResp, RelationshipType - from .score import Score, get_user_first_score_count - from .total_score_best_scores import TotalScoreBestScore - - ruleset = ruleset or obj.playmode - - u = cls.model_validate(obj.model_dump()) - u.id = obj.id - u.default_group = "bot" if u.is_bot else "default" - u.country = Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown")) - u.follower_count = ( - await session.exec( - select(func.count()) - .select_from(Relationship) - .where( - Relationship.target_id == obj.id, - Relationship.type == RelationshipType.FOLLOW, - ) - ) - ).one() - u.scores_best_count = ( - await session.exec( - select(func.count()) - .select_from(TotalScoreBestScore) - .where( - TotalScoreBestScore.user_id == obj.id, - ) - .limit(200) - ) - ).one() - redis = get_redis() - u.is_online = bool(await redis.exists(f"metadata:online:{obj.id}")) - u.cover_url = obj.cover.get("url", "") if obj.cover else "" - - await obj.awaitable_attrs.user_preference - if obj.user_preference: - u.profile_order = obj.user_preference.extras_order - - if "user_preference" in include: - u.user_preference = obj.user_preference - - if "friends" in include: - u.friends = [ - await RelationshipResp.from_db(session, r) - for r in ( - await session.exec( - select(Relationship).where( - Relationship.user_id == obj.id, - Relationship.type == RelationshipType.FOLLOW, - ) - ) - ).all() - ] - - if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership): - u.team = team_membership.team - - if "account_history" in include: - u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history] - - if "daily_challenge_user_stats" in include and ( - daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats - ): - u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats) - - if "statistics" in include: - current_stattistics = None - for i in await obj.awaitable_attrs.statistics: - if i.mode == ruleset: - current_stattistics = i - break - u.statistics = ( - await UserStatisticsResp.from_db(current_stattistics, session, obj.country_code) - if current_stattistics - else None - ) - - if "statistics_rulesets" in include: - u.statistics_rulesets = { - i.mode.value: await UserStatisticsResp.from_db(i, session, obj.country_code) - for i in await obj.awaitable_attrs.statistics - } - - if "monthly_playcounts" in include: - u.monthly_playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts] - if len(u.monthly_playcounts) == 1: - d = u.monthly_playcounts[0].start_date - u.monthly_playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0)) - - if "replays_watched_counts" in include: - u.replay_watched_counts = [ - CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts - ] - if len(u.replay_watched_counts) == 1: - d = u.replay_watched_counts[0].start_date - u.replay_watched_counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0)) - - if "achievements" in include: - u.user_achievements = [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement] - if "rank_history" in include: - rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset) - if len(rank_history.data) != 0: - u.rank_history = rank_history - - rank_top = ( - await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == ruleset)) - ).first() - if rank_top: - u.rank_highest = ( - RankHighest( - rank=rank_top.rank, - updated_at=datetime.combine(rank_top.date, datetime.min.time()), - ) - if rank_top - else None - ) - if "is_restricted" in include: - u.is_restricted = await obj.is_restricted(session) - - u.favourite_beatmapset_count = ( - await session.exec( - select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id) - ) - ).one() - u.scores_pinned_count = ( - await session.exec( - select(func.count()) - .select_from(Score) - .where( - Score.user_id == obj.id, - Score.pinned_order > 0, - Score.gamemode == ruleset, - col(Score.passed).is_(True), - ) - ) - ).one() - u.scores_best_count = ( - await session.exec( - select(func.count()) - .select_from(BestScore) - .where( - BestScore.user_id == obj.id, - BestScore.gamemode == ruleset, - ) - .limit(200) - ) - ).one() - u.scores_recent_count = ( - await session.exec( - select(func.count()) - .select_from(Score) - .where( - Score.user_id == obj.id, - Score.gamemode == ruleset, - col(Score.passed).is_(True), - Score.ended_at > utcnow() - timedelta(hours=24), - ) - ) - ).one() - u.scores_first_count = await get_user_first_score_count(session, obj.id, ruleset) - u.beatmap_playcounts_count = ( - await session.exec( - select(func.count()) - .select_from(BeatmapPlaycounts) - .where( - BeatmapPlaycounts.user_id == obj.id, - ) - ) - ).one() - - return u - - -class MeResp(UserResp): - session_verification_method: Literal["totp", "mail"] | None = None - session_verified: bool = True - - @classmethod - async def from_db( - cls, - obj: User, - session: AsyncSession, - ruleset: GameMode | None = None, - *, - token_id: int | None = None, - ) -> "MeResp": - from app.dependencies.database import get_redis - from app.service.verification_service import LoginSessionService - - u = await super().from_db(obj, session, ALL_INCLUDED, ruleset) - u.session_verified = ( - not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id) - if token_id - else True - ) - u = cls.model_validate(u.model_dump()) - if (settings.enable_totp_verification or settings.enable_email_verification) and token_id: - redis = get_redis() - if not u.session_verified: - u.session_verification_method = await LoginSessionService.get_login_method(obj.id, token_id, redis) - else: - u.session_verification_method = None - return u - - -ALL_INCLUDED = [ - "friends", - "team", - "account_history", - "daily_challenge_user_stats", - "statistics", - "statistics_rulesets", - "achievements", - "monthly_playcounts", - "replays_watched_counts", - "rank_history", - "is_restricted", - "session_verified", - "user_preference", -] - - -SEARCH_INCLUDED = [ - "team", - "daily_challenge_user_stats", - "statistics", - "statistics_rulesets", - "achievements", - "monthly_playcounts", - "replays_watched_counts", - "rank_history", -] - -BASE_INCLUDES = [ - "team", - "daily_challenge_user_stats", - "statistics", -] - -RANKING_INCLUDES = [ - "team", - "statistics", -] +# 为了向后兼容,在 SQL 查询中使用 User +# 例如: select(User).where(User.id == 1) +# 但类型注解和返回值使用 User +# 例如: async def get_user() -> User | None: +# return (await session.exec(select(User)...)).first() diff --git a/app/fetcher/beatmap.py b/app/fetcher/beatmap.py index 272f41f..206cf34 100644 --- a/app/fetcher/beatmap.py +++ b/app/fetcher/beatmap.py @@ -1,13 +1,40 @@ -from app.database.beatmap import BeatmapResp +from app.database.beatmap import BeatmapDict, BeatmapModel from app.log import fetcher_logger from ._base import BaseFetcher +from pydantic import TypeAdapter + logger = fetcher_logger("BeatmapFetcher") +adapter = TypeAdapter( + BeatmapModel.generate_typeddict( + ( + "checksum", + "accuracy", + "ar", + "bpm", + "convert", + "count_circles", + "count_sliders", + "count_spinners", + "cs", + "deleted_at", + "drain", + "hit_length", + "is_scoreable", + "last_updated", + "mode_int", + "ranked", + "url", + "max_combo", + "beatmapset", + ) + ) +) class BeatmapFetcher(BaseFetcher): - async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp: + async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapDict: if beatmap_id: params = {"id": beatmap_id} elif beatmap_checksum: @@ -16,7 +43,7 @@ class BeatmapFetcher(BaseFetcher): raise ValueError("Either beatmap_id or beatmap_checksum must be provided.") logger.opt(colors=True).debug(f"get_beatmap: {params}") - return BeatmapResp.model_validate( + return adapter.validate_python( # pyright: ignore[reportReturnType] await self.request_api( "https://osu.ppy.sh/api/v2/beatmaps/lookup", params=params, diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py index 893552f..910163f 100644 --- a/app/fetcher/beatmapset.py +++ b/app/fetcher/beatmapset.py @@ -3,7 +3,7 @@ import base64 import hashlib import json -from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp +from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp from app.helpers.rate_limiter import osu_api_rate_limiter from app.log import fetcher_logger from app.models.beatmap import SearchQueryModel @@ -13,6 +13,7 @@ from app.utils import bg_tasks from ._base import BaseFetcher from httpx import AsyncClient +from pydantic import TypeAdapter import redis.asyncio as redis @@ -26,6 +27,46 @@ logger = fetcher_logger("BeatmapsetFetcher") MAX_RETRY_ATTEMPTS = 3 +adapter = TypeAdapter( + BeatmapsetModel.generate_typeddict( + ( + "availability", + "bpm", + "last_updated", + "ranked", + "ranked_date", + "submitted_date", + "tags", + "storyboard", + "description", + "genre", + "language", + *[ + f"beatmaps.{inc}" + for inc in ( + "checksum", + "accuracy", + "ar", + "bpm", + "convert", + "count_circles", + "count_sliders", + "count_spinners", + "cs", + "deleted_at", + "drain", + "hit_length", + "is_scoreable", + "last_updated", + "mode_int", + "ranked", + "url", + "max_combo", + ) + ], + ) + ) +) class BeatmapsetFetcher(BaseFetcher): @@ -139,10 +180,9 @@ class BeatmapsetFetcher(BaseFetcher): except Exception: return {} - async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp: + async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetDict: logger.opt(colors=True).debug(f"get_beatmapset: {beatmap_set_id}") - - return BeatmapsetResp.model_validate( + return adapter.validate_python( # pyright: ignore[reportReturnType] await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}") ) diff --git a/app/models/chat.py b/app/models/chat.py index 0e2b9da..c77bbdb 100644 --- a/app/models/chat.py +++ b/app/models/chat.py @@ -1,8 +1,6 @@ -from typing import Any - -from pydantic import BaseModel +from typing import Any, TypedDict -class ChatEvent(BaseModel): +class ChatEvent(TypedDict): event: str - data: dict[str, Any] | None = None + data: dict[str, Any] | None diff --git a/app/models/model.py b/app/models/model.py index 4c28048..a44a9df 100644 --- a/app/models/model.py +++ b/app/models/model.py @@ -3,16 +3,16 @@ from datetime import UTC, datetime from app.models.score import GameMode -from pydantic import BaseModel, field_serializer +from pydantic import BaseModel, FieldSerializationInfo, field_serializer class UTCBaseModel(BaseModel): - @field_serializer("*", when_used="json") - def serialize_datetime(self, v, _info): + @field_serializer("*", when_used="always") + def serialize_datetime(self, v, _info: FieldSerializationInfo): if isinstance(v, datetime): if v.tzinfo is None: v = v.replace(tzinfo=UTC) - return v.astimezone(UTC).isoformat().replace("+00:00", "Z") + return v.astimezone(UTC) return v diff --git a/app/router/lio.py b/app/router/lio.py index 33fbaa1..e9bcc18 100644 --- a/app/router/lio.py +++ b/app/router/lio.py @@ -447,7 +447,7 @@ async def create_multiplayer_room( # 让房主加入频道 host_user = await db.get(User, host_user_id) if host_user: - await server.batch_join_channel([host_user], channel, db) + await server.batch_join_channel([host_user], channel) # Add playlist items await _add_playlist_items(db, room_id, room_data, host_user_id) diff --git a/app/router/notification/banchobot.py b/app/router/notification/banchobot.py index 39a01c5..84f41be 100644 --- a/app/router/notification/banchobot.py +++ b/app/router/notification/banchobot.py @@ -3,11 +3,14 @@ from collections.abc import Awaitable, Callable from math import ceil import random import shlex +from typing import TYPE_CHECKING from app.calculator import calculate_weighted_pp from app.const import BANCHOBOT_ID -from app.database import ChatMessageResp -from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType + +if TYPE_CHECKING: + pass +from app.database.chat import ChannelType, ChatChannel, ChatMessage, ChatMessageModel, MessageType from app.database.score import Score, get_best_id from app.database.statistics import UserStatistics, get_rank from app.database.user import User @@ -95,7 +98,7 @@ class Bot: await session.commit() await session.refresh(msg) await session.refresh(bot) - resp = await ChatMessageResp.from_db(msg, session, bot) + resp = await ChatMessageModel.transform(msg, includes=["sender"]) await server.send_message_to_channel(resp) async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None: @@ -119,7 +122,7 @@ class Bot: await session.refresh(channel) await session.refresh(user) await session.refresh(bot) - await server.batch_join_channel([user, bot], channel, session) + await server.batch_join_channel([user, bot], channel) return channel async def _send_reply( diff --git a/app/router/notification/channel.py b/app/router/notification/channel.py index 3d98dda..4428572 100644 --- a/app/router/notification/channel.py +++ b/app/router/notification/channel.py @@ -1,37 +1,40 @@ -from typing import Annotated, Any, Literal, Self +from typing import Annotated, Literal, Self from app.database.chat import ( ChannelType, ChatChannel, - ChatChannelResp, + ChatChannelModel, ChatMessage, SilenceUser, UserSilenceResp, ) -from app.database.user import User, UserResp +from app.database.user import User, UserModel from app.dependencies.database import Database, Redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router +from app.utils import api_doc from .server import server from fastapi import Depends, HTTPException, Path, Query, Security -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, model_validator from sqlmodel import col, select -class UpdateResponse(BaseModel): - presence: list[ChatChannelResp] = Field(default_factory=list) - silences: list[Any] = Field(default_factory=list) - - @router.get( "/chat/updates", - response_model=UpdateResponse, name="获取更新", description="获取当前用户所在频道的最新的禁言情况。", tags=["聊天"], + responses={ + 200: api_doc( + "获取更新响应。", + {"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]}, + ChatChannel.LISTING_INCLUDES, + name="UpdateResponse", + ) + }, ) async def get_update( session: Database, @@ -44,45 +47,44 @@ async def get_update( Query(alias="includes[]", description="要包含的更新类型"), ] = ["presence", "silences"], ): - resp = UpdateResponse() + resp = { + "presence": [], + "silences": [], + } if "presence" in includes: channel_ids = server.get_user_joined_channel(current_user.id) for channel_id in channel_ids: # 使用明确的查询避免延迟加载 db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first() if db_channel: - # 提取必要的属性避免惰性加载 - channel_type = db_channel.type - - resp.presence.append( - await ChatChannelResp.from_db( + resp["presence"].append( + await ChatChannelModel.transform( db_channel, - session, - current_user, - redis, - server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None, + user=current_user, + server=server, + includes=ChatChannel.LISTING_INCLUDES, ) ) if "silences" in includes: if history_since: silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all() - resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences]) + resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences]) elif since: msg = await session.get(ChatMessage, since) if msg: silences = ( await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp)) ).all() - resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences]) + resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences]) return resp @router.put( "/chat/channels/{channel}/users/{user}", - response_model=ChatChannelResp, name="加入频道", description="加入指定的公开/房间频道。", tags=["聊天"], + responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)}, ) async def join_channel( session: Database, @@ -101,7 +103,7 @@ async def join_channel( if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") - return await server.join_channel(current_user, db_channel, session) + return await server.join_channel(current_user, db_channel) @router.delete( @@ -128,13 +130,13 @@ async def leave_channel( if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") - await server.leave_channel(current_user, db_channel, session) + await server.leave_channel(current_user, db_channel) return @router.get( "/chat/channels", - response_model=list[ChatChannelResp], + responses={200: api_doc("加入的频道", list[ChatChannelModel])}, name="获取频道列表", description="获取所有公开频道。", tags=["聊天"], @@ -142,35 +144,30 @@ async def leave_channel( async def get_channel_list( session: Database, current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])], - redis: Redis, ): channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all() - results = [] - for channel in channels: - # 提取必要的属性避免惰性加载 - channel_id = channel.channel_id - channel_type = channel.type + results = await ChatChannelModel.transform_many( + channels, + user=current_user, + server=server, + ) - results.append( - await ChatChannelResp.from_db( - channel, - session, - current_user, - redis, - server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None, - ) - ) return results -class GetChannelResp(BaseModel): - channel: ChatChannelResp - users: list[UserResp] = Field(default_factory=list) - - @router.get( "/chat/channels/{channel}", - response_model=GetChannelResp, + responses={ + 200: api_doc( + "频道详细信息", + { + "channel": ChatChannelModel, + "users": list[UserModel], + }, + ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES, + name="GetChannelResponse", + ) + }, name="获取频道信息", description="获取指定频道的信息。", tags=["聊天"], @@ -191,7 +188,6 @@ async def get_channel( raise HTTPException(status_code=404, detail="Channel not found") # 立即提取需要的属性 - channel_id = db_channel.channel_id channel_type = db_channel.type channel_name = db_channel.name @@ -209,15 +205,15 @@ async def get_channel( users.extend([target_user, current_user]) break - return GetChannelResp( - channel=await ChatChannelResp.from_db( + return { + "channel": await ChatChannelModel.transform( db_channel, - session, - current_user, - redis, - server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None, - ) - ) + user=current_user, + server=server, + includes=ChatChannel.LISTING_INCLUDES, + ), + "users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES), + } class CreateChannelReq(BaseModel): @@ -244,7 +240,7 @@ class CreateChannelReq(BaseModel): @router.post( "/chat/channels", - response_model=ChatChannelResp, + responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])}, name="创建频道", description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。", tags=["聊天"], @@ -289,21 +285,13 @@ async def create_channel( await session.refresh(current_user) if req.type == "PM": await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable] - await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable] + await server.batch_join_channel([target, current_user], channel) # pyright: ignore[reportPossiblyUnboundVariable] else: target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or []))) - await server.batch_join_channel([*target_users, current_user], channel, session) + await server.batch_join_channel([*target_users, current_user], channel) - await server.join_channel(current_user, channel, session) + await server.join_channel(current_user, channel) - # 提取必要的属性避免惰性加载 - channel_id = channel.channel_id - - return await ChatChannelResp.from_db( - channel, - session, - current_user, - redis, - server.channels.get(channel_id, []), - include_recent_messages=True, + return await ChatChannelModel.transform( + channel, user=current_user, server=server, includes=["recent_messages.sender"] ) diff --git a/app/router/notification/message.py b/app/router/notification/message.py index 4d5312c..f5b2d5b 100644 --- a/app/router/notification/message.py +++ b/app/router/notification/message.py @@ -1,11 +1,11 @@ from typing import Annotated -from app.database import ChatMessageResp +from app.database import ChatChannelModel from app.database.chat import ( ChannelType, ChatChannel, - ChatChannelResp, ChatMessage, + ChatMessageModel, MessageType, SilenceUser, UserSilenceResp, @@ -18,6 +18,7 @@ from app.log import log from app.models.notification import ChannelMessage, ChannelMessageTeam from app.router.v2 import api_v2_router as router from app.service.redis_message_system import redis_message_system +from app.utils import api_doc from .banchobot import bot from .server import server @@ -68,7 +69,7 @@ class MessageReq(BaseModel): @router.post( "/chat/channels/{channel}/messages", - response_model=ChatMessageResp, + responses={200: api_doc("发送的消息", ChatMessageModel, ["sender", "is_action"])}, name="发送消息", description="发送消息到指定频道。", tags=["聊天"], @@ -130,7 +131,7 @@ async def send_message( # 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道) if channel_type in [ChannelType.PM, ChannelType.TEAM]: temp_msg = ChatMessage( - message_id=resp.message_id, # 使用 Redis 系统生成的ID + message_id=resp["message_id"], # 使用 Redis 系统生成的ID channel_id=channel_id, content=req.message, sender_id=user_id, @@ -151,7 +152,7 @@ async def send_message( @router.get( "/chat/channels/{channel}/messages", - response_model=list[ChatMessageResp], + responses={200: api_doc("获取的消息", list[ChatMessageModel], ["sender"])}, name="获取消息", description="获取指定频道的消息列表(统一按时间正序返回)。", tags=["聊天"], @@ -177,7 +178,7 @@ async def get_message( try: messages = await redis_message_system.get_messages(channel_id, limit, since) - if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id: + if len(messages) >= 2 and messages[0]["message_id"] > messages[-1]["message_id"]: messages.reverse() return messages except Exception as e: @@ -189,7 +190,7 @@ async def get_message( # 向前加载新消息 → 直接 ASC query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit) rows = (await session.exec(query)).all() - resp = [await ChatMessageResp.from_db(m, session) for m in rows] + resp = await ChatMessageModel.transform_many(rows, includes=["sender"]) # 已经 ASC,无需反转 return resp @@ -202,15 +203,14 @@ async def get_message( rows = (await session.exec(query)).all() rows = list(rows) rows.reverse() # 反转为 ASC - resp = [await ChatMessageResp.from_db(m, session) for m in rows] + resp = await ChatMessageModel.transform_many(rows, includes=["sender"]) return resp query = base.order_by(col(ChatMessage.message_id).desc()).limit(limit) rows = (await session.exec(query)).all() rows = list(rows) rows.reverse() # 反转为 ASC - resp = [await ChatMessageResp.from_db(m, session) for m in rows] - return resp + resp = await ChatMessageModel.transform_many(rows, includes=["sender"]) return resp @@ -248,17 +248,23 @@ class PMReq(BaseModel): uuid: str | None = None -class NewPMResp(BaseModel): - channel: ChatChannelResp - message: ChatMessageResp - new_channel_id: int - - @router.post( "/chat/new", name="创建私聊频道", description="创建一个新的私聊频道。", tags=["聊天"], + responses={ + 200: api_doc( + "创建私聊频道响应", + { + "channel": ChatChannelModel, + "message": ChatMessageModel, + "new_channel_id": int, + }, + ["recent_messages.sender", "sender"], + name="NewPMResponse", + ) + }, ) async def create_new_pm( session: Database, @@ -290,9 +296,9 @@ async def create_new_pm( await session.refresh(target) await session.refresh(current_user) - await server.batch_join_channel([target, current_user], channel, session) - channel_resp = await ChatChannelResp.from_db( - channel, session, current_user, redis, server.channels[channel.channel_id] + await server.batch_join_channel([target, current_user], channel) + channel_resp = await ChatChannelModel.transform( + channel, user=current_user, server=server, includes=["recent_messages.sender"] ) msg = ChatMessage( channel_id=channel.channel_id, @@ -306,10 +312,10 @@ async def create_new_pm( await session.refresh(msg) await session.refresh(current_user) await session.refresh(channel) - message_resp = await ChatMessageResp.from_db(msg, session, current_user) + message_resp = await ChatMessageModel.transform(msg, user=current_user, includes=["sender"]) await server.send_message_to_channel(message_resp) - return NewPMResp( - channel=channel_resp, - message=message_resp, - new_channel_id=channel_resp.channel_id, - ) + return { + "channel": channel_resp, + "message": message_resp, + "new_channel_id": channel_resp["channel_id"], + } diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 7818887..ad4a17a 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -1,7 +1,8 @@ import asyncio from typing import Annotated, overload -from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp +from app.database import ChatMessageDict +from app.database.chat import ChannelType, ChatChannel, ChatChannelDict, ChatChannelModel from app.database.notification import UserNotification, insert_notification from app.database.user import User from app.dependencies.database import ( @@ -16,7 +17,7 @@ from app.log import log from app.models.chat import ChatEvent from app.models.notification import NotificationDetail from app.service.subscribers.chat import ChatSubscriber -from app.utils import bg_tasks +from app.utils import bg_tasks, safe_json_dumps from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes @@ -65,7 +66,7 @@ class ChatServer: await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id)) ).first() if db_channel: - await self.leave_channel(user, db_channel, session) + await self.leave_channel(user, db_channel) @overload async def send_event(self, client: int, event: ChatEvent): ... @@ -80,7 +81,7 @@ class ChatServer: return client = client_ if client.client_state == WebSocketState.CONNECTED: - await client.send_text(event.model_dump_json()) + await client.send_text(safe_json_dumps(event)) async def broadcast(self, channel_id: int, event: ChatEvent): users_in_channel = self.channels.get(channel_id, []) @@ -107,38 +108,38 @@ class ChatServer: async def mark_as_read(self, channel_id: int, user_id: int, message_id: int): await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id) - async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False): + async def send_message_to_channel(self, message: ChatMessageDict, is_bot_command: bool = False): logger.info( - f"Sending message to channel {message.channel_id}, message_id: " - f"{message.message_id}, is_bot_command: {is_bot_command}" + f"Sending message to channel {message['channel_id']}, message_id: " + f"{message['message_id']}, is_bot_command: {is_bot_command}" ) event = ChatEvent( event="chat.message.new", - data={"messages": [message], "users": [message.sender]}, + data={"messages": [message], "users": [message["sender"]]}, # pyright: ignore[reportTypedDictNotRequiredAccess] ) if is_bot_command: - logger.info(f"Sending bot command to user {message.sender_id}") - bg_tasks.add_task(self.send_event, message.sender_id, event) + logger.info(f"Sending bot command to user {message['sender_id']}") + bg_tasks.add_task(self.send_event, message["sender_id"], event) else: # 总是广播消息,无论是临时ID还是真实ID - logger.info(f"Broadcasting message to all users in channel {message.channel_id}") + logger.info(f"Broadcasting message to all users in channel {message['channel_id']}") bg_tasks.add_task( self.broadcast, - message.channel_id, + message["channel_id"], event, ) # 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息 # Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理 - if message.message_id and message.message_id > 0: - await self.mark_as_read(message.channel_id, message.sender_id, message.message_id) - await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id) - logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}") + if message["message_id"] and message["message_id"] > 0: + await self.mark_as_read(message["channel_id"], message["sender_id"], message["message_id"]) + await self.redis.set(f"chat:{message['channel_id']}:last_msg", message["message_id"]) + logger.info(f"Updated last message ID for channel {message['channel_id']} to {message['message_id']}") else: - logger.debug(f"Skipping last message update for message ID: {message.message_id}") + logger.debug(f"Skipping last message update for message ID: {message['message_id']}") - async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession): + async def batch_join_channel(self, users: list[User], channel: ChatChannel): channel_id = channel.channel_id not_joined = [] @@ -151,22 +152,18 @@ class ChatServer: not_joined.append(user) for user in not_joined: - channel_resp = await ChatChannelResp.from_db( - channel, - session, - user, - self.redis, - self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None, + channel_resp = await ChatChannelModel.transform( + channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES ) await self.send_event( user.id, ChatEvent( event="chat.channel.join", - data=channel_resp.model_dump(), + data=channel_resp, # pyright: ignore[reportArgumentType] ), ) - async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp: + async def join_channel(self, user: User, channel: ChatChannel) -> ChatChannelDict: user_id = user.id channel_id = channel.channel_id @@ -175,25 +172,21 @@ class ChatServer: if user_id not in self.channels[channel_id]: self.channels[channel_id].append(user_id) - channel_resp = await ChatChannelResp.from_db( - channel, - session, - user, - self.redis, - self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None, + channel_resp: ChatChannelDict = await ChatChannelModel.transform( + channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES ) await self.send_event( user_id, ChatEvent( event="chat.channel.join", - data=channel_resp.model_dump(), + data=channel_resp, # pyright: ignore[reportArgumentType] ), ) return channel_resp - async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None: + async def leave_channel(self, user: User, channel: ChatChannel) -> None: user_id = user.id channel_id = channel.channel_id @@ -203,18 +196,14 @@ class ChatServer: if (c := self.channels.get(channel_id)) is not None and not c: del self.channels[channel_id] - channel_resp = await ChatChannelResp.from_db( - channel, - session, - user, - self.redis, - self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None, + channel_resp = await ChatChannelModel.transform( + channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES ) await self.send_event( user_id, ChatEvent( event="chat.channel.part", - data=channel_resp.model_dump(), + data=channel_resp, # pyright: ignore[reportArgumentType] ), ) @@ -232,7 +221,7 @@ class ChatServer: return logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})") - await self.join_channel(user, db_channel, session) + await self.join_channel(user, db_channel) async def leave_room_channel(self, channel_id: int, user_id: int): async with with_db() as session: @@ -248,7 +237,7 @@ class ChatServer: return logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})") - await self.leave_channel(user, db_channel, session) + await self.leave_channel(user, db_channel) async def new_private_notification(self, detail: NotificationDetail): async with with_db() as session: @@ -336,6 +325,6 @@ async def chat_websocket( # 使用明确的查询避免延迟加载 db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first() if db_channel is not None: - await server.join_channel(user, db_channel, session) + await server.join_channel(user, db_channel) await _listen_stop(websocket, user_id, factory) diff --git a/app/router/private/team.py b/app/router/private/team.py index 826176a..1afa61b 100644 --- a/app/router/private/team.py +++ b/app/router/private/team.py @@ -2,7 +2,7 @@ import hashlib from typing import Annotated from app.database.team import Team, TeamMember, TeamRequest, TeamResp -from app.database.user import BASE_INCLUDES, User, UserResp +from app.database.user import User, UserModel from app.dependencies.database import Database, Redis from app.dependencies.storage import StorageService from app.dependencies.user import ClientUser @@ -14,12 +14,11 @@ from app.models.notification import ( from app.models.score import GameMode from app.router.notification import server from app.service.ranking_cache_service import get_ranking_cache_service -from app.utils import check_image, utcnow +from app.utils import api_doc, check_image, utcnow from .router import router from fastapi import File, Form, HTTPException, Path, Query, Request -from pydantic import BaseModel from sqlmodel import col, exists, select @@ -214,12 +213,22 @@ async def delete_team( await cache_service.invalidate_team_cache() -class TeamQueryResp(BaseModel): - team: TeamResp - members: list[UserResp] - - -@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"]) +@router.get( + "/team/{team_id}", + name="查询战队", + tags=["战队", "g0v0 API"], + responses={ + 200: api_doc( + "战队信息", + { + "team": TeamResp, + "members": list[UserModel], + }, + ["statistics", "country"], + name="TeamQueryResp", + ) + }, +) async def get_team( session: Database, team_id: Annotated[int, Path(..., description="战队 ID")], @@ -233,10 +242,10 @@ async def get_team( ) ) ).all() - return TeamQueryResp( - team=await TeamResp.from_db(members[0].team, session, gamemode), - members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members], - ) + return { + "team": await TeamResp.from_db(members[0].team, session, gamemode), + "members": await UserModel.transform_many([m.user for m in members], includes=["statistics", "country"]), + } @router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"]) diff --git a/app/router/v1/user.py b/app/router/v1/user.py index aee0569..29b642f 100644 --- a/app/router/v1/user.py +++ b/app/router/v1/user.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Annotated, Literal -from app.database.statistics import UserStatistics, UserStatisticsResp +from app.database.statistics import UserStatistics, UserStatisticsModel from app.database.user import User from app.dependencies.database import Database, get_redis from app.log import logger @@ -46,7 +46,7 @@ class V1User(AllStrModel): return f"v1_user:{user_id}" @classmethod - async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User": + async def from_db(cls, db_user: User, ruleset: GameMode | None = None) -> "V1User": ruleset = ruleset or db_user.playmode current_statistics: UserStatistics | None = None for i in await db_user.awaitable_attrs.statistics: @@ -54,31 +54,33 @@ class V1User(AllStrModel): current_statistics = i break if current_statistics: - statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code) + statistics = await UserStatisticsModel.transform( + current_statistics, country_code=db_user.country_code, includes=["country_rank"] + ) else: statistics = None return cls( user_id=db_user.id, username=db_user.username, join_date=db_user.join_date, - count300=statistics.count_300 if statistics else 0, - count100=statistics.count_100 if statistics else 0, - count50=statistics.count_50 if statistics else 0, - playcount=statistics.play_count if statistics else 0, - ranked_score=statistics.ranked_score if statistics else 0, - total_score=statistics.total_score if statistics else 0, - pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0, + count300=current_statistics.count_300 if current_statistics else 0, + count100=current_statistics.count_100 if current_statistics else 0, + count50=current_statistics.count_50 if current_statistics else 0, + playcount=current_statistics.play_count if current_statistics else 0, + ranked_score=current_statistics.ranked_score if current_statistics else 0, + total_score=current_statistics.total_score if current_statistics else 0, + pp_rank=statistics.get("global_rank") or 0 if statistics else 0, level=current_statistics.level_current if current_statistics else 0, - pp_raw=statistics.pp if statistics else 0.0, - accuracy=statistics.hit_accuracy if statistics else 0, + pp_raw=current_statistics.pp if current_statistics else 0.0, + accuracy=current_statistics.hit_accuracy if current_statistics else 0, count_rank_ss=current_statistics.grade_ss if current_statistics else 0, count_rank_ssh=current_statistics.grade_ssh if current_statistics else 0, count_rank_s=current_statistics.grade_s if current_statistics else 0, count_rank_sh=current_statistics.grade_sh if current_statistics else 0, count_rank_a=current_statistics.grade_a if current_statistics else 0, country=db_user.country_code, - total_seconds_played=statistics.play_time if statistics else 0, - pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0, + total_seconds_played=current_statistics.play_time if current_statistics else 0, + pp_country_rank=statistics.get("country_rank") or 0 if statistics else 0, events=[], # TODO ) @@ -134,7 +136,7 @@ async def get_user( try: # 生成用户数据 - v1_user = await V1User.from_db(session, db_user, ruleset) + v1_user = await V1User.from_db(db_user, ruleset) # 异步缓存结果(如果有用户ID) if db_user.id is not None: diff --git a/app/router/v2/beatmap.py b/app/router/v2/beatmap.py index 53cf70f..943317d 100644 --- a/app/router/v2/beatmap.py +++ b/app/router/v2/beatmap.py @@ -5,7 +5,11 @@ from typing import Annotated from app.calculator import get_calculator from app.calculators.performance import ConvertError -from app.database import Beatmap, BeatmapResp, User +from app.database import ( + Beatmap, + BeatmapModel, + User, +) from app.database.beatmap import calculate_beatmap_attributes from app.dependencies.database import Database, Redis from app.dependencies.fetcher import Fetcher @@ -19,29 +23,20 @@ from app.models.performance import ( from app.models.score import ( GameMode, ) +from app.utils import api_doc from .router import router from fastapi import HTTPException, Path, Query, Security from httpx import HTTPError, HTTPStatusError -from pydantic import BaseModel from sqlmodel import col, select -class BatchGetResp(BaseModel): - """批量获取谱面返回模型。 - - 返回字段说明: - - beatmaps: 谱面详细信息列表。""" - - beatmaps: list[BeatmapResp] - - @router.get( "/beatmaps/lookup", tags=["谱面"], name="查询单个谱面", - response_model=BeatmapResp, + responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)}, description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"), ) @asset_proxy_response @@ -67,14 +62,14 @@ async def lookup_beatmap( raise HTTPException(status_code=404, detail="Beatmap not found") await db.refresh(current_user) - return await BeatmapResp.from_db(beatmap, session=db, user=current_user) + return await BeatmapModel.transform(beatmap, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES) @router.get( "/beatmaps/{beatmap_id}", tags=["谱面"], name="获取谱面详情", - response_model=BeatmapResp, + responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)}, description="获取单个谱面详情。", ) @asset_proxy_response @@ -86,7 +81,12 @@ async def get_beatmap( ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id) - return await BeatmapResp.from_db(beatmap, session=db, user=current_user) + await db.refresh(current_user) + return await BeatmapModel.transform( + beatmap, + user=current_user, + includes=BeatmapModel.TRANSFORMER_INCLUDES, + ) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -95,7 +95,11 @@ async def get_beatmap( "/beatmaps/", tags=["谱面"], name="批量获取谱面", - response_model=BatchGetResp, + responses={ + 200: api_doc( + "谱面列表", {"beatmaps": list[BeatmapModel]}, BeatmapModel.TRANSFORMER_INCLUDES, name="BatchBeatmapResponse" + ) + }, description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"), ) @asset_proxy_response @@ -124,7 +128,12 @@ async def batch_get_beatmaps( for beatmap in beatmaps: await db.refresh(beatmap) await db.refresh(current_user) - return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps]) + return { + "beatmaps": [ + await BeatmapModel.transform(bm, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES) + for bm in beatmaps + ] + } @router.post( diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py index 5129c32..7fb685a 100644 --- a/app/router/v2/beatmapset.py +++ b/app/router/v2/beatmapset.py @@ -2,17 +2,24 @@ import re from typing import Annotated, Literal from urllib.parse import parse_qs -from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User -from app.database.beatmapset import SearchBeatmapsetsResp +from app.database import ( + Beatmap, + Beatmapset, + BeatmapsetModel, + FavouriteBeatmapset, + SearchBeatmapsetsResp, + User, +) from app.dependencies.beatmap_download import DownloadService from app.dependencies.cache import BeatmapsetCacheService, UserCacheService -from app.dependencies.database import Database, Redis, with_db +from app.dependencies.database import Database, Redis from app.dependencies.fetcher import Fetcher from app.dependencies.geoip import IPAddress, get_geoip_helper from app.dependencies.user import ClientUser, get_current_user from app.helpers.asset_proxy_helper import asset_proxy_response from app.models.beatmap import SearchQueryModel from app.service.beatmapset_cache_service import generate_hash +from app.utils import api_doc from .router import router @@ -27,14 +34,7 @@ from fastapi import ( ) from fastapi.responses import RedirectResponse from httpx import HTTPError -from sqlmodel import exists, select - - -async def _save_to_db(sets: SearchBeatmapsetsResp): - async with with_db() as session: - for s in sets.beatmapsets: - if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first(): - await Beatmapset.from_resp(session, s) +from sqlmodel import select @router.get( @@ -105,7 +105,6 @@ async def search_beatmapset( try: sets = await fetcher.search_beatmapset(query, cursor, redis) - background_tasks.add_task(_save_to_db, sets) # 缓存搜索结果 await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump()) @@ -117,8 +116,8 @@ async def search_beatmapset( @router.get( "/beatmapsets/lookup", tags=["谱面集"], + responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)}, name="查询谱面集 (通过谱面 ID)", - response_model=BeatmapsetResp, description=("通过谱面 ID 查询所属谱面集。"), ) @asset_proxy_response @@ -137,7 +136,10 @@ async def lookup_beatmapset( try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) - resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user) + + resp = await BeatmapsetModel.transform( + beatmap.beatmapset, user=current_user, includes=BeatmapsetModel.API_INCLUDES + ) # 缓存结果 await cache_service.cache_beatmap_lookup(beatmap_id, resp) @@ -149,8 +151,8 @@ async def lookup_beatmapset( @router.get( "/beatmapsets/{beatmapset_id}", tags=["谱面集"], + responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)}, name="获取谱面集详情", - response_model=BeatmapsetResp, description="获取单个谱面集详情。", ) @asset_proxy_response @@ -169,7 +171,8 @@ async def get_beatmapset( try: beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id) - resp = await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user) + await db.refresh(current_user) + resp = await BeatmapsetModel.transform(beatmapset, includes=BeatmapsetModel.API_INCLUDES, user=current_user) # 缓存结果 await cache_service.cache_beatmapset(resp) diff --git a/app/router/v2/me.py b/app/router/v2/me.py index 7122638..d0e94b6 100644 --- a/app/router/v2/me.py +++ b/app/router/v2/me.py @@ -1,20 +1,29 @@ from typing import Annotated -from app.database import FavouriteBeatmapset, MeResp, User +from app.database import FavouriteBeatmapset, User +from app.database.user import UserModel from app.dependencies.database import Database from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token from app.models.score import GameMode +from app.utils import api_doc from .router import router from fastapi import Path, Security from fastapi.responses import RedirectResponse +from pydantic import BaseModel from sqlmodel import select +ME_INCLUDES = [*User.USER_INCLUDES, "session_verified", "session_verification_method"] + + +class BeatmapsetIds(BaseModel): + beatmapset_ids: list[int] + @router.get( "/me/beatmapset-favourites", - response_model=list[int], + response_model=BeatmapsetIds, name="获取当前用户收藏的谱面集 ID 列表", description="获取当前登录用户收藏的谱面集 ID 列表。", tags=["用户", "谱面集"], @@ -26,37 +35,39 @@ async def get_user_beatmapset_favourites( beatmapset_ids = await session.exec( select(FavouriteBeatmapset.beatmapset_id).where(FavouriteBeatmapset.user_id == current_user.id) ) - return beatmapset_ids.all() + return BeatmapsetIds(beatmapset_ids=list(beatmapset_ids.all())) @router.get( "/me/{ruleset}", - response_model=MeResp, + responses={200: api_doc("当前用户信息(含指定 ruleset 统计)", UserModel, ME_INCLUDES)}, name="获取当前用户信息 (指定 ruleset)", description="获取当前登录用户信息 (含指定 ruleset 统计)。", tags=["用户"], ) async def get_user_info_with_ruleset( - session: Database, ruleset: Annotated[GameMode, Path(description="指定 ruleset")], user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])], ): - user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id) + user_resp = await UserModel.transform( + user_and_token[0], ruleset=ruleset, token_id=user_and_token[1].id, includes=ME_INCLUDES + ) return user_resp @router.get( "/me/", - response_model=MeResp, + responses={200: api_doc("当前用户信息", UserModel, ME_INCLUDES)}, name="获取当前用户信息", description="获取当前登录用户信息。", tags=["用户"], ) async def get_user_info_default( - session: Database, user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])], ): - user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id) + user_resp = await UserModel.transform( + user_and_token[0], ruleset=None, token_id=user_and_token[1].id, includes=ME_INCLUDES + ) return user_resp diff --git a/app/router/v2/ranking.py b/app/router/v2/ranking.py index 17aabb6..98682a6 100644 --- a/app/router/v2/ranking.py +++ b/app/router/v2/ranking.py @@ -1,11 +1,13 @@ from typing import Annotated, Literal from app.config import settings -from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp +from app.database import Team, TeamMember, User, UserStatistics +from app.database.statistics import UserStatisticsModel from app.dependencies.database import Database, get_redis from app.dependencies.user import get_current_user from app.models.score import GameMode from app.service.ranking_cache_service import get_ranking_cache_service +from app.utils import api_doc from .router import router @@ -308,14 +310,16 @@ async def get_country_ranking( return response -class TopUsersResponse(BaseModel): - ranking: list[UserStatisticsResp] - total: int - - @router.get( "/rankings/{ruleset}/{sort}", - response_model=TopUsersResponse, + responses={ + 200: api_doc( + "用户排行榜", + {"ranking": list[UserStatisticsModel], "total": int}, + ["user.country", "user.cover"], + name="TopUsersResponse", + ) + }, name="获取用户排行榜", description="获取在指定模式下的用户排行榜", tags=["排行榜"], @@ -339,10 +343,10 @@ async def get_user_ranking( if cached_data and cached_stats: # 从缓存返回数据 - return TopUsersResponse( - ranking=[UserStatisticsResp.model_validate(item) for item in cached_data], - total=cached_stats.get("total", 0), - ) + return { + "ranking": cached_data, + "total": cached_stats.get("total", 0), + } # 缓存未命中,从数据库查询 wheres = [ @@ -350,7 +354,7 @@ async def get_user_ranking( col(UserStatistics.pp) > 0, col(UserStatistics.is_ranked), ] - include = ["user"] + include = UserStatistics.RANKING_INCLUDES.copy() if sort == "performance": order_by = col(UserStatistics.pp).desc() include.append("rank_change_since_30_days") @@ -358,6 +362,7 @@ async def get_user_ranking( order_by = col(UserStatistics.ranked_score).desc() if country: wheres.append(col(UserStatistics.user).has(country_code=country.upper())) + include.append("country_rank") # 查询总数 count_query = select(func.count()).select_from(UserStatistics).where(*wheres) @@ -378,12 +383,14 @@ async def get_user_ranking( # 转换为响应格式 ranking_data = [] for statistics in statistics_list: - user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include) + user_stats_resp = await UserStatisticsModel.transform( + statistics, includes=include, user_country=current_user.country_code + ) ranking_data.append(user_stats_resp) # 异步缓存数据(不等待完成) # 使用配置文件中的TTL设置 - cache_data = [item.model_dump() for item in ranking_data] + cache_data = ranking_data stats_data = {"total": total_count} # 创建后台任务来缓存数据 @@ -407,5 +414,7 @@ async def get_user_ranking( ttl=settings.ranking_cache_expire_minutes * 60, ) - resp = TopUsersResponse(ranking=ranking_data, total=total_count) - return resp + return { + "ranking": ranking_data, + "total": total_count, + } diff --git a/app/router/v2/relationship.py b/app/router/v2/relationship.py index 554596a..7d0c891 100644 --- a/app/router/v2/relationship.py +++ b/app/router/v2/relationship.py @@ -1,15 +1,16 @@ -from typing import Annotated +from typing import Annotated, Any -from app.database import Relationship, RelationshipResp, RelationshipType, User -from app.database.user import UserResp +from app.database import Relationship, RelationshipType, User +from app.database.relationship import RelationshipModel +from app.database.user import UserModel from app.dependencies.api_version import APIVersion from app.dependencies.database import Database from app.dependencies.user import ClientUser, get_current_user +from app.utils import api_doc from .router import router from fastapi import HTTPException, Path, Query, Request, Security -from pydantic import BaseModel from sqlmodel import col, exists, select @@ -17,38 +18,19 @@ from sqlmodel import col, exists, select "/friends", tags=["用户关系"], responses={ - 200: { - "description": "好友列表", - "content": { - "application/json": { - "schema": { - "oneOf": [ - { - "type": "array", - "items": {"$ref": "#/components/schemas/RelationshipResp"}, - "description": "好友列表", - }, - { - "type": "array", - "items": {"$ref": "#/components/schemas/UserResp"}, - "description": "好友列表 (`x-api-version < 20241022`)", - }, - ] - } - } - }, - } + 200: api_doc( + "好友列表\n\n如果 `x-api-version < 20241022`,返回值为 `User` 列表,否则为 `Relationship` 列表。", + list[RelationshipModel] | list[UserModel], + [f"target.{inc}" for inc in User.LIST_INCLUDES], + ) }, name="获取好友列表", - description=( - "获取当前用户的好友列表。\n\n" - "如果 `x-api-version < 20241022`,返回值为 `UserResp` 列表,否则为 `RelationshipResp` 列表。" - ), + description="获取当前用户的好友列表。", ) @router.get( "/blocks", tags=["用户关系"], - response_model=list[RelationshipResp], + response_model=list[dict[str, Any]], name="获取屏蔽列表", description="获取当前用户的屏蔽用户列表。", ) @@ -67,35 +49,29 @@ async def get_relationship( ) ) if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK: - return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()] + return [ + await RelationshipModel.transform( + rel, + includes=[f"target.{inc}" for inc in User.LIST_INCLUDES], + ruleset=current_user.playmode, + ) + for rel in relationships.unique() + ] else: return [ - await UserResp.from_db( + await UserModel.transform( rel.target, - db, - include=[ - "team", - "daily_challenge_user_stats", - "statistics", - "statistics_rulesets", - ], + ruleset=current_user.playmode, + includes=User.LIST_INCLUDES, ) for rel in relationships.unique() ] -class AddFriendResp(BaseModel): - """添加好友/屏蔽 返回模型。 - - - user_relation: 新的或更新后的关系对象。""" - - user_relation: RelationshipResp - - @router.post( "/friends", tags=["用户关系"], - response_model=AddFriendResp, + responses={200: api_doc("好友关系", {"user_relation": RelationshipModel}, name="UserRelationshipResponse")}, name="添加或更新好友关系", description="\n添加或更新与目标用户的好友关系。", ) @@ -163,7 +139,13 @@ async def add_relationship( ) ) ).one() - return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship)) + return { + "user_relation": await RelationshipModel.transform( + relationship, + includes=[], + ruleset=current_user.playmode, + ) + } @router.delete( diff --git a/app/router/v2/room.py b/app/router/v2/room.py index b8d06ed..94ed5a0 100644 --- a/app/router/v2/room.py +++ b/app/router/v2/room.py @@ -1,25 +1,27 @@ from datetime import UTC from typing import Annotated, Literal -from app.database.beatmap import Beatmap, BeatmapResp -from app.database.beatmapset import BeatmapsetResp -from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsResp +from app.database.beatmap import ( + Beatmap, + BeatmapModel, +) +from app.database.beatmapset import BeatmapsetModel +from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsCountModel from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp -from app.database.playlists import Playlist, PlaylistResp -from app.database.room import APIUploadedRoom, Room, RoomResp +from app.database.playlists import Playlist, PlaylistModel +from app.database.room import APIUploadedRoom, Room, RoomModel from app.database.room_participated_user import RoomParticipatedUser from app.database.score import Score -from app.database.user import User, UserResp +from app.database.user import User, UserModel from app.dependencies.database import Database, Redis from app.dependencies.user import ClientUser, get_current_user from app.models.room import MatchType, RoomCategory, RoomStatus from app.service.room import create_playlist_room_from_api -from app.utils import utcnow +from app.utils import api_doc, utcnow from .router import router from fastapi import HTTPException, Path, Query, Security -from pydantic import BaseModel, Field from sqlalchemy.sql.elements import ColumnElement from sqlmodel import col, exists, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -28,7 +30,19 @@ from sqlmodel.ext.asyncio.session import AsyncSession @router.get( "/rooms", tags=["房间"], - response_model=list[RoomResp], + responses={ + 200: api_doc( + "房间列表", + list[RoomModel], + [ + "current_playlist_item.beatmap.beatmapset", + "difficulty_range", + "host.country", + "playlist_item_stats", + "recent_participants", + ], + ) + }, name="获取房间列表", description="获取房间列表。支持按状态/模式筛选", ) @@ -49,7 +63,7 @@ async def get_all_rooms( ] = RoomCategory.NORMAL, status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None, ): - resp_list: list[RoomResp] = [] + resp_list = [] where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category, col(Room.type) != MatchType.MATCHMAKING] now = utcnow() @@ -90,22 +104,24 @@ async def get_all_rooms( .all() ) for room in db_rooms: - resp = await RoomResp.from_db(room, db) + resp = await RoomModel.transform( + room, + includes=[ + "current_playlist_item.beatmap.beatmapset", + "difficulty_range", + "host.country", + "playlist_item_stats", + "recent_participants", + ], + ) if category == RoomCategory.REALTIME: - resp.category = RoomCategory.NORMAL + resp["category"] = RoomCategory.NORMAL resp_list.append(resp) return resp_list -class APICreatedRoom(RoomResp): - """创建房间返回模型,继承 RoomResp。额外字段: - - error: 错误信息(为空表示成功)。""" - - error: str = "" - - async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis): participated_user = ( await session.exec( @@ -133,9 +149,15 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session: @router.post( "/rooms", tags=["房间"], - response_model=APICreatedRoom, name="创建房间", description="\n创建一个新的房间。", + responses={ + 200: api_doc( + "创建的房间信息", + RoomModel, + Room.SHOW_RESPONSE_INCLUDES, + ) + }, ) async def create_room( db: Database, @@ -145,23 +167,27 @@ async def create_room( ): if await current_user.is_restricted(db): raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.") - user_id = current_user.id db_room = await create_playlist_room_from_api(db, room, user_id) await _participate_room(db_room.id, user_id, db_room, db, redis) await db.commit() await db.refresh(db_room) - created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db)) - created_room.error = "" + created_room = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES) return created_room @router.get( "/rooms/{room_id}", tags=["房间"], - response_model=RoomResp, + responses={ + 200: api_doc( + "房间详细信息", + RoomModel, + Room.SHOW_RESPONSE_INCLUDES, + ) + }, name="获取房间详情", - description="获取单个房间详情。", + description="获取指定房间详情。", ) async def get_room( db: Database, @@ -177,7 +203,7 @@ async def get_room( db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is None: raise HTTPException(404, "Room not found") - resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user) + resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES, user=current_user) return resp @@ -225,10 +251,10 @@ async def add_user_to_room( await _participate_room(room_id, user_id, db_room, db, redis) await db.commit() await db.refresh(db_room) - resp = await RoomResp.from_db(db_room, db) + resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES) return resp else: - raise HTTPException(404, "room not found0") + raise HTTPException(404, "room not found") @router.delete( @@ -268,21 +294,22 @@ async def remove_user_from_room( raise HTTPException(404, "Room not found") -class APILeaderboard(BaseModel): - """房间全局排行榜返回模型。 - - leaderboard: 用户游玩统计(尝试次数/分数等)。 - - user_score: 当前用户对应统计。""" - - leaderboard: list[ItemAttemptsResp] = Field(default_factory=list) - user_score: ItemAttemptsResp | None = None - - @router.get( "/rooms/{room_id}/leaderboard", tags=["房间"], - response_model=APILeaderboard, name="获取房间排行榜", description="获取房间内累计得分排行榜。", + responses={ + 200: api_doc( + "房间排行榜", + { + "leaderboard": list[ItemAttemptsCountModel], + "user_score": ItemAttemptsCountModel | None, + }, + ["user.country", "position"], + name="RoomLeaderboardResponse", + ) + }, ) async def get_room_leaderboard( db: Database, @@ -300,45 +327,43 @@ async def get_room_leaderboard( aggs_resp = [] user_agg = None for i, agg in enumerate(aggs): - resp = await ItemAttemptsResp.from_db(agg, db) - resp.position = i + 1 + includes = ["user.country"] + if agg.user_id == current_user.id: + includes.append("position") + resp = await ItemAttemptsCountModel.transform(agg, includes=includes) aggs_resp.append(resp) if agg.user_id == current_user.id: user_agg = resp - return APILeaderboard( - leaderboard=aggs_resp, - user_score=user_agg, - ) - -class RoomEvents(BaseModel): - """房间事件流返回模型。 - - beatmaps: 本次结果涉及的谱面列表。 - - beatmapsets: 谱面集映射。 - - current_playlist_item_id: 当前游玩列表(项目)项 ID。 - - events: 事件列表。 - - first_event_id / last_event_id: 事件范围。 - - playlist_items: 房间游玩列表(项目)详情。 - - room: 房间详情。 - - user: 关联用户列表。""" - - beatmaps: list[BeatmapResp] = Field(default_factory=list) - beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict) - current_playlist_item_id: int = 0 - events: list[MultiplayerEventResp] = Field(default_factory=list) - first_event_id: int = 0 - last_event_id: int = 0 - playlist_items: list[PlaylistResp] = Field(default_factory=list) - room: RoomResp - user: list[UserResp] = Field(default_factory=list) + return { + "leaderboard": aggs_resp, + "user_score": user_agg, + } @router.get( "/rooms/{room_id}/events", - response_model=RoomEvents, tags=["房间"], name="获取房间事件", description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。", + responses={ + 200: api_doc( + "房间事件", + { + "beatmaps": list[BeatmapModel], + "beatmapsets": list[BeatmapsetModel], + "current_playlist_item_id": int, + "events": list[MultiplayerEventResp], + "first_event_id": int, + "last_event_id": int, + "playlist_items": list[PlaylistModel], + "room": RoomModel, + "user": list[UserModel], + }, + ["country", "details", "scores"], + name="RoomEventsResponse", + ) + }, ) async def get_room_events( db: Database, @@ -402,28 +427,44 @@ async def get_room_events( room = (await db.exec(select(Room).where(Room.id == room_id))).first() if room is None: raise HTTPException(404, "Room not found") - room_resp = await RoomResp.from_db(room, db) - if room.category == RoomCategory.REALTIME and room_resp.current_playlist_item: - current_playlist_item_id = room_resp.current_playlist_item.id + room_resp = await RoomModel.transform(room, includes=["current_playlist_item"]) + if room.category == RoomCategory.REALTIME: + current_playlist_item_id = (await Room.current_playlist_item(db, room))["id"] users = await db.exec(select(User).where(col(User.id).in_(user_ids))) - user_resps = [await UserResp.from_db(user, db) for user in users] + user_resps = [await UserModel.transform(user, includes=["country"]) for user in users] + beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids))) - beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps] - beatmapset_resps = {} - for beatmap_resp in beatmap_resps: - beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset + beatmap_resps = [ + await BeatmapModel.transform( + beatmap, + ) + for beatmap in beatmaps + ] - playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()] + beatmapsets = [] + for beatmap in beatmaps: + if beatmap.beatmapset_id not in beatmapsets: + beatmapsets.append(beatmap.beatmapset) + beatmapset_resps = [ + await BeatmapsetModel.transform( + beatmapset, + ) + for beatmapset in beatmapsets + ] - return RoomEvents( - beatmaps=beatmap_resps, - beatmapsets=beatmapset_resps, - current_playlist_item_id=current_playlist_item_id, - events=event_resps, - first_event_id=first_event_id, - last_event_id=last_event_id, - playlist_items=playlist_items_resps, - room=room_resp, - user=user_resps, - ) + playlist_items_resps = [ + await PlaylistModel.transform(item, includes=["details", "scores"]) for item in playlist_items.values() + ] + + return { + "beatmaps": beatmap_resps, + "beatmapsets": beatmapset_resps, + "current_playlist_item_id": current_playlist_item_id, + "events": event_resps, + "first_event_id": first_event_id, + "last_event_id": last_event_id, + "playlist_items": playlist_items_resps, + "room": room_resp, + "user": user_resps, + } diff --git a/app/router/v2/score.py b/app/router/v2/score.py index 55eb6ec..fffcea2 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -9,7 +9,6 @@ from app.database import ( Playlist, Room, Score, - ScoreResp, ScoreToken, ScoreTokenResp, User, @@ -27,8 +26,10 @@ from app.database.relationship import Relationship, RelationshipType from app.database.score import ( LegacyScoreResp, MultiplayerScores, - ScoreAround, + MultiplayScoreDict, + ScoreModel, get_leaderboard, + get_score_position_by_id, process_score, process_user, ) @@ -49,7 +50,7 @@ from app.models.score import ( ) from app.service.beatmap_cache_service import get_beatmap_cache_service from app.service.user_cache_service import refresh_user_cache_background -from app.utils import utcnow +from app.utils import api_doc, utcnow from .router import router @@ -72,6 +73,7 @@ from sqlmodel import col, exists, func, select from sqlmodel.ext.asyncio.session import AsyncSession READ_SCORE_TIMEOUT = 10 +DEFAULT_SCORE_INCLUDES = ["user", "user.country", "user.cover", "user.team"] logger = log("Score") @@ -180,13 +182,15 @@ async def submit_score( await db.refresh(score) background_task.add_task(_process_user, score_id, user_id, redis, fetcher) - resp: ScoreResp = await ScoreResp.from_db(db, score) + resp = await ScoreModel.transform( + score, + ) score_gamemode = score.gamemode await db.commit() if user_id is not None: background_task.add_task(refresh_user_cache_background, redis, user_id, score_gamemode) - background_task.add_task(_process_user_achievement, resp.id) + background_task.add_task(_process_user_achievement, resp["id"]) return resp @@ -218,27 +222,36 @@ async def _preload_beatmap_for_pp_calculation(beatmap_id: int) -> None: logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}") -class BeatmapUserScore[T: ScoreResp | LegacyScoreResp](BaseModel): +LeaderboardScoreType = ScoreModel.generate_typeddict(tuple(DEFAULT_SCORE_INCLUDES)) | LegacyScoreResp + + +class BeatmapUserScore(BaseModel): position: int - score: T + score: LeaderboardScoreType # pyright: ignore[reportInvalidTypeForm] -class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel): - scores: list[T] - user_score: BeatmapUserScore[T] | None = None +class BeatmapScores(BaseModel): + scores: list[LeaderboardScoreType] # pyright: ignore[reportInvalidTypeForm] + user_score: BeatmapUserScore | None = None score_count: int = 0 @router.get( "/beatmaps/{beatmap_id}/scores", tags=["成绩"], - response_model=BeatmapScores[ScoreResp] | BeatmapScores[LegacyScoreResp], + responses={ + 200: { + "model": BeatmapScores, + "description": ( + "排行榜及当前用户成绩。\n\n" + f"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[Score]`" + f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])})," + "否则为 `BeatmapScores[LegacyScoreResp]`。" + ), + } + }, name="获取谱面排行榜", - description=( - "获取指定谱面在特定条件下的排行榜及当前用户成绩。\n\n" - "如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[ScoreResp]`," - "否则为 `BeatmapScores[LegacyScoreResp]`。" - ), + description="获取指定谱面在特定条件下的排行榜及当前用户成绩。", ) async def get_beatmap_scores( db: Database, @@ -266,27 +279,46 @@ async def get_beatmap_scores( mods=sorted(mods), ) - user_score_resp = await user_score.to_resp(db, api_version) if user_score else None - resp = BeatmapScores( - scores=[await score.to_resp(db, api_version) for score in all_scores], - user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0) - if user_score_resp - else None, - score_count=count, - ) - return resp + user_score_resp = await user_score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) if user_score else None + return { + "scores": [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_scores], + "user_score": ( + { + "score": user_score_resp, + "position": ( + await get_score_position_by_id( + db, + user_score.beatmap_id, + user_score.id, + mode=user_score.gamemode, + user=user_score.user, + ) + or 0 + ), + } + if user_score and user_score_resp + else None + ), + "score_count": count, + } @router.get( "/beatmaps/{beatmap_id}/scores/users/{user_id}", tags=["成绩"], - response_model=BeatmapUserScore[ScoreResp] | BeatmapUserScore[LegacyScoreResp], + responses={ + 200: { + "model": BeatmapUserScore, + "description": ( + "指定用户在指定谱面上的最高成绩\n\n" + "如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[Score]`," + f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])})," + "否则为 `BeatmapUserScore[LegacyScoreResp]`。" + ), + } + }, name="获取用户谱面最高成绩", - description=( - "获取指定用户在指定谱面上的最高成绩。\n\n" - "如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[ScoreResp]`," - "否则为 `BeatmapUserScore[LegacyScoreResp]`。" - ), + description="获取指定用户在指定谱面上的最高成绩。", ) async def get_user_beatmap_score( db: Database, @@ -318,23 +350,38 @@ async def get_user_beatmap_score( detail=f"Cannot find user {user_id}'s score on this beatmap", ) else: - resp = await user_score.to_resp(db, api_version=api_version) - return BeatmapUserScore( - position=resp.rank_global or 0, - score=resp, - ) + resp = await user_score.to_resp(db, api_version=api_version, includes=DEFAULT_SCORE_INCLUDES) + return { + "position": ( + await get_score_position_by_id( + db, + user_score.beatmap_id, + user_score.id, + mode=user_score.gamemode, + user=user_score.user, + ) + or 0 + ), + "score": resp, + } @router.get( "/beatmaps/{beatmap_id}/scores/users/{user_id}/all", tags=["成绩"], - response_model=list[ScoreResp] | list[LegacyScoreResp], + responses={ + 200: api_doc( + ( + "用户谱面全部成绩\n\n" + "如果 `x-api-version >= 20220705`,返回值为 `Score`列表," + "否则为 `LegacyScoreResp`列表。" + ), + list[ScoreModel] | list[LegacyScoreResp], + DEFAULT_SCORE_INCLUDES, + ) + }, name="获取用户谱面全部成绩", - description=( - "获取指定用户在指定谱面上的全部成绩列表。\n\n" - "如果 `x-api-version >= 20220705`,返回值为 `ScoreResp`列表," - "否则为 `LegacyScoreResp`列表。" - ), + description="获取指定用户在指定谱面上的全部成绩列表。", ) async def get_user_all_beatmap_scores( db: Database, @@ -359,7 +406,7 @@ async def get_user_all_beatmap_scores( ) ).all() - return [await score.to_resp(db, api_version) for score in all_user_scores] + return [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_user_scores] @router.post( @@ -413,9 +460,9 @@ async def create_solo_score( @router.put( "/beatmaps/{beatmap_id}/solo/scores/{token}", tags=["游玩"], - response_model=ScoreResp, name="提交单曲成绩", description="\n使用令牌提交单曲成绩。", + responses={200: api_doc("单曲成绩提交结果。", ScoreModel)}, ) async def submit_solo_score( background_task: BackgroundTasks, @@ -520,6 +567,7 @@ async def create_playlist_score( tags=["游玩"], name="提交房间项目成绩", description="\n提交房间游玩项目成绩。", + responses={200: api_doc("单曲成绩提交结果。", ScoreModel)}, ) async def submit_playlist_score( background_task: BackgroundTasks, @@ -560,13 +608,13 @@ async def submit_playlist_score( room_id, playlist_id, user_id, - score_resp.id, - score_resp.total_score, + score_resp["id"], + score_resp["total_score"], session, redis, ) await session.commit() - if room_category == RoomCategory.DAILY_CHALLENGE and score_resp.passed: + if room_category == RoomCategory.DAILY_CHALLENGE and score_resp["passed"]: await process_daily_challenge_score(session, user_id, room_id) await ItemAttemptsCount.get_or_create(room_id, user_id, session) await session.commit() @@ -575,15 +623,23 @@ async def submit_playlist_score( class IndexedScoreResp(MultiplayerScores): total: int - user_score: ScoreResp | None = None + user_score: MultiplayScoreDict | None = None # pyright: ignore[reportInvalidTypeForm] @router.get( "/rooms/{room_id}/playlist/{playlist_id}/scores", - response_model=IndexedScoreResp, + # response_model=IndexedScoreResp, name="获取房间项目排行榜", description="获取房间游玩项目排行榜。", tags=["成绩"], + responses={ + 200: { + "description": ( + f"房间项目排行榜。\n\n包含:{', '.join([f'`{inc}`' for inc in Score.MULTIPLAYER_BASE_INCLUDES])}" + ), + "model": IndexedScoreResp, + } + }, ) async def index_playlist_scores( session: Database, @@ -620,16 +676,14 @@ async def index_playlist_scores( scores = scores[:-1] user_score = None - score_resp = [await ScoreResp.from_db(session, score.score) for score in scores] + score_resp = [await ScoreModel.transform(score.score, includes=Score.MULTIPLAYER_BASE_INCLUDES) for score in scores] for score in score_resp: - score.position = await get_position(room_id, playlist_id, score.id, session) - if score.user_id == user_id: + if (room.category == RoomCategory.DAILY_CHALLENGE and score["user_id"] == user_id and score["passed"]) or score[ + "user_id" + ] == user_id: user_score = score - - if room.category == RoomCategory.DAILY_CHALLENGE: - score_resp = [s for s in score_resp if s.passed] - if user_score and not user_score.passed: - user_score = None + user_score["position"] = await get_position(room_id, playlist_id, score["id"], session) + break resp = IndexedScoreResp( scores=score_resp, @@ -648,10 +702,16 @@ async def index_playlist_scores( @router.get( "/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}", - response_model=ScoreResp, name="获取房间项目单个成绩", description="获取指定房间游玩项目中单个成绩详情。", tags=["成绩"], + responses={ + 200: api_doc( + "房间项目单个成绩详情。", + ScoreModel, + [*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"], + ) + }, ) async def show_playlist_score( session: Database, @@ -687,39 +747,25 @@ async def show_playlist_score( break if not score_record: raise HTTPException(status_code=404, detail="Score not found") - resp = await ScoreResp.from_db(session, score_record.score) - resp.position = await get_position(room_id, playlist_id, score_id, session) + includes = [ + *Score.MULTIPLAYER_BASE_INCLUDES, + "position", + ] if completed: - scores = ( - await session.exec( - select(PlaylistBestScore).where( - PlaylistBestScore.playlist_id == playlist_id, - PlaylistBestScore.room_id == room_id, - ~User.is_restricted_query(col(PlaylistBestScore.user_id)), - ) - ) - ).all() - higher_scores = [] - lower_scores = [] - for score in scores: - resp = await ScoreResp.from_db(session, score.score) - if is_playlist and not resp.passed: - continue - if score.total_score > resp.total_score: - higher_scores.append(resp) - elif score.total_score < resp.total_score: - lower_scores.append(resp) - resp.scores_around = ScoreAround( - higher=MultiplayerScores(scores=higher_scores), - lower=MultiplayerScores(scores=lower_scores), - ) - + includes.append("scores_around") + resp = await ScoreModel.transform(score_record.score, includes=includes) return resp @router.get( "rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}", - response_model=ScoreResp, + responses={ + 200: api_doc( + "房间项目单个成绩详情。", + ScoreModel, + [*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"], + ) + }, name="获取房间项目用户成绩", description="获取指定用户在房间游玩项目中的成绩。", tags=["成绩"], @@ -749,8 +795,14 @@ async def get_user_playlist_score( if not score_record: raise HTTPException(status_code=404, detail="Score not found") - resp = await ScoreResp.from_db(session, score_record.score) - resp.position = await get_position(room_id, playlist_id, score_record.score_id, session) + resp = await ScoreModel.transform( + score_record.score, + includes=[ + *Score.MULTIPLAYER_BASE_INCLUDES, + "position", + "scores_around", + ], + ) return resp diff --git a/app/router/v2/user.py b/app/router/v2/user.py index ce3e080..31403f9 100644 --- a/app/router/v2/user.py +++ b/app/router/v2/user.py @@ -5,17 +5,16 @@ from app.config import settings from app.const import BANCHOBOT_ID from app.database import ( Beatmap, + BeatmapModel, BeatmapPlaycounts, - BeatmapPlaycountsResp, - BeatmapResp, - BeatmapsetResp, + BeatmapsetModel, User, - UserResp, ) +from app.database.beatmap_playcounts import BeatmapPlaycountsModel from app.database.best_scores import BestScore from app.database.events import Event -from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores -from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED +from app.database.score import Score, get_user_first_scores +from app.database.user import UserModel from app.dependencies.api_version import APIVersion from app.dependencies.cache import UserCacheService from app.dependencies.database import Database, get_redis @@ -26,24 +25,15 @@ from app.models.mods import API_MODS from app.models.score import GameMode from app.models.user import BeatmapsetType from app.service.user_cache_service import get_user_cache_service -from app.utils import utcnow +from app.utils import api_doc, utcnow from .router import router from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security -from pydantic import BaseModel from sqlmodel import exists, false, select from sqlmodel.sql.expression import col -class BatchUserResponse(BaseModel): - users: list[UserResp] - - -class BeatmapsPassedResponse(BaseModel): - beatmaps_passed: list[BeatmapResp] - - def _get_difficulty_reduction_mods() -> set[str]: mods: set[str] = set() for ruleset_mods in API_MODS.values(): @@ -63,13 +53,15 @@ async def visible_to_current_user(user: User, current_user: User | None, session @router.get( "/users/", - response_model=BatchUserResponse, + responses={ + 200: api_doc("批量获取用户信息", {"users": list[UserModel]}, User.CARD_INCLUDES, name="UsersLookupResponse") + }, name="批量获取用户信息", description="通过用户 ID 列表批量获取用户信息。", tags=["用户"], ) -@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False) -@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False) +@router.get("/users/lookup", include_in_schema=False) +@router.get("/users/lookup/", include_in_schema=False) @asset_proxy_response async def get_users( session: Database, @@ -108,16 +100,15 @@ async def get_users( # 将查询到的用户添加到缓存并返回 for searched_user in searched_users: if searched_user.id != BANCHOBOT_ID: - user_resp = await UserResp.from_db( + user_resp = await UserModel.transform( searched_user, - session, - include=SEARCH_INCLUDED, + includes=User.CARD_INCLUDES, ) cached_users.append(user_resp) # 异步缓存,不阻塞响应 background_task.add_task(cache_service.cache_user, user_resp) - response = BatchUserResponse(users=cached_users) + response = {"users": cached_users} return response else: searched_users = ( @@ -127,16 +118,15 @@ async def get_users( for searched_user in searched_users: if searched_user.id == BANCHOBOT_ID: continue - user_resp = await UserResp.from_db( + user_resp = await UserModel.transform( searched_user, - session, - include=SEARCH_INCLUDED, + includes=User.CARD_INCLUDES, ) users.append(user_resp) # 异步缓存 background_task.add_task(cache_service.cache_user, user_resp) - response = BatchUserResponse(users=users) + response = {"users": users} return response @@ -200,10 +190,12 @@ async def get_user_kudosu( @router.get( "/users/{user_id}/beatmaps-passed", - response_model=BeatmapsPassedResponse, name="获取用户已通过谱面", description="获取指定用户在给定谱面集中的已通过谱面列表。", tags=["用户"], + responses={ + 200: api_doc("用户已通过谱面列表", {"beatmaps_passed": list[BeatmapModel]}, name="BeatmapsPassedResponse") + }, ) @asset_proxy_response async def get_user_beatmaps_passed( @@ -226,7 +218,7 @@ async def get_user_beatmaps_passed( no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True, ): if not beatmapset_ids: - return BeatmapsPassedResponse(beatmaps_passed=[]) + return {"beatmaps_passed": []} if len(beatmapset_ids) > 50: raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items") @@ -255,7 +247,7 @@ async def get_user_beatmaps_passed( scores = (await session.exec(score_query)).all() if not scores: - return BeatmapsPassedResponse(beatmaps_passed=[]) + return {"beatmaps_passed": []} difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set() passed_beatmap_ids: set[int] = set() @@ -269,7 +261,7 @@ async def get_user_beatmaps_passed( continue passed_beatmap_ids.add(beatmap_id) if not passed_beatmap_ids: - return BeatmapsPassedResponse(beatmaps_passed=[]) + return {"beatmaps_passed": []} beatmaps = ( await session.exec( @@ -279,19 +271,24 @@ async def get_user_beatmaps_passed( ) ).all() - return BeatmapsPassedResponse( - beatmaps_passed=[ - await BeatmapResp.from_db(beatmap, allowed_mode, session=session, user=user) for beatmap in beatmaps + return { + "beatmaps_passed": [ + await BeatmapModel.transform( + beatmap, + ) + for beatmap in beatmaps ] - ) + } @router.get( "/users/{user_id}/{ruleset}", - response_model=UserResp, name="获取用户信息(指定ruleset)", description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。", tags=["用户"], + responses={ + 200: api_doc("用户信息", UserModel, User.USER_INCLUDES), + }, ) @asset_proxy_response async def get_user_info_ruleset( @@ -325,29 +322,26 @@ async def get_user_info_ruleset( if should_not_show: raise HTTPException(404, detail="User not found") - include = SEARCH_INCLUDED - if searched_is_self: - include = ALL_INCLUDED - user_resp = await UserResp.from_db( + user_resp = await UserModel.transform( searched_user, - session, - include=include, + includes=User.USER_INCLUDES, ruleset=ruleset, ) # 异步缓存结果 background_task.add_task(cache_service.cache_user, user_resp, ruleset) - return user_resp -@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False) +@router.get("/users/{user_id}/", include_in_schema=False) @router.get( "/users/{user_id}", - response_model=UserResp, name="获取用户信息", description="通过用户 ID 或用户名获取单个用户的详细信息。", tags=["用户"], + responses={ + 200: api_doc("用户信息", UserModel, User.USER_INCLUDES), + }, ) @asset_proxy_response async def get_user_info( @@ -381,27 +375,31 @@ async def get_user_info( if should_not_show: raise HTTPException(404, detail="User not found") - include = SEARCH_INCLUDED - if searched_is_self: - include = ALL_INCLUDED - user_resp = await UserResp.from_db( + user_resp = await UserModel.transform( searched_user, - session, - include=include, + includes=User.USER_INCLUDES, ) # 异步缓存结果 background_task.add_task(cache_service.cache_user, user_resp) - return user_resp +beatmapset_includes = [*BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES, "beatmaps"] + + @router.get( "/users/{user_id}/beatmapsets/{type}", - response_model=list[BeatmapsetResp | BeatmapPlaycountsResp], name="获取用户谱面集列表", description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。", tags=["用户"], + responses={ + 200: api_doc( + "当类型为 `most_played` 时返回 `list[BeatmapPlaycountsModel]`,其他为 `list[BeatmapsetModel]`", + list[BeatmapsetModel] | list[BeatmapPlaycountsModel], + beatmapset_includes, + ) + }, ) @asset_proxy_response async def get_user_beatmapsets( @@ -417,11 +415,7 @@ async def get_user_beatmapsets( # 先尝试从缓存获取 cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset) if cached_result is not None: - # 根据类型恢复对象 - if type == BeatmapsetType.MOST_PLAYED: - return [BeatmapPlaycountsResp(**item) for item in cached_result] - else: - return [BeatmapsetResp(**item) for item in cached_result] + return cached_result user = await session.get(User, user_id) if not user or user.id == BANCHOBOT_ID: @@ -444,7 +438,10 @@ async def get_user_beatmapsets( raise HTTPException(404, detail="User not found") favourites = await user.awaitable_attrs.favourite_beatmapsets resp = [ - await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites + await BeatmapsetModel.transform( + favourite.beatmapset, session=session, user=user, includes=beatmapset_includes + ) + for favourite in favourites ] elif type == BeatmapsetType.MOST_PLAYED: @@ -459,7 +456,10 @@ async def get_user_beatmapsets( .limit(limit) .offset(offset) ) - resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played] + resp = [ + await BeatmapPlaycountsModel.transform(most_played_beatmap, user=user, includes=beatmapset_includes) + for most_played_beatmap in most_played + ] else: raise HTTPException(400, detail="Invalid beatmapset type") @@ -477,7 +477,6 @@ async def get_user_beatmapsets( @router.get( "/users/{user_id}/scores/{type}", - response_model=list[ScoreResp] | list[LegacyScoreResp], name="获取用户成绩列表", description=( "获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n" @@ -523,6 +522,7 @@ async def get_user_scores( gamemode = mode or db_user.playmode order_by = None where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode) + includes = Score.USER_PROFILE_INCLUDES.copy() if not include_fails: where_clause &= col(Score.passed).is_(True) if type == "pinned": @@ -531,6 +531,7 @@ async def get_user_scores( elif type == "best": where_clause &= exists().where(col(BestScore.score_id) == Score.id) order_by = col(Score.pp).desc() + includes.append("weight") elif type == "recent": where_clause &= Score.ended_at > utcnow() - timedelta(hours=24) order_by = col(Score.ended_at).desc() @@ -551,6 +552,7 @@ async def get_user_scores( await score.to_resp( session, api_version, + includes=includes, ) for score in scores ] diff --git a/app/service/beatmapset_cache_service.py b/app/service/beatmapset_cache_service.py index 16ec2e2..dbb7682 100644 --- a/app/service/beatmapset_cache_service.py +++ b/app/service/beatmapset_cache_service.py @@ -3,14 +3,14 @@ Beatmapset缓存服务 用于缓存beatmapset数据,减少数据库查询频率 """ -from datetime import datetime import hashlib import json from typing import TYPE_CHECKING from app.config import settings -from app.database.beatmapset import BeatmapsetResp +from app.database import BeatmapsetDict from app.log import logger +from app.utils import safe_json_dumps from redis.asyncio import Redis @@ -18,20 +18,6 @@ if TYPE_CHECKING: pass -class DateTimeEncoder(json.JSONEncoder): - """处理datetime序列化的JSON编码器""" - - def default(self, obj): - if isinstance(obj, datetime): - return obj.isoformat() - return super().default(obj) - - -def safe_json_dumps(data) -> str: - """安全的JSON序列化,处理datetime对象""" - return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False) - - def generate_hash(data) -> str: """生成数据的MD5哈希值""" content = data if isinstance(data, str) else safe_json_dumps(data) @@ -57,15 +43,14 @@ class BeatmapsetCacheService: """生成搜索结果缓存键""" return f"beatmapset_search:{query_hash}:{cursor_hash}" - async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetResp | None: + async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetDict | None: """从缓存获取beatmapset信息""" try: cache_key = self._get_beatmapset_cache_key(beatmapset_id) cached_data = await self.redis.get(cache_key) if cached_data: logger.debug(f"Beatmapset cache hit for {beatmapset_id}") - data = json.loads(cached_data) - return BeatmapsetResp(**data) + return json.loads(cached_data) return None except (ValueError, TypeError, AttributeError) as e: logger.error(f"Error getting beatmapset from cache: {e}") @@ -73,24 +58,21 @@ class BeatmapsetCacheService: async def cache_beatmapset( self, - beatmapset_resp: BeatmapsetResp, + beatmapset_resp: BeatmapsetDict, expire_seconds: int | None = None, ): """缓存beatmapset信息""" try: if expire_seconds is None: expire_seconds = self._default_ttl - if beatmapset_resp.id is None: - logger.warning("Cannot cache beatmapset with None id") - return - cache_key = self._get_beatmapset_cache_key(beatmapset_resp.id) - cached_data = beatmapset_resp.model_dump_json() + cache_key = self._get_beatmapset_cache_key(beatmapset_resp["id"]) + cached_data = safe_json_dumps(beatmapset_resp) await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore - logger.debug(f"Cached beatmapset {beatmapset_resp.id} for {expire_seconds}s") + logger.debug(f"Cached beatmapset {beatmapset_resp['id']} for {expire_seconds}s") except (ValueError, TypeError, AttributeError) as e: logger.error(f"Error caching beatmapset: {e}") - async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetResp | None: + async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetDict | None: """从缓存获取通过beatmap ID查找的beatmapset信息""" try: cache_key = self._get_beatmap_lookup_cache_key(beatmap_id) @@ -98,7 +80,7 @@ class BeatmapsetCacheService: if cached_data: logger.debug(f"Beatmap lookup cache hit for {beatmap_id}") data = json.loads(cached_data) - return BeatmapsetResp(**data) + return data return None except (ValueError, TypeError, AttributeError) as e: logger.error(f"Error getting beatmap lookup from cache: {e}") @@ -107,7 +89,7 @@ class BeatmapsetCacheService: async def cache_beatmap_lookup( self, beatmap_id: int, - beatmapset_resp: BeatmapsetResp, + beatmapset_resp: BeatmapsetDict, expire_seconds: int | None = None, ): """缓存通过beatmap ID查找的beatmapset信息""" @@ -115,7 +97,7 @@ class BeatmapsetCacheService: if expire_seconds is None: expire_seconds = self._default_ttl cache_key = self._get_beatmap_lookup_cache_key(beatmap_id) - cached_data = beatmapset_resp.model_dump_json() + cached_data = safe_json_dumps(beatmapset_resp) await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore logger.debug(f"Cached beatmap lookup {beatmap_id} for {expire_seconds}s") except (ValueError, TypeError, AttributeError) as e: diff --git a/app/service/beatmapset_update_service.py b/app/service/beatmapset_update_service.py index b642ca7..5c09123 100644 --- a/app/service/beatmapset_update_service.py +++ b/app/service/beatmapset_update_service.py @@ -3,12 +3,12 @@ from datetime import timedelta from enum import Enum import math import random -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, NamedTuple, cast from app.config import OldScoreProcessingMode, settings -from app.database.beatmap import Beatmap, BeatmapResp +from app.database.beatmap import Beatmap, BeatmapDict from app.database.beatmap_sync import BeatmapSync, SavedBeatmapMeta -from app.database.beatmapset import Beatmapset, BeatmapsetResp +from app.database.beatmapset import Beatmapset, BeatmapsetDict from app.database.score import Score from app.dependencies.database import get_redis, with_db from app.dependencies.storage import get_storage_service @@ -62,10 +62,23 @@ STATUS_FACTOR: dict[BeatmapRankStatus, float] = { SCHEDULER_INTERVAL_MINUTES = 2 +class EnsuredBeatmap(BeatmapDict): + checksum: str + ranked: int + + +class EnsuredBeatmapset(BeatmapsetDict): + ranked: int + ranked_date: datetime.datetime + last_updated: datetime.datetime + play_count: int + beatmaps: list[EnsuredBeatmap] + + class ProcessingBeatmapset: - def __init__(self, beatmapset: BeatmapsetResp, record: BeatmapSync) -> None: + def __init__(self, beatmapset: EnsuredBeatmapset, record: BeatmapSync) -> None: self.beatmapset = beatmapset - self.status = BeatmapRankStatus(self.beatmapset.ranked) + self.status = BeatmapRankStatus(self.beatmapset["ranked"]) self.record = record def calculate_next_sync_time( @@ -76,19 +89,19 @@ class ProcessingBeatmapset: now = utcnow() if self.status == BeatmapRankStatus.QUALIFIED: - assert self.beatmapset.ranked_date is not None, "ranked_date should not be None for qualified maps" - time_to_ranked = (self.beatmapset.ranked_date + timedelta(days=7) - now).total_seconds() + assert self.beatmapset["ranked_date"] is not None, "ranked_date should not be None for qualified maps" + time_to_ranked = (self.beatmapset["ranked_date"] + timedelta(days=7) - now).total_seconds() baseline = max(MIN_DELTA, time_to_ranked / 2) next_delta = max(MIN_DELTA, baseline) elif self.status in {BeatmapRankStatus.WIP, BeatmapRankStatus.PENDING}: - seconds_since_update = (now - self.beatmapset.last_updated).total_seconds() + seconds_since_update = (now - self.beatmapset["last_updated"]).total_seconds() factor_update = max(1.0, seconds_since_update / TAU) - factor_play = 1.0 + math.log(1.0 + self.beatmapset.play_count) + factor_play = 1.0 + math.log(1.0 + self.beatmapset["play_count"]) status_factor = STATUS_FACTOR[self.status] baseline = BASE * factor_play / factor_update * status_factor next_delta = max(MIN_DELTA, baseline * (GROWTH ** (self.record.consecutive_no_change + 1))) elif self.status == BeatmapRankStatus.GRAVEYARD: - days_since_update = (now - self.beatmapset.last_updated).days + days_since_update = (now - self.beatmapset["last_updated"]).days doubling_periods = days_since_update / GRAVEYARD_DOUBLING_PERIOD_DAYS delta = MIN_DELTA * (2**doubling_periods) max_seconds = GRAVEYARD_MAX_DAYS * 86400 @@ -105,21 +118,24 @@ class ProcessingBeatmapset: @property def beatmapset_changed(self) -> bool: - return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset.ranked) + return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset["ranked"]) @property def changed_beatmaps(self) -> list[ChangedBeatmap]: changed_beatmaps = [] - for bm in self.beatmapset.beatmaps: - saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None) + for bm in self.beatmapset["beatmaps"]: + saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm["id"]), None) if not saved or saved["is_deleted"]: - changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED)) - elif saved["md5"] != bm.checksum: - changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED)) - elif saved["beatmap_status"] != BeatmapRankStatus(bm.ranked): - changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.STATUS_CHANGED)) + changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_ADDED)) + elif saved["md5"] != bm["checksum"]: + changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_UPDATED)) + elif saved["beatmap_status"] != BeatmapRankStatus(bm["ranked"]): + changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.STATUS_CHANGED)) for saved in self.record.beatmaps: - if not any(bm.id == saved["beatmap_id"] for bm in self.beatmapset.beatmaps) and not saved["is_deleted"]: + if ( + not any(bm["id"] == saved["beatmap_id"] for bm in self.beatmapset["beatmaps"]) + and not saved["is_deleted"] + ): changed_beatmaps.append(ChangedBeatmap(saved["beatmap_id"], BeatmapChangeType.MAP_DELETED)) return changed_beatmaps @@ -132,7 +148,7 @@ class BeatmapsetUpdateService: async def add_missing_beatmapset(self, beatmapset_id: int, immediate: bool = False) -> bool: beatmapset = await self.fetcher.get_beatmapset(beatmapset_id) if immediate: - await self._sync_immediately(beatmapset) + await self._sync_immediately(cast(EnsuredBeatmapset, beatmapset)) logger.debug(f"triggered immediate sync for beatmapset {beatmapset_id} ") return True await self.add(beatmapset) @@ -172,7 +188,7 @@ class BeatmapsetUpdateService: BeatmapSync( beatmapset_id=missing, beatmap_status=BeatmapRankStatus.GRAVEYARD, - next_sync_time=datetime.datetime.max, + next_sync_time=datetime.datetime(year=6000, month=1, day=1), beatmaps=[], ) ) @@ -185,11 +201,13 @@ class BeatmapsetUpdateService: await session.commit() self._adding_missing = False - async def add(self, beatmapset: BeatmapsetResp, calculate_next_sync: bool = True): + async def add(self, set: BeatmapsetDict, calculate_next_sync: bool = True): + beatmapset = cast(EnsuredBeatmapset, set) async with with_db() as session: - sync_record = await session.get(BeatmapSync, beatmapset.id) + beatmapset_id = beatmapset["id"] + sync_record = await session.get(BeatmapSync, beatmapset_id) if not sync_record: - database_beatmapset = await session.get(Beatmapset, beatmapset.id) + database_beatmapset = await session.get(Beatmapset, beatmapset_id) if database_beatmapset: status = BeatmapRankStatus(database_beatmapset.beatmap_status) await database_beatmapset.awaitable_attrs.beatmaps @@ -203,19 +221,29 @@ class BeatmapsetUpdateService: for bm in database_beatmapset.beatmaps ] else: - status = BeatmapRankStatus(beatmapset.ranked) - beatmaps = [ - SavedBeatmapMeta( - beatmap_id=bm.id, - md5=bm.checksum, - is_deleted=False, - beatmap_status=BeatmapRankStatus(bm.ranked), + ranked = beatmapset.get("ranked") + if ranked is None: + raise ValueError("ranked field is required") + status = BeatmapRankStatus(ranked) + beatmap_list = beatmapset.get("beatmaps", []) + beatmaps = [] + for bm in beatmap_list: + bm_id = bm.get("id") + checksum = bm.get("checksum") + ranked = bm.get("ranked") + if bm_id is None or checksum is None or ranked is None: + continue + beatmaps.append( + SavedBeatmapMeta( + beatmap_id=bm_id, + md5=checksum, + is_deleted=False, + beatmap_status=BeatmapRankStatus(ranked), + ) ) - for bm in beatmapset.beatmaps - ] sync_record = BeatmapSync( - beatmapset_id=beatmapset.id, + beatmapset_id=beatmapset_id, beatmaps=beatmaps, beatmap_status=status, ) @@ -223,13 +251,27 @@ class BeatmapsetUpdateService: await session.commit() await session.refresh(sync_record) else: - sync_record.beatmaps = [ - SavedBeatmapMeta( - beatmap_id=bm.id, md5=bm.checksum, is_deleted=False, beatmap_status=BeatmapRankStatus(bm.ranked) + ranked = beatmapset.get("ranked") + if ranked is None: + raise ValueError("ranked field is required") + beatmap_list = beatmapset.get("beatmaps", []) + beatmaps = [] + for bm in beatmap_list: + bm_id = bm.get("id") + checksum = bm.get("checksum") + bm_ranked = bm.get("ranked") + if bm_id is None or checksum is None or bm_ranked is None: + continue + beatmaps.append( + SavedBeatmapMeta( + beatmap_id=bm_id, + md5=checksum, + is_deleted=False, + beatmap_status=BeatmapRankStatus(bm_ranked), + ) ) - for bm in beatmapset.beatmaps - ] - sync_record.beatmap_status = BeatmapRankStatus(beatmapset.ranked) + sync_record.beatmaps = beatmaps + sync_record.beatmap_status = BeatmapRankStatus(ranked) if calculate_next_sync: processing = ProcessingBeatmapset(beatmapset, sync_record) next_time_delta = processing.calculate_next_sync_time() @@ -238,17 +280,19 @@ class BeatmapsetUpdateService: await BeatmapsetUpdateService._sync_immediately(self, beatmapset) return sync_record.next_sync_time = utcnow() + next_time_delta - logger.opt(colors=True).info(f"[{beatmapset.id}] next sync at {sync_record.next_sync_time}") + beatmapset_id = beatmapset.get("id") + if beatmapset_id: + logger.opt(colors=True).debug(f"[{beatmapset_id}] next sync at {sync_record.next_sync_time}") await session.commit() - async def _sync_immediately(self, beatmapset: BeatmapsetResp) -> None: + async def _sync_immediately(self, beatmapset: EnsuredBeatmapset) -> None: async with with_db() as session: - record = await session.get(BeatmapSync, beatmapset.id) + record = await session.get(BeatmapSync, beatmapset["id"]) if not record: record = BeatmapSync( - beatmapset_id=beatmapset.id, + beatmapset_id=beatmapset["id"], beatmaps=[], - beatmap_status=BeatmapRankStatus(beatmapset.ranked), + beatmap_status=BeatmapRankStatus(beatmapset["ranked"]), ) session.add(record) await session.commit() @@ -261,19 +305,18 @@ class BeatmapsetUpdateService: record: BeatmapSync, session: AsyncSession, *, - beatmapset: BeatmapsetResp | None = None, + beatmapset: EnsuredBeatmapset | None = None, ): - logger.opt(colors=True).info(f"[{record.beatmapset_id}] syncing...") + logger.opt(colors=True).debug(f"[{record.beatmapset_id}] syncing...") if beatmapset is None: try: - beatmapset = await self.fetcher.get_beatmapset(record.beatmapset_id) + beatmapset = cast(EnsuredBeatmapset, await self.fetcher.get_beatmapset(record.beatmapset_id)) except Exception as e: if isinstance(e, HTTPStatusError) and e.response.status_code == 404: logger.opt(colors=True).warning( f"[{record.beatmapset_id}] beatmapset not found (404), removing from sync list" ) await session.delete(record) - await session.commit() return if isinstance(e, HTTPError): logger.opt(colors=True).warning( @@ -292,20 +335,20 @@ class BeatmapsetUpdateService: if changed: record.beatmaps = [ SavedBeatmapMeta( - beatmap_id=bm.id, - md5=bm.checksum, + beatmap_id=bm["id"], + md5=bm["checksum"], is_deleted=False, - beatmap_status=BeatmapRankStatus(bm.ranked), + beatmap_status=BeatmapRankStatus(bm["ranked"]), ) - for bm in beatmapset.beatmaps + for bm in beatmapset["beatmaps"] ] - record.beatmap_status = BeatmapRankStatus(beatmapset.ranked) + record.beatmap_status = BeatmapRankStatus(beatmapset["ranked"]) record.consecutive_no_change = 0 bg_tasks.add_task( self._process_changed_beatmaps, changed_beatmaps, - beatmapset.beatmaps, + beatmapset["beatmaps"], ) bg_tasks.add_task( self._process_changed_beatmapset, @@ -317,13 +360,13 @@ class BeatmapsetUpdateService: next_time_delta = processing.calculate_next_sync_time() if not next_time_delta: logger.opt(colors=True).info( - f"[{beatmapset.id}] beatmapset has transformed to ranked or loved," + f"[{beatmapset['id']}] beatmapset has transformed to ranked or loved," f" removing from sync list" ) await session.delete(record) else: record.next_sync_time = utcnow() + next_time_delta - logger.opt(colors=True).info(f"[{record.beatmapset_id}] next sync at {record.next_sync_time}") + logger.opt(colors=True).debug(f"[{record.beatmapset_id}] next sync at {record.next_sync_time}") async def _update_beatmaps(self): async with with_db() as session: @@ -338,18 +381,18 @@ class BeatmapsetUpdateService: await self.sync(record, session) await session.commit() - async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp): + async def _process_changed_beatmapset(self, beatmapset: EnsuredBeatmapset): async with with_db() as session: - db_beatmapset = await session.get(Beatmapset, beatmapset.id) - new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset) + db_beatmapset = await session.get(Beatmapset, beatmapset["id"]) + new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset) # pyright: ignore[reportArgumentType] if db_beatmapset: await session.merge(new_beatmapset) - await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset.id) + await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset["id"]) await session.commit() - async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[BeatmapResp]): + async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[EnsuredBeatmap]): storage_service = get_storage_service() - beatmaps = {bm.id: bm for bm in beatmaps_list} + beatmaps = {bm["id"]: bm for bm in beatmaps_list} async with with_db() as session: @@ -380,9 +423,9 @@ class BeatmapsetUpdateService: ) continue logger.opt(colors=True).info( - f"[{beatmap.beatmapset_id}] adding beatmap {beatmap.id}" + f"[{beatmap['beatmapset_id']}] adding beatmap {beatmap['id']}" ) - await Beatmap.from_resp_no_save(session, beatmap) + await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType] else: beatmap = beatmaps.get(change.beatmap_id) if not beatmap: @@ -391,10 +434,10 @@ class BeatmapsetUpdateService: ) continue logger.opt(colors=True).info( - f"[{beatmap.beatmapset_id}] processing beatmap {beatmap.id} " + f"[{beatmap['beatmapset_id']}] processing beatmap {beatmap['id']} " f"change {change.type}" ) - new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap) + new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType] existing_beatmap = await session.get(Beatmap, change.beatmap_id) if existing_beatmap: await session.merge(new_db_beatmap) diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py index d2702ef..e1c475c 100644 --- a/app/service/ranking_cache_service.py +++ b/app/service/ranking_cache_service.py @@ -4,16 +4,15 @@ """ import asyncio -from datetime import datetime import json from typing import TYPE_CHECKING, Literal from app.config import settings -from app.database.statistics import UserStatistics, UserStatisticsResp +from app.database.statistics import UserStatistics, UserStatisticsModel from app.helpers.asset_proxy_helper import replace_asset_urls from app.log import logger from app.models.score import GameMode -from app.utils import utcnow +from app.utils import safe_json_dumps, utcnow from redis.asyncio import Redis from sqlmodel import col, select @@ -23,20 +22,6 @@ if TYPE_CHECKING: pass -class DateTimeEncoder(json.JSONEncoder): - """自定义 JSON 编码器,支持 datetime 序列化""" - - def default(self, obj): - if isinstance(obj, datetime): - return obj.isoformat() - return super().default(obj) - - -def safe_json_dumps(data) -> str: - """安全的 JSON 序列化,支持 datetime 对象""" - return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")) - - class RankingCacheService: """用户排行榜缓存服务""" @@ -311,7 +296,7 @@ class RankingCacheService: col(UserStatistics.pp) > 0, col(UserStatistics.is_ranked).is_(True), ] - include = ["user"] + include = UserStatistics.RANKING_INCLUDES.copy() if type == "performance": order_by = col(UserStatistics.pp).desc() @@ -321,6 +306,7 @@ class RankingCacheService: if country: wheres.append(col(UserStatistics.user).has(country_code=country.upper())) + include.append("country_rank") # 获取总用户数用于统计 total_users_query = select(UserStatistics).where(*wheres) @@ -353,9 +339,9 @@ class RankingCacheService: # 转换为响应格式并确保正确序列化 ranking_data = [] for statistics in statistics_data: - user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include) + user_stats_resp = await UserStatisticsModel.transform(statistics, includes=include) - user_dict = user_stats_resp.model_dump() + user_dict = user_stats_resp # 应用资源代理处理 if settings.enable_asset_proxy: diff --git a/app/service/redis_message_system.py b/app/service/redis_message_system.py index 34fefba..2485261 100644 --- a/app/service/redis_message_system.py +++ b/app/service/redis_message_system.py @@ -8,14 +8,14 @@ import asyncio from datetime import datetime import json -import time from typing import Any -from app.database.chat import ChatMessage, ChatMessageResp, MessageType -from app.database.user import RANKING_INCLUDES, User, UserResp +from app.database import ChatMessageDict +from app.database.chat import ChatMessage, ChatMessageModel, MessageType +from app.database.user import User, UserModel from app.dependencies.database import get_redis_message, with_db from app.log import logger -from app.utils import bg_tasks +from app.utils import bg_tasks, safe_json_dumps class RedisMessageSystem: @@ -35,7 +35,7 @@ class RedisMessageSystem: content: str, is_action: bool = False, user_uuid: str | None = None, - ) -> ChatMessageResp: + ) -> "ChatMessageDict": """ 发送消息 - 立即存储到 Redis 并返回 @@ -47,7 +47,7 @@ class RedisMessageSystem: user_uuid: 用户UUID Returns: - ChatMessageResp: 消息响应对象 + ChatMessage: 消息响应对象 """ # 生成消息ID和时间戳 message_id = await self._generate_message_id(channel_id) @@ -57,28 +57,16 @@ class RedisMessageSystem: if not user.id: raise ValueError("User ID is required") - # 获取频道类型以判断是否需要存储到数据库 - async with with_db() as session: - from app.database.chat import ChannelType, ChatChannel - - from sqlmodel import select - - channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)) - channel_type = channel_result.first() - is_multiplayer = channel_type == ChannelType.MULTIPLAYER - # 准备消息数据 - message_data = { + message_data: "ChatMessageDict" = { "message_id": message_id, "channel_id": channel_id, "sender_id": user.id, "content": content, - "timestamp": timestamp.isoformat(), - "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value, + "timestamp": timestamp, + "type": MessageType.ACTION if is_action else MessageType.PLAIN, "uuid": user_uuid or "", - "status": "cached", # Redis 缓存状态 - "created_at": time.time(), - "is_multiplayer": is_multiplayer, # 标记是否为多人房间消息 + "is_action": is_action, } # 立即存储到 Redis @@ -86,51 +74,13 @@ class RedisMessageSystem: # 创建响应对象 async with with_db() as session: - user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES) + user_resp = await UserModel.transform(user, session=session, includes=User.LIST_INCLUDES) + message_data["sender"] = user_resp - # 确保 statistics 不为空 - if user_resp.statistics is None: - from app.database.statistics import UserStatisticsResp + logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}") + return message_data - user_resp.statistics = UserStatisticsResp( - mode=user.playmode, - global_rank=0, - country_rank=0, - pp=0.0, - ranked_score=0, - hit_accuracy=0.0, - play_count=0, - play_time=0, - total_score=0, - total_hits=0, - maximum_combo=0, - replays_watched_by_others=0, - is_ranked=False, - grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0}, - level={"current": 1, "progress": 0}, - ) - - response = ChatMessageResp( - message_id=message_id, - channel_id=channel_id, - content=content, - timestamp=timestamp, - sender_id=user.id, - sender=user_resp, - is_action=is_action, - uuid=user_uuid, - ) - - if is_multiplayer: - logger.info( - f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}," - " will not be persisted to database" - ) - else: - logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}") - return response - - async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]: + async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]: """ 获取频道消息 - 优先从 Redis 获取最新消息 @@ -140,9 +90,9 @@ class RedisMessageSystem: since: 起始消息ID Returns: - List[ChatMessageResp]: 消息列表 + List[ChatMessageDict]: 消息列表 """ - messages = [] + messages: list["ChatMessageDict"] = [] try: # 从 Redis 获取最新消息 @@ -154,45 +104,21 @@ class RedisMessageSystem: # 获取发送者信息 sender = await session.get(User, msg_data["sender_id"]) if sender: - user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES) + user_resp = await UserModel.transform(sender, includes=User.LIST_INCLUDES) - if user_resp.statistics is None: - from app.database.statistics import UserStatisticsResp + from app.database.chat import ChatMessageDict - user_resp.statistics = UserStatisticsResp( - mode=sender.playmode, - global_rank=0, - country_rank=0, - pp=0.0, - ranked_score=0, - hit_accuracy=0.0, - play_count=0, - play_time=0, - total_score=0, - total_hits=0, - maximum_combo=0, - replays_watched_by_others=0, - is_ranked=False, - grade_counts={ - "ssh": 0, - "ss": 0, - "sh": 0, - "s": 0, - "a": 0, - }, - level={"current": 1, "progress": 0}, - ) - - message_resp = ChatMessageResp( - message_id=msg_data["message_id"], - channel_id=msg_data["channel_id"], - content=msg_data["content"], - timestamp=datetime.fromisoformat(msg_data["timestamp"]), - sender_id=msg_data["sender_id"], - sender=user_resp, - is_action=msg_data["type"] == MessageType.ACTION.value, - uuid=msg_data.get("uuid") or None, - ) + message_resp: ChatMessageDict = { + "message_id": msg_data["message_id"], + "channel_id": msg_data["channel_id"], + "content": msg_data["content"], + "timestamp": datetime.fromisoformat(msg_data["timestamp"]), # pyright: ignore[reportArgumentType] + "sender_id": msg_data["sender_id"], + "sender": user_resp, + "is_action": msg_data["type"] == MessageType.ACTION.value, + "uuid": msg_data.get("uuid") or None, + "type": MessageType(msg_data["type"]), + } messages.append(message_resp) # 如果 Redis 消息不够,从数据库补充 @@ -216,86 +142,46 @@ class RedisMessageSystem: return message_id - async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]): + async def _store_to_redis(self, message_id: int, channel_id: int, message_data: ChatMessageDict): """存储消息到 Redis""" try: - # 检查是否是多人房间消息 - is_multiplayer = message_data.get("is_multiplayer", False) - - # 存储消息数据 - await self.redis.hset( + # 存储消息数据为 JSON 字符串 + await self.redis.set( f"msg:{channel_id}:{message_id}", - mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()}, + safe_json_dumps(message_data), + ex=604800, # 7天过期 ) - # 设置消息过期时间(7天) - await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800) - - # 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序) + # 添加到频道消息列表(按时间排序) channel_messages_key = f"channel:{channel_id}:messages" - # 更健壮的键类型检查和清理 + # 检查并清理错误类型的键 try: key_type = await self.redis.type(channel_messages_key) - if key_type == "none": - # 键不存在,这是正常的 - pass - elif key_type != "zset": - # 键类型错误,需要清理 + if key_type not in ("none", "zset"): logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}") await self.redis.delete(channel_messages_key) - - # 验证删除是否成功 - verify_type = await self.redis.type(channel_messages_key) - if verify_type != "none": - logger.error( - f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}" - ) - # 强制删除 - await self.redis.unlink(channel_messages_key) - except Exception as type_check_error: logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}") - # 如果检查失败,尝试强制删除键以确保清理 - try: - await self.redis.delete(channel_messages_key) - except Exception: - # 最后的努力:使用unlink - try: - await self.redis.unlink(channel_messages_key) - except Exception as final_error: - logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}") + await self.redis.delete(channel_messages_key) # 添加到频道消息列表(sorted set) - try: - await self.redis.zadd( - channel_messages_key, - mapping={f"msg:{channel_id}:{message_id}": message_id}, - ) - except Exception as zadd_error: - logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}") - # 如果添加失败,再次尝试清理并重试 - await self.redis.delete(channel_messages_key) - await self.redis.zadd( - channel_messages_key, - mapping={f"msg:{channel_id}:{message_id}": message_id}, - ) + await self.redis.zadd( + channel_messages_key, + mapping={f"msg:{channel_id}:{message_id}": message_id}, + ) # 保持频道消息列表大小(最多1000条) await self.redis.zremrangebyrank(channel_messages_key, 0, -1001) - # 只有非多人房间消息才添加到待持久化队列 - if not is_multiplayer: - await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}") - logger.debug(f"Message {message_id} added to persistence queue") - else: - logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue") + await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}") + logger.debug(f"Message {message_id} added to persistence queue") except Exception as e: logger.error(f"Failed to store message to Redis: {e}") raise - async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]: + async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]: """从 Redis 获取消息""" try: # 获取消息键列表,按消息ID排序 @@ -314,28 +200,16 @@ class RedisMessageSystem: messages = [] for key in message_keys: - # 获取消息数据 - raw_data = await self.redis.hgetall(key) + # 获取消息数据(JSON 字符串) + raw_data = await self.redis.get(key) if raw_data: - # 解码数据 - message_data: dict[str, Any] = {} - for k, v in raw_data.items(): - # 尝试解析 JSON - try: - if k in ["grade_counts", "level"] or v.startswith(("{", "[")): - message_data[k] = json.loads(v) - elif k in ["message_id", "channel_id", "sender_id"]: - message_data[k] = int(v) - elif k == "is_multiplayer": - message_data[k] = v == "True" - elif k == "created_at": - message_data[k] = float(v) - else: - message_data[k] = v - except (json.JSONDecodeError, ValueError): - message_data[k] = v - - messages.append(message_data) + try: + # 解析 JSON 字符串为字典 + message_data = json.loads(raw_data) + messages.append(message_data) + except json.JSONDecodeError as e: + logger.error(f"Failed to decode message JSON from {key}: {e}") + continue # 确保消息按ID正序排序(时间顺序) messages.sort(key=lambda x: x.get("message_id", 0)) @@ -350,15 +224,15 @@ class RedisMessageSystem: logger.error(f"Failed to get messages from Redis: {e}") return [] - async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int): + async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageDict], limit: int): """从数据库补充历史消息""" try: # 找到最小的消息ID min_id = float("inf") if existing_messages: for msg in existing_messages: - if msg.message_id is not None and msg.message_id < min_id: - min_id = msg.message_id + if msg["message_id"] is not None and msg["message_id"] < min_id: + min_id = msg["message_id"] needed = limit - len(existing_messages) @@ -378,13 +252,13 @@ class RedisMessageSystem: db_messages = (await session.exec(query)).all() for msg in reversed(db_messages): # 按时间正序插入 - msg_resp = await ChatMessageResp.from_db(msg, session) + msg_resp = await ChatMessageModel.transform(msg, includes=["sender"]) existing_messages.insert(0, msg_resp) except Exception as e: logger.error(f"Failed to backfill from database: {e}") - async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]: + async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageDict]: """仅从数据库获取消息(回退方案)""" try: async with with_db() as session: @@ -402,7 +276,7 @@ class RedisMessageSystem: messages = (await session.exec(query)).all() - results = [await ChatMessageResp.from_db(msg, session) for msg in messages] + results = await ChatMessageModel.transform_many(messages, includes=["sender"]) # 如果是 since > 0,保持正序;否则反转为时间正序 if since == 0: @@ -450,27 +324,17 @@ class RedisMessageSystem: # 解析频道ID和消息ID channel_id, message_id = map(int, key.split(":")) - # 从 Redis 获取消息数据 - raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}") + # 从 Redis 获取消息数据(JSON 字符串) + raw_data = await self.redis.get(f"msg:{channel_id}:{message_id}") if not raw_data: continue - # 解码数据 - message_data = {} - for k, v in raw_data.items(): - message_data[k] = v - - # 检查是否是多人房间消息,如果是则跳过数据库存储 - is_multiplayer = message_data.get("is_multiplayer", "False") == "True" - if is_multiplayer: - # 多人房间消息不存储到数据库,直接标记为已跳过 - await self.redis.hset( - f"msg:{channel_id}:{message_id}", - "status", - "skipped_multiplayer", - ) - logger.debug(f"Message {message_id} in multiplayer room skipped from database storage") + # 解析 JSON 字符串为字典 + try: + message_data = json.loads(raw_data) + except json.JSONDecodeError as e: + logger.error(f"Failed to decode message JSON for {channel_id}:{message_id}: {e}") continue # 检查消息是否已存在于数据库 @@ -491,13 +355,6 @@ class RedisMessageSystem: session.add(db_message) - # 更新 Redis 中的状态 - await self.redis.hset( - f"msg:{channel_id}:{message_id}", - "status", - "persisted", - ) - logger.debug(f"Message {message_id} persisted to database") except Exception as e: diff --git a/app/service/room.py b/app/service/room.py index cec7029..3204383 100644 --- a/app/service/room.py +++ b/app/service/room.py @@ -14,8 +14,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room: - db_room = room.to_room() - db_room.host_id = host_id + db_room = Room.model_validate({"host_id": host_id, **room.model_dump(exclude={"playlist"})}) db_room.starts_at = utcnow() db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0) session.add(db_room) diff --git a/app/service/user_cache_service.py b/app/service/user_cache_service.py index 3da0ab3..5e634fd 100644 --- a/app/service/user_cache_service.py +++ b/app/service/user_cache_service.py @@ -3,19 +3,19 @@ 用于缓存用户信息,提供热缓存和实时刷新功能 """ -from datetime import datetime import json from typing import TYPE_CHECKING, Any from app.config import settings from app.const import BANCHOBOT_ID -from app.database import User, UserResp -from app.database.score import LegacyScoreResp, ScoreResp -from app.database.user import SEARCH_INCLUDED +from app.database import User +from app.database.score import LegacyScoreResp +from app.database.user import UserDict, UserModel from app.dependencies.database import with_db from app.helpers.asset_proxy_helper import replace_asset_urls from app.log import logger from app.models.score import GameMode +from app.utils import safe_json_dumps from redis.asyncio import Redis from sqlmodel import col, select @@ -25,20 +25,6 @@ if TYPE_CHECKING: pass -class DateTimeEncoder(json.JSONEncoder): - """自定义 JSON 编码器,支持 datetime 序列化""" - - def default(self, obj): - if isinstance(obj, datetime): - return obj.isoformat() - return super().default(obj) - - -def safe_json_dumps(data: Any) -> str: - """安全的 JSON 序列化,支持 datetime 对象""" - return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False) - - class UserCacheService: """用户缓存服务""" @@ -125,7 +111,7 @@ class UserCacheService: """生成用户谱面集缓存键""" return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}" - async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None: + async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserDict | None: """从缓存获取用户信息""" try: cache_key = self._get_user_cache_key(user_id, ruleset) @@ -133,7 +119,7 @@ class UserCacheService: if cached_data: logger.debug(f"User cache hit for user {user_id}") data = json.loads(cached_data) - return UserResp(**data) + return data return None except Exception as e: logger.error(f"Error getting user from cache: {e}") @@ -141,7 +127,7 @@ class UserCacheService: async def cache_user( self, - user_resp: UserResp, + user_resp: UserDict, ruleset: GameMode | None = None, expire_seconds: int | None = None, ): @@ -149,13 +135,10 @@ class UserCacheService: try: if expire_seconds is None: expire_seconds = settings.user_cache_expire_seconds - if user_resp.id is None: - logger.warning("Cannot cache user with None id") - return - cache_key = self._get_user_cache_key(user_resp.id, ruleset) - cached_data = user_resp.model_dump_json() + cache_key = self._get_user_cache_key(user_resp["id"], ruleset) + cached_data = safe_json_dumps(user_resp) await self.redis.setex(cache_key, expire_seconds, cached_data) - logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s") + logger.debug(f"Cached user {user_resp['id']} for {expire_seconds}s") except Exception as e: logger.error(f"Error caching user: {e}") @@ -168,10 +151,9 @@ class UserCacheService: limit: int = 100, offset: int = 0, is_legacy: bool = False, - ) -> list[ScoreResp] | list[LegacyScoreResp] | None: + ) -> list[UserDict] | list[LegacyScoreResp] | None: """从缓存获取用户成绩""" try: - model = LegacyScoreResp if is_legacy else ScoreResp cache_key = self._get_user_scores_cache_key( user_id, score_type, include_fail, mode, limit, offset, is_legacy ) @@ -179,7 +161,7 @@ class UserCacheService: if cached_data: logger.debug(f"User scores cache hit for user {user_id}, type {score_type}") data = json.loads(cached_data) - return [model(**score_data) for score_data in data] # pyright: ignore[reportReturnType] + return [LegacyScoreResp(**score_data) for score_data in data] if is_legacy else data return None except Exception as e: logger.error(f"Error getting user scores from cache: {e}") @@ -189,7 +171,7 @@ class UserCacheService: self, user_id: int, score_type: str, - scores: list[ScoreResp] | list[LegacyScoreResp], + scores: list[UserDict] | list[LegacyScoreResp], include_fail: bool, mode: GameMode | None = None, limit: int = 100, @@ -204,8 +186,12 @@ class UserCacheService: cache_key = self._get_user_scores_cache_key( user_id, score_type, include_fail, mode, limit, offset, is_legacy ) - # 使用 model_dump_json() 而不是 model_dump() + json.dumps() - scores_json_list = [score.model_dump_json() for score in scores] + if len(scores) == 0: + return + if isinstance(scores[0], dict): + scores_json_list = [safe_json_dumps(score) for score in scores] + else: + scores_json_list = [score.model_dump_json() for score in scores] # pyright: ignore[reportAttributeAccessIssue] cached_data = f"[{','.join(scores_json_list)}]" await self.redis.setex(cache_key, expire_seconds, cached_data) logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s") @@ -308,7 +294,7 @@ class UserCacheService: for user in users: if user.id != BANCHOBOT_ID: try: - await self._cache_single_user(user, session) + await self._cache_single_user(user) cached_count += 1 except Exception as e: logger.error(f"Failed to cache user {user.id}: {e}") @@ -320,10 +306,10 @@ class UserCacheService: finally: self._refreshing = False - async def _cache_single_user(self, user: User, session: AsyncSession): + async def _cache_single_user(self, user: User): """缓存单个用户""" try: - user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED) + user_resp = await UserModel.transform(user, includes=User.USER_INCLUDES) # 应用资源代理处理 if settings.enable_asset_proxy: @@ -347,7 +333,7 @@ class UserCacheService: # 立即重新加载用户信息 user = await session.get(User, user_id) if user and user.id != BANCHOBOT_ID: - await self._cache_single_user(user, session) + await self._cache_single_user(user) logger.info(f"Refreshed cache for user {user_id} after score submit") except Exception as e: logger.error(f"Error refreshing user cache on score submit: {e}") diff --git a/app/utils.py b/app/utils.py index c786511..1123353 100644 --- a/app/utils.py +++ b/app/utils.py @@ -4,10 +4,13 @@ from datetime import UTC, datetime import functools import inspect from io import BytesIO +import json import re -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar +from types import NoneType, UnionType +from typing import TYPE_CHECKING, Any, ParamSpec, TypedDict, TypeVar, Union, get_args, get_origin from fastapi import HTTPException +from fastapi.encoders import jsonable_encoder from PIL import Image if TYPE_CHECKING: @@ -299,3 +302,51 @@ def hex_to_hue(hex_color: str) -> int: hue = (60 * ((r - g) / delta) + 240) % 360 return int(hue) + + +def safe_json_dumps(data) -> str: + return json.dumps(jsonable_encoder(data), ensure_ascii=False) + + +def type_is_optional(typ: type): + origin_type = get_origin(typ) + args = get_args(typ) + return (origin_type is UnionType or origin_type is Union) and len(args) == 2 and NoneType in args + + +def _get_type(typ: type, includes: tuple[str, ...]) -> Any: + from app.database._base import DatabaseModel + + origin = get_origin(typ) + if issubclass(typ, DatabaseModel): + return typ.generate_typeddict(includes) + elif origin is list: + item_type = typ.__args__[0] + return list[_get_type(item_type, includes)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] + elif origin is dict: + key_type, value_type = typ.__args__ + return dict[key_type, _get_type(value_type, includes)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] + elif type_is_optional(typ): + inner_type = next(arg for arg in get_args(typ) if arg is not NoneType) + return Union[_get_type(inner_type, includes), None] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] # noqa: UP007 + elif origin is UnionType or origin is Union: + new_types = [] + for arg in get_args(typ): + new_types.append(_get_type(arg, includes)) # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] + return Union[tuple(new_types)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] # noqa: UP007 + else: + return typ + + +def api_doc(desc: str, model: Any, includes: list[str] = [], *, name: str = "APIDict"): + if includes: + includes_str = ", ".join(f"`{inc}`" for inc in includes) + desc += f"\n\n包含:{includes_str}" + if isinstance(model, dict): + fields = {} + for k, v in model.items(): + fields[k] = _get_type(v, tuple(includes)) + typed_dict = TypedDict(name, fields) # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] + else: + typed_dict = _get_type(model, tuple(includes)) + return {"description": desc, "model": typed_dict} diff --git a/migrations/versions/2025-11-23_23707640303c_project_remove_unused_fields_in_database.py b/migrations/versions/2025-11-23_23707640303c_project_remove_unused_fields_in_database.py new file mode 100644 index 0000000..4af10eb --- /dev/null +++ b/migrations/versions/2025-11-23_23707640303c_project_remove_unused_fields_in_database.py @@ -0,0 +1,498 @@ +"""project: remove unused fields in database + +Revision ID: 23707640303c +Revises: 3f0f22f38c3d +Create Date: 2025-11-23 08:14:05.284238 + +""" + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "23707640303c" +down_revision: str | Sequence[str] | None = "3f0f22f38c3d" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "beatmaps", + "mode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.drop_column("beatmaps", "current_user_playcount") + op.alter_column("beatmapsync", "updated_at", existing_type=mysql.DATETIME(), nullable=True) + op.alter_column( + "best_scores", + "gamemode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.alter_column( + "lazer_user_statistics", + "mode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.alter_column("lazer_user_statistics", "ranked_score", existing_type=mysql.BIGINT(), nullable=True) + op.alter_column( + "lazer_users", + "playmode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.alter_column( + "lazer_users", + "g0v0_playmode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.drop_column("lazer_users", "beatmap_playcounts_count") + op.alter_column("login_sessions", "is_new_device", existing_type=mysql.TINYINT(display_width=1), nullable=False) + op.alter_column( + "rank_history", + "mode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.alter_column( + "rank_top", + "mode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.alter_column( + "rooms", + "type", + existing_type=mysql.ENUM("PLAYLISTS", "HEAD_TO_HEAD", "TEAM_VERSUS", "MATCHMAKING"), + nullable=False, + ) + op.execute("UPDATE rooms SET channel_id = 0 WHERE channel_id IS NULL") + op.alter_column("rooms", "channel_id", existing_type=mysql.INTEGER(), nullable=False) + op.alter_column( + "score_tokens", + "ruleset_id", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.alter_column( + "scores", + "gamemode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.alter_column( + "teams", + "playmode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + op.alter_column( + "total_score_best_scores", + "gamemode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "total_score_best_scores", + "gamemode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column( + "teams", + "playmode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column( + "scores", + "gamemode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column( + "score_tokens", + "ruleset_id", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column("rooms", "channel_id", existing_type=mysql.INTEGER(), nullable=True) + op.alter_column( + "rooms", + "type", + existing_type=mysql.ENUM("PLAYLISTS", "HEAD_TO_HEAD", "TEAM_VERSUS", "MATCHMAKING"), + nullable=True, + ) + op.alter_column( + "rank_top", + "mode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column( + "rank_history", + "mode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column("login_sessions", "is_new_device", existing_type=mysql.TINYINT(display_width=1), nullable=True) + op.add_column( + "lazer_users", sa.Column("beatmap_playcounts_count", mysql.INTEGER(), autoincrement=False, nullable=False) + ) + op.alter_column( + "lazer_users", + "g0v0_playmode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column( + "lazer_users", + "playmode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column("lazer_user_statistics", "ranked_score", existing_type=mysql.BIGINT(), nullable=False) + op.alter_column( + "lazer_user_statistics", + "mode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column( + "best_scores", + "gamemode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + op.alter_column("beatmapsync", "updated_at", existing_type=mysql.DATETIME(), nullable=False) + op.add_column("beatmaps", sa.Column("current_user_playcount", mysql.INTEGER(), autoincrement=False, nullable=False)) + op.alter_column( + "beatmaps", + "mode", + existing_type=mysql.ENUM( + "OSU", + "TAIKO", + "FRUITS", + "MANIA", + "OSURX", + "OSUAP", + "TAIKORX", + "FRUITSRX", + "SENTAKKI", + "TAU", + "RUSH", + "HISHIGATA", + "SOYOKAZE", + ), + nullable=True, + ) + # ### end Alembic commands ###