refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
156
CONTRIBUTING.md
156
CONTRIBUTING.md
@@ -9,23 +9,26 @@ git clone https://github.com/GooGuTeam/g0v0-server.git
|
|||||||
此外,您还需要:
|
此外,您还需要:
|
||||||
|
|
||||||
- clone 旁观服务器到 g0v0-server 的文件夹。
|
- clone 旁观服务器到 g0v0-server 的文件夹。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/GooGuTeam/osu-server-spectator.git spectator-server
|
git clone https://github.com/GooGuTeam/osu-server-spectator.git spectator-server
|
||||||
```
|
```
|
||||||
|
|
||||||
- clone 表现分计算器到 g0v0-server 的文件夹。
|
- clone 表现分计算器到 g0v0-server 的文件夹。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/GooGuTeam/osu-performance-server.git performance-server
|
git clone https://github.com/GooGuTeam/osu-performance-server.git performance-server
|
||||||
```
|
```
|
||||||
|
|
||||||
- 下载并放置自定义规则集 DLL 到 `rulesets/` 目录(如果需要)。
|
- 下载并放置自定义规则集 DLL 到 `rulesets/` 目录(如果需要)。
|
||||||
|
|
||||||
## 开发环境
|
## 开发环境
|
||||||
|
|
||||||
为了确保一致的开发环境,我们强烈建议使用提供的 Dev Container。这将设置一个容器化的环境,预先安装所有必要的工具和依赖项。
|
为了确保一致的开发环境,我们强烈建议使用提供的 Dev Container。这将设置一个容器化的环境,预先安装所有必要的工具和依赖项。
|
||||||
|
|
||||||
1. 安装 [Docker](https://www.docker.com/products/docker-desktop/)。
|
1. 安装 [Docker](https://www.docker.com/products/docker-desktop/)。
|
||||||
2. 在 Visual Studio Code 中安装 [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)。
|
2. 在 Visual Studio Code 中安装 [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)。
|
||||||
3. 在 VS Code 中打开项目。当被提示时,点击“在容器中重新打开”以启动开发容器。
|
3. 在 VS Code 中打开项目。当被提示时,点击“在容器中重新打开”以启动开发容器。
|
||||||
|
|
||||||
## 配置项目
|
## 配置项目
|
||||||
|
|
||||||
@@ -67,54 +70,109 @@ uv sync
|
|||||||
|
|
||||||
以下是项目主要目录和文件的结构说明:
|
以下是项目主要目录和文件的结构说明:
|
||||||
|
|
||||||
- `main.py`: FastAPI 应用的主入口点,负责初始化和启动服务器。
|
- `main.py`: FastAPI 应用的主入口点,负责初始化和启动服务器。
|
||||||
- `pyproject.toml`: 项目配置文件,用于管理依赖项 (uv)、代码格式化 (Ruff) 和类型检查 (Pyright)。
|
- `pyproject.toml`: 项目配置文件,用于管理依赖项 (uv)、代码格式化 (Ruff) 和类型检查 (Pyright)。
|
||||||
- `alembic.ini`: Alembic 数据库迁移工具的配置文件。
|
- `alembic.ini`: Alembic 数据库迁移工具的配置文件。
|
||||||
- `app/`: 存放所有核心应用代码。
|
- `app/`: 存放所有核心应用代码。
|
||||||
- `router/`: 包含所有 API 端点的定义,根据 API 版本和功能进行组织。
|
- `router/`: 包含所有 API 端点的定义,根据 API 版本和功能进行组织。
|
||||||
- `service/`: 存放核心业务逻辑,例如用户排名计算、每日挑战处理等。
|
- `service/`: 存放核心业务逻辑,例如用户排名计算、每日挑战处理等。
|
||||||
- `database/`: 定义数据库模型 (SQLModel) 和会话管理。
|
- `database/`: 定义数据库模型 (SQLModel) 和会话管理。
|
||||||
- `models/`: 定义非数据库模型和其他模型。
|
- `models/`: 定义非数据库模型和其他模型。
|
||||||
- `tasks/`: 包含由 APScheduler 调度的后台任务和启动/关闭任务。
|
- `tasks/`: 包含由 APScheduler 调度的后台任务和启动/关闭任务。
|
||||||
- `dependencies/`: 管理 FastAPI 的依赖项注入。
|
- `dependencies/`: 管理 FastAPI 的依赖项注入。
|
||||||
- `achievements/`: 存放与成就相关的逻辑。
|
- `achievements/`: 存放与成就相关的逻辑。
|
||||||
- `storage/`: 存储服务代码。
|
- `storage/`: 存储服务代码。
|
||||||
- `fetcher/`: 用于从外部服务(如 osu! 官网)获取数据的模块。
|
- `fetcher/`: 用于从外部服务(如 osu! 官网)获取数据的模块。
|
||||||
- `middleware/`: 定义中间件,例如会话验证。
|
- `middleware/`: 定义中间件,例如会话验证。
|
||||||
- `helpers/`: 存放辅助函数和工具类。
|
- `helpers/`: 存放辅助函数和工具类。
|
||||||
- `config.py`: 应用配置,使用 pydantic-settings 管理。
|
- `config.py`: 应用配置,使用 pydantic-settings 管理。
|
||||||
- `calculator.py`: 存放所有的计算逻辑,例如 pp 和等级。
|
- `calculator.py`: 存放所有的计算逻辑,例如 pp 和等级。
|
||||||
- `log.py`: 日志记录模块,提供统一的日志接口。
|
- `log.py`: 日志记录模块,提供统一的日志接口。
|
||||||
- `const.py`: 定义常量。
|
- `const.py`: 定义常量。
|
||||||
- `path.py`: 定义跨文件使用的常量。
|
- `path.py`: 定义跨文件使用的常量。
|
||||||
- `migrations/`: 存放 Alembic 生成的数据库迁移脚本。
|
- `migrations/`: 存放 Alembic 生成的数据库迁移脚本。
|
||||||
- `static/`: 存放静态文件,如 `mods.json`。
|
- `static/`: 存放静态文件,如 `mods.json`。
|
||||||
|
|
||||||
### 数据库模型定义
|
### 数据库模型定义
|
||||||
|
|
||||||
所有的数据库模型定义在 `app.database` 里,并且在 `__init__.py` 中导出。
|
所有的数据库模型定义在 `app.database` 里,并且在 `__init__.py` 中导出。
|
||||||
|
|
||||||
如果这个模型的数据表结构和响应不完全相同,遵循 `Base` - `Table` - `Resp` 结构:
|
本项目使用一种“按需返回”的设计模式,遵循 `Dict` - `Model` - `Table` 结构。详细设计思路请参考[这篇文章](https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/)。
|
||||||
|
|
||||||
|
#### 基本结构
|
||||||
|
|
||||||
|
1. **Dict**: 定义模型转换后的字典结构,用于类型检查。必须在 Model 之前定义。
|
||||||
|
2. **Model**: 继承自 `DatabaseModel[Dict]`,定义字段和计算属性。
|
||||||
|
3. **Table**: 继承自 `Model`,定义数据库表结构。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class ModelBase(SQLModel):
|
from typing import TypedDict, NotRequired
|
||||||
# 定义共有内容
|
from app.database._base import DatabaseModel, OnDemand, included, ondemand
|
||||||
...
|
from sqlmodel import Field
|
||||||
|
|
||||||
|
# 1. 定义 Dict
|
||||||
|
class UserDict(TypedDict):
|
||||||
|
id: int
|
||||||
|
username: str
|
||||||
|
email: NotRequired[str] # 可选字段
|
||||||
|
followers_count: int # 计算属性
|
||||||
|
|
||||||
class Model(ModelBase, table=True):
|
# 2. 定义 Model
|
||||||
# 定义数据库表内容
|
class UserModel(DatabaseModel[UserDict]):
|
||||||
...
|
id: int = Field(primary_key=True)
|
||||||
|
username: str
|
||||||
|
email: OnDemand[str] # 使用 OnDemand 标记可选字段
|
||||||
|
|
||||||
|
# 普通计算属性 (总是返回)
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def followers_count(session: AsyncSession, instance: "User") -> int:
|
||||||
|
return await session.scalar(select(func.count()).where(Follower.followed_id == instance.id))
|
||||||
|
|
||||||
class ModelResp(ModelBase):
|
# 可选计算属性 (仅在 includes 中指定时返回)
|
||||||
# 定义响应内容
|
@ondemand
|
||||||
...
|
@staticmethod
|
||||||
|
async def some_optional_property(session: AsyncSession, instance: "User") -> str:
|
||||||
@classmethod
|
|
||||||
def from_db(cls, db: Model) -> "ModelResp":
|
|
||||||
# 从数据库模型转换
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
# 3. 定义 Table
|
||||||
|
class User(UserModel, table=True):
|
||||||
|
password: str # 仅在数据库中存在的字段
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 字段类型
|
||||||
|
|
||||||
|
- **普通属性**: 直接定义在 Model 中,总是返回。
|
||||||
|
- **可选属性**: 使用 `OnDemand[T]` 标记,仅在 `includes` 中指定时返回。
|
||||||
|
- **普通计算属性**: 使用 `@included` 装饰的静态方法,总是返回。
|
||||||
|
- **可选计算属性**: 使用 `@ondemand` 装饰的静态方法,仅在 `includes` 中指定时返回。
|
||||||
|
|
||||||
|
#### 使用方法
|
||||||
|
|
||||||
|
**转换模型**:
|
||||||
|
|
||||||
|
使用 `Model.transform` 方法将数据库实例转换为字典:
|
||||||
|
|
||||||
|
```python
|
||||||
|
user = await session.get(User, 1)
|
||||||
|
user_dict = await UserModel.transform(
|
||||||
|
user,
|
||||||
|
includes=["email"], # 指定需要返回的可选字段
|
||||||
|
some_context="foo-bar", # 如果计算属性需要上下文,可以传入额外参数
|
||||||
|
session=session # 可选传入自己的 session
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**API 文档**:
|
||||||
|
|
||||||
|
在 FastAPI 路由中,使用 `Model.generate_typeddict` 生成准确的响应文档:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get("/users/{id}", response_model=UserModel.generate_typeddict(includes=("email",)))
|
||||||
|
async def get_user(id: int) -> dict:
|
||||||
|
...
|
||||||
|
return await UserModel.transform(user, includes=["email"])
|
||||||
```
|
```
|
||||||
|
|
||||||
数据库模块名应与表名相同,定义了多个模型的除外。
|
数据库模块名应与表名相同,定义了多个模型的除外。
|
||||||
@@ -227,16 +285,16 @@ pre-commit 不提供 pyright 的 hook,您需要手动运行 `pyright` 检查
|
|||||||
|
|
||||||
**类型** 必须是以下之一:
|
**类型** 必须是以下之一:
|
||||||
|
|
||||||
* **feat**:新功能
|
- **feat**:新功能
|
||||||
* **fix**:错误修复
|
- **fix**:错误修复
|
||||||
* **docs**:仅文档更改
|
- **docs**:仅文档更改
|
||||||
* **style**:不影响代码含义的更改(空格、格式、缺少分号等)
|
- **style**:不影响代码含义的更改(空格、格式、缺少分号等)
|
||||||
* **refactor**:代码重构
|
- **refactor**:代码重构
|
||||||
* **perf**:改善性能的代码更改
|
- **perf**:改善性能的代码更改
|
||||||
* **test**:添加缺失的测试或修正现有测试
|
- **test**:添加缺失的测试或修正现有测试
|
||||||
* **chore**:对构建过程或辅助工具和库(如文档生成)的更改
|
- **chore**:对构建过程或辅助工具和库(如文档生成)的更改
|
||||||
* **ci**:持续集成相关的更改
|
- **ci**:持续集成相关的更改
|
||||||
* **deploy**: 部署相关的更改
|
- **deploy**: 部署相关的更改
|
||||||
|
|
||||||
**范围** 可以是任何指定提交更改位置的内容。例如 `api`、`db`、`auth` 等等。对整个项目的更改使用 `project`。
|
**范围** 可以是任何指定提交更改位置的内容。例如 `api`、`db`、`auth` 等等。对整个项目的更改使用 `project`。
|
||||||
|
|
||||||
|
|||||||
@@ -2,23 +2,31 @@ from .achievement import UserAchievement, UserAchievementResp
|
|||||||
from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
|
from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
|
||||||
from .beatmap import (
|
from .beatmap import (
|
||||||
Beatmap,
|
Beatmap,
|
||||||
BeatmapResp,
|
BeatmapDict,
|
||||||
|
BeatmapModel,
|
||||||
|
)
|
||||||
|
from .beatmap_playcounts import (
|
||||||
|
BeatmapPlaycounts,
|
||||||
|
BeatmapPlaycountsDict,
|
||||||
|
BeatmapPlaycountsModel,
|
||||||
)
|
)
|
||||||
from .beatmap_playcounts import BeatmapPlaycounts, BeatmapPlaycountsResp
|
|
||||||
from .beatmap_sync import BeatmapSync
|
from .beatmap_sync import BeatmapSync
|
||||||
from .beatmap_tags import BeatmapTagVote
|
from .beatmap_tags import BeatmapTagVote
|
||||||
from .beatmapset import (
|
from .beatmapset import (
|
||||||
Beatmapset,
|
Beatmapset,
|
||||||
BeatmapsetResp,
|
BeatmapsetDict,
|
||||||
|
BeatmapsetModel,
|
||||||
)
|
)
|
||||||
from .beatmapset_ratings import BeatmapRating
|
from .beatmapset_ratings import BeatmapRating
|
||||||
from .best_scores import BestScore
|
from .best_scores import BestScore
|
||||||
from .chat import (
|
from .chat import (
|
||||||
ChannelType,
|
ChannelType,
|
||||||
ChatChannel,
|
ChatChannel,
|
||||||
ChatChannelResp,
|
ChatChannelDict,
|
||||||
|
ChatChannelModel,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatMessageResp,
|
ChatMessageDict,
|
||||||
|
ChatMessageModel,
|
||||||
)
|
)
|
||||||
from .counts import (
|
from .counts import (
|
||||||
CountResp,
|
CountResp,
|
||||||
@@ -30,8 +38,8 @@ from .events import Event
|
|||||||
from .favourite_beatmapset import FavouriteBeatmapset
|
from .favourite_beatmapset import FavouriteBeatmapset
|
||||||
from .item_attempts_count import (
|
from .item_attempts_count import (
|
||||||
ItemAttemptsCount,
|
ItemAttemptsCount,
|
||||||
ItemAttemptsResp,
|
ItemAttemptsCountDict,
|
||||||
PlaylistAggregateScore,
|
ItemAttemptsCountModel,
|
||||||
)
|
)
|
||||||
from .matchmaking import (
|
from .matchmaking import (
|
||||||
MatchmakingPool,
|
MatchmakingPool,
|
||||||
@@ -42,30 +50,32 @@ from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
|||||||
from .notification import Notification, UserNotification
|
from .notification import Notification, UserNotification
|
||||||
from .password_reset import PasswordReset
|
from .password_reset import PasswordReset
|
||||||
from .playlist_best_score import PlaylistBestScore
|
from .playlist_best_score import PlaylistBestScore
|
||||||
from .playlists import Playlist, PlaylistResp
|
from .playlists import Playlist, PlaylistDict, PlaylistModel
|
||||||
from .rank_history import RankHistory, RankHistoryResp, RankTop
|
from .rank_history import RankHistory, RankHistoryResp, RankTop
|
||||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
from .relationship import Relationship, RelationshipDict, RelationshipModel, RelationshipType
|
||||||
from .room import APIUploadedRoom, Room, RoomResp
|
from .room import APIUploadedRoom, Room, RoomDict, RoomModel
|
||||||
from .room_participated_user import RoomParticipatedUser
|
from .room_participated_user import RoomParticipatedUser
|
||||||
from .score import (
|
from .score import (
|
||||||
MultiplayerScores,
|
MultiplayerScores,
|
||||||
Score,
|
Score,
|
||||||
ScoreAround,
|
ScoreAround,
|
||||||
ScoreBase,
|
ScoreDict,
|
||||||
ScoreResp,
|
ScoreModel,
|
||||||
ScoreStatistics,
|
ScoreStatistics,
|
||||||
)
|
)
|
||||||
from .score_token import ScoreToken, ScoreTokenResp
|
from .score_token import ScoreToken, ScoreTokenResp
|
||||||
|
from .search_beatmapset import SearchBeatmapsetsResp
|
||||||
from .statistics import (
|
from .statistics import (
|
||||||
UserStatistics,
|
UserStatistics,
|
||||||
UserStatisticsResp,
|
UserStatisticsDict,
|
||||||
|
UserStatisticsModel,
|
||||||
)
|
)
|
||||||
from .team import Team, TeamMember, TeamRequest, TeamResp
|
from .team import Team, TeamMember, TeamRequest, TeamResp
|
||||||
from .total_score_best_scores import TotalScoreBestScore
|
from .total_score_best_scores import TotalScoreBestScore
|
||||||
from .user import (
|
from .user import (
|
||||||
MeResp,
|
|
||||||
User,
|
User,
|
||||||
UserResp,
|
UserDict,
|
||||||
|
UserModel,
|
||||||
)
|
)
|
||||||
from .user_account_history import (
|
from .user_account_history import (
|
||||||
UserAccountHistory,
|
UserAccountHistory,
|
||||||
@@ -79,20 +89,25 @@ from .verification import EmailVerification, LoginSession, LoginSessionResp, Tru
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"APIUploadedRoom",
|
"APIUploadedRoom",
|
||||||
"Beatmap",
|
"Beatmap",
|
||||||
|
"BeatmapDict",
|
||||||
|
"BeatmapModel",
|
||||||
"BeatmapPlaycounts",
|
"BeatmapPlaycounts",
|
||||||
"BeatmapPlaycountsResp",
|
"BeatmapPlaycountsDict",
|
||||||
|
"BeatmapPlaycountsModel",
|
||||||
"BeatmapRating",
|
"BeatmapRating",
|
||||||
"BeatmapResp",
|
|
||||||
"BeatmapSync",
|
"BeatmapSync",
|
||||||
"BeatmapTagVote",
|
"BeatmapTagVote",
|
||||||
"Beatmapset",
|
"Beatmapset",
|
||||||
"BeatmapsetResp",
|
"BeatmapsetDict",
|
||||||
|
"BeatmapsetModel",
|
||||||
"BestScore",
|
"BestScore",
|
||||||
"ChannelType",
|
"ChannelType",
|
||||||
"ChatChannel",
|
"ChatChannel",
|
||||||
"ChatChannelResp",
|
"ChatChannelDict",
|
||||||
|
"ChatChannelModel",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"ChatMessageResp",
|
"ChatMessageDict",
|
||||||
|
"ChatMessageModel",
|
||||||
"CountResp",
|
"CountResp",
|
||||||
"DailyChallengeStats",
|
"DailyChallengeStats",
|
||||||
"DailyChallengeStatsResp",
|
"DailyChallengeStatsResp",
|
||||||
@@ -100,13 +115,13 @@ __all__ = [
|
|||||||
"Event",
|
"Event",
|
||||||
"FavouriteBeatmapset",
|
"FavouriteBeatmapset",
|
||||||
"ItemAttemptsCount",
|
"ItemAttemptsCount",
|
||||||
"ItemAttemptsResp",
|
"ItemAttemptsCountDict",
|
||||||
|
"ItemAttemptsCountModel",
|
||||||
"LoginSession",
|
"LoginSession",
|
||||||
"LoginSessionResp",
|
"LoginSessionResp",
|
||||||
"MatchmakingPool",
|
"MatchmakingPool",
|
||||||
"MatchmakingPoolBeatmap",
|
"MatchmakingPoolBeatmap",
|
||||||
"MatchmakingUserStats",
|
"MatchmakingUserStats",
|
||||||
"MeResp",
|
|
||||||
"MonthlyPlaycounts",
|
"MonthlyPlaycounts",
|
||||||
"MultiplayerEvent",
|
"MultiplayerEvent",
|
||||||
"MultiplayerEventResp",
|
"MultiplayerEventResp",
|
||||||
@@ -116,26 +131,29 @@ __all__ = [
|
|||||||
"OAuthToken",
|
"OAuthToken",
|
||||||
"PasswordReset",
|
"PasswordReset",
|
||||||
"Playlist",
|
"Playlist",
|
||||||
"PlaylistAggregateScore",
|
|
||||||
"PlaylistBestScore",
|
"PlaylistBestScore",
|
||||||
"PlaylistResp",
|
"PlaylistDict",
|
||||||
|
"PlaylistModel",
|
||||||
"RankHistory",
|
"RankHistory",
|
||||||
"RankHistoryResp",
|
"RankHistoryResp",
|
||||||
"RankTop",
|
"RankTop",
|
||||||
"Relationship",
|
"Relationship",
|
||||||
"RelationshipResp",
|
"RelationshipDict",
|
||||||
|
"RelationshipModel",
|
||||||
"RelationshipType",
|
"RelationshipType",
|
||||||
"ReplayWatchedCount",
|
"ReplayWatchedCount",
|
||||||
"Room",
|
"Room",
|
||||||
|
"RoomDict",
|
||||||
|
"RoomModel",
|
||||||
"RoomParticipatedUser",
|
"RoomParticipatedUser",
|
||||||
"RoomResp",
|
|
||||||
"Score",
|
"Score",
|
||||||
"ScoreAround",
|
"ScoreAround",
|
||||||
"ScoreBase",
|
"ScoreDict",
|
||||||
"ScoreResp",
|
"ScoreModel",
|
||||||
"ScoreStatistics",
|
"ScoreStatistics",
|
||||||
"ScoreToken",
|
"ScoreToken",
|
||||||
"ScoreTokenResp",
|
"ScoreTokenResp",
|
||||||
|
"SearchBeatmapsetsResp",
|
||||||
"Team",
|
"Team",
|
||||||
"TeamMember",
|
"TeamMember",
|
||||||
"TeamRequest",
|
"TeamRequest",
|
||||||
@@ -149,17 +167,18 @@ __all__ = [
|
|||||||
"UserAccountHistoryResp",
|
"UserAccountHistoryResp",
|
||||||
"UserAccountHistoryType",
|
"UserAccountHistoryType",
|
||||||
"UserAchievement",
|
"UserAchievement",
|
||||||
"UserAchievement",
|
|
||||||
"UserAchievementResp",
|
"UserAchievementResp",
|
||||||
|
"UserDict",
|
||||||
"UserLoginLog",
|
"UserLoginLog",
|
||||||
|
"UserModel",
|
||||||
"UserNotification",
|
"UserNotification",
|
||||||
"UserPreference",
|
"UserPreference",
|
||||||
"UserResp",
|
|
||||||
"UserStatistics",
|
"UserStatistics",
|
||||||
"UserStatisticsResp",
|
"UserStatisticsDict",
|
||||||
|
"UserStatisticsModel",
|
||||||
"V1APIKeys",
|
"V1APIKeys",
|
||||||
]
|
]
|
||||||
|
|
||||||
for i in __all__:
|
for i in __all__:
|
||||||
if i.endswith("Resp"):
|
if i.endswith("Model") or i.endswith("Resp"):
|
||||||
globals()[i].model_rebuild() # type: ignore[call-arg]
|
globals()[i].model_rebuild()
|
||||||
|
|||||||
499
app/database/_base.py
Normal file
499
app/database/_base.py
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
|
from functools import lru_cache, wraps
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from types import NoneType, get_original_bases
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
ClassVar,
|
||||||
|
Concatenate,
|
||||||
|
ForwardRef,
|
||||||
|
ParamSpec,
|
||||||
|
TypedDict,
|
||||||
|
cast,
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.models.model import UTCBaseModel
|
||||||
|
from app.utils import type_is_optional
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import async_object_session
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlmodel.main import SQLModelMetaclass
|
||||||
|
|
||||||
|
_dict_to_model: dict[type, type["DatabaseModel"]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any:
|
||||||
|
"""Safely evaluate a ForwardRef, with fallback to app.database module"""
|
||||||
|
if isinstance(type_, str):
|
||||||
|
type_ = ForwardRef(type_)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return evaluate_forwardref(
|
||||||
|
type_,
|
||||||
|
globalns=vars(sys.modules[module_name]),
|
||||||
|
localns={},
|
||||||
|
)
|
||||||
|
except (NameError, AttributeError, KeyError):
|
||||||
|
# Fallback to app.database module
|
||||||
|
try:
|
||||||
|
import app.database
|
||||||
|
|
||||||
|
return evaluate_forwardref(
|
||||||
|
type_,
|
||||||
|
globalns=vars(app.database),
|
||||||
|
localns={},
|
||||||
|
)
|
||||||
|
except (NameError, AttributeError, KeyError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class OnDemand[T]:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
def __get__(self, instance: object | None, owner: Any) -> T: ...
|
||||||
|
|
||||||
|
def __set__(self, instance: Any, value: T) -> None: ...
|
||||||
|
|
||||||
|
def __delete__(self, instance: Any) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class Exclude[T]:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
def __get__(self, instance: object | None, owner: Any) -> T: ...
|
||||||
|
|
||||||
|
def __set__(self, instance: Any, value: T) -> None: ...
|
||||||
|
|
||||||
|
def __delete__(self, instance: Any) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140
|
||||||
|
def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
raw_annotations: dict[str, Any] = class_dict.get("__annotations__", {})
|
||||||
|
if sys.version_info >= (3, 14) and "__annotations__" not in class_dict:
|
||||||
|
# See https://github.com/pydantic/pydantic/pull/11991
|
||||||
|
from annotationlib import (
|
||||||
|
Format,
|
||||||
|
call_annotate_function,
|
||||||
|
get_annotate_from_class_namespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
if annotate := get_annotate_from_class_namespace(class_dict):
|
||||||
|
raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF)
|
||||||
|
return raw_annotations
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77
|
||||||
|
if sys.version_info < (3, 12, 4):
|
||||||
|
|
||||||
|
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||||
|
# Even though it is the right signature for python 3.9, mypy complains with
|
||||||
|
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
|
||||||
|
# Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
|
||||||
|
# TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
|
||||||
|
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||||
|
# Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid
|
||||||
|
# warnings:
|
||||||
|
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseModelMetaclass(SQLModelMetaclass):
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
bases: tuple[type, ...],
|
||||||
|
namespace: dict[str, Any],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "DatabaseModelMetaclass":
|
||||||
|
original_annotations = _get_annotations(namespace)
|
||||||
|
new_annotations = {}
|
||||||
|
ondemands = []
|
||||||
|
excludes = []
|
||||||
|
|
||||||
|
for k, v in original_annotations.items():
|
||||||
|
if get_origin(v) is OnDemand:
|
||||||
|
inner_type = v.__args__[0]
|
||||||
|
new_annotations[k] = inner_type
|
||||||
|
ondemands.append(k)
|
||||||
|
elif get_origin(v) is Exclude:
|
||||||
|
inner_type = v.__args__[0]
|
||||||
|
new_annotations[k] = inner_type
|
||||||
|
excludes.append(k)
|
||||||
|
else:
|
||||||
|
new_annotations[k] = v
|
||||||
|
|
||||||
|
new_class = super().__new__(
|
||||||
|
cls,
|
||||||
|
name,
|
||||||
|
bases,
|
||||||
|
{
|
||||||
|
**namespace,
|
||||||
|
"__annotations__": new_annotations,
|
||||||
|
},
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {}))
|
||||||
|
new_class._ONDEMAND_DATABASE_FIELDS = list(getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", [])) + list(
|
||||||
|
ondemands
|
||||||
|
)
|
||||||
|
new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {}))
|
||||||
|
new_class._EXCLUDED_DATABASE_FIELDS = list(getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", [])) + list(excludes)
|
||||||
|
|
||||||
|
for attr_name, attr_value in namespace.items():
|
||||||
|
target = _get_callable_target(attr_value)
|
||||||
|
if target is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if getattr(target, "__included__", False):
|
||||||
|
new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target)
|
||||||
|
_pre_calculate_context_params(target, attr_value)
|
||||||
|
|
||||||
|
if getattr(target, "__calculated_ondemand__", False):
|
||||||
|
new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target)
|
||||||
|
_pre_calculate_context_params(target, attr_value)
|
||||||
|
|
||||||
|
# Register TDict to DatabaseModel mapping
|
||||||
|
for base in get_original_bases(new_class):
|
||||||
|
cls_name = base.__name__
|
||||||
|
if "DatabaseModel" in cls_name and "[" in cls_name and "]" in cls_name:
|
||||||
|
generic_type_name = cls_name[cls_name.index("[") : cls_name.rindex("]") + 1]
|
||||||
|
generic_type = evaluate_forwardref(
|
||||||
|
ForwardRef(generic_type_name),
|
||||||
|
globalns=vars(sys.modules[new_class.__module__]),
|
||||||
|
localns={},
|
||||||
|
)
|
||||||
|
_dict_to_model[generic_type[0]] = new_class
|
||||||
|
|
||||||
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
|
def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None:
|
||||||
|
if hasattr(target, "__context_params__"):
|
||||||
|
return
|
||||||
|
|
||||||
|
sig = inspect.signature(target)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
start_index = 2
|
||||||
|
if isinstance(attr_value, classmethod):
|
||||||
|
start_index = 3
|
||||||
|
|
||||||
|
context_params = [] if len(params) < start_index else params[start_index:]
|
||||||
|
|
||||||
|
setattr(target, "__context_params__", context_params)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_callable_target(value: Any) -> Callable | None:
|
||||||
|
if isinstance(value, (staticmethod, classmethod)):
|
||||||
|
return value.__func__
|
||||||
|
if inspect.isfunction(value):
|
||||||
|
return value
|
||||||
|
if inspect.ismethod(value):
|
||||||
|
return value.__func__
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _mark_callable(value: Any, flag: str) -> Callable | None:
|
||||||
|
target = _get_callable_target(value)
|
||||||
|
if target is None:
|
||||||
|
return None
|
||||||
|
setattr(target, flag, True)
|
||||||
|
return target
|
||||||
|
|
||||||
|
|
||||||
|
def _get_return_type(func: Callable) -> type:
|
||||||
|
sig = inspect.get_annotations(func)
|
||||||
|
return sig.get("return", Any)
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]]
|
||||||
|
DecoratorTarget = CalculatedField | staticmethod | classmethod
|
||||||
|
|
||||||
|
|
||||||
|
def included(func: DecoratorTarget) -> DecoratorTarget:
|
||||||
|
marker = _mark_callable(func, "__included__")
|
||||||
|
if marker is None:
|
||||||
|
raise RuntimeError("@included is only usable on callables.")
|
||||||
|
|
||||||
|
@wraps(marker)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
return await marker(*args, **kwargs)
|
||||||
|
|
||||||
|
if isinstance(func, staticmethod):
|
||||||
|
return staticmethod(wrapper)
|
||||||
|
if isinstance(func, classmethod):
|
||||||
|
return classmethod(wrapper)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def ondemand(func: DecoratorTarget) -> DecoratorTarget:
|
||||||
|
marker = _mark_callable(func, "__calculated_ondemand__")
|
||||||
|
if marker is None:
|
||||||
|
raise RuntimeError("@ondemand is only usable on callables.")
|
||||||
|
|
||||||
|
@wraps(marker)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
return await marker(*args, **kwargs)
|
||||||
|
|
||||||
|
if isinstance(func, staticmethod):
|
||||||
|
return staticmethod(wrapper)
|
||||||
|
if isinstance(func, classmethod):
|
||||||
|
return classmethod(wrapper)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
async def call_awaitable_with_context(
|
||||||
|
func: CalculatedField,
|
||||||
|
session: AsyncSession,
|
||||||
|
instance: Any,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> Any:
|
||||||
|
context_params: list[str] | None = getattr(func, "__context_params__", None)
|
||||||
|
|
||||||
|
if context_params is None:
|
||||||
|
# Fallback if not pre-calculated
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
if len(sig.parameters) == 2:
|
||||||
|
return await func(session, instance)
|
||||||
|
else:
|
||||||
|
call_params = {}
|
||||||
|
for param in sig.parameters.values():
|
||||||
|
if param.name in context:
|
||||||
|
call_params[param.name] = context[param.name]
|
||||||
|
return await func(session, instance, **call_params)
|
||||||
|
|
||||||
|
if not context_params:
|
||||||
|
return await func(session, instance)
|
||||||
|
|
||||||
|
call_params = {}
|
||||||
|
for name in context_params:
|
||||||
|
if name in context:
|
||||||
|
call_params[name] = context[name]
|
||||||
|
return await func(session, instance, **call_params)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass):
|
||||||
|
_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
|
||||||
|
|
||||||
|
_ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = []
|
||||||
|
_ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
|
||||||
|
|
||||||
|
_EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = []
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def transform(
|
||||||
|
cls,
|
||||||
|
db_instance: "DatabaseModel",
|
||||||
|
*,
|
||||||
|
session: AsyncSession,
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
**context: Any,
|
||||||
|
) -> TDict: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def transform(
|
||||||
|
cls,
|
||||||
|
db_instance: "DatabaseModel",
|
||||||
|
*,
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
**context: Any,
|
||||||
|
) -> TDict: ...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def transform(
|
||||||
|
cls,
|
||||||
|
db_instance: "DatabaseModel",
|
||||||
|
*,
|
||||||
|
session: AsyncSession | None = None,
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
**context: Any,
|
||||||
|
) -> TDict:
|
||||||
|
includes = includes.copy() if includes is not None else []
|
||||||
|
session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session
|
||||||
|
if session is None:
|
||||||
|
raise RuntimeError("DatabaseModel.transform requires a session-bound instance.")
|
||||||
|
resp_obj = cls.model_validate(db_instance.model_dump())
|
||||||
|
data = resp_obj.model_dump()
|
||||||
|
|
||||||
|
for field in cls._CALCULATED_FIELDS:
|
||||||
|
func = getattr(cls, field)
|
||||||
|
value = await call_awaitable_with_context(func, session, db_instance, context)
|
||||||
|
data[field] = value
|
||||||
|
|
||||||
|
sub_include_map: dict[str, list[str]] = {}
|
||||||
|
for include in [i for i in includes if "." in i]:
|
||||||
|
parent, sub_include = include.split(".", 1)
|
||||||
|
if parent not in sub_include_map:
|
||||||
|
sub_include_map[parent] = []
|
||||||
|
sub_include_map[parent].append(sub_include)
|
||||||
|
includes.remove(include) # pyright: ignore[reportOptionalMemberAccess]
|
||||||
|
|
||||||
|
for field, sub_includes in sub_include_map.items():
|
||||||
|
if field in cls._ONDEMAND_CALCULATED_FIELDS:
|
||||||
|
func = getattr(cls, field)
|
||||||
|
value = await call_awaitable_with_context(
|
||||||
|
func, session, db_instance, {**context, "includes": sub_includes}
|
||||||
|
)
|
||||||
|
data[field] = value
|
||||||
|
|
||||||
|
for include in includes:
|
||||||
|
if include in data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if include in cls._ONDEMAND_CALCULATED_FIELDS:
|
||||||
|
func = getattr(cls, include)
|
||||||
|
value = await call_awaitable_with_context(func, session, db_instance, context)
|
||||||
|
data[include] = value
|
||||||
|
|
||||||
|
for field in cls._ONDEMAND_DATABASE_FIELDS:
|
||||||
|
if field not in includes:
|
||||||
|
del data[field]
|
||||||
|
|
||||||
|
for field in cls._EXCLUDED_DATABASE_FIELDS:
|
||||||
|
if field in data:
|
||||||
|
del data[field]
|
||||||
|
|
||||||
|
return cast(TDict, data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def transform_many(
|
||||||
|
cls,
|
||||||
|
db_instances: Sequence["DatabaseModel"],
|
||||||
|
*,
|
||||||
|
session: AsyncSession | None = None,
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
**context: Any,
|
||||||
|
) -> list[TDict]:
|
||||||
|
if not db_instances:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here
|
||||||
|
# if the transform method performs any database operations using the shared session.
|
||||||
|
# Since we don't know if the transform method (or its calculated fields) will use the DB,
|
||||||
|
# we must execute them serially to be safe.
|
||||||
|
results = []
|
||||||
|
for instance in db_instances:
|
||||||
|
results.append(await cls.transform(instance, session=session, includes=includes, **context))
|
||||||
|
return results
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@lru_cache
|
||||||
|
def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm]
|
||||||
|
def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any:
|
||||||
|
# Evaluate ForwardRef if present
|
||||||
|
if isinstance(field_type, (str, ForwardRef)):
|
||||||
|
resolved = _safe_evaluate_forwardref(field_type, cls.__module__)
|
||||||
|
if resolved is not None:
|
||||||
|
field_type = resolved
|
||||||
|
|
||||||
|
origin_type = get_origin(field_type)
|
||||||
|
inner_type = field_type
|
||||||
|
args = get_args(field_type)
|
||||||
|
|
||||||
|
is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType]
|
||||||
|
if is_optional:
|
||||||
|
inner_type = next((arg for arg in args if arg is not NoneType), field_type)
|
||||||
|
|
||||||
|
is_list = False
|
||||||
|
if origin_type is list:
|
||||||
|
is_list = True
|
||||||
|
inner_type = args[0]
|
||||||
|
|
||||||
|
# Evaluate ForwardRef in inner_type if present
|
||||||
|
if isinstance(inner_type, (str, ForwardRef)):
|
||||||
|
resolved = _safe_evaluate_forwardref(inner_type, cls.__module__)
|
||||||
|
if resolved is not None:
|
||||||
|
inner_type = resolved
|
||||||
|
|
||||||
|
if not resolve_database_model:
|
||||||
|
if is_optional:
|
||||||
|
return inner_type | None # pyright: ignore[reportOperatorIssue]
|
||||||
|
elif is_list:
|
||||||
|
return list[inner_type]
|
||||||
|
return inner_type
|
||||||
|
|
||||||
|
model_class = None
|
||||||
|
|
||||||
|
# First check if inner_type is directly a DatabaseModel subclass
|
||||||
|
try:
|
||||||
|
if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore
|
||||||
|
model_class = inner_type
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If not found, look up in _dict_to_model
|
||||||
|
if model_class is None:
|
||||||
|
model_class = _dict_to_model.get(inner_type) # type: ignore
|
||||||
|
|
||||||
|
if model_class is not None:
|
||||||
|
nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ())))
|
||||||
|
resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore
|
||||||
|
|
||||||
|
if is_optional:
|
||||||
|
resolved_type = resolved_type | None # type: ignore
|
||||||
|
|
||||||
|
return resolved_type
|
||||||
|
|
||||||
|
# Fallback: use the resolved inner_type
|
||||||
|
resolved_type = list[inner_type] if is_list else inner_type # type: ignore
|
||||||
|
if is_optional:
|
||||||
|
resolved_type = resolved_type | None # type: ignore
|
||||||
|
return resolved_type
|
||||||
|
|
||||||
|
if includes is None:
|
||||||
|
includes = ()
|
||||||
|
|
||||||
|
# Parse nested includes
|
||||||
|
direct_includes = []
|
||||||
|
sub_include_map: dict[str, list[str]] = {}
|
||||||
|
for include in includes:
|
||||||
|
if "." in include:
|
||||||
|
parent, sub_include = include.split(".", 1)
|
||||||
|
if parent not in sub_include_map:
|
||||||
|
sub_include_map[parent] = []
|
||||||
|
sub_include_map[parent].append(sub_include)
|
||||||
|
if parent not in direct_includes:
|
||||||
|
direct_includes.append(parent)
|
||||||
|
else:
|
||||||
|
direct_includes.append(include)
|
||||||
|
|
||||||
|
fields = {}
|
||||||
|
|
||||||
|
# Process model fields
|
||||||
|
for field_name, field_info in cls.model_fields.items():
|
||||||
|
field_type = field_info.annotation or Any
|
||||||
|
field_type = _evaluate_type(field_type, field_name=field_name)
|
||||||
|
|
||||||
|
if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
fields[field_name] = field_type
|
||||||
|
|
||||||
|
# Process calculated fields
|
||||||
|
for field_name, field_type in cls._CALCULATED_FIELDS.items():
|
||||||
|
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
|
||||||
|
fields[field_name] = field_type
|
||||||
|
|
||||||
|
# Process ondemand calculated fields
|
||||||
|
for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items():
|
||||||
|
if field_name not in direct_includes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
|
||||||
|
fields[field_name] = field_type
|
||||||
|
|
||||||
|
return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType]
|
||||||
@@ -1,117 +1,339 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||||
|
|
||||||
from app.calculator import get_calculator
|
from app.calculator import get_calculator
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database.beatmap_tags import BeatmapTagVote
|
|
||||||
from app.database.failtime import FailTime, FailTimeResp
|
|
||||||
from app.models.beatmap import BeatmapRankStatus
|
from app.models.beatmap import BeatmapRankStatus
|
||||||
from app.models.mods import APIMod
|
from app.models.mods import APIMod
|
||||||
from app.models.performance import DifficultyAttributesUnion
|
from app.models.performance import DifficultyAttributesUnion
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
|
from ._base import DatabaseModel, OnDemand, included, ondemand
|
||||||
from .beatmap_playcounts import BeatmapPlaycounts
|
from .beatmap_playcounts import BeatmapPlaycounts
|
||||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
from .beatmap_tags import BeatmapTagVote
|
||||||
|
from .beatmapset import Beatmapset, BeatmapsetDict, BeatmapsetModel
|
||||||
|
from .failtime import FailTime, FailTimeResp
|
||||||
|
from .user import User, UserDict, UserModel
|
||||||
|
|
||||||
from pydantic import BaseModel, TypeAdapter
|
from pydantic import BaseModel, TypeAdapter
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlalchemy import Column, DateTime
|
from sqlalchemy import Column, DateTime
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
|
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.fetcher import Fetcher
|
from app.fetcher import Fetcher
|
||||||
|
|
||||||
from .user import User
|
|
||||||
|
|
||||||
|
|
||||||
class BeatmapOwner(SQLModel):
|
class BeatmapOwner(SQLModel):
|
||||||
id: int
|
id: int
|
||||||
username: str
|
username: str
|
||||||
|
|
||||||
|
|
||||||
class BeatmapBase(SQLModel):
|
class BeatmapDict(TypedDict):
|
||||||
# Beatmap
|
beatmapset_id: int
|
||||||
url: str
|
difficulty_rating: float
|
||||||
|
id: int
|
||||||
mode: GameMode
|
mode: GameMode
|
||||||
|
total_length: int
|
||||||
|
user_id: int
|
||||||
|
version: str
|
||||||
|
url: str
|
||||||
|
|
||||||
|
checksum: NotRequired[str]
|
||||||
|
max_combo: NotRequired[int | None]
|
||||||
|
ar: NotRequired[float]
|
||||||
|
cs: NotRequired[float]
|
||||||
|
drain: NotRequired[float]
|
||||||
|
accuracy: NotRequired[float]
|
||||||
|
bpm: NotRequired[float]
|
||||||
|
count_circles: NotRequired[int]
|
||||||
|
count_sliders: NotRequired[int]
|
||||||
|
count_spinners: NotRequired[int]
|
||||||
|
deleted_at: NotRequired[datetime | None]
|
||||||
|
hit_length: NotRequired[int]
|
||||||
|
last_updated: NotRequired[datetime]
|
||||||
|
|
||||||
|
status: NotRequired[str]
|
||||||
|
beatmapset: NotRequired[BeatmapsetDict]
|
||||||
|
current_user_playcount: NotRequired[int]
|
||||||
|
current_user_tag_ids: NotRequired[list[int]]
|
||||||
|
failtimes: NotRequired[FailTimeResp]
|
||||||
|
top_tag_ids: NotRequired[list[dict[str, int]]]
|
||||||
|
user: NotRequired[UserDict]
|
||||||
|
convert: NotRequired[bool]
|
||||||
|
is_scoreable: NotRequired[bool]
|
||||||
|
mode_int: NotRequired[int]
|
||||||
|
ranked: NotRequired[int]
|
||||||
|
playcount: NotRequired[int]
|
||||||
|
passcount: NotRequired[int]
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapModel(DatabaseModel[BeatmapDict]):
|
||||||
|
BEATMAP_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
"checksum",
|
||||||
|
"accuracy",
|
||||||
|
"ar",
|
||||||
|
"bpm",
|
||||||
|
"convert",
|
||||||
|
"count_circles",
|
||||||
|
"count_sliders",
|
||||||
|
"count_spinners",
|
||||||
|
"cs",
|
||||||
|
"deleted_at",
|
||||||
|
"drain",
|
||||||
|
"hit_length",
|
||||||
|
"is_scoreable",
|
||||||
|
"last_updated",
|
||||||
|
"mode_int",
|
||||||
|
"passcount",
|
||||||
|
"playcount",
|
||||||
|
"ranked",
|
||||||
|
"url",
|
||||||
|
]
|
||||||
|
DEFAULT_API_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
"beatmapset.ratings",
|
||||||
|
"current_user_playcount",
|
||||||
|
"failtimes",
|
||||||
|
"max_combo",
|
||||||
|
"owners",
|
||||||
|
]
|
||||||
|
TRANSFORMER_INCLUDES: ClassVar[list[str]] = [*DEFAULT_API_INCLUDES, *BEATMAP_TRANSFORMER_INCLUDES]
|
||||||
|
|
||||||
|
# Beatmap
|
||||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||||
difficulty_rating: float = Field(default=0.0, index=True)
|
difficulty_rating: float = Field(default=0.0, index=True)
|
||||||
|
id: int = Field(primary_key=True, index=True)
|
||||||
|
mode: GameMode
|
||||||
total_length: int
|
total_length: int
|
||||||
user_id: int = Field(index=True)
|
user_id: int = Field(index=True)
|
||||||
version: str = Field(index=True)
|
version: str = Field(index=True)
|
||||||
|
|
||||||
|
url: OnDemand[str]
|
||||||
# optional
|
# optional
|
||||||
checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
|
checksum: OnDemand[str] = Field(sa_column=Column(VARCHAR(32), index=True))
|
||||||
current_user_playcount: int = Field(default=0)
|
max_combo: OnDemand[int | None] = Field(default=0)
|
||||||
max_combo: int | None = Field(default=0)
|
# TODO: owners
|
||||||
# TODO: failtimes, owners
|
|
||||||
|
|
||||||
# BeatmapExtended
|
# BeatmapExtended
|
||||||
ar: float = Field(default=0.0)
|
ar: OnDemand[float] = Field(default=0.0)
|
||||||
cs: float = Field(default=0.0)
|
cs: OnDemand[float] = Field(default=0.0)
|
||||||
drain: float = Field(default=0.0) # hp
|
drain: OnDemand[float] = Field(default=0.0) # hp
|
||||||
accuracy: float = Field(default=0.0) # od
|
accuracy: OnDemand[float] = Field(default=0.0) # od
|
||||||
bpm: float = Field(default=0.0)
|
bpm: OnDemand[float] = Field(default=0.0)
|
||||||
count_circles: int = Field(default=0)
|
count_circles: OnDemand[int] = Field(default=0)
|
||||||
count_sliders: int = Field(default=0)
|
count_sliders: OnDemand[int] = Field(default=0)
|
||||||
count_spinners: int = Field(default=0)
|
count_spinners: OnDemand[int] = Field(default=0)
|
||||||
deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
deleted_at: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime))
|
||||||
hit_length: int = Field(default=0)
|
hit_length: OnDemand[int] = Field(default=0)
|
||||||
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def status(_session: AsyncSession, beatmap: "Beatmap") -> str:
|
||||||
|
if settings.enable_all_beatmap_leaderboard and not beatmap.beatmap_status.has_leaderboard():
|
||||||
|
return BeatmapRankStatus.APPROVED.name.lower()
|
||||||
|
return beatmap.beatmap_status.name.lower()
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def beatmapset(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmap: "Beatmap",
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
) -> BeatmapsetDict | None:
|
||||||
|
if beatmap.beatmapset is not None:
|
||||||
|
return await BeatmapsetModel.transform(
|
||||||
|
beatmap.beatmapset, includes=(includes or []) + Beatmapset.BEATMAPSET_TRANSFORMER_INCLUDES
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def current_user_playcount(_session: AsyncSession, beatmap: "Beatmap", user: "User") -> int:
|
||||||
|
playcount = (
|
||||||
|
await _session.exec(
|
||||||
|
select(BeatmapPlaycounts.playcount).where(
|
||||||
|
BeatmapPlaycounts.beatmap_id == beatmap.id, BeatmapPlaycounts.user_id == user.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
return int(playcount or 0)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def current_user_tag_ids(_session: AsyncSession, beatmap: "Beatmap", user: "User | None" = None) -> list[int]:
|
||||||
|
if user is None:
|
||||||
|
return []
|
||||||
|
tag_ids = (
|
||||||
|
await _session.exec(
|
||||||
|
select(BeatmapTagVote.tag_id).where(
|
||||||
|
BeatmapTagVote.beatmap_id == beatmap.id,
|
||||||
|
BeatmapTagVote.user_id == user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
return list(tag_ids)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def failtimes(_session: AsyncSession, beatmap: "Beatmap") -> FailTimeResp:
|
||||||
|
if beatmap.failtimes is not None:
|
||||||
|
return FailTimeResp.from_db(beatmap.failtimes)
|
||||||
|
return FailTimeResp()
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def top_tag_ids(_session: AsyncSession, beatmap: "Beatmap") -> list[dict[str, int]]:
|
||||||
|
all_votes = (
|
||||||
|
await _session.exec(
|
||||||
|
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
|
||||||
|
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||||
|
.group_by(col(BeatmapTagVote.tag_id))
|
||||||
|
.having(func.count() > settings.beatmap_tag_top_count)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
top_tag_ids: list[dict[str, int]] = []
|
||||||
|
for id, votes in all_votes:
|
||||||
|
top_tag_ids.append({"tag_id": id, "count": votes})
|
||||||
|
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
|
||||||
|
return top_tag_ids
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def user(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmap: "Beatmap",
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
) -> UserDict | None:
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
user = await _session.get(User, beatmap.user_id)
|
||||||
|
if user is None:
|
||||||
|
return None
|
||||||
|
return await UserModel.transform(user, includes=includes)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def convert(_session: AsyncSession, _beatmap: "Beatmap") -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def is_scoreable(_session: AsyncSession, beatmap: "Beatmap") -> bool:
|
||||||
|
beatmap_status = beatmap.beatmap_status
|
||||||
|
if settings.enable_all_beatmap_leaderboard:
|
||||||
|
return True
|
||||||
|
return beatmap_status.has_leaderboard()
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def mode_int(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||||
|
return int(beatmap.mode)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def ranked(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||||
|
beatmap_status = beatmap.beatmap_status
|
||||||
|
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||||
|
return BeatmapRankStatus.APPROVED.value
|
||||||
|
return beatmap_status.value
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def playcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||||
|
result = (
|
||||||
|
await _session.exec(
|
||||||
|
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
return int(result or 0)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def passcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||||
|
from .score import Score
|
||||||
|
|
||||||
|
return (
|
||||||
|
await _session.exec(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(Score)
|
||||||
|
.where(
|
||||||
|
Score.beatmap_id == beatmap.id,
|
||||||
|
col(Score.passed).is_(True),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).one()
|
||||||
|
|
||||||
|
|
||||||
class Beatmap(BeatmapBase, table=True):
|
class Beatmap(AsyncAttrs, BeatmapModel, table=True):
|
||||||
__tablename__: str = "beatmaps"
|
__tablename__: str = "beatmaps"
|
||||||
id: int = Field(primary_key=True, index=True)
|
|
||||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
|
||||||
beatmap_status: BeatmapRankStatus = Field(index=True)
|
beatmap_status: BeatmapRankStatus = Field(index=True)
|
||||||
# optional
|
# optional
|
||||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
|
beatmapset: "Beatmapset" = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
|
||||||
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
|
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
async def from_resp_no_save(cls, _session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
|
||||||
d = resp.model_dump()
|
d = {k: v for k, v in resp.items() if k != "beatmapset"}
|
||||||
del d["beatmapset"]
|
beatmapset_id = resp.get("beatmapset_id")
|
||||||
|
bid = resp.get("id")
|
||||||
|
ranked = resp.get("ranked")
|
||||||
|
if beatmapset_id is None or bid is None or ranked is None:
|
||||||
|
raise ValueError("beatmapset_id, id and ranked are required")
|
||||||
beatmap = cls.model_validate(
|
beatmap = cls.model_validate(
|
||||||
{
|
{
|
||||||
**d,
|
**d,
|
||||||
"beatmapset_id": resp.beatmapset_id,
|
"beatmapset_id": beatmapset_id,
|
||||||
"id": resp.id,
|
"id": bid,
|
||||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
"beatmap_status": BeatmapRankStatus(ranked),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return beatmap
|
return beatmap
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
async def from_resp(cls, session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
|
||||||
beatmap = await cls.from_resp_no_save(session, resp)
|
beatmap = await cls.from_resp_no_save(session, resp)
|
||||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
resp_id = resp.get("id")
|
||||||
|
if resp_id is None:
|
||||||
|
raise ValueError("id is required")
|
||||||
|
if not (await session.exec(select(exists()).where(Beatmap.id == resp_id))).first():
|
||||||
session.add(beatmap)
|
session.add(beatmap)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
|
return (await session.exec(select(Beatmap).where(Beatmap.id == resp_id))).one()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
|
async def from_resp_batch(cls, session: AsyncSession, inp: list[BeatmapDict], from_: int = 0) -> list["Beatmap"]:
|
||||||
beatmaps = []
|
beatmaps = []
|
||||||
for resp in inp:
|
for resp_dict in inp:
|
||||||
if resp.id == from_:
|
bid = resp_dict.get("id")
|
||||||
|
if bid == from_ or bid is None:
|
||||||
continue
|
continue
|
||||||
d = resp.model_dump()
|
|
||||||
del d["beatmapset"]
|
beatmapset_id = resp_dict.get("beatmapset_id")
|
||||||
|
ranked = resp_dict.get("ranked")
|
||||||
|
if beatmapset_id is None or ranked is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建 beatmap 字典,移除 beatmapset
|
||||||
|
d = {k: v for k, v in resp_dict.items() if k != "beatmapset"}
|
||||||
|
|
||||||
beatmap = Beatmap.model_validate(
|
beatmap = Beatmap.model_validate(
|
||||||
{
|
{
|
||||||
**d,
|
**d,
|
||||||
"beatmapset_id": resp.beatmapset_id,
|
"beatmapset_id": beatmapset_id,
|
||||||
"id": resp.id,
|
"id": bid,
|
||||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
"beatmap_status": BeatmapRankStatus(ranked),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
if not (await session.exec(select(exists()).where(Beatmap.id == bid))).first():
|
||||||
session.add(beatmap)
|
session.add(beatmap)
|
||||||
beatmaps.append(beatmap)
|
beatmaps.append(beatmap)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
for beatmap in beatmaps:
|
||||||
|
await session.refresh(beatmap)
|
||||||
return beatmaps
|
return beatmaps
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -132,10 +354,14 @@ class Beatmap(BeatmapBase, table=True):
|
|||||||
beatmap = (await session.exec(stmt)).first()
|
beatmap = (await session.exec(stmt)).first()
|
||||||
if not beatmap:
|
if not beatmap:
|
||||||
resp = await fetcher.get_beatmap(bid, md5)
|
resp = await fetcher.get_beatmap(bid, md5)
|
||||||
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
|
beatmapset_id = resp.get("beatmapset_id")
|
||||||
|
if beatmapset_id is None:
|
||||||
|
raise ValueError("beatmapset_id is required")
|
||||||
|
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == beatmapset_id))
|
||||||
if not r.first():
|
if not r.first():
|
||||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
set_resp = await fetcher.get_beatmapset(beatmapset_id)
|
||||||
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
|
resp_id = resp.get("id")
|
||||||
|
await Beatmapset.from_resp(session, set_resp, from_=resp_id or 0)
|
||||||
return await Beatmap.from_resp(session, resp)
|
return await Beatmap.from_resp(session, resp)
|
||||||
return beatmap
|
return beatmap
|
||||||
|
|
||||||
@@ -145,97 +371,6 @@ class APIBeatmapTag(BaseModel):
|
|||||||
count: int
|
count: int
|
||||||
|
|
||||||
|
|
||||||
class BeatmapResp(BeatmapBase):
|
|
||||||
id: int
|
|
||||||
beatmapset_id: int
|
|
||||||
beatmapset: BeatmapsetResp | None = None
|
|
||||||
convert: bool = False
|
|
||||||
is_scoreable: bool
|
|
||||||
status: str
|
|
||||||
mode_int: int
|
|
||||||
ranked: int
|
|
||||||
url: str = ""
|
|
||||||
playcount: int = 0
|
|
||||||
passcount: int = 0
|
|
||||||
failtimes: FailTimeResp | None = None
|
|
||||||
top_tag_ids: list[APIBeatmapTag] | None = None
|
|
||||||
current_user_tag_ids: list[int] | None = None
|
|
||||||
is_deleted: bool = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(
|
|
||||||
cls,
|
|
||||||
beatmap: Beatmap,
|
|
||||||
query_mode: GameMode | None = None,
|
|
||||||
from_set: bool = False,
|
|
||||||
session: AsyncSession | None = None,
|
|
||||||
user: "User | None" = None,
|
|
||||||
) -> "BeatmapResp":
|
|
||||||
from .score import Score
|
|
||||||
|
|
||||||
beatmap_ = beatmap.model_dump()
|
|
||||||
beatmap_status = beatmap.beatmap_status
|
|
||||||
if query_mode is not None and beatmap.mode != query_mode:
|
|
||||||
beatmap_["convert"] = True
|
|
||||||
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
|
|
||||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
|
||||||
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
|
|
||||||
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
|
||||||
else:
|
|
||||||
beatmap_["status"] = beatmap_status.name.lower()
|
|
||||||
beatmap_["ranked"] = beatmap_status.value
|
|
||||||
beatmap_["mode_int"] = int(beatmap.mode)
|
|
||||||
beatmap_["is_deleted"] = beatmap.deleted_at is not None
|
|
||||||
if not from_set:
|
|
||||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
|
|
||||||
if beatmap.failtimes is not None:
|
|
||||||
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
|
|
||||||
else:
|
|
||||||
beatmap_["failtimes"] = FailTimeResp()
|
|
||||||
if session:
|
|
||||||
beatmap_["playcount"] = (
|
|
||||||
await session.exec(
|
|
||||||
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
|
|
||||||
)
|
|
||||||
).first() or 0
|
|
||||||
beatmap_["passcount"] = (
|
|
||||||
await session.exec(
|
|
||||||
select(func.count())
|
|
||||||
.select_from(Score)
|
|
||||||
.where(
|
|
||||||
Score.beatmap_id == beatmap.id,
|
|
||||||
col(Score.passed).is_(True),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).one()
|
|
||||||
|
|
||||||
all_votes = (
|
|
||||||
await session.exec(
|
|
||||||
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
|
|
||||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
|
||||||
.group_by(col(BeatmapTagVote.tag_id))
|
|
||||||
.having(func.count() > settings.beatmap_tag_top_count)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
top_tag_ids: list[dict[str, int]] = []
|
|
||||||
for id, votes in all_votes:
|
|
||||||
top_tag_ids.append({"tag_id": id, "count": votes})
|
|
||||||
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
|
|
||||||
beatmap_["top_tag_ids"] = top_tag_ids
|
|
||||||
|
|
||||||
if user is not None:
|
|
||||||
beatmap_["current_user_tag_ids"] = (
|
|
||||||
await session.exec(
|
|
||||||
select(BeatmapTagVote.tag_id)
|
|
||||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
|
||||||
.where(BeatmapTagVote.user_id == user.id)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
else:
|
|
||||||
beatmap_["current_user_tag_ids"] = []
|
|
||||||
return cls.model_validate(beatmap_)
|
|
||||||
|
|
||||||
|
|
||||||
class BannedBeatmaps(SQLModel, table=True):
|
class BannedBeatmaps(SQLModel, table=True):
|
||||||
__tablename__: str = "banned_beatmaps"
|
__tablename__: str = "banned_beatmaps"
|
||||||
id: int | None = Field(primary_key=True, index=True, default=None)
|
id: int | None = Field(primary_key=True, index=True, default=None)
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database.events import Event, EventType
|
|
||||||
from app.utils import utcnow
|
from app.utils import utcnow
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from ._base import DatabaseModel, included
|
||||||
|
from .events import Event, EventType
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
BigInteger,
|
BigInteger,
|
||||||
@@ -12,52 +13,65 @@ from sqlmodel import (
|
|||||||
Field,
|
Field,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Relationship,
|
Relationship,
|
||||||
SQLModel,
|
|
||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .beatmap import Beatmap, BeatmapResp
|
from .beatmap import Beatmap, BeatmapDict
|
||||||
from .beatmapset import BeatmapsetResp
|
from .beatmapset import BeatmapsetDict
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
|
class BeatmapPlaycountsDict(TypedDict):
|
||||||
|
user_id: int
|
||||||
|
beatmap_id: int
|
||||||
|
count: NotRequired[int]
|
||||||
|
beatmap: NotRequired["BeatmapDict"]
|
||||||
|
beatmapset: NotRequired["BeatmapsetDict"]
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapPlaycountsModel(AsyncAttrs, DatabaseModel[BeatmapPlaycountsDict]):
|
||||||
__tablename__: str = "beatmap_playcounts"
|
__tablename__: str = "beatmap_playcounts"
|
||||||
|
|
||||||
id: int | None = Field(
|
id: int | None = Field(
|
||||||
default=None,
|
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), exclude=True
|
||||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
|
||||||
)
|
)
|
||||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||||
playcount: int = Field(default=0)
|
playcount: int = Field(default=0, exclude=True)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def count(_session: AsyncSession, obj: "BeatmapPlaycounts") -> int:
|
||||||
|
return obj.playcount
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def beatmap(
|
||||||
|
_session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
|
||||||
|
) -> "BeatmapDict":
|
||||||
|
from .beatmap import BeatmapModel
|
||||||
|
|
||||||
|
await obj.awaitable_attrs.beatmap
|
||||||
|
return await BeatmapModel.transform(obj.beatmap, includes=includes)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def beatmapset(
|
||||||
|
_session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
|
||||||
|
) -> "BeatmapsetDict":
|
||||||
|
from .beatmap import BeatmapsetModel
|
||||||
|
|
||||||
|
await obj.awaitable_attrs.beatmap
|
||||||
|
return await BeatmapsetModel.transform(obj.beatmap.beatmapset, includes=includes)
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapPlaycounts(BeatmapPlaycountsModel, table=True):
|
||||||
user: "User" = Relationship()
|
user: "User" = Relationship()
|
||||||
beatmap: "Beatmap" = Relationship()
|
beatmap: "Beatmap" = Relationship()
|
||||||
|
|
||||||
|
|
||||||
class BeatmapPlaycountsResp(BaseModel):
|
|
||||||
beatmap_id: int
|
|
||||||
beatmap: "BeatmapResp | None" = None
|
|
||||||
beatmapset: "BeatmapsetResp | None" = None
|
|
||||||
count: int
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(cls, db_model: BeatmapPlaycounts) -> "BeatmapPlaycountsResp":
|
|
||||||
from .beatmap import BeatmapResp
|
|
||||||
from .beatmapset import BeatmapsetResp
|
|
||||||
|
|
||||||
await db_model.awaitable_attrs.beatmap
|
|
||||||
return cls(
|
|
||||||
beatmap_id=db_model.beatmap_id,
|
|
||||||
count=db_model.playcount,
|
|
||||||
beatmap=await BeatmapResp.from_db(db_model.beatmap),
|
|
||||||
beatmapset=await BeatmapsetResp.from_db(db_model.beatmap.beatmapset),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
|
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
|
||||||
existing_playcount = (
|
existing_playcount = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, NotRequired, Self, TypedDict
|
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database.beatmap_playcounts import BeatmapPlaycounts
|
|
||||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from .user import BASE_INCLUDES, User, UserResp
|
from ._base import DatabaseModel, OnDemand, included, ondemand
|
||||||
|
from .beatmap_playcounts import BeatmapPlaycounts
|
||||||
|
from .user import User, UserDict
|
||||||
|
|
||||||
from pydantic import BaseModel, field_validator, model_validator
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import JSON, Boolean, Column, DateTime, Text
|
from sqlalchemy import JSON, Boolean, Column, DateTime, Text
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
|
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
|
||||||
@@ -17,7 +18,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.fetcher import Fetcher
|
from app.fetcher import Fetcher
|
||||||
|
|
||||||
from .beatmap import Beatmap, BeatmapResp
|
from .beatmap import Beatmap, BeatmapDict
|
||||||
from .favourite_beatmapset import FavouriteBeatmapset
|
from .favourite_beatmapset import FavouriteBeatmapset
|
||||||
|
|
||||||
|
|
||||||
@@ -68,8 +69,99 @@ class BeatmapTranslationText(BaseModel):
|
|||||||
id: int | None = None
|
id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class BeatmapsetBase(SQLModel):
|
class BeatmapsetDict(TypedDict):
|
||||||
|
id: int
|
||||||
|
artist: str
|
||||||
|
artist_unicode: str
|
||||||
|
covers: BeatmapCovers | None
|
||||||
|
creator: str
|
||||||
|
nsfw: bool
|
||||||
|
preview_url: str
|
||||||
|
source: str
|
||||||
|
spotlight: bool
|
||||||
|
title: str
|
||||||
|
title_unicode: str
|
||||||
|
track_id: int | None
|
||||||
|
user_id: int
|
||||||
|
video: bool
|
||||||
|
current_nominations: list[BeatmapNomination] | None
|
||||||
|
description: BeatmapDescription | None
|
||||||
|
pack_tags: list[str]
|
||||||
|
|
||||||
|
bpm: NotRequired[float]
|
||||||
|
can_be_hyped: NotRequired[bool]
|
||||||
|
discussion_locked: NotRequired[bool]
|
||||||
|
last_updated: NotRequired[datetime]
|
||||||
|
ranked_date: NotRequired[datetime | None]
|
||||||
|
storyboard: NotRequired[bool]
|
||||||
|
submitted_date: NotRequired[datetime]
|
||||||
|
tags: NotRequired[str]
|
||||||
|
discussion_enabled: NotRequired[bool]
|
||||||
|
legacy_thread_url: NotRequired[str | None]
|
||||||
|
status: NotRequired[str]
|
||||||
|
ranked: NotRequired[int]
|
||||||
|
is_scoreable: NotRequired[bool]
|
||||||
|
favourite_count: NotRequired[int]
|
||||||
|
genre_id: NotRequired[int]
|
||||||
|
hype: NotRequired[BeatmapHype]
|
||||||
|
language_id: NotRequired[int]
|
||||||
|
play_count: NotRequired[int]
|
||||||
|
availability: NotRequired[BeatmapAvailability]
|
||||||
|
beatmaps: NotRequired[list["BeatmapDict"]]
|
||||||
|
has_favourited: NotRequired[bool]
|
||||||
|
recent_favourites: NotRequired[list[UserDict]]
|
||||||
|
genre: NotRequired[BeatmapTranslationText]
|
||||||
|
language: NotRequired[BeatmapTranslationText]
|
||||||
|
nominations: NotRequired["BeatmapNominations"]
|
||||||
|
ratings: NotRequired[list[int]]
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapsetModel(DatabaseModel[BeatmapsetDict]):
|
||||||
|
BEATMAPSET_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
"availability",
|
||||||
|
"has_favourited",
|
||||||
|
"bpm",
|
||||||
|
"deleted_atcan_be_hyped",
|
||||||
|
"discussion_locked",
|
||||||
|
"is_scoreable",
|
||||||
|
"last_updated",
|
||||||
|
"legacy_thread_url",
|
||||||
|
"ranked",
|
||||||
|
"ranked_date",
|
||||||
|
"submitted_date",
|
||||||
|
"tags",
|
||||||
|
"rating",
|
||||||
|
"storyboard",
|
||||||
|
]
|
||||||
|
API_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
*BEATMAPSET_TRANSFORMER_INCLUDES,
|
||||||
|
"beatmaps.current_user_playcount",
|
||||||
|
"beatmaps.current_user_tag_ids",
|
||||||
|
"beatmaps.max_combo",
|
||||||
|
"current_nominations",
|
||||||
|
"current_user_attributes",
|
||||||
|
"description",
|
||||||
|
"genre",
|
||||||
|
"language",
|
||||||
|
"pack_tags",
|
||||||
|
"ratings",
|
||||||
|
"recent_favourites",
|
||||||
|
"related_tags",
|
||||||
|
"related_users",
|
||||||
|
"user",
|
||||||
|
"version_count",
|
||||||
|
*[
|
||||||
|
f"beatmaps.{inc}"
|
||||||
|
for inc in {
|
||||||
|
"failtimes",
|
||||||
|
"owners",
|
||||||
|
"top_tag_ids",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
# Beatmapset
|
# Beatmapset
|
||||||
|
id: int = Field(default=None, primary_key=True, index=True)
|
||||||
artist: str = Field(index=True)
|
artist: str = Field(index=True)
|
||||||
artist_unicode: str = Field(index=True)
|
artist_unicode: str = Field(index=True)
|
||||||
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
|
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
|
||||||
@@ -77,41 +169,285 @@ class BeatmapsetBase(SQLModel):
|
|||||||
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
|
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
|
||||||
preview_url: str
|
preview_url: str
|
||||||
source: str = Field(default="")
|
source: str = Field(default="")
|
||||||
|
|
||||||
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
|
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
|
||||||
title: str = Field(index=True)
|
title: str = Field(index=True)
|
||||||
title_unicode: str = Field(index=True)
|
title_unicode: str = Field(index=True)
|
||||||
|
track_id: int | None = Field(default=None, index=True) # feature artist?
|
||||||
user_id: int = Field(index=True)
|
user_id: int = Field(index=True)
|
||||||
video: bool = Field(sa_column=Column(Boolean, index=True))
|
video: bool = Field(sa_column=Column(Boolean, index=True))
|
||||||
|
|
||||||
# optional
|
# optional
|
||||||
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
|
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
|
||||||
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
|
current_nominations: OnDemand[list[BeatmapNomination] | None] = Field(None, sa_column=Column(JSON))
|
||||||
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
|
description: OnDemand[BeatmapDescription | None] = Field(default=None, sa_column=Column(JSON))
|
||||||
# TODO: discussions: list[BeatmapsetDiscussion] = None
|
# TODO: discussions: list[BeatmapsetDiscussion] = None
|
||||||
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
|
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
|
||||||
# TODO: events: Optional[list[BeatmapsetEvent]] = None
|
# TODO: events: Optional[list[BeatmapsetEvent]] = None
|
||||||
|
|
||||||
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
|
pack_tags: OnDemand[list[str]] = Field(default=[], sa_column=Column(JSON))
|
||||||
# TODO: related_users: Optional[list[User]] = None
|
# TODO: related_users: Optional[list[User]] = None
|
||||||
# TODO: user: Optional[User] = Field(default=None)
|
# TODO: user: Optional[User] = Field(default=None)
|
||||||
track_id: int | None = Field(default=None, index=True) # feature artist?
|
|
||||||
|
|
||||||
# BeatmapsetExtended
|
# BeatmapsetExtended
|
||||||
bpm: float = Field(default=0.0)
|
bpm: OnDemand[float] = Field(default=0.0)
|
||||||
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
|
can_be_hyped: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
|
||||||
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
|
discussion_locked: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
|
||||||
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
|
||||||
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
|
ranked_date: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime, index=True))
|
||||||
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
|
storyboard: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean, index=True))
|
||||||
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
|
submitted_date: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
|
||||||
tags: str = Field(default="", sa_column=Column(Text))
|
tags: OnDemand[str] = Field(default="", sa_column=Column(Text))
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def legacy_thread_url(
|
||||||
|
_session: AsyncSession,
|
||||||
|
_beatmapset: "Beatmapset",
|
||||||
|
) -> str | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def discussion_enabled(
|
||||||
|
_session: AsyncSession,
|
||||||
|
_beatmapset: "Beatmapset",
|
||||||
|
) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def status(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> str:
|
||||||
|
beatmap_status = beatmapset.beatmap_status
|
||||||
|
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||||
|
return BeatmapRankStatus.APPROVED.name.lower()
|
||||||
|
return beatmap_status.name.lower()
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def ranked(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> int:
|
||||||
|
beatmap_status = beatmapset.beatmap_status
|
||||||
|
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||||
|
return BeatmapRankStatus.APPROVED.value
|
||||||
|
return beatmap_status.value
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def is_scoreable(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> bool:
|
||||||
|
beatmap_status = beatmapset.beatmap_status
|
||||||
|
if settings.enable_all_beatmap_leaderboard:
|
||||||
|
return True
|
||||||
|
return beatmap_status.has_leaderboard()
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def favourite_count(
|
||||||
|
session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> int:
|
||||||
|
from .favourite_beatmapset import FavouriteBeatmapset
|
||||||
|
|
||||||
|
count = await session.exec(
|
||||||
|
select(func.count())
|
||||||
|
.select_from(FavouriteBeatmapset)
|
||||||
|
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||||
|
)
|
||||||
|
return count.one()
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def genre_id(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> int:
|
||||||
|
return beatmapset.beatmap_genre.value
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def hype(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> BeatmapHype:
|
||||||
|
return BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def language_id(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> int:
|
||||||
|
return beatmapset.beatmap_language.value
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def play_count(
|
||||||
|
session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> int:
|
||||||
|
from .beatmap import Beatmap
|
||||||
|
|
||||||
|
playcount = await session.exec(
|
||||||
|
select(func.sum(BeatmapPlaycounts.playcount)).where(
|
||||||
|
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(playcount.first() or 0)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def availability(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> BeatmapAvailability:
|
||||||
|
return BeatmapAvailability(
|
||||||
|
more_information=beatmapset.availability_info,
|
||||||
|
download_disabled=beatmapset.download_disabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def beatmaps(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
user: "User | None" = None,
|
||||||
|
) -> list["BeatmapDict"]:
|
||||||
|
from .beatmap import BeatmapModel
|
||||||
|
|
||||||
|
return [
|
||||||
|
await BeatmapModel.transform(
|
||||||
|
beatmap, includes=(includes or []) + BeatmapModel.BEATMAP_TRANSFORMER_INCLUDES, user=user
|
||||||
|
)
|
||||||
|
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
||||||
|
]
|
||||||
|
|
||||||
|
# @ondemand
|
||||||
|
# @staticmethod
|
||||||
|
# async def current_nominations(
|
||||||
|
# _session: AsyncSession,
|
||||||
|
# beatmapset: "Beatmapset",
|
||||||
|
# ) -> list[BeatmapNomination] | None:
|
||||||
|
# return beatmapset.current_nominations or []
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def has_favourited(
|
||||||
|
session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
user: User | None = None,
|
||||||
|
) -> bool:
|
||||||
|
from .favourite_beatmapset import FavouriteBeatmapset
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
return False
|
||||||
|
query = select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||||
|
if user is not None:
|
||||||
|
query = query.where(FavouriteBeatmapset.user_id == user.id)
|
||||||
|
existing = (await session.exec(query)).first()
|
||||||
|
return existing is not None
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def recent_favourites(
|
||||||
|
session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
) -> list[UserDict]:
|
||||||
|
from .favourite_beatmapset import FavouriteBeatmapset
|
||||||
|
|
||||||
|
recent_favourites = (
|
||||||
|
await session.exec(
|
||||||
|
select(FavouriteBeatmapset)
|
||||||
|
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||||
|
.order_by(col(FavouriteBeatmapset.date).desc())
|
||||||
|
.limit(50)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
return [
|
||||||
|
await User.transform(
|
||||||
|
(await favourite.awaitable_attrs.user),
|
||||||
|
includes=includes,
|
||||||
|
)
|
||||||
|
for favourite in recent_favourites
|
||||||
|
]
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def genre(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> BeatmapTranslationText:
|
||||||
|
return BeatmapTranslationText(
|
||||||
|
name=beatmapset.beatmap_genre.name,
|
||||||
|
id=beatmapset.beatmap_genre.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def language(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> BeatmapTranslationText:
|
||||||
|
return BeatmapTranslationText(
|
||||||
|
name=beatmapset.beatmap_language.name,
|
||||||
|
id=beatmapset.beatmap_language.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def nominations(
|
||||||
|
_session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> BeatmapNominations:
|
||||||
|
return BeatmapNominations(
|
||||||
|
required=beatmapset.nominations_required,
|
||||||
|
current=beatmapset.nominations_current,
|
||||||
|
)
|
||||||
|
|
||||||
|
# @ondemand
|
||||||
|
# @staticmethod
|
||||||
|
# async def user(
|
||||||
|
# session: AsyncSession,
|
||||||
|
# beatmapset: Beatmapset,
|
||||||
|
# includes: list[str] | None = None,
|
||||||
|
# ) -> dict[str, Any] | None:
|
||||||
|
# db_user = await session.get(User, beatmapset.user_id)
|
||||||
|
# if not db_user:
|
||||||
|
# return None
|
||||||
|
# return await UserResp.transform(db_user, includes=includes)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def ratings(
|
||||||
|
session: AsyncSession,
|
||||||
|
beatmapset: "Beatmapset",
|
||||||
|
) -> list[int]:
|
||||||
|
# Provide a stable default shape if no session is available
|
||||||
|
if session is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
from .beatmapset_ratings import BeatmapRating
|
||||||
|
|
||||||
|
beatmapset_all_ratings = (
|
||||||
|
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
|
||||||
|
).all()
|
||||||
|
ratings_list = [0] * 11
|
||||||
|
for rating in beatmapset_all_ratings:
|
||||||
|
ratings_list[rating.rating] += 1
|
||||||
|
return ratings_list
|
||||||
|
|
||||||
|
|
||||||
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
class Beatmapset(AsyncAttrs, BeatmapsetModel, table=True):
|
||||||
__tablename__: str = "beatmapsets"
|
__tablename__: str = "beatmapsets"
|
||||||
|
|
||||||
id: int = Field(default=None, primary_key=True, index=True)
|
|
||||||
# Beatmapset
|
# Beatmapset
|
||||||
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
|
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
|
||||||
|
|
||||||
@@ -130,29 +466,45 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
|||||||
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
|
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
|
async def from_resp_no_save(cls, resp: BeatmapsetDict) -> "Beatmapset":
|
||||||
d = resp.model_dump()
|
# make a shallow copy so we can mutate safely
|
||||||
if resp.nominations:
|
d: dict[str, Any] = dict(resp)
|
||||||
d["nominations_required"] = resp.nominations.required
|
|
||||||
d["nominations_current"] = resp.nominations.current
|
# nominations = resp.get("nominations")
|
||||||
if resp.hype:
|
# if nominations is not None:
|
||||||
d["hype_current"] = resp.hype.current
|
# d["nominations_required"] = nominations.required
|
||||||
d["hype_required"] = resp.hype.required
|
# d["nominations_current"] = nominations.current
|
||||||
if resp.genre_id:
|
|
||||||
d["beatmap_genre"] = Genre(resp.genre_id)
|
hype = resp.get("hype")
|
||||||
elif resp.genre:
|
if hype is not None:
|
||||||
d["beatmap_genre"] = Genre(resp.genre.id)
|
d["hype_current"] = hype.current
|
||||||
if resp.language_id:
|
d["hype_required"] = hype.required
|
||||||
d["beatmap_language"] = Language(resp.language_id)
|
|
||||||
elif resp.language:
|
genre_id = resp.get("genre_id")
|
||||||
d["beatmap_language"] = Language(resp.language.id)
|
genre = resp.get("genre")
|
||||||
|
if genre_id is not None:
|
||||||
|
d["beatmap_genre"] = Genre(genre_id)
|
||||||
|
elif genre is not None:
|
||||||
|
d["beatmap_genre"] = Genre(genre.id)
|
||||||
|
|
||||||
|
language_id = resp.get("language_id")
|
||||||
|
language = resp.get("language")
|
||||||
|
if language_id is not None:
|
||||||
|
d["beatmap_language"] = Language(language_id)
|
||||||
|
elif language is not None:
|
||||||
|
d["beatmap_language"] = Language(language.id)
|
||||||
|
|
||||||
|
availability = resp.get("availability")
|
||||||
|
ranked = resp.get("ranked")
|
||||||
|
if ranked is None:
|
||||||
|
raise ValueError("ranked field is required")
|
||||||
|
|
||||||
beatmapset = Beatmapset.model_validate(
|
beatmapset = Beatmapset.model_validate(
|
||||||
{
|
{
|
||||||
**d,
|
**d,
|
||||||
"id": resp.id,
|
"beatmap_status": BeatmapRankStatus(ranked),
|
||||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
"availability_info": availability.more_information if availability is not None else None,
|
||||||
"availability_info": resp.availability.more_information,
|
"download_disabled": bool(availability.download_disabled) if availability is not None else False,
|
||||||
"download_disabled": resp.availability.download_disabled or False,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return beatmapset
|
return beatmapset
|
||||||
@@ -161,17 +513,19 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
|||||||
async def from_resp(
|
async def from_resp(
|
||||||
cls,
|
cls,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
resp: "BeatmapsetResp",
|
resp: BeatmapsetDict,
|
||||||
from_: int = 0,
|
from_: int = 0,
|
||||||
) -> "Beatmapset":
|
) -> "Beatmapset":
|
||||||
from .beatmap import Beatmap
|
from .beatmap import Beatmap
|
||||||
|
|
||||||
|
beatmapset_id = resp["id"]
|
||||||
beatmapset = await cls.from_resp_no_save(resp)
|
beatmapset = await cls.from_resp_no_save(resp)
|
||||||
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
|
if not (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first():
|
||||||
session.add(beatmapset)
|
session.add(beatmapset)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
|
beatmaps = resp.get("beatmaps", [])
|
||||||
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one()
|
await Beatmap.from_resp_batch(session, beatmaps, from_=from_)
|
||||||
|
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))).one()
|
||||||
return beatmapset
|
return beatmapset
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -183,170 +537,5 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
|||||||
resp = await fetcher.get_beatmapset(sid)
|
resp = await fetcher.get_beatmapset(sid)
|
||||||
beatmapset = await cls.from_resp(session, resp)
|
beatmapset = await cls.from_resp(session, resp)
|
||||||
await get_beatmapset_update_service().add(resp)
|
await get_beatmapset_update_service().add(resp)
|
||||||
|
await session.refresh(beatmapset)
|
||||||
return beatmapset
|
return beatmapset
|
||||||
|
|
||||||
|
|
||||||
class BeatmapsetResp(BeatmapsetBase):
|
|
||||||
id: int
|
|
||||||
beatmaps: list["BeatmapResp"] = Field(default_factory=list)
|
|
||||||
discussion_enabled: bool = True
|
|
||||||
status: str
|
|
||||||
ranked: int
|
|
||||||
legacy_thread_url: str | None = ""
|
|
||||||
is_scoreable: bool
|
|
||||||
hype: BeatmapHype | None = None
|
|
||||||
availability: BeatmapAvailability
|
|
||||||
genre: BeatmapTranslationText | None = None
|
|
||||||
genre_id: int
|
|
||||||
language: BeatmapTranslationText | None = None
|
|
||||||
language_id: int
|
|
||||||
nominations: BeatmapNominations | None = None
|
|
||||||
has_favourited: bool = False
|
|
||||||
favourite_count: int = 0
|
|
||||||
recent_favourites: list[UserResp] = Field(default_factory=list)
|
|
||||||
play_count: int = 0
|
|
||||||
|
|
||||||
@field_validator(
|
|
||||||
"nsfw",
|
|
||||||
"spotlight",
|
|
||||||
"video",
|
|
||||||
"can_be_hyped",
|
|
||||||
"discussion_locked",
|
|
||||||
"storyboard",
|
|
||||||
"discussion_enabled",
|
|
||||||
"is_scoreable",
|
|
||||||
"has_favourited",
|
|
||||||
mode="before",
|
|
||||||
)
|
|
||||||
@classmethod
|
|
||||||
def validate_bool_fields(cls, v):
|
|
||||||
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
|
||||||
if isinstance(v, int):
|
|
||||||
return bool(v)
|
|
||||||
return v
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def fix_genre_language(self) -> Self:
|
|
||||||
if self.genre is None:
|
|
||||||
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
|
|
||||||
if self.language is None:
|
|
||||||
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
|
|
||||||
return self
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(
|
|
||||||
cls,
|
|
||||||
beatmapset: Beatmapset,
|
|
||||||
include: list[str] = [],
|
|
||||||
session: AsyncSession | None = None,
|
|
||||||
user: User | None = None,
|
|
||||||
) -> "BeatmapsetResp":
|
|
||||||
from .beatmap import Beatmap, BeatmapResp
|
|
||||||
from .favourite_beatmapset import FavouriteBeatmapset
|
|
||||||
|
|
||||||
update = {
|
|
||||||
"beatmaps": [
|
|
||||||
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
|
|
||||||
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
|
||||||
],
|
|
||||||
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
|
|
||||||
"availability": BeatmapAvailability(
|
|
||||||
more_information=beatmapset.availability_info,
|
|
||||||
download_disabled=beatmapset.download_disabled,
|
|
||||||
),
|
|
||||||
"genre": BeatmapTranslationText(
|
|
||||||
name=beatmapset.beatmap_genre.name,
|
|
||||||
id=beatmapset.beatmap_genre.value,
|
|
||||||
),
|
|
||||||
"language": BeatmapTranslationText(
|
|
||||||
name=beatmapset.beatmap_language.name,
|
|
||||||
id=beatmapset.beatmap_language.value,
|
|
||||||
),
|
|
||||||
"genre_id": beatmapset.beatmap_genre.value,
|
|
||||||
"language_id": beatmapset.beatmap_language.value,
|
|
||||||
"nominations": BeatmapNominations(
|
|
||||||
required=beatmapset.nominations_required,
|
|
||||||
current=beatmapset.nominations_current,
|
|
||||||
),
|
|
||||||
"is_scoreable": beatmapset.beatmap_status.has_leaderboard(),
|
|
||||||
**beatmapset.model_dump(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if session is not None:
|
|
||||||
# 从数据库读取对应谱面集的评分
|
|
||||||
from .beatmapset_ratings import BeatmapRating
|
|
||||||
|
|
||||||
beatmapset_all_ratings = (
|
|
||||||
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
|
|
||||||
).all()
|
|
||||||
ratings_list = [0] * 11
|
|
||||||
for rating in beatmapset_all_ratings:
|
|
||||||
ratings_list[rating.rating] += 1
|
|
||||||
update["ratings"] = ratings_list
|
|
||||||
else:
|
|
||||||
# 返回非空值避免客户端崩溃
|
|
||||||
if update.get("ratings") is None:
|
|
||||||
update["ratings"] = []
|
|
||||||
|
|
||||||
beatmap_status = beatmapset.beatmap_status
|
|
||||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
|
||||||
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
|
||||||
update["ranked"] = BeatmapRankStatus.APPROVED.value
|
|
||||||
else:
|
|
||||||
update["status"] = beatmap_status.name.lower()
|
|
||||||
update["ranked"] = beatmap_status.value
|
|
||||||
|
|
||||||
if session and user:
|
|
||||||
existing_favourite = (
|
|
||||||
await session.exec(
|
|
||||||
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
|
||||||
)
|
|
||||||
).first()
|
|
||||||
update["has_favourited"] = existing_favourite is not None
|
|
||||||
|
|
||||||
if session and "recent_favourites" in include:
|
|
||||||
recent_favourites = (
|
|
||||||
await session.exec(
|
|
||||||
select(FavouriteBeatmapset)
|
|
||||||
.where(
|
|
||||||
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
|
|
||||||
)
|
|
||||||
.order_by(col(FavouriteBeatmapset.date).desc())
|
|
||||||
.limit(50)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
update["recent_favourites"] = [
|
|
||||||
await UserResp.from_db(
|
|
||||||
await favourite.awaitable_attrs.user,
|
|
||||||
session=session,
|
|
||||||
include=BASE_INCLUDES,
|
|
||||||
)
|
|
||||||
for favourite in recent_favourites
|
|
||||||
]
|
|
||||||
|
|
||||||
if session:
|
|
||||||
update["favourite_count"] = (
|
|
||||||
await session.exec(
|
|
||||||
select(func.count())
|
|
||||||
.select_from(FavouriteBeatmapset)
|
|
||||||
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
|
||||||
)
|
|
||||||
).one()
|
|
||||||
|
|
||||||
update["play_count"] = (
|
|
||||||
await session.exec(
|
|
||||||
select(func.sum(BeatmapPlaycounts.playcount)).where(
|
|
||||||
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).first() or 0
|
|
||||||
return cls.model_validate(
|
|
||||||
update,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SearchBeatmapsetsResp(SQLModel):
|
|
||||||
beatmapsets: list[BeatmapsetResp]
|
|
||||||
total: int
|
|
||||||
cursor: dict[str, int | float | str] | None = None
|
|
||||||
cursor_string: str | None = None
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from app.database.beatmapset import Beatmapset
|
from .beatmapset import Beatmapset
|
||||||
from app.database.user import User
|
from .user import User
|
||||||
|
|
||||||
from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel
|
from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from app.database.statistics import UserStatistics
|
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
|
from .statistics import UserStatistics
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Self
|
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||||
|
|
||||||
from app.database.user import RANKING_INCLUDES, User, UserResp
|
|
||||||
from app.models.model import UTCBaseModel
|
from app.models.model import UTCBaseModel
|
||||||
from app.utils import utcnow
|
from app.utils import utcnow
|
||||||
|
|
||||||
|
from ._base import DatabaseModel, included, ondemand
|
||||||
|
from .user import User, UserDict, UserModel
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from redis.asyncio import Redis
|
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
VARCHAR,
|
VARCHAR,
|
||||||
BigInteger,
|
BigInteger,
|
||||||
@@ -22,6 +23,8 @@ from sqlmodel import (
|
|||||||
)
|
)
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.router.notification.server import ChatServer
|
||||||
# ChatChannel
|
# ChatChannel
|
||||||
|
|
||||||
|
|
||||||
@@ -44,16 +47,168 @@ class ChannelType(str, Enum):
|
|||||||
TEAM = "TEAM"
|
TEAM = "TEAM"
|
||||||
|
|
||||||
|
|
||||||
class ChatChannelBase(SQLModel):
|
class MessageType(str, Enum):
|
||||||
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
ACTION = "action"
|
||||||
|
MARKDOWN = "markdown"
|
||||||
|
PLAIN = "plain"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatChannelDict(TypedDict):
|
||||||
|
channel_id: int
|
||||||
|
description: str
|
||||||
|
name: str
|
||||||
|
icon: str | None
|
||||||
|
type: ChannelType
|
||||||
|
uuid: NotRequired[str | None]
|
||||||
|
message_length_limit: NotRequired[int]
|
||||||
|
moderated: NotRequired[bool]
|
||||||
|
current_user_attributes: NotRequired[ChatUserAttributes]
|
||||||
|
last_read_id: NotRequired[int | None]
|
||||||
|
last_message_id: NotRequired[int | None]
|
||||||
|
recent_messages: NotRequired[list["ChatMessageDict"]]
|
||||||
|
users: NotRequired[list[int]]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatChannelModel(DatabaseModel[ChatChannelDict]):
|
||||||
|
CONVERSATION_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
"last_message_id",
|
||||||
|
"users",
|
||||||
|
]
|
||||||
|
LISTING_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
*CONVERSATION_INCLUDES,
|
||||||
|
"current_user_attributes",
|
||||||
|
"last_read_id",
|
||||||
|
]
|
||||||
|
|
||||||
|
channel_id: int = Field(primary_key=True, index=True, default=None)
|
||||||
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
|
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
|
||||||
icon: str | None = Field(default=None)
|
icon: str | None = Field(default=None)
|
||||||
type: ChannelType = Field(index=True)
|
type: ChannelType = Field(index=True)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def name(session: AsyncSession, channel: "ChatChannel", user: User, server: "ChatServer") -> str:
|
||||||
|
users = server.channels.get(channel.channel_id, [])
|
||||||
|
if channel.type == ChannelType.PM and users and len(users) == 2:
|
||||||
|
target_user_id = next(u for u in users if u != user.id)
|
||||||
|
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
||||||
|
return target_name.one()
|
||||||
|
return channel.name
|
||||||
|
|
||||||
class ChatChannel(ChatChannelBase, table=True):
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def moderated(session: AsyncSession, channel: "ChatChannel", user: User) -> bool:
|
||||||
|
silence = (
|
||||||
|
await session.exec(
|
||||||
|
select(SilenceUser).where(
|
||||||
|
SilenceUser.channel_id == channel.channel_id,
|
||||||
|
SilenceUser.user_id == user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
return silence is not None
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def current_user_attributes(
|
||||||
|
session: AsyncSession,
|
||||||
|
channel: "ChatChannel",
|
||||||
|
user: User,
|
||||||
|
) -> ChatUserAttributes:
|
||||||
|
from app.dependencies.database import get_redis
|
||||||
|
|
||||||
|
silence = (
|
||||||
|
await session.exec(
|
||||||
|
select(SilenceUser).where(
|
||||||
|
SilenceUser.channel_id == channel.channel_id,
|
||||||
|
SilenceUser.user_id == user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
can_message = silence is None
|
||||||
|
can_message_error = "You are silenced in this channel" if not can_message else None
|
||||||
|
|
||||||
|
redis = get_redis()
|
||||||
|
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||||
|
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||||
|
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||||
|
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else (last_msg or 0)
|
||||||
|
|
||||||
|
return ChatUserAttributes(
|
||||||
|
can_message=can_message,
|
||||||
|
can_message_error=can_message_error,
|
||||||
|
last_read_id=last_read_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def last_read_id(_session: AsyncSession, channel: "ChatChannel", user: User) -> int | None:
|
||||||
|
from app.dependencies.database import get_redis
|
||||||
|
|
||||||
|
redis = get_redis()
|
||||||
|
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||||
|
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||||
|
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||||
|
return int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def last_message_id(_session: AsyncSession, channel: "ChatChannel") -> int | None:
|
||||||
|
from app.dependencies.database import get_redis
|
||||||
|
|
||||||
|
redis = get_redis()
|
||||||
|
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||||
|
return int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def recent_messages(
|
||||||
|
session: AsyncSession,
|
||||||
|
channel: "ChatChannel",
|
||||||
|
) -> list["ChatMessageDict"]:
|
||||||
|
messages = (
|
||||||
|
await session.exec(
|
||||||
|
select(ChatMessage)
|
||||||
|
.where(ChatMessage.channel_id == channel.channel_id)
|
||||||
|
.order_by(col(ChatMessage.message_id).desc())
|
||||||
|
.limit(50)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
result = [
|
||||||
|
await ChatMessageModel.transform(
|
||||||
|
msg,
|
||||||
|
)
|
||||||
|
for msg in reversed(messages)
|
||||||
|
]
|
||||||
|
return result
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def users(
|
||||||
|
_session: AsyncSession,
|
||||||
|
channel: "ChatChannel",
|
||||||
|
server: "ChatServer",
|
||||||
|
user: User,
|
||||||
|
) -> list[int]:
|
||||||
|
if channel.type == ChannelType.PUBLIC:
|
||||||
|
return []
|
||||||
|
users = server.channels.get(channel.channel_id, []).copy()
|
||||||
|
if channel.type == ChannelType.PM and users and len(users) == 2:
|
||||||
|
target_user_id = next(u for u in users if u != user.id)
|
||||||
|
users = [target_user_id, user.id]
|
||||||
|
return users
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def message_length_limit(_session: AsyncSession, _channel: "ChatChannel") -> int:
|
||||||
|
return 1000
|
||||||
|
|
||||||
|
|
||||||
|
class ChatChannel(ChatChannelModel, table=True):
|
||||||
__tablename__: str = "chat_channels"
|
__tablename__: str = "chat_channels"
|
||||||
channel_id: int = Field(primary_key=True, index=True, default=None)
|
|
||||||
|
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
|
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
|
||||||
@@ -74,93 +229,20 @@ class ChatChannel(ChatChannelBase, table=True):
|
|||||||
return channel
|
return channel
|
||||||
|
|
||||||
|
|
||||||
class ChatChannelResp(ChatChannelBase):
|
|
||||||
channel_id: int
|
|
||||||
moderated: bool = False
|
|
||||||
uuid: str | None = None
|
|
||||||
current_user_attributes: ChatUserAttributes | None = None
|
|
||||||
last_read_id: int | None = None
|
|
||||||
last_message_id: int | None = None
|
|
||||||
recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
|
|
||||||
users: list[int] = Field(default_factory=list)
|
|
||||||
message_length_limit: int = 1000
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(
|
|
||||||
cls,
|
|
||||||
channel: ChatChannel,
|
|
||||||
session: AsyncSession,
|
|
||||||
user: User,
|
|
||||||
redis: Redis,
|
|
||||||
users: list[int] | None = None,
|
|
||||||
include_recent_messages: bool = False,
|
|
||||||
) -> Self:
|
|
||||||
c = cls.model_validate(channel)
|
|
||||||
silence = (
|
|
||||||
await session.exec(
|
|
||||||
select(SilenceUser).where(
|
|
||||||
SilenceUser.channel_id == channel.channel_id,
|
|
||||||
SilenceUser.user_id == user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).first()
|
|
||||||
|
|
||||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
|
||||||
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
|
||||||
|
|
||||||
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
|
||||||
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
|
|
||||||
|
|
||||||
if silence is not None:
|
|
||||||
attribute = ChatUserAttributes(
|
|
||||||
can_message=False,
|
|
||||||
can_message_error=silence.reason or "You are muted in this channel.",
|
|
||||||
last_read_id=last_read_id or 0,
|
|
||||||
)
|
|
||||||
c.moderated = True
|
|
||||||
else:
|
|
||||||
attribute = ChatUserAttributes(
|
|
||||||
can_message=True,
|
|
||||||
last_read_id=last_read_id or 0,
|
|
||||||
)
|
|
||||||
c.moderated = False
|
|
||||||
|
|
||||||
c.current_user_attributes = attribute
|
|
||||||
if c.type != ChannelType.PUBLIC and users is not None:
|
|
||||||
c.users = users
|
|
||||||
c.last_message_id = last_msg
|
|
||||||
c.last_read_id = last_read_id
|
|
||||||
|
|
||||||
if include_recent_messages:
|
|
||||||
messages = (
|
|
||||||
await session.exec(
|
|
||||||
select(ChatMessage)
|
|
||||||
.where(ChatMessage.channel_id == channel.channel_id)
|
|
||||||
.order_by(col(ChatMessage.timestamp).desc())
|
|
||||||
.limit(10)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
|
|
||||||
c.recent_messages.reverse()
|
|
||||||
|
|
||||||
if c.type == ChannelType.PM and users and len(users) == 2:
|
|
||||||
target_user_id = next(u for u in users if u != user.id)
|
|
||||||
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
|
||||||
c.name = target_name.one()
|
|
||||||
c.users = [target_user_id, user.id]
|
|
||||||
return c
|
|
||||||
|
|
||||||
|
|
||||||
# ChatMessage
|
# ChatMessage
|
||||||
|
class ChatMessageDict(TypedDict):
|
||||||
|
channel_id: int
|
||||||
|
content: str
|
||||||
|
message_id: int
|
||||||
|
sender_id: int
|
||||||
|
timestamp: datetime
|
||||||
|
type: MessageType
|
||||||
|
uuid: str | None
|
||||||
|
is_action: NotRequired[bool]
|
||||||
|
sender: NotRequired[UserDict]
|
||||||
|
|
||||||
|
|
||||||
class MessageType(str, Enum):
|
class ChatMessageModel(DatabaseModel[ChatMessageDict]):
|
||||||
ACTION = "action"
|
|
||||||
MARKDOWN = "markdown"
|
|
||||||
PLAIN = "plain"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
|
||||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||||
message_id: int = Field(index=True, primary_key=True, default=None)
|
message_id: int = Field(index=True, primary_key=True, default=None)
|
||||||
@@ -169,31 +251,21 @@ class ChatMessageBase(UTCBaseModel, SQLModel):
|
|||||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||||
uuid: str | None = Field(default=None)
|
uuid: str | None = Field(default=None)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def is_action(_session: AsyncSession, db_message: "ChatMessage") -> bool:
|
||||||
|
return db_message.type == MessageType.ACTION
|
||||||
|
|
||||||
class ChatMessage(ChatMessageBase, table=True):
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def sender(_session: AsyncSession, db_message: "ChatMessage") -> UserDict:
|
||||||
|
return await UserModel.transform(db_message.user)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(ChatMessageModel, table=True):
|
||||||
__tablename__: str = "chat_messages"
|
__tablename__: str = "chat_messages"
|
||||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||||
channel: ChatChannel = Relationship()
|
channel: "ChatChannel" = Relationship()
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageResp(ChatMessageBase):
|
|
||||||
sender: UserResp | None = None
|
|
||||||
is_action: bool = False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(
|
|
||||||
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
|
|
||||||
) -> "ChatMessageResp":
|
|
||||||
m = cls.model_validate(db_message.model_dump())
|
|
||||||
m.is_action = db_message.type == MessageType.ACTION
|
|
||||||
if user:
|
|
||||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
|
||||||
else:
|
|
||||||
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
|
|
||||||
return m
|
|
||||||
|
|
||||||
|
|
||||||
# SilenceUser
|
|
||||||
|
|
||||||
|
|
||||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from app.database.beatmapset import Beatmapset
|
from .beatmapset import Beatmapset
|
||||||
from app.database.user import User
|
from .user import User
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
from .playlist_best_score import PlaylistBestScore
|
from typing import Any, NotRequired, TypedDict
|
||||||
from .user import User, UserResp
|
|
||||||
|
from ._base import DatabaseModel, ondemand
|
||||||
|
from .playlist_best_score import PlaylistBestScore
|
||||||
|
from .user import User, UserDict, UserModel
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
BigInteger,
|
BigInteger,
|
||||||
@@ -9,7 +11,6 @@ from sqlmodel import (
|
|||||||
Field,
|
Field,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Relationship,
|
Relationship,
|
||||||
SQLModel,
|
|
||||||
col,
|
col,
|
||||||
func,
|
func,
|
||||||
select,
|
select,
|
||||||
@@ -17,17 +18,66 @@ from sqlmodel import (
|
|||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
class ItemAttemptsCountBase(SQLModel):
|
class ItemAttemptsCountDict(TypedDict):
|
||||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
accuracy: float
|
||||||
|
attempts: int
|
||||||
|
completed: int
|
||||||
|
pp: float
|
||||||
|
room_id: int
|
||||||
|
total_score: int
|
||||||
|
user_id: int
|
||||||
|
user: NotRequired[UserDict]
|
||||||
|
position: NotRequired[int]
|
||||||
|
playlist_item_attempts: NotRequired[list[dict[str, Any]]]
|
||||||
|
|
||||||
|
|
||||||
|
class ItemAttemptsCountModel(DatabaseModel[ItemAttemptsCountDict]):
|
||||||
|
accuracy: float = 0.0
|
||||||
attempts: int = Field(default=0)
|
attempts: int = Field(default=0)
|
||||||
completed: int = Field(default=0)
|
completed: int = Field(default=0)
|
||||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
|
||||||
accuracy: float = 0.0
|
|
||||||
pp: float = 0
|
pp: float = 0
|
||||||
|
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||||
total_score: int = 0
|
total_score: int = 0
|
||||||
|
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def user(_session: AsyncSession, item_attempts: "ItemAttemptsCount") -> UserDict:
|
||||||
|
user_instance = await item_attempts.awaitable_attrs.user
|
||||||
|
return await UserModel.transform(user_instance)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def position(session: AsyncSession, item_attempts: "ItemAttemptsCount") -> int:
|
||||||
|
return await item_attempts.get_position(session)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def playlist_item_attempts(
|
||||||
|
session: AsyncSession,
|
||||||
|
item_attempts: "ItemAttemptsCount",
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
playlist_scores = (
|
||||||
|
await session.exec(
|
||||||
|
select(PlaylistBestScore).where(
|
||||||
|
PlaylistBestScore.room_id == item_attempts.room_id,
|
||||||
|
PlaylistBestScore.user_id == item_attempts.user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
result: list[dict[str, Any]] = []
|
||||||
|
for score in playlist_scores:
|
||||||
|
result.append(
|
||||||
|
{
|
||||||
|
"id": score.playlist_id,
|
||||||
|
"attempts": score.attempts,
|
||||||
|
"passed": score.score.passed,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountModel, table=True):
|
||||||
__tablename__: str = "item_attempts_count"
|
__tablename__: str = "item_attempts_count"
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
|
||||||
@@ -37,15 +87,15 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
|||||||
rownum = (
|
rownum = (
|
||||||
func.row_number()
|
func.row_number()
|
||||||
.over(
|
.over(
|
||||||
partition_by=col(ItemAttemptsCountBase.room_id),
|
partition_by=col(ItemAttemptsCount.room_id),
|
||||||
order_by=col(ItemAttemptsCountBase.total_score).desc(),
|
order_by=col(ItemAttemptsCount.total_score).desc(),
|
||||||
)
|
)
|
||||||
.label("rn")
|
.label("rn")
|
||||||
)
|
)
|
||||||
subq = select(ItemAttemptsCountBase, rownum).subquery()
|
subq = select(ItemAttemptsCount, rownum).subquery()
|
||||||
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
|
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
|
||||||
result = await session.exec(stmt)
|
result = await session.exec(stmt)
|
||||||
return result.one()
|
return result.first() or 0
|
||||||
|
|
||||||
async def update(self, session: AsyncSession):
|
async def update(self, session: AsyncSession):
|
||||||
playlist_scores = (
|
playlist_scores = (
|
||||||
@@ -88,62 +138,3 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
|||||||
await session.refresh(item_attempts)
|
await session.refresh(item_attempts)
|
||||||
await item_attempts.update(session)
|
await item_attempts.update(session)
|
||||||
return item_attempts
|
return item_attempts
|
||||||
|
|
||||||
|
|
||||||
class ItemAttemptsResp(ItemAttemptsCountBase):
|
|
||||||
user: UserResp | None = None
|
|
||||||
position: int | None = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(
|
|
||||||
cls,
|
|
||||||
item_attempts: ItemAttemptsCount,
|
|
||||||
session: AsyncSession,
|
|
||||||
include: list[str] = [],
|
|
||||||
) -> "ItemAttemptsResp":
|
|
||||||
resp = cls.model_validate(item_attempts.model_dump())
|
|
||||||
resp.user = await UserResp.from_db(
|
|
||||||
await item_attempts.awaitable_attrs.user,
|
|
||||||
session=session,
|
|
||||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
|
||||||
)
|
|
||||||
if "position" in include:
|
|
||||||
resp.position = await item_attempts.get_position(session)
|
|
||||||
# resp.accuracy *= 100
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
class ItemAttemptsCountForItem(BaseModel):
|
|
||||||
id: int
|
|
||||||
attempts: int
|
|
||||||
passed: bool
|
|
||||||
|
|
||||||
|
|
||||||
class PlaylistAggregateScore(BaseModel):
|
|
||||||
playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(
|
|
||||||
cls,
|
|
||||||
room_id: int,
|
|
||||||
user_id: int,
|
|
||||||
session: AsyncSession,
|
|
||||||
) -> "PlaylistAggregateScore":
|
|
||||||
playlist_scores = (
|
|
||||||
await session.exec(
|
|
||||||
select(PlaylistBestScore).where(
|
|
||||||
PlaylistBestScore.room_id == room_id,
|
|
||||||
PlaylistBestScore.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
playlist_item_attempts = []
|
|
||||||
for score in playlist_scores:
|
|
||||||
playlist_item_attempts.append(
|
|
||||||
ItemAttemptsCountForItem(
|
|
||||||
id=score.playlist_id,
|
|
||||||
attempts=score.attempts,
|
|
||||||
passed=score.score.passed,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return cls(playlist_item_attempts=playlist_item_attempts)
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
|
||||||
|
|
||||||
from app.models.model import UTCBaseModel
|
|
||||||
from app.models.mods import APIMod
|
from app.models.mods import APIMod
|
||||||
from app.models.playlist import PlaylistItem
|
from app.models.playlist import PlaylistItem
|
||||||
|
|
||||||
from .beatmap import Beatmap, BeatmapResp
|
from ._base import DatabaseModel, ondemand
|
||||||
|
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
|
||||||
|
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
JSON,
|
JSON,
|
||||||
@@ -15,7 +15,6 @@ from sqlmodel import (
|
|||||||
Field,
|
Field,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Relationship,
|
Relationship,
|
||||||
SQLModel,
|
|
||||||
func,
|
func,
|
||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
@@ -23,18 +22,34 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .room import Room
|
from .room import Room
|
||||||
|
from .score import ScoreDict
|
||||||
|
|
||||||
|
|
||||||
class PlaylistBase(SQLModel, UTCBaseModel):
|
class PlaylistDict(TypedDict):
|
||||||
id: int = Field(index=True)
|
id: int
|
||||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
room_id: int
|
||||||
|
beatmap_id: int
|
||||||
|
created_at: datetime | None
|
||||||
ruleset_id: int
|
ruleset_id: int
|
||||||
expired: bool = Field(default=False)
|
allowed_mods: list[APIMod]
|
||||||
playlist_order: int = Field(default=0)
|
required_mods: list[APIMod]
|
||||||
played_at: datetime | None = Field(
|
freestyle: bool
|
||||||
sa_column=Column(DateTime(timezone=True)),
|
expired: bool
|
||||||
default=None,
|
owner_id: int
|
||||||
|
playlist_order: int
|
||||||
|
played_at: datetime | None
|
||||||
|
beatmap: NotRequired["BeatmapDict"]
|
||||||
|
scores: NotRequired[list[dict[str, Any]]]
|
||||||
|
|
||||||
|
|
||||||
|
class PlaylistModel(DatabaseModel[PlaylistDict]):
|
||||||
|
id: int = Field(index=True)
|
||||||
|
room_id: int = Field(foreign_key="rooms.id")
|
||||||
|
beatmap_id: int = Field(
|
||||||
|
foreign_key="beatmaps.id",
|
||||||
)
|
)
|
||||||
|
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
|
||||||
|
ruleset_id: int
|
||||||
allowed_mods: list[APIMod] = Field(
|
allowed_mods: list[APIMod] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
sa_column=Column(JSON),
|
sa_column=Column(JSON),
|
||||||
@@ -43,16 +58,46 @@ class PlaylistBase(SQLModel, UTCBaseModel):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
sa_column=Column(JSON),
|
sa_column=Column(JSON),
|
||||||
)
|
)
|
||||||
beatmap_id: int = Field(
|
|
||||||
foreign_key="beatmaps.id",
|
|
||||||
)
|
|
||||||
freestyle: bool = Field(default=False)
|
freestyle: bool = Field(default=False)
|
||||||
|
expired: bool = Field(default=False)
|
||||||
|
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||||
|
playlist_order: int = Field(default=0)
|
||||||
|
played_at: datetime | None = Field(
|
||||||
|
sa_column=Column(DateTime(timezone=True)),
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def beatmap(_session: AsyncSession, playlist: "Playlist", includes: list[str] | None = None) -> BeatmapDict:
|
||||||
|
return await BeatmapModel.transform(playlist.beatmap, includes=includes)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def scores(session: AsyncSession, playlist: "Playlist") -> list["ScoreDict"]:
|
||||||
|
from .score import Score, ScoreModel
|
||||||
|
|
||||||
|
scores = (
|
||||||
|
await session.exec(
|
||||||
|
select(Score).where(
|
||||||
|
Score.playlist_item_id == playlist.id,
|
||||||
|
Score.room_id == playlist.room_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
result: list[ScoreDict] = []
|
||||||
|
for score in scores:
|
||||||
|
result.append(
|
||||||
|
await ScoreModel.transform(
|
||||||
|
score,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class Playlist(PlaylistBase, table=True):
|
class Playlist(PlaylistModel, table=True):
|
||||||
__tablename__: str = "room_playlists"
|
__tablename__: str = "room_playlists"
|
||||||
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
||||||
room_id: int = Field(foreign_key="rooms.id", exclude=True)
|
|
||||||
|
|
||||||
beatmap: Beatmap = Relationship(
|
beatmap: Beatmap = Relationship(
|
||||||
sa_relationship_kwargs={
|
sa_relationship_kwargs={
|
||||||
@@ -60,7 +105,6 @@ class Playlist(PlaylistBase, table=True):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
room: "Room" = Relationship()
|
room: "Room" = Relationship()
|
||||||
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
|
|
||||||
updated_at: datetime | None = Field(
|
updated_at: datetime | None = Field(
|
||||||
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
|
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
|
||||||
)
|
)
|
||||||
@@ -121,15 +165,3 @@ class Playlist(PlaylistBase, table=True):
|
|||||||
raise ValueError("Playlist item not found")
|
raise ValueError("Playlist item not found")
|
||||||
await session.delete(db_playlist)
|
await session.delete(db_playlist)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
class PlaylistResp(PlaylistBase):
|
|
||||||
beatmap: BeatmapResp | None = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
|
|
||||||
data = playlist.model_dump()
|
|
||||||
if "beatmap" in include:
|
|
||||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
|
|
||||||
resp = cls.model_validate(data)
|
|
||||||
return resp
|
|
||||||
|
|||||||
@@ -1,26 +1,39 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||||
|
|
||||||
from .user import User, UserResp
|
from app.models.score import GameMode
|
||||||
|
|
||||||
|
from ._base import DatabaseModel, included, ondemand
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
BigInteger,
|
BigInteger,
|
||||||
Column,
|
Column,
|
||||||
Field,
|
Field,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Relationship as SQLRelationship,
|
Relationship as SQLRelationship,
|
||||||
SQLModel,
|
|
||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .user import User, UserDict
|
||||||
|
|
||||||
|
|
||||||
class RelationshipType(str, Enum):
|
class RelationshipType(str, Enum):
|
||||||
FOLLOW = "Friend"
|
FOLLOW = "friend"
|
||||||
BLOCK = "Block"
|
BLOCK = "block"
|
||||||
|
|
||||||
|
|
||||||
class Relationship(SQLModel, table=True):
|
class RelationshipDict(TypedDict):
|
||||||
|
target_id: int | None
|
||||||
|
type: RelationshipType
|
||||||
|
id: NotRequired[int | None]
|
||||||
|
user_id: NotRequired[int | None]
|
||||||
|
mutual: NotRequired[bool]
|
||||||
|
target: NotRequired["UserDict"]
|
||||||
|
|
||||||
|
|
||||||
|
class RelationshipModel(DatabaseModel[RelationshipDict]):
|
||||||
__tablename__: str = "relationship"
|
__tablename__: str = "relationship"
|
||||||
id: int | None = Field(
|
id: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -34,6 +47,7 @@ class Relationship(SQLModel, table=True):
|
|||||||
ForeignKey("lazer_users.id"),
|
ForeignKey("lazer_users.id"),
|
||||||
index=True,
|
index=True,
|
||||||
),
|
),
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
target_id: int = Field(
|
target_id: int = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -44,22 +58,10 @@ class Relationship(SQLModel, table=True):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
||||||
target: User = SQLRelationship(
|
|
||||||
sa_relationship_kwargs={
|
|
||||||
"foreign_keys": "[Relationship.target_id]",
|
|
||||||
"lazy": "selectin",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
@included
|
||||||
class RelationshipResp(BaseModel):
|
@staticmethod
|
||||||
target_id: int
|
async def mutual(session: AsyncSession, relationship: "Relationship") -> bool:
|
||||||
target: UserResp
|
|
||||||
mutual: bool = False
|
|
||||||
type: RelationshipType
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
|
|
||||||
target_relationship = (
|
target_relationship = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(Relationship).where(
|
select(Relationship).where(
|
||||||
@@ -68,23 +70,29 @@ class RelationshipResp(BaseModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
).first()
|
).first()
|
||||||
mutual = bool(
|
return bool(
|
||||||
target_relationship is not None
|
target_relationship is not None
|
||||||
and relationship.type == RelationshipType.FOLLOW
|
and relationship.type == RelationshipType.FOLLOW
|
||||||
and target_relationship.type == RelationshipType.FOLLOW
|
and target_relationship.type == RelationshipType.FOLLOW
|
||||||
)
|
)
|
||||||
return cls(
|
|
||||||
target_id=relationship.target_id,
|
@ondemand
|
||||||
target=await UserResp.from_db(
|
@staticmethod
|
||||||
relationship.target,
|
async def target(
|
||||||
session,
|
_session: AsyncSession,
|
||||||
include=[
|
relationship: "Relationship",
|
||||||
"team",
|
ruleset: GameMode | None = None,
|
||||||
"daily_challenge_user_stats",
|
includes: list[str] | None = None,
|
||||||
"statistics",
|
) -> "UserDict":
|
||||||
"statistics_rulesets",
|
from .user import UserModel
|
||||||
],
|
|
||||||
),
|
return await UserModel.transform(relationship.target, ruleset=ruleset, includes=includes)
|
||||||
mutual=mutual,
|
|
||||||
type=relationship.type,
|
|
||||||
)
|
class Relationship(RelationshipModel, table=True):
|
||||||
|
target: "User" = SQLRelationship(
|
||||||
|
sa_relationship_kwargs={
|
||||||
|
"foreign_keys": "[Relationship.target_id]",
|
||||||
|
"lazy": "selectin",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import ClassVar, NotRequired, TypedDict
|
||||||
|
|
||||||
from app.database.item_attempts_count import PlaylistAggregateScore
|
|
||||||
from app.database.room_participated_user import RoomParticipatedUser
|
|
||||||
from app.models.model import UTCBaseModel
|
|
||||||
from app.models.room import (
|
from app.models.room import (
|
||||||
MatchType,
|
MatchType,
|
||||||
QueueMode,
|
QueueMode,
|
||||||
@@ -13,28 +11,58 @@ from app.models.room import (
|
|||||||
)
|
)
|
||||||
from app.utils import utcnow
|
from app.utils import utcnow
|
||||||
|
|
||||||
from .playlists import Playlist, PlaylistResp
|
from ._base import DatabaseModel, included, ondemand
|
||||||
from .user import User, UserResp
|
from .item_attempts_count import ItemAttemptsCount, ItemAttemptsCountDict, ItemAttemptsCountModel
|
||||||
|
from .playlists import Playlist, PlaylistDict, PlaylistModel
|
||||||
|
from .room_participated_user import RoomParticipatedUser
|
||||||
|
from .user import User, UserDict, UserModel
|
||||||
|
|
||||||
|
from pydantic import field_validator
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import (
|
from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, Relationship, SQLModel, col, select
|
||||||
BigInteger,
|
|
||||||
Column,
|
|
||||||
DateTime,
|
|
||||||
Field,
|
|
||||||
ForeignKey,
|
|
||||||
Relationship,
|
|
||||||
SQLModel,
|
|
||||||
col,
|
|
||||||
func,
|
|
||||||
select,
|
|
||||||
)
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
class RoomBase(SQLModel, UTCBaseModel):
|
class RoomDict(TypedDict):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
category: RoomCategory
|
||||||
|
status: RoomStatus
|
||||||
|
type: MatchType
|
||||||
|
duration: int | None
|
||||||
|
starts_at: datetime | None
|
||||||
|
ends_at: datetime | None
|
||||||
|
max_attempts: int | None
|
||||||
|
participant_count: int
|
||||||
|
channel_id: int
|
||||||
|
queue_mode: QueueMode
|
||||||
|
auto_skip: bool
|
||||||
|
auto_start_duration: int
|
||||||
|
has_password: NotRequired[bool]
|
||||||
|
current_playlist_item: NotRequired["PlaylistDict | None"]
|
||||||
|
playlist: NotRequired[list["PlaylistDict"]]
|
||||||
|
playlist_item_stats: NotRequired[RoomPlaylistItemStats]
|
||||||
|
difficulty_range: NotRequired[RoomDifficultyRange]
|
||||||
|
host: NotRequired[UserDict]
|
||||||
|
recent_participants: NotRequired[list[UserDict]]
|
||||||
|
current_user_score: NotRequired["ItemAttemptsCountDict | None"]
|
||||||
|
|
||||||
|
|
||||||
|
class RoomModel(DatabaseModel[RoomDict]):
|
||||||
|
SHOW_RESPONSE_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
"current_user_score.playlist_item_attempts",
|
||||||
|
"host.country",
|
||||||
|
"playlist.beatmap.beatmapset",
|
||||||
|
"playlist.beatmap.checksum",
|
||||||
|
"playlist.beatmap.max_combo",
|
||||||
|
"recent_participants",
|
||||||
|
]
|
||||||
|
|
||||||
|
id: int = Field(default=None, primary_key=True, index=True)
|
||||||
name: str = Field(index=True)
|
name: str = Field(index=True)
|
||||||
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
|
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
|
||||||
|
status: RoomStatus
|
||||||
|
type: MatchType
|
||||||
duration: int | None = Field(default=None) # minutes
|
duration: int | None = Field(default=None) # minutes
|
||||||
starts_at: datetime | None = Field(
|
starts_at: datetime | None = Field(
|
||||||
sa_column=Column(
|
sa_column=Column(
|
||||||
@@ -48,76 +76,88 @@ class RoomBase(SQLModel, UTCBaseModel):
|
|||||||
),
|
),
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
participant_count: int = Field(default=0)
|
|
||||||
max_attempts: int | None = Field(default=None) # playlists
|
max_attempts: int | None = Field(default=None) # playlists
|
||||||
type: MatchType
|
participant_count: int = Field(default=0)
|
||||||
|
channel_id: int = 0
|
||||||
queue_mode: QueueMode
|
queue_mode: QueueMode
|
||||||
auto_skip: bool
|
auto_skip: bool
|
||||||
|
|
||||||
auto_start_duration: int
|
auto_start_duration: int
|
||||||
status: RoomStatus
|
|
||||||
channel_id: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class Room(AsyncAttrs, RoomBase, table=True):
|
|
||||||
__tablename__: str = "rooms"
|
|
||||||
id: int = Field(default=None, primary_key=True, index=True)
|
|
||||||
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
|
||||||
password: str | None = Field(default=None)
|
|
||||||
|
|
||||||
host: User = Relationship()
|
|
||||||
playlist: list[Playlist] = Relationship(
|
|
||||||
sa_relationship_kwargs={
|
|
||||||
"lazy": "selectin",
|
|
||||||
"cascade": "all, delete-orphan",
|
|
||||||
"overlaps": "room",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RoomResp(RoomBase):
|
|
||||||
id: int
|
|
||||||
has_password: bool = False
|
|
||||||
host: UserResp | None = None
|
|
||||||
playlist: list[PlaylistResp] = []
|
|
||||||
playlist_item_stats: RoomPlaylistItemStats | None = None
|
|
||||||
difficulty_range: RoomDifficultyRange | None = None
|
|
||||||
current_playlist_item: PlaylistResp | None = None
|
|
||||||
current_user_score: PlaylistAggregateScore | None = None
|
|
||||||
recent_participants: list[UserResp] = Field(default_factory=list)
|
|
||||||
channel_id: int = 0
|
|
||||||
|
|
||||||
|
@field_validator("channel_id", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_db(
|
def validate_channel_id(cls, v):
|
||||||
cls,
|
"""将 None 转换为 0"""
|
||||||
room: Room,
|
if v is None:
|
||||||
session: AsyncSession,
|
return 0
|
||||||
include: list[str] = [],
|
return v
|
||||||
user: User | None = None,
|
|
||||||
) -> "RoomResp":
|
|
||||||
d = room.model_dump()
|
|
||||||
d["channel_id"] = d.get("channel_id", 0) or 0
|
|
||||||
d["has_password"] = bool(room.password)
|
|
||||||
resp = cls.model_validate(d)
|
|
||||||
|
|
||||||
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
|
@included
|
||||||
difficulty_range = RoomDifficultyRange(
|
@staticmethod
|
||||||
min=0,
|
async def has_password(_session: AsyncSession, room: "Room") -> bool:
|
||||||
max=0,
|
return bool(room.password)
|
||||||
)
|
|
||||||
rulesets = set()
|
@ondemand
|
||||||
for playlist in room.playlist:
|
@staticmethod
|
||||||
|
async def current_playlist_item(
|
||||||
|
_session: AsyncSession, room: "Room", includes: list[str] | None = None
|
||||||
|
) -> "PlaylistDict | None":
|
||||||
|
playlists = await room.awaitable_attrs.playlist
|
||||||
|
if not playlists:
|
||||||
|
return None
|
||||||
|
return await PlaylistModel.transform(playlists[-1], includes=includes)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def playlist(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> list["PlaylistDict"]:
|
||||||
|
playlists = await room.awaitable_attrs.playlist
|
||||||
|
result: list[PlaylistDict] = []
|
||||||
|
for playlist_item in playlists:
|
||||||
|
result.append(await PlaylistModel.transform(playlist_item, includes=includes))
|
||||||
|
return result
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def playlist_item_stats(_session: AsyncSession, room: "Room") -> RoomPlaylistItemStats:
|
||||||
|
playlists = await room.awaitable_attrs.playlist
|
||||||
|
stats = RoomPlaylistItemStats(count_active=0, count_total=0, ruleset_ids=[])
|
||||||
|
rulesets: set[int] = set()
|
||||||
|
for playlist in playlists:
|
||||||
stats.count_total += 1
|
stats.count_total += 1
|
||||||
if not playlist.expired:
|
if not playlist.expired:
|
||||||
stats.count_active += 1
|
stats.count_active += 1
|
||||||
rulesets.add(playlist.ruleset_id)
|
rulesets.add(playlist.ruleset_id)
|
||||||
difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
|
|
||||||
difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
|
|
||||||
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
|
|
||||||
stats.ruleset_ids = list(rulesets)
|
stats.ruleset_ids = list(rulesets)
|
||||||
resp.playlist_item_stats = stats
|
return stats
|
||||||
resp.difficulty_range = difficulty_range
|
|
||||||
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
|
@ondemand
|
||||||
resp.recent_participants = []
|
@staticmethod
|
||||||
|
async def difficulty_range(_session: AsyncSession, room: "Room") -> RoomDifficultyRange:
|
||||||
|
playlists = await room.awaitable_attrs.playlist
|
||||||
|
if not playlists:
|
||||||
|
return RoomDifficultyRange(min=0.0, max=0.0)
|
||||||
|
min_diff = float("inf")
|
||||||
|
max_diff = float("-inf")
|
||||||
|
for playlist in playlists:
|
||||||
|
rating = playlist.beatmap.difficulty_rating
|
||||||
|
min_diff = min(min_diff, rating)
|
||||||
|
max_diff = max(max_diff, rating)
|
||||||
|
if min_diff == float("inf"):
|
||||||
|
min_diff = 0.0
|
||||||
|
if max_diff == float("-inf"):
|
||||||
|
max_diff = 0.0
|
||||||
|
return RoomDifficultyRange(min=min_diff, max=max_diff)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def host(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> UserDict:
|
||||||
|
host_user = await room.awaitable_attrs.host
|
||||||
|
return await UserModel.transform(host_user, includes=includes)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def recent_participants(session: AsyncSession, room: "Room") -> list[UserDict]:
|
||||||
|
participants: list[UserDict] = []
|
||||||
if room.category == RoomCategory.REALTIME:
|
if room.category == RoomCategory.REALTIME:
|
||||||
query = (
|
query = (
|
||||||
select(RoomParticipatedUser)
|
select(RoomParticipatedUser)
|
||||||
@@ -137,39 +177,67 @@ class RoomResp(RoomBase):
|
|||||||
.limit(8)
|
.limit(8)
|
||||||
.order_by(col(RoomParticipatedUser.joined_at).desc())
|
.order_by(col(RoomParticipatedUser.joined_at).desc())
|
||||||
)
|
)
|
||||||
resp.participant_count = (
|
|
||||||
await session.exec(
|
|
||||||
select(func.count())
|
|
||||||
.select_from(RoomParticipatedUser)
|
|
||||||
.where(
|
|
||||||
RoomParticipatedUser.room_id == room.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).first() or 0
|
|
||||||
for recent_participant in await session.exec(query):
|
for recent_participant in await session.exec(query):
|
||||||
resp.recent_participants.append(
|
user_instance = await recent_participant.awaitable_attrs.user
|
||||||
await UserResp.from_db(
|
participants.append(await UserModel.transform(user_instance))
|
||||||
await recent_participant.awaitable_attrs.user,
|
return participants
|
||||||
session,
|
|
||||||
include=["statistics"],
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def current_user_score(
|
||||||
|
session: AsyncSession, room: "Room", includes: list[str] | None = None
|
||||||
|
) -> "ItemAttemptsCountDict | None":
|
||||||
|
item_attempt = (
|
||||||
|
await session.exec(
|
||||||
|
select(ItemAttemptsCount).where(
|
||||||
|
ItemAttemptsCount.room_id == room.id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
|
).first()
|
||||||
if "current_user_score" in include and user:
|
if item_attempt is None:
|
||||||
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
|
return None
|
||||||
return resp
|
|
||||||
|
return await ItemAttemptsCountModel.transform(item_attempt, includes=includes)
|
||||||
|
|
||||||
|
|
||||||
class APIUploadedRoom(RoomBase):
|
class Room(AsyncAttrs, RoomModel, table=True):
|
||||||
def to_room(self) -> Room:
|
__tablename__: str = "rooms"
|
||||||
"""
|
|
||||||
将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。
|
|
||||||
"""
|
|
||||||
room_dict = self.model_dump()
|
|
||||||
room_dict.pop("playlist", None)
|
|
||||||
# host_id 已在字段中
|
|
||||||
return Room(**room_dict)
|
|
||||||
|
|
||||||
id: int | None
|
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||||
host_id: int | None = None
|
password: str | None = Field(default=None)
|
||||||
|
|
||||||
|
host: User = Relationship()
|
||||||
|
playlist: list[Playlist] = Relationship(
|
||||||
|
sa_relationship_kwargs={
|
||||||
|
"lazy": "selectin",
|
||||||
|
"cascade": "all, delete-orphan",
|
||||||
|
"overlaps": "room",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class APIUploadedRoom(SQLModel):
|
||||||
|
name: str = Field(index=True)
|
||||||
|
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
|
||||||
|
status: RoomStatus
|
||||||
|
type: MatchType
|
||||||
|
duration: int | None = Field(default=None) # minutes
|
||||||
|
starts_at: datetime | None = Field(
|
||||||
|
sa_column=Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
),
|
||||||
|
default_factory=utcnow,
|
||||||
|
)
|
||||||
|
ends_at: datetime | None = Field(
|
||||||
|
sa_column=Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
),
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
max_attempts: int | None = Field(default=None) # playlists
|
||||||
|
participant_count: int = Field(default=0)
|
||||||
|
channel_id: int = 0
|
||||||
|
queue_mode: QueueMode
|
||||||
|
auto_skip: bool
|
||||||
|
auto_start_duration: int
|
||||||
playlist: list[Playlist] = Field(default_factory=list)
|
playlist: list[Playlist] = Field(default_factory=list)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from datetime import date, datetime
|
|||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
|
||||||
|
|
||||||
from app.calculator import (
|
from app.calculator import (
|
||||||
calculate_pp_weight,
|
calculate_pp_weight,
|
||||||
@@ -15,8 +15,6 @@ from app.calculator import (
|
|||||||
pre_fetch_and_calculate_pp,
|
pre_fetch_and_calculate_pp,
|
||||||
)
|
)
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database.beatmap_playcounts import BeatmapPlaycounts
|
|
||||||
from app.database.team import TeamMember
|
|
||||||
from app.dependencies.database import get_redis
|
from app.dependencies.database import get_redis
|
||||||
from app.log import log
|
from app.log import log
|
||||||
from app.models.beatmap import BeatmapRankStatus
|
from app.models.beatmap import BeatmapRankStatus
|
||||||
@@ -39,8 +37,10 @@ from app.models.scoring_mode import ScoringMode
|
|||||||
from app.storage import StorageService
|
from app.storage import StorageService
|
||||||
from app.utils import utcnow
|
from app.utils import utcnow
|
||||||
|
|
||||||
from .beatmap import Beatmap, BeatmapResp
|
from ._base import DatabaseModel, OnDemand, included, ondemand
|
||||||
from .beatmapset import BeatmapsetResp
|
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
|
||||||
|
from .beatmap_playcounts import BeatmapPlaycounts
|
||||||
|
from .beatmapset import BeatmapsetDict, BeatmapsetModel
|
||||||
from .best_scores import BestScore
|
from .best_scores import BestScore
|
||||||
from .counts import MonthlyPlaycounts
|
from .counts import MonthlyPlaycounts
|
||||||
from .events import Event, EventType
|
from .events import Event, EventType
|
||||||
@@ -50,8 +50,9 @@ from .relationship import (
|
|||||||
RelationshipType,
|
RelationshipType,
|
||||||
)
|
)
|
||||||
from .score_token import ScoreToken
|
from .score_token import ScoreToken
|
||||||
|
from .team import TeamMember
|
||||||
from .total_score_best_scores import TotalScoreBestScore
|
from .total_score_best_scores import TotalScoreBestScore
|
||||||
from .user import User, UserResp
|
from .user import User, UserDict, UserModel
|
||||||
|
|
||||||
from pydantic import BaseModel, field_serializer, field_validator
|
from pydantic import BaseModel, field_serializer, field_validator
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
@@ -80,30 +81,290 @@ if TYPE_CHECKING:
|
|||||||
logger = log("Score")
|
logger = log("Score")
|
||||||
|
|
||||||
|
|
||||||
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
class ScoreDict(TypedDict):
|
||||||
# 基本字段
|
beatmap_id: int
|
||||||
|
id: int
|
||||||
|
rank: Rank
|
||||||
|
type: str
|
||||||
|
user_id: int
|
||||||
|
accuracy: float
|
||||||
|
build_id: int | None
|
||||||
|
ended_at: datetime
|
||||||
|
has_replay: bool
|
||||||
|
max_combo: int
|
||||||
|
passed: bool
|
||||||
|
pp: float
|
||||||
|
started_at: datetime
|
||||||
|
total_score: int
|
||||||
|
maximum_statistics: ScoreStatistics
|
||||||
|
mods: list[APIMod]
|
||||||
|
classic_total_score: int | None
|
||||||
|
preserve: bool
|
||||||
|
processed: bool
|
||||||
|
ranked: bool
|
||||||
|
playlist_item_id: NotRequired[int | None]
|
||||||
|
room_id: NotRequired[int | None]
|
||||||
|
best_id: NotRequired[int | None]
|
||||||
|
legacy_perfect: NotRequired[bool]
|
||||||
|
is_perfect_combo: NotRequired[bool]
|
||||||
|
ruleset_id: NotRequired[int]
|
||||||
|
statistics: NotRequired[ScoreStatistics]
|
||||||
|
beatmapset: NotRequired[BeatmapsetDict]
|
||||||
|
beatmap: NotRequired[BeatmapDict]
|
||||||
|
current_user_attributes: NotRequired[CurrentUserAttributes]
|
||||||
|
position: NotRequired[int | None]
|
||||||
|
scores_around: NotRequired["ScoreAround | None"]
|
||||||
|
rank_country: NotRequired[int | None]
|
||||||
|
rank_global: NotRequired[int | None]
|
||||||
|
user: NotRequired[UserDict]
|
||||||
|
weight: NotRequired[float | None]
|
||||||
|
|
||||||
|
# ScoreResp 字段
|
||||||
|
legacy_total_score: NotRequired[int]
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreModel(AsyncAttrs, DatabaseModel[ScoreDict]):
|
||||||
|
# https://github.com/ppy/osu-web/blob/master/app/Transformers/ScoreTransformer.php#L72-L84
|
||||||
|
MULTIPLAYER_SCORE_INCLUDE: ClassVar[list[str]] = ["playlist_item_id", "room_id", "solo_score_id"]
|
||||||
|
MULTIPLAYER_BASE_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
"user.country",
|
||||||
|
"user.cover",
|
||||||
|
"user.team",
|
||||||
|
*MULTIPLAYER_SCORE_INCLUDE,
|
||||||
|
]
|
||||||
|
USER_PROFILE_INCLUDES: ClassVar[list[str]] = ["beatmap", "beatmapset", "user"]
|
||||||
|
|
||||||
|
# 基本字段
|
||||||
|
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||||
|
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
|
||||||
|
rank: Rank
|
||||||
|
type: str
|
||||||
|
user_id: int = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column(
|
||||||
|
BigInteger,
|
||||||
|
ForeignKey("lazer_users.id"),
|
||||||
|
index=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
accuracy: float
|
accuracy: float
|
||||||
map_md5: str = Field(max_length=32, index=True)
|
|
||||||
build_id: int | None = Field(default=None)
|
build_id: int | None = Field(default=None)
|
||||||
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
|
|
||||||
ended_at: datetime = Field(sa_column=Column(DateTime))
|
ended_at: datetime = Field(sa_column=Column(DateTime))
|
||||||
has_replay: bool = Field(sa_column=Column(Boolean))
|
has_replay: bool = Field(sa_column=Column(Boolean))
|
||||||
max_combo: int
|
max_combo: int
|
||||||
mods: list[APIMod] = Field(sa_column=Column(JSON))
|
|
||||||
passed: bool = Field(sa_column=Column(Boolean))
|
passed: bool = Field(sa_column=Column(Boolean))
|
||||||
playlist_item_id: int | None = Field(default=None) # multiplayer
|
|
||||||
pp: float = Field(default=0.0)
|
pp: float = Field(default=0.0)
|
||||||
preserve: bool = Field(default=True, sa_column=Column(Boolean))
|
|
||||||
rank: Rank
|
|
||||||
room_id: int | None = Field(default=None) # multiplayer
|
|
||||||
started_at: datetime = Field(sa_column=Column(DateTime))
|
started_at: datetime = Field(sa_column=Column(DateTime))
|
||||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||||
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
|
|
||||||
type: str
|
|
||||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
|
||||||
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
|
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
|
||||||
processed: bool = False # solo_score
|
mods: list[APIMod] = Field(sa_column=Column(JSON))
|
||||||
ranked: bool = False
|
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
|
||||||
|
|
||||||
|
# solo
|
||||||
|
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger))
|
||||||
|
preserve: bool = Field(default=True, sa_column=Column(Boolean))
|
||||||
|
processed: bool = Field(default=False)
|
||||||
|
ranked: bool = Field(default=False)
|
||||||
|
|
||||||
|
# multiplayer
|
||||||
|
playlist_item_id: OnDemand[int | None] = Field(default=None)
|
||||||
|
room_id: OnDemand[int | None] = Field(default=None)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def best_id(
|
||||||
|
session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> int | None:
|
||||||
|
return await get_best_id(session, score.id)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def legacy_perfect(
|
||||||
|
_session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> bool:
|
||||||
|
await score.awaitable_attrs.beatmap
|
||||||
|
return score.max_combo == score.beatmap.max_combo
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def is_perfect_combo(
|
||||||
|
_session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> bool:
|
||||||
|
await score.awaitable_attrs.beatmap
|
||||||
|
return score.max_combo == score.beatmap.max_combo
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def ruleset_id(
|
||||||
|
_session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> int:
|
||||||
|
return int(score.gamemode)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def statistics(
|
||||||
|
_session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> ScoreStatistics:
|
||||||
|
stats = {
|
||||||
|
HitResult.MISS: score.nmiss,
|
||||||
|
HitResult.MEH: score.n50,
|
||||||
|
HitResult.OK: score.n100,
|
||||||
|
HitResult.GREAT: score.n300,
|
||||||
|
HitResult.PERFECT: score.ngeki,
|
||||||
|
HitResult.GOOD: score.nkatu,
|
||||||
|
}
|
||||||
|
if score.nlarge_tick_miss is not None:
|
||||||
|
stats[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
|
||||||
|
if score.nslider_tail_hit is not None:
|
||||||
|
stats[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
|
||||||
|
if score.nsmall_tick_hit is not None:
|
||||||
|
stats[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
|
||||||
|
if score.nlarge_tick_hit is not None:
|
||||||
|
stats[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
|
||||||
|
return stats
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def beatmapset(
|
||||||
|
_session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
) -> BeatmapsetDict:
|
||||||
|
await score.awaitable_attrs.beatmap
|
||||||
|
return await BeatmapsetModel.transform(score.beatmap.beatmapset, includes=includes)
|
||||||
|
|
||||||
|
# reorder beatmapset and beatmap
|
||||||
|
# https://github.com/ppy/osu/blob/d8900defd34690de92be3406003fb3839fc0df1d/osu.Game/Online/API/Requests/Responses/SoloScoreInfo.cs#L111-L112
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def beatmap(
|
||||||
|
_session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
) -> BeatmapDict:
|
||||||
|
await score.awaitable_attrs.beatmap
|
||||||
|
return await BeatmapModel.transform(score.beatmap, includes=includes)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def current_user_attributes(
|
||||||
|
_session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> CurrentUserAttributes:
|
||||||
|
return CurrentUserAttributes(pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id))
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def position(
|
||||||
|
session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> int | None:
|
||||||
|
return await get_score_position_by_id(
|
||||||
|
session,
|
||||||
|
score.beatmap_id,
|
||||||
|
score.id,
|
||||||
|
mode=score.gamemode,
|
||||||
|
user=score.user,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def scores_around(
|
||||||
|
session: AsyncSession, _score: "Score", playlist_id: int, room_id: int, is_playlist: bool
|
||||||
|
) -> "ScoreAround | None":
|
||||||
|
scores = (
|
||||||
|
await session.exec(
|
||||||
|
select(PlaylistBestScore).where(
|
||||||
|
PlaylistBestScore.playlist_id == playlist_id,
|
||||||
|
PlaylistBestScore.room_id == room_id,
|
||||||
|
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
|
||||||
|
col(PlaylistBestScore.score).has(col(Score.passed).is_(True)) if not is_playlist else True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
higher_scores = []
|
||||||
|
lower_scores = []
|
||||||
|
for score in scores:
|
||||||
|
total_score = score.score.total_score
|
||||||
|
resp = await ScoreModel.transform(score.score, includes=ScoreModel.MULTIPLAYER_BASE_INCLUDES)
|
||||||
|
if score.total_score > total_score:
|
||||||
|
higher_scores.append(resp)
|
||||||
|
elif score.total_score < total_score:
|
||||||
|
lower_scores.append(resp)
|
||||||
|
|
||||||
|
return ScoreAround(
|
||||||
|
higher=MultiplayerScores(scores=higher_scores),
|
||||||
|
lower=MultiplayerScores(scores=lower_scores),
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def rank_country(
|
||||||
|
session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> int | None:
|
||||||
|
return (
|
||||||
|
await get_score_position_by_id(
|
||||||
|
session,
|
||||||
|
score.beatmap_id,
|
||||||
|
score.id,
|
||||||
|
score.gamemode,
|
||||||
|
score.user,
|
||||||
|
type=LeaderboardType.COUNTRY,
|
||||||
|
)
|
||||||
|
or None
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def rank_global(
|
||||||
|
session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> int | None:
|
||||||
|
return (
|
||||||
|
await get_score_position_by_id(
|
||||||
|
session,
|
||||||
|
score.beatmap_id,
|
||||||
|
score.id,
|
||||||
|
mode=score.gamemode,
|
||||||
|
user=score.user,
|
||||||
|
)
|
||||||
|
or None
|
||||||
|
)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def user(
|
||||||
|
_session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
includes: list[str] | None = None,
|
||||||
|
) -> UserDict:
|
||||||
|
return await UserModel.transform(score.user, ruleset=score.gamemode, includes=includes or [])
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def weight(
|
||||||
|
session: AsyncSession,
|
||||||
|
score: "Score",
|
||||||
|
) -> float | None:
|
||||||
|
best_id = await get_best_id(session, score.id)
|
||||||
|
if best_id:
|
||||||
|
return calculate_pp_weight(best_id - 1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def legacy_total_score(
|
||||||
|
_session: AsyncSession,
|
||||||
|
_score: "Score",
|
||||||
|
) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
@field_validator("maximum_statistics", mode="before")
|
@field_validator("maximum_statistics", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -151,17 +412,9 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
|||||||
# TODO: current_user_attributes
|
# TODO: current_user_attributes
|
||||||
|
|
||||||
|
|
||||||
class Score(ScoreBase, table=True):
|
class Score(ScoreModel, table=True):
|
||||||
__tablename__: str = "scores"
|
__tablename__: str = "scores"
|
||||||
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
|
|
||||||
user_id: int = Field(
|
|
||||||
default=None,
|
|
||||||
sa_column=Column(
|
|
||||||
BigInteger,
|
|
||||||
ForeignKey("lazer_users.id"),
|
|
||||||
index=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# ScoreStatistics
|
# ScoreStatistics
|
||||||
n300: int = Field(exclude=True)
|
n300: int = Field(exclude=True)
|
||||||
n100: int = Field(exclude=True)
|
n100: int = Field(exclude=True)
|
||||||
@@ -175,6 +428,7 @@ class Score(ScoreBase, table=True):
|
|||||||
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
|
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
|
||||||
gamemode: GameMode = Field(index=True)
|
gamemode: GameMode = Field(index=True)
|
||||||
pinned_order: int = Field(default=0, exclude=True)
|
pinned_order: int = Field(default=0, exclude=True)
|
||||||
|
map_md5: str = Field(max_length=32, index=True, exclude=True)
|
||||||
|
|
||||||
@field_validator("gamemode", mode="before")
|
@field_validator("gamemode", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -245,9 +499,11 @@ class Score(ScoreBase, table=True):
|
|||||||
maximum_statistics=self.maximum_statistics,
|
maximum_statistics=self.maximum_statistics,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def to_resp(self, session: AsyncSession, api_version: int) -> "ScoreResp | LegacyScoreResp":
|
async def to_resp(
|
||||||
|
self, session: AsyncSession, api_version: int, includes: list[str] = []
|
||||||
|
) -> "ScoreDict | LegacyScoreResp":
|
||||||
if api_version >= 20220705:
|
if api_version >= 20220705:
|
||||||
return await ScoreResp.from_db(session, self)
|
return await ScoreModel.transform(self, includes=includes)
|
||||||
return await LegacyScoreResp.from_db(session, self)
|
return await LegacyScoreResp.from_db(session, self)
|
||||||
|
|
||||||
async def delete(
|
async def delete(
|
||||||
@@ -270,141 +526,7 @@ class Score(ScoreBase, table=True):
|
|||||||
await session.delete(self)
|
await session.delete(self)
|
||||||
|
|
||||||
|
|
||||||
class ScoreResp(ScoreBase):
|
MultiplayScoreDict = ScoreModel.generate_typeddict(tuple(Score.MULTIPLAYER_BASE_INCLUDES)) # pyright: ignore[reportGeneralTypeIssues]
|
||||||
id: int
|
|
||||||
user_id: int
|
|
||||||
is_perfect_combo: bool = False
|
|
||||||
legacy_perfect: bool = False
|
|
||||||
legacy_total_score: int = 0 # FIXME
|
|
||||||
weight: float = 0.0
|
|
||||||
best_id: int | None = None
|
|
||||||
ruleset_id: int | None = None
|
|
||||||
beatmap: BeatmapResp | None = None
|
|
||||||
beatmapset: BeatmapsetResp | None = None
|
|
||||||
user: UserResp | None = None
|
|
||||||
statistics: ScoreStatistics | None = None
|
|
||||||
maximum_statistics: ScoreStatistics | None = None
|
|
||||||
rank_global: int | None = None
|
|
||||||
rank_country: int | None = None
|
|
||||||
position: int | None = None
|
|
||||||
scores_around: "ScoreAround | None" = None
|
|
||||||
current_user_attributes: CurrentUserAttributes | None = None
|
|
||||||
|
|
||||||
@field_validator(
|
|
||||||
"has_replay",
|
|
||||||
"passed",
|
|
||||||
"preserve",
|
|
||||||
"is_perfect_combo",
|
|
||||||
"legacy_perfect",
|
|
||||||
"processed",
|
|
||||||
"ranked",
|
|
||||||
mode="before",
|
|
||||||
)
|
|
||||||
@classmethod
|
|
||||||
def validate_bool_fields(cls, v):
|
|
||||||
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
|
||||||
if isinstance(v, int):
|
|
||||||
return bool(v)
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("statistics", "maximum_statistics", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_statistics_fields(cls, v):
|
|
||||||
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
|
|
||||||
if isinstance(v, dict):
|
|
||||||
converted = {}
|
|
||||||
for key, value in v.items():
|
|
||||||
if isinstance(key, str):
|
|
||||||
try:
|
|
||||||
# 尝试将字符串转换为 HitResult 枚举
|
|
||||||
enum_key = HitResult(key)
|
|
||||||
converted[enum_key] = value
|
|
||||||
except ValueError:
|
|
||||||
# 如果转换失败,跳过这个键值对
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
converted[key] = value
|
|
||||||
return converted
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_serializer("statistics", when_used="json")
|
|
||||||
def serialize_statistics_fields(self, v):
|
|
||||||
"""序列化统计字段,确保枚举值正确转换为字符串"""
|
|
||||||
if isinstance(v, dict):
|
|
||||||
serialized = {}
|
|
||||||
for key, value in v.items():
|
|
||||||
if hasattr(key, "value"):
|
|
||||||
# 如果是枚举,使用其值
|
|
||||||
serialized[key.value] = value
|
|
||||||
else:
|
|
||||||
# 否则直接使用键
|
|
||||||
serialized[str(key)] = value
|
|
||||||
return serialized
|
|
||||||
return v
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
|
||||||
# 确保 score 对象完全加载,避免懒加载问题
|
|
||||||
await session.refresh(score)
|
|
||||||
|
|
||||||
s = cls.model_validate(score.model_dump())
|
|
||||||
await score.awaitable_attrs.beatmap
|
|
||||||
s.beatmap = await BeatmapResp.from_db(score.beatmap)
|
|
||||||
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
|
|
||||||
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
|
||||||
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
|
|
||||||
s.ruleset_id = int(score.gamemode)
|
|
||||||
best_id = await get_best_id(session, score.id)
|
|
||||||
if best_id:
|
|
||||||
s.best_id = best_id
|
|
||||||
s.weight = calculate_pp_weight(best_id - 1)
|
|
||||||
s.statistics = {
|
|
||||||
HitResult.MISS: score.nmiss,
|
|
||||||
HitResult.MEH: score.n50,
|
|
||||||
HitResult.OK: score.n100,
|
|
||||||
HitResult.GREAT: score.n300,
|
|
||||||
HitResult.PERFECT: score.ngeki,
|
|
||||||
HitResult.GOOD: score.nkatu,
|
|
||||||
}
|
|
||||||
if score.nlarge_tick_miss is not None:
|
|
||||||
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
|
|
||||||
if score.nslider_tail_hit is not None:
|
|
||||||
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
|
|
||||||
if score.nsmall_tick_hit is not None:
|
|
||||||
s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
|
|
||||||
if score.nlarge_tick_hit is not None:
|
|
||||||
s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
|
|
||||||
s.user = await UserResp.from_db(
|
|
||||||
score.user,
|
|
||||||
session,
|
|
||||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
|
||||||
ruleset=score.gamemode,
|
|
||||||
)
|
|
||||||
s.rank_global = (
|
|
||||||
await get_score_position_by_id(
|
|
||||||
session,
|
|
||||||
score.beatmap_id,
|
|
||||||
score.id,
|
|
||||||
mode=score.gamemode,
|
|
||||||
user=score.user,
|
|
||||||
)
|
|
||||||
or None
|
|
||||||
)
|
|
||||||
s.rank_country = (
|
|
||||||
await get_score_position_by_id(
|
|
||||||
session,
|
|
||||||
score.beatmap_id,
|
|
||||||
score.id,
|
|
||||||
score.gamemode,
|
|
||||||
score.user,
|
|
||||||
type=LeaderboardType.COUNTRY,
|
|
||||||
)
|
|
||||||
or None
|
|
||||||
)
|
|
||||||
s.current_user_attributes = CurrentUserAttributes(
|
|
||||||
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
|
|
||||||
)
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
class LegacyStatistics(BaseModel):
|
class LegacyStatistics(BaseModel):
|
||||||
@@ -417,31 +539,25 @@ class LegacyStatistics(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class LegacyScoreResp(UTCBaseModel):
|
class LegacyScoreResp(UTCBaseModel):
|
||||||
accuracy: float
|
|
||||||
best_id: int
|
|
||||||
created_at: datetime
|
|
||||||
id: int
|
id: int
|
||||||
max_combo: int
|
best_id: int
|
||||||
mode: GameMode
|
user_id: int
|
||||||
mode_int: int
|
accuracy: float
|
||||||
mods: list[str] # acronym
|
mods: list[str] # acronym
|
||||||
passed: bool
|
score: int
|
||||||
|
max_combo: int
|
||||||
perfect: bool = False
|
perfect: bool = False
|
||||||
|
statistics: LegacyStatistics
|
||||||
|
passed: bool
|
||||||
pp: float
|
pp: float
|
||||||
rank: Rank
|
rank: Rank
|
||||||
|
created_at: datetime
|
||||||
|
mode: GameMode
|
||||||
|
mode_int: int
|
||||||
replay: bool
|
replay: bool
|
||||||
score: int
|
|
||||||
statistics: LegacyStatistics
|
|
||||||
type: str
|
|
||||||
user_id: int
|
|
||||||
current_user_attributes: CurrentUserAttributes
|
|
||||||
user: UserResp
|
|
||||||
beatmap: BeatmapResp
|
|
||||||
rank_global: int | None = Field(default=None, exclude=True)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_db(cls, session: AsyncSession, score: Score) -> "LegacyScoreResp":
|
async def from_db(cls, session: AsyncSession, score: "Score") -> "LegacyScoreResp":
|
||||||
await session.refresh(score)
|
|
||||||
await score.awaitable_attrs.beatmap
|
await score.awaitable_attrs.beatmap
|
||||||
return cls(
|
return cls(
|
||||||
accuracy=score.accuracy,
|
accuracy=score.accuracy,
|
||||||
@@ -465,34 +581,13 @@ class LegacyScoreResp(UTCBaseModel):
|
|||||||
count_geki=score.ngeki or 0,
|
count_geki=score.ngeki or 0,
|
||||||
count_katu=score.nkatu or 0,
|
count_katu=score.nkatu or 0,
|
||||||
),
|
),
|
||||||
type=score.type,
|
|
||||||
user_id=score.user_id,
|
user_id=score.user_id,
|
||||||
current_user_attributes=CurrentUserAttributes(
|
|
||||||
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
|
|
||||||
),
|
|
||||||
user=await UserResp.from_db(
|
|
||||||
score.user,
|
|
||||||
session,
|
|
||||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
|
||||||
ruleset=score.gamemode,
|
|
||||||
),
|
|
||||||
beatmap=await BeatmapResp.from_db(score.beatmap),
|
|
||||||
perfect=score.is_perfect_combo,
|
perfect=score.is_perfect_combo,
|
||||||
rank_global=(
|
|
||||||
await get_score_position_by_id(
|
|
||||||
session,
|
|
||||||
score.beatmap_id,
|
|
||||||
score.id,
|
|
||||||
mode=score.gamemode,
|
|
||||||
user=score.user,
|
|
||||||
)
|
|
||||||
or None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiplayerScores(RespWithCursor):
|
class MultiplayerScores(RespWithCursor):
|
||||||
scores: list[ScoreResp] = Field(default_factory=list)
|
scores: list[MultiplayScoreDict] = Field(default_factory=list) # pyright: ignore[reportInvalidTypeForm]
|
||||||
params: dict[str, Any] = Field(default_factory=dict)
|
params: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@@ -842,13 +937,13 @@ async def get_user_best_pp(
|
|||||||
|
|
||||||
|
|
||||||
# https://github.com/ppy/osu-queue-score-statistics/blob/master/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/PlayValidityHelper.cs
|
# https://github.com/ppy/osu-queue-score-statistics/blob/master/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/PlayValidityHelper.cs
|
||||||
def get_play_length(score: Score, beatmap_length: int):
|
def get_play_length(score: "Score", beatmap_length: int):
|
||||||
speed_rate = get_speed_rate(score.mods)
|
speed_rate = get_speed_rate(score.mods)
|
||||||
length = beatmap_length / speed_rate
|
length = beatmap_length / speed_rate
|
||||||
return int(min(length, (score.ended_at - score.started_at).total_seconds()))
|
return int(min(length, (score.ended_at - score.started_at).total_seconds()))
|
||||||
|
|
||||||
|
|
||||||
def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
|
def calculate_playtime(score: "Score", beatmap_length: int) -> tuple[int, bool]:
|
||||||
total_length = get_play_length(score, beatmap_length)
|
total_length = get_play_length(score, beatmap_length)
|
||||||
total_obj_hited = (
|
total_obj_hited = (
|
||||||
score.n300
|
score.n300
|
||||||
@@ -937,7 +1032,7 @@ async def process_score(
|
|||||||
return score
|
return score
|
||||||
|
|
||||||
|
|
||||||
async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
|
async def _process_score_pp(score: "Score", session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
|
||||||
if score.pp != 0:
|
if score.pp != 0:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
|
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
|
||||||
@@ -984,7 +1079,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _process_score_events(score: Score, session: AsyncSession):
|
async def _process_score_events(score: "Score", session: AsyncSession):
|
||||||
total_users = (await session.exec(select(func.count()).select_from(User))).one()
|
total_users = (await session.exec(select(func.count()).select_from(User))).one()
|
||||||
rank_global = await get_score_position_by_id(
|
rank_global = await get_score_position_by_id(
|
||||||
session,
|
session,
|
||||||
@@ -1088,7 +1183,7 @@ async def _process_statistics(
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
user: User,
|
user: User,
|
||||||
score: Score,
|
score: "Score",
|
||||||
score_token: int,
|
score_token: int,
|
||||||
beatmap_length: int,
|
beatmap_length: int,
|
||||||
beatmap_status: BeatmapRankStatus,
|
beatmap_status: BeatmapRankStatus,
|
||||||
@@ -1318,7 +1413,7 @@ async def process_user(
|
|||||||
redis: Redis,
|
redis: Redis,
|
||||||
fetcher: "Fetcher",
|
fetcher: "Fetcher",
|
||||||
user: User,
|
user: User,
|
||||||
score: Score,
|
score: "Score",
|
||||||
score_token: int,
|
score_token: int,
|
||||||
beatmap_length: int,
|
beatmap_length: int,
|
||||||
beatmap_status: BeatmapRankStatus,
|
beatmap_status: BeatmapRankStatus,
|
||||||
|
|||||||
13
app/database/search_beatmapset.py
Normal file
13
app/database/search_beatmapset.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from . import beatmap # noqa: F401
|
||||||
|
from .beatmapset import BeatmapsetModel
|
||||||
|
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
|
SearchBeatmapset = BeatmapsetModel.generate_typeddict(("beatmaps.max_combo", "pack_tags"))
|
||||||
|
|
||||||
|
|
||||||
|
class SearchBeatmapsetsResp(SQLModel):
|
||||||
|
beatmapsets: list[SearchBeatmapset] # pyright: ignore[reportInvalidTypeForm]
|
||||||
|
total: int
|
||||||
|
cursor: dict[str, int | float | str] | None = None
|
||||||
|
cursor_string: str | None = None
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||||
|
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
from app.utils import utcnow
|
from app.utils import utcnow
|
||||||
|
|
||||||
|
from ._base import DatabaseModel, included, ondemand
|
||||||
from .rank_history import RankHistory
|
from .rank_history import RankHistory
|
||||||
|
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
@@ -15,7 +16,6 @@ from sqlmodel import (
|
|||||||
Field,
|
Field,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Relationship,
|
Relationship,
|
||||||
SQLModel,
|
|
||||||
col,
|
col,
|
||||||
func,
|
func,
|
||||||
select,
|
select,
|
||||||
@@ -23,10 +23,40 @@ from sqlmodel import (
|
|||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User, UserResp
|
from .user import User, UserDict
|
||||||
|
|
||||||
|
|
||||||
class UserStatisticsBase(SQLModel):
|
class UserStatisticsDict(TypedDict):
|
||||||
|
mode: GameMode
|
||||||
|
count_100: int
|
||||||
|
count_300: int
|
||||||
|
count_50: int
|
||||||
|
count_miss: int
|
||||||
|
pp: float
|
||||||
|
ranked_score: int
|
||||||
|
hit_accuracy: float
|
||||||
|
total_score: int
|
||||||
|
total_hits: int
|
||||||
|
maximum_combo: int
|
||||||
|
play_count: int
|
||||||
|
play_time: int
|
||||||
|
replays_watched_by_others: int
|
||||||
|
is_ranked: bool
|
||||||
|
level: NotRequired[dict[str, int]]
|
||||||
|
global_rank: NotRequired[int | None]
|
||||||
|
grade_counts: NotRequired[dict[str, int]]
|
||||||
|
rank_change_since_30_days: NotRequired[int]
|
||||||
|
country_rank: NotRequired[int | None]
|
||||||
|
user: NotRequired["UserDict"]
|
||||||
|
|
||||||
|
|
||||||
|
class UserStatisticsModel(DatabaseModel[UserStatisticsDict]):
|
||||||
|
RANKING_INCLUDES: ClassVar[list[str]] = [
|
||||||
|
"user.country",
|
||||||
|
"user.cover",
|
||||||
|
"user.team",
|
||||||
|
]
|
||||||
|
|
||||||
mode: GameMode = Field(index=True)
|
mode: GameMode = Field(index=True)
|
||||||
count_100: int = Field(default=0, sa_column=Column(BigInteger))
|
count_100: int = Field(default=0, sa_column=Column(BigInteger))
|
||||||
count_300: int = Field(default=0, sa_column=Column(BigInteger))
|
count_300: int = Field(default=0, sa_column=Column(BigInteger))
|
||||||
@@ -57,8 +87,63 @@ class UserStatisticsBase(SQLModel):
|
|||||||
return GameMode.OSU
|
return GameMode.OSU
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def level(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
|
||||||
|
return {
|
||||||
|
"current": int(statistics.level_current),
|
||||||
|
"progress": int(math.fmod(statistics.level_current, 1) * 100),
|
||||||
|
}
|
||||||
|
|
||||||
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def global_rank(session: AsyncSession, statistics: "UserStatistics") -> int | None:
|
||||||
|
return await get_rank(session, statistics)
|
||||||
|
|
||||||
|
@included
|
||||||
|
@staticmethod
|
||||||
|
async def grade_counts(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
|
||||||
|
return {
|
||||||
|
"ss": statistics.grade_ss,
|
||||||
|
"ssh": statistics.grade_ssh,
|
||||||
|
"s": statistics.grade_s,
|
||||||
|
"sh": statistics.grade_sh,
|
||||||
|
"a": statistics.grade_a,
|
||||||
|
}
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def rank_change_since_30_days(session: AsyncSession, statistics: "UserStatistics") -> int:
|
||||||
|
global_rank = await get_rank(session, statistics)
|
||||||
|
rank_best = (
|
||||||
|
await session.exec(
|
||||||
|
select(func.max(RankHistory.rank)).where(
|
||||||
|
RankHistory.date > utcnow() - timedelta(days=30),
|
||||||
|
RankHistory.user_id == statistics.user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if rank_best is None or global_rank is None:
|
||||||
|
return 0
|
||||||
|
return rank_best - global_rank
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def country_rank(
|
||||||
|
session: AsyncSession, statistics: "UserStatistics", user_country: str | None = None
|
||||||
|
) -> int | None:
|
||||||
|
return await get_rank(session, statistics, user_country)
|
||||||
|
|
||||||
|
@ondemand
|
||||||
|
@staticmethod
|
||||||
|
async def user(_session: AsyncSession, statistics: "UserStatistics") -> "UserDict":
|
||||||
|
from .user import UserModel
|
||||||
|
|
||||||
|
user_instance = await statistics.awaitable_attrs.user
|
||||||
|
return await UserModel.transform(user_instance)
|
||||||
|
|
||||||
|
|
||||||
|
class UserStatistics(AsyncAttrs, UserStatisticsModel, table=True):
|
||||||
__tablename__: str = "lazer_user_statistics"
|
__tablename__: str = "lazer_user_statistics"
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
user_id: int = Field(
|
user_id: int = Field(
|
||||||
@@ -80,74 +165,6 @@ class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
|||||||
user: "User" = Relationship(back_populates="statistics")
|
user: "User" = Relationship(back_populates="statistics")
|
||||||
|
|
||||||
|
|
||||||
class UserStatisticsResp(UserStatisticsBase):
|
|
||||||
user: "UserResp | None" = None
|
|
||||||
rank_change_since_30_days: int | None = 0
|
|
||||||
global_rank: int | None = Field(default=None)
|
|
||||||
country_rank: int | None = Field(default=None)
|
|
||||||
grade_counts: dict[str, int] = Field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"ss": 0,
|
|
||||||
"ssh": 0,
|
|
||||||
"s": 0,
|
|
||||||
"sh": 0,
|
|
||||||
"a": 0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
level: dict[str, int] = Field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"current": 1,
|
|
||||||
"progress": 0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def from_db(
|
|
||||||
cls,
|
|
||||||
obj: UserStatistics,
|
|
||||||
session: AsyncSession,
|
|
||||||
user_country: str | None = None,
|
|
||||||
include: list[str] = [],
|
|
||||||
) -> "UserStatisticsResp":
|
|
||||||
s = cls.model_validate(obj.model_dump())
|
|
||||||
s.grade_counts = {
|
|
||||||
"ss": obj.grade_ss,
|
|
||||||
"ssh": obj.grade_ssh,
|
|
||||||
"s": obj.grade_s,
|
|
||||||
"sh": obj.grade_sh,
|
|
||||||
"a": obj.grade_a,
|
|
||||||
}
|
|
||||||
s.level = {
|
|
||||||
"current": int(obj.level_current),
|
|
||||||
"progress": int(math.fmod(obj.level_current, 1) * 100),
|
|
||||||
}
|
|
||||||
if "user" in include:
|
|
||||||
from .user import RANKING_INCLUDES, UserResp
|
|
||||||
|
|
||||||
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
|
|
||||||
s.user = user
|
|
||||||
user_country = user.country_code
|
|
||||||
|
|
||||||
s.global_rank = await get_rank(session, obj)
|
|
||||||
s.country_rank = await get_rank(session, obj, user_country)
|
|
||||||
|
|
||||||
if "rank_change_since_30_days" in include:
|
|
||||||
rank_best = (
|
|
||||||
await session.exec(
|
|
||||||
select(func.max(RankHistory.rank)).where(
|
|
||||||
RankHistory.date > utcnow() - timedelta(days=30),
|
|
||||||
RankHistory.user_id == obj.user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).first()
|
|
||||||
if rank_best is None or s.global_rank is None:
|
|
||||||
s.rank_change_since_30_days = 0
|
|
||||||
else:
|
|
||||||
s.rank_change_since_30_days = rank_best - s.global_rank
|
|
||||||
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
|
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
@@ -164,7 +181,6 @@ async def get_rank(session: AsyncSession, statistics: UserStatistics, country: s
|
|||||||
query = query.join(User).where(User.country_code == country)
|
query = query.join(User).where(User.country_code == country)
|
||||||
|
|
||||||
subq = query.subquery()
|
subq = query.subquery()
|
||||||
|
|
||||||
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
|
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
|
||||||
|
|
||||||
rank = result.first()
|
rank = result.first()
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from app.calculator import calculate_score_to_level
|
from app.calculator import calculate_score_to_level
|
||||||
from app.database.statistics import UserStatistics
|
|
||||||
from app.models.score import GameMode, Rank
|
from app.models.score import GameMode, Rank
|
||||||
|
|
||||||
|
from .statistics import UserStatistics
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,40 @@
|
|||||||
from app.database.beatmap import BeatmapResp
|
from app.database.beatmap import BeatmapDict, BeatmapModel
|
||||||
from app.log import fetcher_logger
|
from app.log import fetcher_logger
|
||||||
|
|
||||||
from ._base import BaseFetcher
|
from ._base import BaseFetcher
|
||||||
|
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
logger = fetcher_logger("BeatmapFetcher")
|
logger = fetcher_logger("BeatmapFetcher")
|
||||||
|
adapter = TypeAdapter(
|
||||||
|
BeatmapModel.generate_typeddict(
|
||||||
|
(
|
||||||
|
"checksum",
|
||||||
|
"accuracy",
|
||||||
|
"ar",
|
||||||
|
"bpm",
|
||||||
|
"convert",
|
||||||
|
"count_circles",
|
||||||
|
"count_sliders",
|
||||||
|
"count_spinners",
|
||||||
|
"cs",
|
||||||
|
"deleted_at",
|
||||||
|
"drain",
|
||||||
|
"hit_length",
|
||||||
|
"is_scoreable",
|
||||||
|
"last_updated",
|
||||||
|
"mode_int",
|
||||||
|
"ranked",
|
||||||
|
"url",
|
||||||
|
"max_combo",
|
||||||
|
"beatmapset",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BeatmapFetcher(BaseFetcher):
|
class BeatmapFetcher(BaseFetcher):
|
||||||
async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp:
|
async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapDict:
|
||||||
if beatmap_id:
|
if beatmap_id:
|
||||||
params = {"id": beatmap_id}
|
params = {"id": beatmap_id}
|
||||||
elif beatmap_checksum:
|
elif beatmap_checksum:
|
||||||
@@ -16,7 +43,7 @@ class BeatmapFetcher(BaseFetcher):
|
|||||||
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
|
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
|
||||||
logger.opt(colors=True).debug(f"get_beatmap: <y>{params}</y>")
|
logger.opt(colors=True).debug(f"get_beatmap: <y>{params}</y>")
|
||||||
|
|
||||||
return BeatmapResp.model_validate(
|
return adapter.validate_python( # pyright: ignore[reportReturnType]
|
||||||
await self.request_api(
|
await self.request_api(
|
||||||
"https://osu.ppy.sh/api/v2/beatmaps/lookup",
|
"https://osu.ppy.sh/api/v2/beatmaps/lookup",
|
||||||
params=params,
|
params=params,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
|
from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp
|
||||||
from app.helpers.rate_limiter import osu_api_rate_limiter
|
from app.helpers.rate_limiter import osu_api_rate_limiter
|
||||||
from app.log import fetcher_logger
|
from app.log import fetcher_logger
|
||||||
from app.models.beatmap import SearchQueryModel
|
from app.models.beatmap import SearchQueryModel
|
||||||
@@ -13,6 +13,7 @@ from app.utils import bg_tasks
|
|||||||
from ._base import BaseFetcher
|
from ._base import BaseFetcher
|
||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
from pydantic import TypeAdapter
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
|
|
||||||
|
|
||||||
@@ -26,6 +27,46 @@ logger = fetcher_logger("BeatmapsetFetcher")
|
|||||||
|
|
||||||
|
|
||||||
MAX_RETRY_ATTEMPTS = 3
|
MAX_RETRY_ATTEMPTS = 3
|
||||||
|
adapter = TypeAdapter(
|
||||||
|
BeatmapsetModel.generate_typeddict(
|
||||||
|
(
|
||||||
|
"availability",
|
||||||
|
"bpm",
|
||||||
|
"last_updated",
|
||||||
|
"ranked",
|
||||||
|
"ranked_date",
|
||||||
|
"submitted_date",
|
||||||
|
"tags",
|
||||||
|
"storyboard",
|
||||||
|
"description",
|
||||||
|
"genre",
|
||||||
|
"language",
|
||||||
|
*[
|
||||||
|
f"beatmaps.{inc}"
|
||||||
|
for inc in (
|
||||||
|
"checksum",
|
||||||
|
"accuracy",
|
||||||
|
"ar",
|
||||||
|
"bpm",
|
||||||
|
"convert",
|
||||||
|
"count_circles",
|
||||||
|
"count_sliders",
|
||||||
|
"count_spinners",
|
||||||
|
"cs",
|
||||||
|
"deleted_at",
|
||||||
|
"drain",
|
||||||
|
"hit_length",
|
||||||
|
"is_scoreable",
|
||||||
|
"last_updated",
|
||||||
|
"mode_int",
|
||||||
|
"ranked",
|
||||||
|
"url",
|
||||||
|
"max_combo",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BeatmapsetFetcher(BaseFetcher):
|
class BeatmapsetFetcher(BaseFetcher):
|
||||||
@@ -139,10 +180,9 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
|
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetDict:
|
||||||
logger.opt(colors=True).debug(f"get_beatmapset: <y>{beatmap_set_id}</y>")
|
logger.opt(colors=True).debug(f"get_beatmapset: <y>{beatmap_set_id}</y>")
|
||||||
|
return adapter.validate_python( # pyright: ignore[reportReturnType]
|
||||||
return BeatmapsetResp.model_validate(
|
|
||||||
await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
|
await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class ChatEvent(BaseModel):
|
class ChatEvent(TypedDict):
|
||||||
event: str
|
event: str
|
||||||
data: dict[str, Any] | None = None
|
data: dict[str, Any] | None
|
||||||
|
|||||||
@@ -3,16 +3,16 @@ from datetime import UTC, datetime
|
|||||||
|
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from pydantic import BaseModel, field_serializer
|
from pydantic import BaseModel, FieldSerializationInfo, field_serializer
|
||||||
|
|
||||||
|
|
||||||
class UTCBaseModel(BaseModel):
|
class UTCBaseModel(BaseModel):
|
||||||
@field_serializer("*", when_used="json")
|
@field_serializer("*", when_used="always")
|
||||||
def serialize_datetime(self, v, _info):
|
def serialize_datetime(self, v, _info: FieldSerializationInfo):
|
||||||
if isinstance(v, datetime):
|
if isinstance(v, datetime):
|
||||||
if v.tzinfo is None:
|
if v.tzinfo is None:
|
||||||
v = v.replace(tzinfo=UTC)
|
v = v.replace(tzinfo=UTC)
|
||||||
return v.astimezone(UTC).isoformat().replace("+00:00", "Z")
|
return v.astimezone(UTC)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -447,7 +447,7 @@ async def create_multiplayer_room(
|
|||||||
# 让房主加入频道
|
# 让房主加入频道
|
||||||
host_user = await db.get(User, host_user_id)
|
host_user = await db.get(User, host_user_id)
|
||||||
if host_user:
|
if host_user:
|
||||||
await server.batch_join_channel([host_user], channel, db)
|
await server.batch_join_channel([host_user], channel)
|
||||||
# Add playlist items
|
# Add playlist items
|
||||||
await _add_playlist_items(db, room_id, room_data, host_user_id)
|
await _add_playlist_items(db, room_id, room_data, host_user_id)
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,14 @@ from collections.abc import Awaitable, Callable
|
|||||||
from math import ceil
|
from math import ceil
|
||||||
import random
|
import random
|
||||||
import shlex
|
import shlex
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from app.calculator import calculate_weighted_pp
|
from app.calculator import calculate_weighted_pp
|
||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database import ChatMessageResp
|
|
||||||
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
from app.database.chat import ChannelType, ChatChannel, ChatMessage, ChatMessageModel, MessageType
|
||||||
from app.database.score import Score, get_best_id
|
from app.database.score import Score, get_best_id
|
||||||
from app.database.statistics import UserStatistics, get_rank
|
from app.database.statistics import UserStatistics, get_rank
|
||||||
from app.database.user import User
|
from app.database.user import User
|
||||||
@@ -95,7 +98,7 @@ class Bot:
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(msg)
|
await session.refresh(msg)
|
||||||
await session.refresh(bot)
|
await session.refresh(bot)
|
||||||
resp = await ChatMessageResp.from_db(msg, session, bot)
|
resp = await ChatMessageModel.transform(msg, includes=["sender"])
|
||||||
await server.send_message_to_channel(resp)
|
await server.send_message_to_channel(resp)
|
||||||
|
|
||||||
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
|
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
|
||||||
@@ -119,7 +122,7 @@ class Bot:
|
|||||||
await session.refresh(channel)
|
await session.refresh(channel)
|
||||||
await session.refresh(user)
|
await session.refresh(user)
|
||||||
await session.refresh(bot)
|
await session.refresh(bot)
|
||||||
await server.batch_join_channel([user, bot], channel, session)
|
await server.batch_join_channel([user, bot], channel)
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
async def _send_reply(
|
async def _send_reply(
|
||||||
|
|||||||
@@ -1,37 +1,40 @@
|
|||||||
from typing import Annotated, Any, Literal, Self
|
from typing import Annotated, Literal, Self
|
||||||
|
|
||||||
from app.database.chat import (
|
from app.database.chat import (
|
||||||
ChannelType,
|
ChannelType,
|
||||||
ChatChannel,
|
ChatChannel,
|
||||||
ChatChannelResp,
|
ChatChannelModel,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
SilenceUser,
|
SilenceUser,
|
||||||
UserSilenceResp,
|
UserSilenceResp,
|
||||||
)
|
)
|
||||||
from app.database.user import User, UserResp
|
from app.database.user import User, UserModel
|
||||||
from app.dependencies.database import Database, Redis
|
from app.dependencies.database import Database, Redis
|
||||||
from app.dependencies.param import BodyOrForm
|
from app.dependencies.param import BodyOrForm
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.router.v2 import api_v2_router as router
|
from app.router.v2 import api_v2_router as router
|
||||||
|
from app.utils import api_doc
|
||||||
|
|
||||||
from .server import server
|
from .server import server
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
|
|
||||||
|
|
||||||
class UpdateResponse(BaseModel):
|
|
||||||
presence: list[ChatChannelResp] = Field(default_factory=list)
|
|
||||||
silences: list[Any] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/chat/updates",
|
"/chat/updates",
|
||||||
response_model=UpdateResponse,
|
|
||||||
name="获取更新",
|
name="获取更新",
|
||||||
description="获取当前用户所在频道的最新的禁言情况。",
|
description="获取当前用户所在频道的最新的禁言情况。",
|
||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"获取更新响应。",
|
||||||
|
{"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]},
|
||||||
|
ChatChannel.LISTING_INCLUDES,
|
||||||
|
name="UpdateResponse",
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def get_update(
|
async def get_update(
|
||||||
session: Database,
|
session: Database,
|
||||||
@@ -44,45 +47,44 @@ async def get_update(
|
|||||||
Query(alias="includes[]", description="要包含的更新类型"),
|
Query(alias="includes[]", description="要包含的更新类型"),
|
||||||
] = ["presence", "silences"],
|
] = ["presence", "silences"],
|
||||||
):
|
):
|
||||||
resp = UpdateResponse()
|
resp = {
|
||||||
|
"presence": [],
|
||||||
|
"silences": [],
|
||||||
|
}
|
||||||
if "presence" in includes:
|
if "presence" in includes:
|
||||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||||
for channel_id in channel_ids:
|
for channel_id in channel_ids:
|
||||||
# 使用明确的查询避免延迟加载
|
# 使用明确的查询避免延迟加载
|
||||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||||
if db_channel:
|
if db_channel:
|
||||||
# 提取必要的属性避免惰性加载
|
resp["presence"].append(
|
||||||
channel_type = db_channel.type
|
await ChatChannelModel.transform(
|
||||||
|
|
||||||
resp.presence.append(
|
|
||||||
await ChatChannelResp.from_db(
|
|
||||||
db_channel,
|
db_channel,
|
||||||
session,
|
user=current_user,
|
||||||
current_user,
|
server=server,
|
||||||
redis,
|
includes=ChatChannel.LISTING_INCLUDES,
|
||||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if "silences" in includes:
|
if "silences" in includes:
|
||||||
if history_since:
|
if history_since:
|
||||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||||
elif since:
|
elif since:
|
||||||
msg = await session.get(ChatMessage, since)
|
msg = await session.get(ChatMessage, since)
|
||||||
if msg:
|
if msg:
|
||||||
silences = (
|
silences = (
|
||||||
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
|
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
|
||||||
).all()
|
).all()
|
||||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
"/chat/channels/{channel}/users/{user}",
|
"/chat/channels/{channel}/users/{user}",
|
||||||
response_model=ChatChannelResp,
|
|
||||||
name="加入频道",
|
name="加入频道",
|
||||||
description="加入指定的公开/房间频道。",
|
description="加入指定的公开/房间频道。",
|
||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
|
responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)},
|
||||||
)
|
)
|
||||||
async def join_channel(
|
async def join_channel(
|
||||||
session: Database,
|
session: Database,
|
||||||
@@ -101,7 +103,7 @@ async def join_channel(
|
|||||||
|
|
||||||
if db_channel is None:
|
if db_channel is None:
|
||||||
raise HTTPException(status_code=404, detail="Channel not found")
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
return await server.join_channel(current_user, db_channel, session)
|
return await server.join_channel(current_user, db_channel)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -128,13 +130,13 @@ async def leave_channel(
|
|||||||
|
|
||||||
if db_channel is None:
|
if db_channel is None:
|
||||||
raise HTTPException(status_code=404, detail="Channel not found")
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
await server.leave_channel(current_user, db_channel, session)
|
await server.leave_channel(current_user, db_channel)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/chat/channels",
|
"/chat/channels",
|
||||||
response_model=list[ChatChannelResp],
|
responses={200: api_doc("加入的频道", list[ChatChannelModel])},
|
||||||
name="获取频道列表",
|
name="获取频道列表",
|
||||||
description="获取所有公开频道。",
|
description="获取所有公开频道。",
|
||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
@@ -142,35 +144,30 @@ async def leave_channel(
|
|||||||
async def get_channel_list(
|
async def get_channel_list(
|
||||||
session: Database,
|
session: Database,
|
||||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||||
redis: Redis,
|
|
||||||
):
|
):
|
||||||
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
||||||
results = []
|
results = await ChatChannelModel.transform_many(
|
||||||
for channel in channels:
|
channels,
|
||||||
# 提取必要的属性避免惰性加载
|
user=current_user,
|
||||||
channel_id = channel.channel_id
|
server=server,
|
||||||
channel_type = channel.type
|
)
|
||||||
|
|
||||||
results.append(
|
|
||||||
await ChatChannelResp.from_db(
|
|
||||||
channel,
|
|
||||||
session,
|
|
||||||
current_user,
|
|
||||||
redis,
|
|
||||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class GetChannelResp(BaseModel):
|
|
||||||
channel: ChatChannelResp
|
|
||||||
users: list[UserResp] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/chat/channels/{channel}",
|
"/chat/channels/{channel}",
|
||||||
response_model=GetChannelResp,
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"频道详细信息",
|
||||||
|
{
|
||||||
|
"channel": ChatChannelModel,
|
||||||
|
"users": list[UserModel],
|
||||||
|
},
|
||||||
|
ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES,
|
||||||
|
name="GetChannelResponse",
|
||||||
|
)
|
||||||
|
},
|
||||||
name="获取频道信息",
|
name="获取频道信息",
|
||||||
description="获取指定频道的信息。",
|
description="获取指定频道的信息。",
|
||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
@@ -191,7 +188,6 @@ async def get_channel(
|
|||||||
raise HTTPException(status_code=404, detail="Channel not found")
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
|
||||||
# 立即提取需要的属性
|
# 立即提取需要的属性
|
||||||
channel_id = db_channel.channel_id
|
|
||||||
channel_type = db_channel.type
|
channel_type = db_channel.type
|
||||||
channel_name = db_channel.name
|
channel_name = db_channel.name
|
||||||
|
|
||||||
@@ -209,15 +205,15 @@ async def get_channel(
|
|||||||
users.extend([target_user, current_user])
|
users.extend([target_user, current_user])
|
||||||
break
|
break
|
||||||
|
|
||||||
return GetChannelResp(
|
return {
|
||||||
channel=await ChatChannelResp.from_db(
|
"channel": await ChatChannelModel.transform(
|
||||||
db_channel,
|
db_channel,
|
||||||
session,
|
user=current_user,
|
||||||
current_user,
|
server=server,
|
||||||
redis,
|
includes=ChatChannel.LISTING_INCLUDES,
|
||||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
),
|
||||||
)
|
"users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES),
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
class CreateChannelReq(BaseModel):
|
class CreateChannelReq(BaseModel):
|
||||||
@@ -244,7 +240,7 @@ class CreateChannelReq(BaseModel):
|
|||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/chat/channels",
|
"/chat/channels",
|
||||||
response_model=ChatChannelResp,
|
responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])},
|
||||||
name="创建频道",
|
name="创建频道",
|
||||||
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
|
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
|
||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
@@ -289,21 +285,13 @@ async def create_channel(
|
|||||||
await session.refresh(current_user)
|
await session.refresh(current_user)
|
||||||
if req.type == "PM":
|
if req.type == "PM":
|
||||||
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||||
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
await server.batch_join_channel([target, current_user], channel) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||||
else:
|
else:
|
||||||
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
|
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
|
||||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
await server.batch_join_channel([*target_users, current_user], channel)
|
||||||
|
|
||||||
await server.join_channel(current_user, channel, session)
|
await server.join_channel(current_user, channel)
|
||||||
|
|
||||||
# 提取必要的属性避免惰性加载
|
return await ChatChannelModel.transform(
|
||||||
channel_id = channel.channel_id
|
channel, user=current_user, server=server, includes=["recent_messages.sender"]
|
||||||
|
|
||||||
return await ChatChannelResp.from_db(
|
|
||||||
channel,
|
|
||||||
session,
|
|
||||||
current_user,
|
|
||||||
redis,
|
|
||||||
server.channels.get(channel_id, []),
|
|
||||||
include_recent_messages=True,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from app.database import ChatMessageResp
|
from app.database import ChatChannelModel
|
||||||
from app.database.chat import (
|
from app.database.chat import (
|
||||||
ChannelType,
|
ChannelType,
|
||||||
ChatChannel,
|
ChatChannel,
|
||||||
ChatChannelResp,
|
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
ChatMessageModel,
|
||||||
MessageType,
|
MessageType,
|
||||||
SilenceUser,
|
SilenceUser,
|
||||||
UserSilenceResp,
|
UserSilenceResp,
|
||||||
@@ -18,6 +18,7 @@ from app.log import log
|
|||||||
from app.models.notification import ChannelMessage, ChannelMessageTeam
|
from app.models.notification import ChannelMessage, ChannelMessageTeam
|
||||||
from app.router.v2 import api_v2_router as router
|
from app.router.v2 import api_v2_router as router
|
||||||
from app.service.redis_message_system import redis_message_system
|
from app.service.redis_message_system import redis_message_system
|
||||||
|
from app.utils import api_doc
|
||||||
|
|
||||||
from .banchobot import bot
|
from .banchobot import bot
|
||||||
from .server import server
|
from .server import server
|
||||||
@@ -68,7 +69,7 @@ class MessageReq(BaseModel):
|
|||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/chat/channels/{channel}/messages",
|
"/chat/channels/{channel}/messages",
|
||||||
response_model=ChatMessageResp,
|
responses={200: api_doc("发送的消息", ChatMessageModel, ["sender", "is_action"])},
|
||||||
name="发送消息",
|
name="发送消息",
|
||||||
description="发送消息到指定频道。",
|
description="发送消息到指定频道。",
|
||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
@@ -130,7 +131,7 @@ async def send_message(
|
|||||||
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
|
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
|
||||||
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
|
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
|
||||||
temp_msg = ChatMessage(
|
temp_msg = ChatMessage(
|
||||||
message_id=resp.message_id, # 使用 Redis 系统生成的ID
|
message_id=resp["message_id"], # 使用 Redis 系统生成的ID
|
||||||
channel_id=channel_id,
|
channel_id=channel_id,
|
||||||
content=req.message,
|
content=req.message,
|
||||||
sender_id=user_id,
|
sender_id=user_id,
|
||||||
@@ -151,7 +152,7 @@ async def send_message(
|
|||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/chat/channels/{channel}/messages",
|
"/chat/channels/{channel}/messages",
|
||||||
response_model=list[ChatMessageResp],
|
responses={200: api_doc("获取的消息", list[ChatMessageModel], ["sender"])},
|
||||||
name="获取消息",
|
name="获取消息",
|
||||||
description="获取指定频道的消息列表(统一按时间正序返回)。",
|
description="获取指定频道的消息列表(统一按时间正序返回)。",
|
||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
@@ -177,7 +178,7 @@ async def get_message(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
messages = await redis_message_system.get_messages(channel_id, limit, since)
|
messages = await redis_message_system.get_messages(channel_id, limit, since)
|
||||||
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
|
if len(messages) >= 2 and messages[0]["message_id"] > messages[-1]["message_id"]:
|
||||||
messages.reverse()
|
messages.reverse()
|
||||||
return messages
|
return messages
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -189,7 +190,7 @@ async def get_message(
|
|||||||
# 向前加载新消息 → 直接 ASC
|
# 向前加载新消息 → 直接 ASC
|
||||||
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||||
rows = (await session.exec(query)).all()
|
rows = (await session.exec(query)).all()
|
||||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||||
# 已经 ASC,无需反转
|
# 已经 ASC,无需反转
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@@ -202,15 +203,14 @@ async def get_message(
|
|||||||
rows = (await session.exec(query)).all()
|
rows = (await session.exec(query)).all()
|
||||||
rows = list(rows)
|
rows = list(rows)
|
||||||
rows.reverse() # 反转为 ASC
|
rows.reverse() # 反转为 ASC
|
||||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
query = base.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
query = base.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||||
rows = (await session.exec(query)).all()
|
rows = (await session.exec(query)).all()
|
||||||
rows = list(rows)
|
rows = list(rows)
|
||||||
rows.reverse() # 反转为 ASC
|
rows.reverse() # 反转为 ASC
|
||||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||||
return resp
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@@ -248,17 +248,23 @@ class PMReq(BaseModel):
|
|||||||
uuid: str | None = None
|
uuid: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class NewPMResp(BaseModel):
|
|
||||||
channel: ChatChannelResp
|
|
||||||
message: ChatMessageResp
|
|
||||||
new_channel_id: int
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/chat/new",
|
"/chat/new",
|
||||||
name="创建私聊频道",
|
name="创建私聊频道",
|
||||||
description="创建一个新的私聊频道。",
|
description="创建一个新的私聊频道。",
|
||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"创建私聊频道响应",
|
||||||
|
{
|
||||||
|
"channel": ChatChannelModel,
|
||||||
|
"message": ChatMessageModel,
|
||||||
|
"new_channel_id": int,
|
||||||
|
},
|
||||||
|
["recent_messages.sender", "sender"],
|
||||||
|
name="NewPMResponse",
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def create_new_pm(
|
async def create_new_pm(
|
||||||
session: Database,
|
session: Database,
|
||||||
@@ -290,9 +296,9 @@ async def create_new_pm(
|
|||||||
await session.refresh(target)
|
await session.refresh(target)
|
||||||
await session.refresh(current_user)
|
await session.refresh(current_user)
|
||||||
|
|
||||||
await server.batch_join_channel([target, current_user], channel, session)
|
await server.batch_join_channel([target, current_user], channel)
|
||||||
channel_resp = await ChatChannelResp.from_db(
|
channel_resp = await ChatChannelModel.transform(
|
||||||
channel, session, current_user, redis, server.channels[channel.channel_id]
|
channel, user=current_user, server=server, includes=["recent_messages.sender"]
|
||||||
)
|
)
|
||||||
msg = ChatMessage(
|
msg = ChatMessage(
|
||||||
channel_id=channel.channel_id,
|
channel_id=channel.channel_id,
|
||||||
@@ -306,10 +312,10 @@ async def create_new_pm(
|
|||||||
await session.refresh(msg)
|
await session.refresh(msg)
|
||||||
await session.refresh(current_user)
|
await session.refresh(current_user)
|
||||||
await session.refresh(channel)
|
await session.refresh(channel)
|
||||||
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
|
message_resp = await ChatMessageModel.transform(msg, user=current_user, includes=["sender"])
|
||||||
await server.send_message_to_channel(message_resp)
|
await server.send_message_to_channel(message_resp)
|
||||||
return NewPMResp(
|
return {
|
||||||
channel=channel_resp,
|
"channel": channel_resp,
|
||||||
message=message_resp,
|
"message": message_resp,
|
||||||
new_channel_id=channel_resp.channel_id,
|
"new_channel_id": channel_resp["channel_id"],
|
||||||
)
|
}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Annotated, overload
|
from typing import Annotated, overload
|
||||||
|
|
||||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
from app.database import ChatMessageDict
|
||||||
|
from app.database.chat import ChannelType, ChatChannel, ChatChannelDict, ChatChannelModel
|
||||||
from app.database.notification import UserNotification, insert_notification
|
from app.database.notification import UserNotification, insert_notification
|
||||||
from app.database.user import User
|
from app.database.user import User
|
||||||
from app.dependencies.database import (
|
from app.dependencies.database import (
|
||||||
@@ -16,7 +17,7 @@ from app.log import log
|
|||||||
from app.models.chat import ChatEvent
|
from app.models.chat import ChatEvent
|
||||||
from app.models.notification import NotificationDetail
|
from app.models.notification import NotificationDetail
|
||||||
from app.service.subscribers.chat import ChatSubscriber
|
from app.service.subscribers.chat import ChatSubscriber
|
||||||
from app.utils import bg_tasks
|
from app.utils import bg_tasks, safe_json_dumps
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.security import SecurityScopes
|
from fastapi.security import SecurityScopes
|
||||||
@@ -65,7 +66,7 @@ class ChatServer:
|
|||||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||||
).first()
|
).first()
|
||||||
if db_channel:
|
if db_channel:
|
||||||
await self.leave_channel(user, db_channel, session)
|
await self.leave_channel(user, db_channel)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def send_event(self, client: int, event: ChatEvent): ...
|
async def send_event(self, client: int, event: ChatEvent): ...
|
||||||
@@ -80,7 +81,7 @@ class ChatServer:
|
|||||||
return
|
return
|
||||||
client = client_
|
client = client_
|
||||||
if client.client_state == WebSocketState.CONNECTED:
|
if client.client_state == WebSocketState.CONNECTED:
|
||||||
await client.send_text(event.model_dump_json())
|
await client.send_text(safe_json_dumps(event))
|
||||||
|
|
||||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||||
users_in_channel = self.channels.get(channel_id, [])
|
users_in_channel = self.channels.get(channel_id, [])
|
||||||
@@ -107,38 +108,38 @@ class ChatServer:
|
|||||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||||
|
|
||||||
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
|
async def send_message_to_channel(self, message: ChatMessageDict, is_bot_command: bool = False):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Sending message to channel {message.channel_id}, message_id: "
|
f"Sending message to channel {message['channel_id']}, message_id: "
|
||||||
f"{message.message_id}, is_bot_command: {is_bot_command}"
|
f"{message['message_id']}, is_bot_command: {is_bot_command}"
|
||||||
)
|
)
|
||||||
|
|
||||||
event = ChatEvent(
|
event = ChatEvent(
|
||||||
event="chat.message.new",
|
event="chat.message.new",
|
||||||
data={"messages": [message], "users": [message.sender]},
|
data={"messages": [message], "users": [message["sender"]]}, # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||||
)
|
)
|
||||||
if is_bot_command:
|
if is_bot_command:
|
||||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
logger.info(f"Sending bot command to user {message['sender_id']}")
|
||||||
bg_tasks.add_task(self.send_event, message.sender_id, event)
|
bg_tasks.add_task(self.send_event, message["sender_id"], event)
|
||||||
else:
|
else:
|
||||||
# 总是广播消息,无论是临时ID还是真实ID
|
# 总是广播消息,无论是临时ID还是真实ID
|
||||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
logger.info(f"Broadcasting message to all users in channel {message['channel_id']}")
|
||||||
bg_tasks.add_task(
|
bg_tasks.add_task(
|
||||||
self.broadcast,
|
self.broadcast,
|
||||||
message.channel_id,
|
message["channel_id"],
|
||||||
event,
|
event,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||||
if message.message_id and message.message_id > 0:
|
if message["message_id"] and message["message_id"] > 0:
|
||||||
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
|
await self.mark_as_read(message["channel_id"], message["sender_id"], message["message_id"])
|
||||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
await self.redis.set(f"chat:{message['channel_id']}:last_msg", message["message_id"])
|
||||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
logger.info(f"Updated last message ID for channel {message['channel_id']} to {message['message_id']}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
logger.debug(f"Skipping last message update for message ID: {message['message_id']}")
|
||||||
|
|
||||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
|
async def batch_join_channel(self, users: list[User], channel: ChatChannel):
|
||||||
channel_id = channel.channel_id
|
channel_id = channel.channel_id
|
||||||
|
|
||||||
not_joined = []
|
not_joined = []
|
||||||
@@ -151,22 +152,18 @@ class ChatServer:
|
|||||||
not_joined.append(user)
|
not_joined.append(user)
|
||||||
|
|
||||||
for user in not_joined:
|
for user in not_joined:
|
||||||
channel_resp = await ChatChannelResp.from_db(
|
channel_resp = await ChatChannelModel.transform(
|
||||||
channel,
|
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||||
session,
|
|
||||||
user,
|
|
||||||
self.redis,
|
|
||||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
|
||||||
)
|
)
|
||||||
await self.send_event(
|
await self.send_event(
|
||||||
user.id,
|
user.id,
|
||||||
ChatEvent(
|
ChatEvent(
|
||||||
event="chat.channel.join",
|
event="chat.channel.join",
|
||||||
data=channel_resp.model_dump(),
|
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
|
async def join_channel(self, user: User, channel: ChatChannel) -> ChatChannelDict:
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
channel_id = channel.channel_id
|
channel_id = channel.channel_id
|
||||||
|
|
||||||
@@ -175,25 +172,21 @@ class ChatServer:
|
|||||||
if user_id not in self.channels[channel_id]:
|
if user_id not in self.channels[channel_id]:
|
||||||
self.channels[channel_id].append(user_id)
|
self.channels[channel_id].append(user_id)
|
||||||
|
|
||||||
channel_resp = await ChatChannelResp.from_db(
|
channel_resp: ChatChannelDict = await ChatChannelModel.transform(
|
||||||
channel,
|
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||||
session,
|
|
||||||
user,
|
|
||||||
self.redis,
|
|
||||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.send_event(
|
await self.send_event(
|
||||||
user_id,
|
user_id,
|
||||||
ChatEvent(
|
ChatEvent(
|
||||||
event="chat.channel.join",
|
event="chat.channel.join",
|
||||||
data=channel_resp.model_dump(),
|
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return channel_resp
|
return channel_resp
|
||||||
|
|
||||||
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
|
async def leave_channel(self, user: User, channel: ChatChannel) -> None:
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
channel_id = channel.channel_id
|
channel_id = channel.channel_id
|
||||||
|
|
||||||
@@ -203,18 +196,14 @@ class ChatServer:
|
|||||||
if (c := self.channels.get(channel_id)) is not None and not c:
|
if (c := self.channels.get(channel_id)) is not None and not c:
|
||||||
del self.channels[channel_id]
|
del self.channels[channel_id]
|
||||||
|
|
||||||
channel_resp = await ChatChannelResp.from_db(
|
channel_resp = await ChatChannelModel.transform(
|
||||||
channel,
|
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||||
session,
|
|
||||||
user,
|
|
||||||
self.redis,
|
|
||||||
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
|
|
||||||
)
|
)
|
||||||
await self.send_event(
|
await self.send_event(
|
||||||
user_id,
|
user_id,
|
||||||
ChatEvent(
|
ChatEvent(
|
||||||
event="chat.channel.part",
|
event="chat.channel.part",
|
||||||
data=channel_resp.model_dump(),
|
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -232,7 +221,7 @@ class ChatServer:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
|
logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
|
||||||
await self.join_channel(user, db_channel, session)
|
await self.join_channel(user, db_channel)
|
||||||
|
|
||||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
@@ -248,7 +237,7 @@ class ChatServer:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
|
logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
|
||||||
await self.leave_channel(user, db_channel, session)
|
await self.leave_channel(user, db_channel)
|
||||||
|
|
||||||
async def new_private_notification(self, detail: NotificationDetail):
|
async def new_private_notification(self, detail: NotificationDetail):
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
@@ -336,6 +325,6 @@ async def chat_websocket(
|
|||||||
# 使用明确的查询避免延迟加载
|
# 使用明确的查询避免延迟加载
|
||||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
|
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
|
||||||
if db_channel is not None:
|
if db_channel is not None:
|
||||||
await server.join_channel(user, db_channel, session)
|
await server.join_channel(user, db_channel)
|
||||||
|
|
||||||
await _listen_stop(websocket, user_id, factory)
|
await _listen_stop(websocket, user_id, factory)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import hashlib
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from app.database.team import Team, TeamMember, TeamRequest, TeamResp
|
from app.database.team import Team, TeamMember, TeamRequest, TeamResp
|
||||||
from app.database.user import BASE_INCLUDES, User, UserResp
|
from app.database.user import User, UserModel
|
||||||
from app.dependencies.database import Database, Redis
|
from app.dependencies.database import Database, Redis
|
||||||
from app.dependencies.storage import StorageService
|
from app.dependencies.storage import StorageService
|
||||||
from app.dependencies.user import ClientUser
|
from app.dependencies.user import ClientUser
|
||||||
@@ -14,12 +14,11 @@ from app.models.notification import (
|
|||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
from app.router.notification import server
|
from app.router.notification import server
|
||||||
from app.service.ranking_cache_service import get_ranking_cache_service
|
from app.service.ranking_cache_service import get_ranking_cache_service
|
||||||
from app.utils import check_image, utcnow
|
from app.utils import api_doc, check_image, utcnow
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import File, Form, HTTPException, Path, Query, Request
|
from fastapi import File, Form, HTTPException, Path, Query, Request
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlmodel import col, exists, select
|
from sqlmodel import col, exists, select
|
||||||
|
|
||||||
|
|
||||||
@@ -214,12 +213,22 @@ async def delete_team(
|
|||||||
await cache_service.invalidate_team_cache()
|
await cache_service.invalidate_team_cache()
|
||||||
|
|
||||||
|
|
||||||
class TeamQueryResp(BaseModel):
|
@router.get(
|
||||||
team: TeamResp
|
"/team/{team_id}",
|
||||||
members: list[UserResp]
|
name="查询战队",
|
||||||
|
tags=["战队", "g0v0 API"],
|
||||||
|
responses={
|
||||||
@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
|
200: api_doc(
|
||||||
|
"战队信息",
|
||||||
|
{
|
||||||
|
"team": TeamResp,
|
||||||
|
"members": list[UserModel],
|
||||||
|
},
|
||||||
|
["statistics", "country"],
|
||||||
|
name="TeamQueryResp",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
async def get_team(
|
async def get_team(
|
||||||
session: Database,
|
session: Database,
|
||||||
team_id: Annotated[int, Path(..., description="战队 ID")],
|
team_id: Annotated[int, Path(..., description="战队 ID")],
|
||||||
@@ -233,10 +242,10 @@ async def get_team(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
return TeamQueryResp(
|
return {
|
||||||
team=await TeamResp.from_db(members[0].team, session, gamemode),
|
"team": await TeamResp.from_db(members[0].team, session, gamemode),
|
||||||
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
|
"members": await UserModel.transform_many([m.user for m in members], includes=["statistics", "country"]),
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])
|
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
from app.database.statistics import UserStatistics, UserStatisticsModel
|
||||||
from app.database.user import User
|
from app.database.user import User
|
||||||
from app.dependencies.database import Database, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
@@ -46,7 +46,7 @@ class V1User(AllStrModel):
|
|||||||
return f"v1_user:{user_id}"
|
return f"v1_user:{user_id}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
|
async def from_db(cls, db_user: User, ruleset: GameMode | None = None) -> "V1User":
|
||||||
ruleset = ruleset or db_user.playmode
|
ruleset = ruleset or db_user.playmode
|
||||||
current_statistics: UserStatistics | None = None
|
current_statistics: UserStatistics | None = None
|
||||||
for i in await db_user.awaitable_attrs.statistics:
|
for i in await db_user.awaitable_attrs.statistics:
|
||||||
@@ -54,31 +54,33 @@ class V1User(AllStrModel):
|
|||||||
current_statistics = i
|
current_statistics = i
|
||||||
break
|
break
|
||||||
if current_statistics:
|
if current_statistics:
|
||||||
statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
|
statistics = await UserStatisticsModel.transform(
|
||||||
|
current_statistics, country_code=db_user.country_code, includes=["country_rank"]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
statistics = None
|
statistics = None
|
||||||
return cls(
|
return cls(
|
||||||
user_id=db_user.id,
|
user_id=db_user.id,
|
||||||
username=db_user.username,
|
username=db_user.username,
|
||||||
join_date=db_user.join_date,
|
join_date=db_user.join_date,
|
||||||
count300=statistics.count_300 if statistics else 0,
|
count300=current_statistics.count_300 if current_statistics else 0,
|
||||||
count100=statistics.count_100 if statistics else 0,
|
count100=current_statistics.count_100 if current_statistics else 0,
|
||||||
count50=statistics.count_50 if statistics else 0,
|
count50=current_statistics.count_50 if current_statistics else 0,
|
||||||
playcount=statistics.play_count if statistics else 0,
|
playcount=current_statistics.play_count if current_statistics else 0,
|
||||||
ranked_score=statistics.ranked_score if statistics else 0,
|
ranked_score=current_statistics.ranked_score if current_statistics else 0,
|
||||||
total_score=statistics.total_score if statistics else 0,
|
total_score=current_statistics.total_score if current_statistics else 0,
|
||||||
pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
|
pp_rank=statistics.get("global_rank") or 0 if statistics else 0,
|
||||||
level=current_statistics.level_current if current_statistics else 0,
|
level=current_statistics.level_current if current_statistics else 0,
|
||||||
pp_raw=statistics.pp if statistics else 0.0,
|
pp_raw=current_statistics.pp if current_statistics else 0.0,
|
||||||
accuracy=statistics.hit_accuracy if statistics else 0,
|
accuracy=current_statistics.hit_accuracy if current_statistics else 0,
|
||||||
count_rank_ss=current_statistics.grade_ss if current_statistics else 0,
|
count_rank_ss=current_statistics.grade_ss if current_statistics else 0,
|
||||||
count_rank_ssh=current_statistics.grade_ssh if current_statistics else 0,
|
count_rank_ssh=current_statistics.grade_ssh if current_statistics else 0,
|
||||||
count_rank_s=current_statistics.grade_s if current_statistics else 0,
|
count_rank_s=current_statistics.grade_s if current_statistics else 0,
|
||||||
count_rank_sh=current_statistics.grade_sh if current_statistics else 0,
|
count_rank_sh=current_statistics.grade_sh if current_statistics else 0,
|
||||||
count_rank_a=current_statistics.grade_a if current_statistics else 0,
|
count_rank_a=current_statistics.grade_a if current_statistics else 0,
|
||||||
country=db_user.country_code,
|
country=db_user.country_code,
|
||||||
total_seconds_played=statistics.play_time if statistics else 0,
|
total_seconds_played=current_statistics.play_time if current_statistics else 0,
|
||||||
pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
|
pp_country_rank=statistics.get("country_rank") or 0 if statistics else 0,
|
||||||
events=[], # TODO
|
events=[], # TODO
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -134,7 +136,7 @@ async def get_user(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 生成用户数据
|
# 生成用户数据
|
||||||
v1_user = await V1User.from_db(session, db_user, ruleset)
|
v1_user = await V1User.from_db(db_user, ruleset)
|
||||||
|
|
||||||
# 异步缓存结果(如果有用户ID)
|
# 异步缓存结果(如果有用户ID)
|
||||||
if db_user.id is not None:
|
if db_user.id is not None:
|
||||||
|
|||||||
@@ -5,7 +5,11 @@ from typing import Annotated
|
|||||||
|
|
||||||
from app.calculator import get_calculator
|
from app.calculator import get_calculator
|
||||||
from app.calculators.performance import ConvertError
|
from app.calculators.performance import ConvertError
|
||||||
from app.database import Beatmap, BeatmapResp, User
|
from app.database import (
|
||||||
|
Beatmap,
|
||||||
|
BeatmapModel,
|
||||||
|
User,
|
||||||
|
)
|
||||||
from app.database.beatmap import calculate_beatmap_attributes
|
from app.database.beatmap import calculate_beatmap_attributes
|
||||||
from app.dependencies.database import Database, Redis
|
from app.dependencies.database import Database, Redis
|
||||||
from app.dependencies.fetcher import Fetcher
|
from app.dependencies.fetcher import Fetcher
|
||||||
@@ -19,29 +23,20 @@ from app.models.performance import (
|
|||||||
from app.models.score import (
|
from app.models.score import (
|
||||||
GameMode,
|
GameMode,
|
||||||
)
|
)
|
||||||
|
from app.utils import api_doc
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import HTTPException, Path, Query, Security
|
from fastapi import HTTPException, Path, Query, Security
|
||||||
from httpx import HTTPError, HTTPStatusError
|
from httpx import HTTPError, HTTPStatusError
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
|
|
||||||
|
|
||||||
class BatchGetResp(BaseModel):
|
|
||||||
"""批量获取谱面返回模型。
|
|
||||||
|
|
||||||
返回字段说明:
|
|
||||||
- beatmaps: 谱面详细信息列表。"""
|
|
||||||
|
|
||||||
beatmaps: list[BeatmapResp]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/beatmaps/lookup",
|
"/beatmaps/lookup",
|
||||||
tags=["谱面"],
|
tags=["谱面"],
|
||||||
name="查询单个谱面",
|
name="查询单个谱面",
|
||||||
response_model=BeatmapResp,
|
responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
|
||||||
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
|
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
@@ -67,14 +62,14 @@ async def lookup_beatmap(
|
|||||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||||
await db.refresh(current_user)
|
await db.refresh(current_user)
|
||||||
|
|
||||||
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
return await BeatmapModel.transform(beatmap, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/beatmaps/{beatmap_id}",
|
"/beatmaps/{beatmap_id}",
|
||||||
tags=["谱面"],
|
tags=["谱面"],
|
||||||
name="获取谱面详情",
|
name="获取谱面详情",
|
||||||
response_model=BeatmapResp,
|
responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
|
||||||
description="获取单个谱面详情。",
|
description="获取单个谱面详情。",
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
@@ -86,7 +81,12 @@ async def get_beatmap(
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||||
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
await db.refresh(current_user)
|
||||||
|
return await BeatmapModel.transform(
|
||||||
|
beatmap,
|
||||||
|
user=current_user,
|
||||||
|
includes=BeatmapModel.TRANSFORMER_INCLUDES,
|
||||||
|
)
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||||
|
|
||||||
@@ -95,7 +95,11 @@ async def get_beatmap(
|
|||||||
"/beatmaps/",
|
"/beatmaps/",
|
||||||
tags=["谱面"],
|
tags=["谱面"],
|
||||||
name="批量获取谱面",
|
name="批量获取谱面",
|
||||||
response_model=BatchGetResp,
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"谱面列表", {"beatmaps": list[BeatmapModel]}, BeatmapModel.TRANSFORMER_INCLUDES, name="BatchBeatmapResponse"
|
||||||
|
)
|
||||||
|
},
|
||||||
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
|
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
@@ -124,7 +128,12 @@ async def batch_get_beatmaps(
|
|||||||
for beatmap in beatmaps:
|
for beatmap in beatmaps:
|
||||||
await db.refresh(beatmap)
|
await db.refresh(beatmap)
|
||||||
await db.refresh(current_user)
|
await db.refresh(current_user)
|
||||||
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
|
return {
|
||||||
|
"beatmaps": [
|
||||||
|
await BeatmapModel.transform(bm, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
|
||||||
|
for bm in beatmaps
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|||||||
@@ -2,17 +2,24 @@ import re
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs
|
||||||
|
|
||||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
from app.database import (
|
||||||
from app.database.beatmapset import SearchBeatmapsetsResp
|
Beatmap,
|
||||||
|
Beatmapset,
|
||||||
|
BeatmapsetModel,
|
||||||
|
FavouriteBeatmapset,
|
||||||
|
SearchBeatmapsetsResp,
|
||||||
|
User,
|
||||||
|
)
|
||||||
from app.dependencies.beatmap_download import DownloadService
|
from app.dependencies.beatmap_download import DownloadService
|
||||||
from app.dependencies.cache import BeatmapsetCacheService, UserCacheService
|
from app.dependencies.cache import BeatmapsetCacheService, UserCacheService
|
||||||
from app.dependencies.database import Database, Redis, with_db
|
from app.dependencies.database import Database, Redis
|
||||||
from app.dependencies.fetcher import Fetcher
|
from app.dependencies.fetcher import Fetcher
|
||||||
from app.dependencies.geoip import IPAddress, get_geoip_helper
|
from app.dependencies.geoip import IPAddress, get_geoip_helper
|
||||||
from app.dependencies.user import ClientUser, get_current_user
|
from app.dependencies.user import ClientUser, get_current_user
|
||||||
from app.helpers.asset_proxy_helper import asset_proxy_response
|
from app.helpers.asset_proxy_helper import asset_proxy_response
|
||||||
from app.models.beatmap import SearchQueryModel
|
from app.models.beatmap import SearchQueryModel
|
||||||
from app.service.beatmapset_cache_service import generate_hash
|
from app.service.beatmapset_cache_service import generate_hash
|
||||||
|
from app.utils import api_doc
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
@@ -27,14 +34,7 @@ from fastapi import (
|
|||||||
)
|
)
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
from sqlmodel import exists, select
|
from sqlmodel import select
|
||||||
|
|
||||||
|
|
||||||
async def _save_to_db(sets: SearchBeatmapsetsResp):
|
|
||||||
async with with_db() as session:
|
|
||||||
for s in sets.beatmapsets:
|
|
||||||
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
|
|
||||||
await Beatmapset.from_resp(session, s)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -105,7 +105,6 @@ async def search_beatmapset(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
sets = await fetcher.search_beatmapset(query, cursor, redis)
|
sets = await fetcher.search_beatmapset(query, cursor, redis)
|
||||||
background_tasks.add_task(_save_to_db, sets)
|
|
||||||
|
|
||||||
# 缓存搜索结果
|
# 缓存搜索结果
|
||||||
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
|
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
|
||||||
@@ -117,8 +116,8 @@ async def search_beatmapset(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/beatmapsets/lookup",
|
"/beatmapsets/lookup",
|
||||||
tags=["谱面集"],
|
tags=["谱面集"],
|
||||||
|
responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
|
||||||
name="查询谱面集 (通过谱面 ID)",
|
name="查询谱面集 (通过谱面 ID)",
|
||||||
response_model=BeatmapsetResp,
|
|
||||||
description=("通过谱面 ID 查询所属谱面集。"),
|
description=("通过谱面 ID 查询所属谱面集。"),
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
@@ -137,7 +136,10 @@ async def lookup_beatmapset(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||||
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
|
|
||||||
|
resp = await BeatmapsetModel.transform(
|
||||||
|
beatmap.beatmapset, user=current_user, includes=BeatmapsetModel.API_INCLUDES
|
||||||
|
)
|
||||||
|
|
||||||
# 缓存结果
|
# 缓存结果
|
||||||
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
|
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
|
||||||
@@ -149,8 +151,8 @@ async def lookup_beatmapset(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/beatmapsets/{beatmapset_id}",
|
"/beatmapsets/{beatmapset_id}",
|
||||||
tags=["谱面集"],
|
tags=["谱面集"],
|
||||||
|
responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
|
||||||
name="获取谱面集详情",
|
name="获取谱面集详情",
|
||||||
response_model=BeatmapsetResp,
|
|
||||||
description="获取单个谱面集详情。",
|
description="获取单个谱面集详情。",
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
@@ -169,7 +171,8 @@ async def get_beatmapset(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
|
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
|
||||||
resp = await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
|
await db.refresh(current_user)
|
||||||
|
resp = await BeatmapsetModel.transform(beatmapset, includes=BeatmapsetModel.API_INCLUDES, user=current_user)
|
||||||
|
|
||||||
# 缓存结果
|
# 缓存结果
|
||||||
await cache_service.cache_beatmapset(resp)
|
await cache_service.cache_beatmapset(resp)
|
||||||
|
|||||||
@@ -1,20 +1,29 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from app.database import FavouriteBeatmapset, MeResp, User
|
from app.database import FavouriteBeatmapset, User
|
||||||
|
from app.database.user import UserModel
|
||||||
from app.dependencies.database import Database
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
|
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
from app.utils import api_doc
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import Path, Security
|
from fastapi import Path, Security
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
|
ME_INCLUDES = [*User.USER_INCLUDES, "session_verified", "session_verification_method"]
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapsetIds(BaseModel):
|
||||||
|
beatmapset_ids: list[int]
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/me/beatmapset-favourites",
|
"/me/beatmapset-favourites",
|
||||||
response_model=list[int],
|
response_model=BeatmapsetIds,
|
||||||
name="获取当前用户收藏的谱面集 ID 列表",
|
name="获取当前用户收藏的谱面集 ID 列表",
|
||||||
description="获取当前登录用户收藏的谱面集 ID 列表。",
|
description="获取当前登录用户收藏的谱面集 ID 列表。",
|
||||||
tags=["用户", "谱面集"],
|
tags=["用户", "谱面集"],
|
||||||
@@ -26,37 +35,39 @@ async def get_user_beatmapset_favourites(
|
|||||||
beatmapset_ids = await session.exec(
|
beatmapset_ids = await session.exec(
|
||||||
select(FavouriteBeatmapset.beatmapset_id).where(FavouriteBeatmapset.user_id == current_user.id)
|
select(FavouriteBeatmapset.beatmapset_id).where(FavouriteBeatmapset.user_id == current_user.id)
|
||||||
)
|
)
|
||||||
return beatmapset_ids.all()
|
return BeatmapsetIds(beatmapset_ids=list(beatmapset_ids.all()))
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/me/{ruleset}",
|
"/me/{ruleset}",
|
||||||
response_model=MeResp,
|
responses={200: api_doc("当前用户信息(含指定 ruleset 统计)", UserModel, ME_INCLUDES)},
|
||||||
name="获取当前用户信息 (指定 ruleset)",
|
name="获取当前用户信息 (指定 ruleset)",
|
||||||
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
|
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
async def get_user_info_with_ruleset(
|
async def get_user_info_with_ruleset(
|
||||||
session: Database,
|
|
||||||
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
|
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
|
||||||
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
||||||
):
|
):
|
||||||
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
|
user_resp = await UserModel.transform(
|
||||||
|
user_and_token[0], ruleset=ruleset, token_id=user_and_token[1].id, includes=ME_INCLUDES
|
||||||
|
)
|
||||||
return user_resp
|
return user_resp
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/me/",
|
"/me/",
|
||||||
response_model=MeResp,
|
responses={200: api_doc("当前用户信息", UserModel, ME_INCLUDES)},
|
||||||
name="获取当前用户信息",
|
name="获取当前用户信息",
|
||||||
description="获取当前登录用户信息。",
|
description="获取当前登录用户信息。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
async def get_user_info_default(
|
async def get_user_info_default(
|
||||||
session: Database,
|
|
||||||
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
||||||
):
|
):
|
||||||
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
|
user_resp = await UserModel.transform(
|
||||||
|
user_and_token[0], ruleset=None, token_id=user_and_token[1].id, includes=ME_INCLUDES
|
||||||
|
)
|
||||||
return user_resp
|
return user_resp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
|
from app.database import Team, TeamMember, User, UserStatistics
|
||||||
|
from app.database.statistics import UserStatisticsModel
|
||||||
from app.dependencies.database import Database, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
from app.service.ranking_cache_service import get_ranking_cache_service
|
from app.service.ranking_cache_service import get_ranking_cache_service
|
||||||
|
from app.utils import api_doc
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
@@ -308,14 +310,16 @@ async def get_country_ranking(
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
class TopUsersResponse(BaseModel):
|
|
||||||
ranking: list[UserStatisticsResp]
|
|
||||||
total: int
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/rankings/{ruleset}/{sort}",
|
"/rankings/{ruleset}/{sort}",
|
||||||
response_model=TopUsersResponse,
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"用户排行榜",
|
||||||
|
{"ranking": list[UserStatisticsModel], "total": int},
|
||||||
|
["user.country", "user.cover"],
|
||||||
|
name="TopUsersResponse",
|
||||||
|
)
|
||||||
|
},
|
||||||
name="获取用户排行榜",
|
name="获取用户排行榜",
|
||||||
description="获取在指定模式下的用户排行榜",
|
description="获取在指定模式下的用户排行榜",
|
||||||
tags=["排行榜"],
|
tags=["排行榜"],
|
||||||
@@ -339,10 +343,10 @@ async def get_user_ranking(
|
|||||||
|
|
||||||
if cached_data and cached_stats:
|
if cached_data and cached_stats:
|
||||||
# 从缓存返回数据
|
# 从缓存返回数据
|
||||||
return TopUsersResponse(
|
return {
|
||||||
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data],
|
"ranking": cached_data,
|
||||||
total=cached_stats.get("total", 0),
|
"total": cached_stats.get("total", 0),
|
||||||
)
|
}
|
||||||
|
|
||||||
# 缓存未命中,从数据库查询
|
# 缓存未命中,从数据库查询
|
||||||
wheres = [
|
wheres = [
|
||||||
@@ -350,7 +354,7 @@ async def get_user_ranking(
|
|||||||
col(UserStatistics.pp) > 0,
|
col(UserStatistics.pp) > 0,
|
||||||
col(UserStatistics.is_ranked),
|
col(UserStatistics.is_ranked),
|
||||||
]
|
]
|
||||||
include = ["user"]
|
include = UserStatistics.RANKING_INCLUDES.copy()
|
||||||
if sort == "performance":
|
if sort == "performance":
|
||||||
order_by = col(UserStatistics.pp).desc()
|
order_by = col(UserStatistics.pp).desc()
|
||||||
include.append("rank_change_since_30_days")
|
include.append("rank_change_since_30_days")
|
||||||
@@ -358,6 +362,7 @@ async def get_user_ranking(
|
|||||||
order_by = col(UserStatistics.ranked_score).desc()
|
order_by = col(UserStatistics.ranked_score).desc()
|
||||||
if country:
|
if country:
|
||||||
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
||||||
|
include.append("country_rank")
|
||||||
|
|
||||||
# 查询总数
|
# 查询总数
|
||||||
count_query = select(func.count()).select_from(UserStatistics).where(*wheres)
|
count_query = select(func.count()).select_from(UserStatistics).where(*wheres)
|
||||||
@@ -378,12 +383,14 @@ async def get_user_ranking(
|
|||||||
# 转换为响应格式
|
# 转换为响应格式
|
||||||
ranking_data = []
|
ranking_data = []
|
||||||
for statistics in statistics_list:
|
for statistics in statistics_list:
|
||||||
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
|
user_stats_resp = await UserStatisticsModel.transform(
|
||||||
|
statistics, includes=include, user_country=current_user.country_code
|
||||||
|
)
|
||||||
ranking_data.append(user_stats_resp)
|
ranking_data.append(user_stats_resp)
|
||||||
|
|
||||||
# 异步缓存数据(不等待完成)
|
# 异步缓存数据(不等待完成)
|
||||||
# 使用配置文件中的TTL设置
|
# 使用配置文件中的TTL设置
|
||||||
cache_data = [item.model_dump() for item in ranking_data]
|
cache_data = ranking_data
|
||||||
stats_data = {"total": total_count}
|
stats_data = {"total": total_count}
|
||||||
|
|
||||||
# 创建后台任务来缓存数据
|
# 创建后台任务来缓存数据
|
||||||
@@ -407,5 +414,7 @@ async def get_user_ranking(
|
|||||||
ttl=settings.ranking_cache_expire_minutes * 60,
|
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||||
)
|
)
|
||||||
|
|
||||||
resp = TopUsersResponse(ranking=ranking_data, total=total_count)
|
return {
|
||||||
return resp
|
"ranking": ranking_data,
|
||||||
|
"total": total_count,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from app.database import Relationship, RelationshipResp, RelationshipType, User
|
from app.database import Relationship, RelationshipType, User
|
||||||
from app.database.user import UserResp
|
from app.database.relationship import RelationshipModel
|
||||||
|
from app.database.user import UserModel
|
||||||
from app.dependencies.api_version import APIVersion
|
from app.dependencies.api_version import APIVersion
|
||||||
from app.dependencies.database import Database
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.user import ClientUser, get_current_user
|
from app.dependencies.user import ClientUser, get_current_user
|
||||||
|
from app.utils import api_doc
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import HTTPException, Path, Query, Request, Security
|
from fastapi import HTTPException, Path, Query, Request, Security
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlmodel import col, exists, select
|
from sqlmodel import col, exists, select
|
||||||
|
|
||||||
|
|
||||||
@@ -17,38 +18,19 @@ from sqlmodel import col, exists, select
|
|||||||
"/friends",
|
"/friends",
|
||||||
tags=["用户关系"],
|
tags=["用户关系"],
|
||||||
responses={
|
responses={
|
||||||
200: {
|
200: api_doc(
|
||||||
"description": "好友列表",
|
"好友列表\n\n如果 `x-api-version < 20241022`,返回值为 `User` 列表,否则为 `Relationship` 列表。",
|
||||||
"content": {
|
list[RelationshipModel] | list[UserModel],
|
||||||
"application/json": {
|
[f"target.{inc}" for inc in User.LIST_INCLUDES],
|
||||||
"schema": {
|
)
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"type": "array",
|
|
||||||
"items": {"$ref": "#/components/schemas/RelationshipResp"},
|
|
||||||
"description": "好友列表",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "array",
|
|
||||||
"items": {"$ref": "#/components/schemas/UserResp"},
|
|
||||||
"description": "好友列表 (`x-api-version < 20241022`)",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
name="获取好友列表",
|
name="获取好友列表",
|
||||||
description=(
|
description="获取当前用户的好友列表。",
|
||||||
"获取当前用户的好友列表。\n\n"
|
|
||||||
"如果 `x-api-version < 20241022`,返回值为 `UserResp` 列表,否则为 `RelationshipResp` 列表。"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
@router.get(
|
@router.get(
|
||||||
"/blocks",
|
"/blocks",
|
||||||
tags=["用户关系"],
|
tags=["用户关系"],
|
||||||
response_model=list[RelationshipResp],
|
response_model=list[dict[str, Any]],
|
||||||
name="获取屏蔽列表",
|
name="获取屏蔽列表",
|
||||||
description="获取当前用户的屏蔽用户列表。",
|
description="获取当前用户的屏蔽用户列表。",
|
||||||
)
|
)
|
||||||
@@ -67,35 +49,29 @@ async def get_relationship(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK:
|
if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK:
|
||||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
|
return [
|
||||||
|
await RelationshipModel.transform(
|
||||||
|
rel,
|
||||||
|
includes=[f"target.{inc}" for inc in User.LIST_INCLUDES],
|
||||||
|
ruleset=current_user.playmode,
|
||||||
|
)
|
||||||
|
for rel in relationships.unique()
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
return [
|
return [
|
||||||
await UserResp.from_db(
|
await UserModel.transform(
|
||||||
rel.target,
|
rel.target,
|
||||||
db,
|
ruleset=current_user.playmode,
|
||||||
include=[
|
includes=User.LIST_INCLUDES,
|
||||||
"team",
|
|
||||||
"daily_challenge_user_stats",
|
|
||||||
"statistics",
|
|
||||||
"statistics_rulesets",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
for rel in relationships.unique()
|
for rel in relationships.unique()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class AddFriendResp(BaseModel):
|
|
||||||
"""添加好友/屏蔽 返回模型。
|
|
||||||
|
|
||||||
- user_relation: 新的或更新后的关系对象。"""
|
|
||||||
|
|
||||||
user_relation: RelationshipResp
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/friends",
|
"/friends",
|
||||||
tags=["用户关系"],
|
tags=["用户关系"],
|
||||||
response_model=AddFriendResp,
|
responses={200: api_doc("好友关系", {"user_relation": RelationshipModel}, name="UserRelationshipResponse")},
|
||||||
name="添加或更新好友关系",
|
name="添加或更新好友关系",
|
||||||
description="\n添加或更新与目标用户的好友关系。",
|
description="\n添加或更新与目标用户的好友关系。",
|
||||||
)
|
)
|
||||||
@@ -163,7 +139,13 @@ async def add_relationship(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
).one()
|
).one()
|
||||||
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
|
return {
|
||||||
|
"user_relation": await RelationshipModel.transform(
|
||||||
|
relationship,
|
||||||
|
includes=[],
|
||||||
|
ruleset=current_user.playmode,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
|
|||||||
@@ -1,25 +1,27 @@
|
|||||||
from datetime import UTC
|
from datetime import UTC
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from app.database.beatmap import Beatmap, BeatmapResp
|
from app.database.beatmap import (
|
||||||
from app.database.beatmapset import BeatmapsetResp
|
Beatmap,
|
||||||
from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsResp
|
BeatmapModel,
|
||||||
|
)
|
||||||
|
from app.database.beatmapset import BeatmapsetModel
|
||||||
|
from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsCountModel
|
||||||
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||||
from app.database.playlists import Playlist, PlaylistResp
|
from app.database.playlists import Playlist, PlaylistModel
|
||||||
from app.database.room import APIUploadedRoom, Room, RoomResp
|
from app.database.room import APIUploadedRoom, Room, RoomModel
|
||||||
from app.database.room_participated_user import RoomParticipatedUser
|
from app.database.room_participated_user import RoomParticipatedUser
|
||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
from app.database.user import User, UserResp
|
from app.database.user import User, UserModel
|
||||||
from app.dependencies.database import Database, Redis
|
from app.dependencies.database import Database, Redis
|
||||||
from app.dependencies.user import ClientUser, get_current_user
|
from app.dependencies.user import ClientUser, get_current_user
|
||||||
from app.models.room import MatchType, RoomCategory, RoomStatus
|
from app.models.room import MatchType, RoomCategory, RoomStatus
|
||||||
from app.service.room import create_playlist_room_from_api
|
from app.service.room import create_playlist_room_from_api
|
||||||
from app.utils import utcnow
|
from app.utils import api_doc, utcnow
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import HTTPException, Path, Query, Security
|
from fastapi import HTTPException, Path, Query, Security
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy.sql.elements import ColumnElement
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
from sqlmodel import col, exists, select
|
from sqlmodel import col, exists, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -28,7 +30,19 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/rooms",
|
"/rooms",
|
||||||
tags=["房间"],
|
tags=["房间"],
|
||||||
response_model=list[RoomResp],
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"房间列表",
|
||||||
|
list[RoomModel],
|
||||||
|
[
|
||||||
|
"current_playlist_item.beatmap.beatmapset",
|
||||||
|
"difficulty_range",
|
||||||
|
"host.country",
|
||||||
|
"playlist_item_stats",
|
||||||
|
"recent_participants",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
},
|
||||||
name="获取房间列表",
|
name="获取房间列表",
|
||||||
description="获取房间列表。支持按状态/模式筛选",
|
description="获取房间列表。支持按状态/模式筛选",
|
||||||
)
|
)
|
||||||
@@ -49,7 +63,7 @@ async def get_all_rooms(
|
|||||||
] = RoomCategory.NORMAL,
|
] = RoomCategory.NORMAL,
|
||||||
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
|
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
|
||||||
):
|
):
|
||||||
resp_list: list[RoomResp] = []
|
resp_list = []
|
||||||
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category, col(Room.type) != MatchType.MATCHMAKING]
|
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category, col(Room.type) != MatchType.MATCHMAKING]
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
|
|
||||||
@@ -90,22 +104,24 @@ async def get_all_rooms(
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
for room in db_rooms:
|
for room in db_rooms:
|
||||||
resp = await RoomResp.from_db(room, db)
|
resp = await RoomModel.transform(
|
||||||
|
room,
|
||||||
|
includes=[
|
||||||
|
"current_playlist_item.beatmap.beatmapset",
|
||||||
|
"difficulty_range",
|
||||||
|
"host.country",
|
||||||
|
"playlist_item_stats",
|
||||||
|
"recent_participants",
|
||||||
|
],
|
||||||
|
)
|
||||||
if category == RoomCategory.REALTIME:
|
if category == RoomCategory.REALTIME:
|
||||||
resp.category = RoomCategory.NORMAL
|
resp["category"] = RoomCategory.NORMAL
|
||||||
|
|
||||||
resp_list.append(resp)
|
resp_list.append(resp)
|
||||||
|
|
||||||
return resp_list
|
return resp_list
|
||||||
|
|
||||||
|
|
||||||
class APICreatedRoom(RoomResp):
|
|
||||||
"""创建房间返回模型,继承 RoomResp。额外字段:
|
|
||||||
- error: 错误信息(为空表示成功)。"""
|
|
||||||
|
|
||||||
error: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
|
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
|
||||||
participated_user = (
|
participated_user = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
@@ -133,9 +149,15 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
|
|||||||
@router.post(
|
@router.post(
|
||||||
"/rooms",
|
"/rooms",
|
||||||
tags=["房间"],
|
tags=["房间"],
|
||||||
response_model=APICreatedRoom,
|
|
||||||
name="创建房间",
|
name="创建房间",
|
||||||
description="\n创建一个新的房间。",
|
description="\n创建一个新的房间。",
|
||||||
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"创建的房间信息",
|
||||||
|
RoomModel,
|
||||||
|
Room.SHOW_RESPONSE_INCLUDES,
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def create_room(
|
async def create_room(
|
||||||
db: Database,
|
db: Database,
|
||||||
@@ -145,23 +167,27 @@ async def create_room(
|
|||||||
):
|
):
|
||||||
if await current_user.is_restricted(db):
|
if await current_user.is_restricted(db):
|
||||||
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
|
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
|
||||||
|
|
||||||
user_id = current_user.id
|
user_id = current_user.id
|
||||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||||
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(db_room)
|
await db.refresh(db_room)
|
||||||
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
|
created_room = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
|
||||||
created_room.error = ""
|
|
||||||
return created_room
|
return created_room
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/rooms/{room_id}",
|
"/rooms/{room_id}",
|
||||||
tags=["房间"],
|
tags=["房间"],
|
||||||
response_model=RoomResp,
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"房间详细信息",
|
||||||
|
RoomModel,
|
||||||
|
Room.SHOW_RESPONSE_INCLUDES,
|
||||||
|
)
|
||||||
|
},
|
||||||
name="获取房间详情",
|
name="获取房间详情",
|
||||||
description="获取单个房间详情。",
|
description="获取指定房间详情。",
|
||||||
)
|
)
|
||||||
async def get_room(
|
async def get_room(
|
||||||
db: Database,
|
db: Database,
|
||||||
@@ -177,7 +203,7 @@ async def get_room(
|
|||||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||||
if db_room is None:
|
if db_room is None:
|
||||||
raise HTTPException(404, "Room not found")
|
raise HTTPException(404, "Room not found")
|
||||||
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
|
resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES, user=current_user)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@@ -225,10 +251,10 @@ async def add_user_to_room(
|
|||||||
await _participate_room(room_id, user_id, db_room, db, redis)
|
await _participate_room(room_id, user_id, db_room, db, redis)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(db_room)
|
await db.refresh(db_room)
|
||||||
resp = await RoomResp.from_db(db_room, db)
|
resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
|
||||||
return resp
|
return resp
|
||||||
else:
|
else:
|
||||||
raise HTTPException(404, "room not found0")
|
raise HTTPException(404, "room not found")
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -268,21 +294,22 @@ async def remove_user_from_room(
|
|||||||
raise HTTPException(404, "Room not found")
|
raise HTTPException(404, "Room not found")
|
||||||
|
|
||||||
|
|
||||||
class APILeaderboard(BaseModel):
|
|
||||||
"""房间全局排行榜返回模型。
|
|
||||||
- leaderboard: 用户游玩统计(尝试次数/分数等)。
|
|
||||||
- user_score: 当前用户对应统计。"""
|
|
||||||
|
|
||||||
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
|
|
||||||
user_score: ItemAttemptsResp | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/rooms/{room_id}/leaderboard",
|
"/rooms/{room_id}/leaderboard",
|
||||||
tags=["房间"],
|
tags=["房间"],
|
||||||
response_model=APILeaderboard,
|
|
||||||
name="获取房间排行榜",
|
name="获取房间排行榜",
|
||||||
description="获取房间内累计得分排行榜。",
|
description="获取房间内累计得分排行榜。",
|
||||||
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"房间排行榜",
|
||||||
|
{
|
||||||
|
"leaderboard": list[ItemAttemptsCountModel],
|
||||||
|
"user_score": ItemAttemptsCountModel | None,
|
||||||
|
},
|
||||||
|
["user.country", "position"],
|
||||||
|
name="RoomLeaderboardResponse",
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def get_room_leaderboard(
|
async def get_room_leaderboard(
|
||||||
db: Database,
|
db: Database,
|
||||||
@@ -300,45 +327,43 @@ async def get_room_leaderboard(
|
|||||||
aggs_resp = []
|
aggs_resp = []
|
||||||
user_agg = None
|
user_agg = None
|
||||||
for i, agg in enumerate(aggs):
|
for i, agg in enumerate(aggs):
|
||||||
resp = await ItemAttemptsResp.from_db(agg, db)
|
includes = ["user.country"]
|
||||||
resp.position = i + 1
|
if agg.user_id == current_user.id:
|
||||||
|
includes.append("position")
|
||||||
|
resp = await ItemAttemptsCountModel.transform(agg, includes=includes)
|
||||||
aggs_resp.append(resp)
|
aggs_resp.append(resp)
|
||||||
if agg.user_id == current_user.id:
|
if agg.user_id == current_user.id:
|
||||||
user_agg = resp
|
user_agg = resp
|
||||||
return APILeaderboard(
|
|
||||||
leaderboard=aggs_resp,
|
|
||||||
user_score=user_agg,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
return {
|
||||||
class RoomEvents(BaseModel):
|
"leaderboard": aggs_resp,
|
||||||
"""房间事件流返回模型。
|
"user_score": user_agg,
|
||||||
- beatmaps: 本次结果涉及的谱面列表。
|
}
|
||||||
- beatmapsets: 谱面集映射。
|
|
||||||
- current_playlist_item_id: 当前游玩列表(项目)项 ID。
|
|
||||||
- events: 事件列表。
|
|
||||||
- first_event_id / last_event_id: 事件范围。
|
|
||||||
- playlist_items: 房间游玩列表(项目)详情。
|
|
||||||
- room: 房间详情。
|
|
||||||
- user: 关联用户列表。"""
|
|
||||||
|
|
||||||
beatmaps: list[BeatmapResp] = Field(default_factory=list)
|
|
||||||
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
|
|
||||||
current_playlist_item_id: int = 0
|
|
||||||
events: list[MultiplayerEventResp] = Field(default_factory=list)
|
|
||||||
first_event_id: int = 0
|
|
||||||
last_event_id: int = 0
|
|
||||||
playlist_items: list[PlaylistResp] = Field(default_factory=list)
|
|
||||||
room: RoomResp
|
|
||||||
user: list[UserResp] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/rooms/{room_id}/events",
|
"/rooms/{room_id}/events",
|
||||||
response_model=RoomEvents,
|
|
||||||
tags=["房间"],
|
tags=["房间"],
|
||||||
name="获取房间事件",
|
name="获取房间事件",
|
||||||
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
|
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
|
||||||
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"房间事件",
|
||||||
|
{
|
||||||
|
"beatmaps": list[BeatmapModel],
|
||||||
|
"beatmapsets": list[BeatmapsetModel],
|
||||||
|
"current_playlist_item_id": int,
|
||||||
|
"events": list[MultiplayerEventResp],
|
||||||
|
"first_event_id": int,
|
||||||
|
"last_event_id": int,
|
||||||
|
"playlist_items": list[PlaylistModel],
|
||||||
|
"room": RoomModel,
|
||||||
|
"user": list[UserModel],
|
||||||
|
},
|
||||||
|
["country", "details", "scores"],
|
||||||
|
name="RoomEventsResponse",
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def get_room_events(
|
async def get_room_events(
|
||||||
db: Database,
|
db: Database,
|
||||||
@@ -402,28 +427,44 @@ async def get_room_events(
|
|||||||
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||||
if room is None:
|
if room is None:
|
||||||
raise HTTPException(404, "Room not found")
|
raise HTTPException(404, "Room not found")
|
||||||
room_resp = await RoomResp.from_db(room, db)
|
room_resp = await RoomModel.transform(room, includes=["current_playlist_item"])
|
||||||
if room.category == RoomCategory.REALTIME and room_resp.current_playlist_item:
|
if room.category == RoomCategory.REALTIME:
|
||||||
current_playlist_item_id = room_resp.current_playlist_item.id
|
current_playlist_item_id = (await Room.current_playlist_item(db, room))["id"]
|
||||||
|
|
||||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
user_resps = [await UserModel.transform(user, includes=["country"]) for user in users]
|
||||||
|
|
||||||
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
||||||
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
|
beatmap_resps = [
|
||||||
beatmapset_resps = {}
|
await BeatmapModel.transform(
|
||||||
for beatmap_resp in beatmap_resps:
|
beatmap,
|
||||||
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
|
)
|
||||||
|
for beatmap in beatmaps
|
||||||
|
]
|
||||||
|
|
||||||
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
|
beatmapsets = []
|
||||||
|
for beatmap in beatmaps:
|
||||||
|
if beatmap.beatmapset_id not in beatmapsets:
|
||||||
|
beatmapsets.append(beatmap.beatmapset)
|
||||||
|
beatmapset_resps = [
|
||||||
|
await BeatmapsetModel.transform(
|
||||||
|
beatmapset,
|
||||||
|
)
|
||||||
|
for beatmapset in beatmapsets
|
||||||
|
]
|
||||||
|
|
||||||
return RoomEvents(
|
playlist_items_resps = [
|
||||||
beatmaps=beatmap_resps,
|
await PlaylistModel.transform(item, includes=["details", "scores"]) for item in playlist_items.values()
|
||||||
beatmapsets=beatmapset_resps,
|
]
|
||||||
current_playlist_item_id=current_playlist_item_id,
|
|
||||||
events=event_resps,
|
return {
|
||||||
first_event_id=first_event_id,
|
"beatmaps": beatmap_resps,
|
||||||
last_event_id=last_event_id,
|
"beatmapsets": beatmapset_resps,
|
||||||
playlist_items=playlist_items_resps,
|
"current_playlist_item_id": current_playlist_item_id,
|
||||||
room=room_resp,
|
"events": event_resps,
|
||||||
user=user_resps,
|
"first_event_id": first_event_id,
|
||||||
)
|
"last_event_id": last_event_id,
|
||||||
|
"playlist_items": playlist_items_resps,
|
||||||
|
"room": room_resp,
|
||||||
|
"user": user_resps,
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from app.database import (
|
|||||||
Playlist,
|
Playlist,
|
||||||
Room,
|
Room,
|
||||||
Score,
|
Score,
|
||||||
ScoreResp,
|
|
||||||
ScoreToken,
|
ScoreToken,
|
||||||
ScoreTokenResp,
|
ScoreTokenResp,
|
||||||
User,
|
User,
|
||||||
@@ -27,8 +26,10 @@ from app.database.relationship import Relationship, RelationshipType
|
|||||||
from app.database.score import (
|
from app.database.score import (
|
||||||
LegacyScoreResp,
|
LegacyScoreResp,
|
||||||
MultiplayerScores,
|
MultiplayerScores,
|
||||||
ScoreAround,
|
MultiplayScoreDict,
|
||||||
|
ScoreModel,
|
||||||
get_leaderboard,
|
get_leaderboard,
|
||||||
|
get_score_position_by_id,
|
||||||
process_score,
|
process_score,
|
||||||
process_user,
|
process_user,
|
||||||
)
|
)
|
||||||
@@ -49,7 +50,7 @@ from app.models.score import (
|
|||||||
)
|
)
|
||||||
from app.service.beatmap_cache_service import get_beatmap_cache_service
|
from app.service.beatmap_cache_service import get_beatmap_cache_service
|
||||||
from app.service.user_cache_service import refresh_user_cache_background
|
from app.service.user_cache_service import refresh_user_cache_background
|
||||||
from app.utils import utcnow
|
from app.utils import api_doc, utcnow
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
@@ -72,6 +73,7 @@ from sqlmodel import col, exists, func, select
|
|||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
READ_SCORE_TIMEOUT = 10
|
READ_SCORE_TIMEOUT = 10
|
||||||
|
DEFAULT_SCORE_INCLUDES = ["user", "user.country", "user.cover", "user.team"]
|
||||||
logger = log("Score")
|
logger = log("Score")
|
||||||
|
|
||||||
|
|
||||||
@@ -180,13 +182,15 @@ async def submit_score(
|
|||||||
await db.refresh(score)
|
await db.refresh(score)
|
||||||
|
|
||||||
background_task.add_task(_process_user, score_id, user_id, redis, fetcher)
|
background_task.add_task(_process_user, score_id, user_id, redis, fetcher)
|
||||||
resp: ScoreResp = await ScoreResp.from_db(db, score)
|
resp = await ScoreModel.transform(
|
||||||
|
score,
|
||||||
|
)
|
||||||
score_gamemode = score.gamemode
|
score_gamemode = score.gamemode
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
background_task.add_task(refresh_user_cache_background, redis, user_id, score_gamemode)
|
background_task.add_task(refresh_user_cache_background, redis, user_id, score_gamemode)
|
||||||
background_task.add_task(_process_user_achievement, resp.id)
|
background_task.add_task(_process_user_achievement, resp["id"])
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@@ -218,27 +222,36 @@ async def _preload_beatmap_for_pp_calculation(beatmap_id: int) -> None:
|
|||||||
logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")
|
logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
class BeatmapUserScore[T: ScoreResp | LegacyScoreResp](BaseModel):
|
LeaderboardScoreType = ScoreModel.generate_typeddict(tuple(DEFAULT_SCORE_INCLUDES)) | LegacyScoreResp
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapUserScore(BaseModel):
|
||||||
position: int
|
position: int
|
||||||
score: T
|
score: LeaderboardScoreType # pyright: ignore[reportInvalidTypeForm]
|
||||||
|
|
||||||
|
|
||||||
class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
|
class BeatmapScores(BaseModel):
|
||||||
scores: list[T]
|
scores: list[LeaderboardScoreType] # pyright: ignore[reportInvalidTypeForm]
|
||||||
user_score: BeatmapUserScore[T] | None = None
|
user_score: BeatmapUserScore | None = None
|
||||||
score_count: int = 0
|
score_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/beatmaps/{beatmap_id}/scores",
|
"/beatmaps/{beatmap_id}/scores",
|
||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
response_model=BeatmapScores[ScoreResp] | BeatmapScores[LegacyScoreResp],
|
responses={
|
||||||
|
200: {
|
||||||
|
"model": BeatmapScores,
|
||||||
|
"description": (
|
||||||
|
"排行榜及当前用户成绩。\n\n"
|
||||||
|
f"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[Score]`"
|
||||||
|
f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}),"
|
||||||
|
"否则为 `BeatmapScores[LegacyScoreResp]`。"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
},
|
||||||
name="获取谱面排行榜",
|
name="获取谱面排行榜",
|
||||||
description=(
|
description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
|
||||||
"获取指定谱面在特定条件下的排行榜及当前用户成绩。\n\n"
|
|
||||||
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[ScoreResp]`,"
|
|
||||||
"否则为 `BeatmapScores[LegacyScoreResp]`。"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
async def get_beatmap_scores(
|
async def get_beatmap_scores(
|
||||||
db: Database,
|
db: Database,
|
||||||
@@ -266,27 +279,46 @@ async def get_beatmap_scores(
|
|||||||
mods=sorted(mods),
|
mods=sorted(mods),
|
||||||
)
|
)
|
||||||
|
|
||||||
user_score_resp = await user_score.to_resp(db, api_version) if user_score else None
|
user_score_resp = await user_score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) if user_score else None
|
||||||
resp = BeatmapScores(
|
return {
|
||||||
scores=[await score.to_resp(db, api_version) for score in all_scores],
|
"scores": [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_scores],
|
||||||
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
|
"user_score": (
|
||||||
if user_score_resp
|
{
|
||||||
else None,
|
"score": user_score_resp,
|
||||||
score_count=count,
|
"position": (
|
||||||
)
|
await get_score_position_by_id(
|
||||||
return resp
|
db,
|
||||||
|
user_score.beatmap_id,
|
||||||
|
user_score.id,
|
||||||
|
mode=user_score.gamemode,
|
||||||
|
user=user_score.user,
|
||||||
|
)
|
||||||
|
or 0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
if user_score and user_score_resp
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"score_count": count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/beatmaps/{beatmap_id}/scores/users/{user_id}",
|
"/beatmaps/{beatmap_id}/scores/users/{user_id}",
|
||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
response_model=BeatmapUserScore[ScoreResp] | BeatmapUserScore[LegacyScoreResp],
|
responses={
|
||||||
|
200: {
|
||||||
|
"model": BeatmapUserScore,
|
||||||
|
"description": (
|
||||||
|
"指定用户在指定谱面上的最高成绩\n\n"
|
||||||
|
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[Score]`,"
|
||||||
|
f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}),"
|
||||||
|
"否则为 `BeatmapUserScore[LegacyScoreResp]`。"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
},
|
||||||
name="获取用户谱面最高成绩",
|
name="获取用户谱面最高成绩",
|
||||||
description=(
|
description="获取指定用户在指定谱面上的最高成绩。",
|
||||||
"获取指定用户在指定谱面上的最高成绩。\n\n"
|
|
||||||
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[ScoreResp]`,"
|
|
||||||
"否则为 `BeatmapUserScore[LegacyScoreResp]`。"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
async def get_user_beatmap_score(
|
async def get_user_beatmap_score(
|
||||||
db: Database,
|
db: Database,
|
||||||
@@ -318,23 +350,38 @@ async def get_user_beatmap_score(
|
|||||||
detail=f"Cannot find user {user_id}'s score on this beatmap",
|
detail=f"Cannot find user {user_id}'s score on this beatmap",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
resp = await user_score.to_resp(db, api_version=api_version)
|
resp = await user_score.to_resp(db, api_version=api_version, includes=DEFAULT_SCORE_INCLUDES)
|
||||||
return BeatmapUserScore(
|
return {
|
||||||
position=resp.rank_global or 0,
|
"position": (
|
||||||
score=resp,
|
await get_score_position_by_id(
|
||||||
)
|
db,
|
||||||
|
user_score.beatmap_id,
|
||||||
|
user_score.id,
|
||||||
|
mode=user_score.gamemode,
|
||||||
|
user=user_score.user,
|
||||||
|
)
|
||||||
|
or 0
|
||||||
|
),
|
||||||
|
"score": resp,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/beatmaps/{beatmap_id}/scores/users/{user_id}/all",
|
"/beatmaps/{beatmap_id}/scores/users/{user_id}/all",
|
||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
response_model=list[ScoreResp] | list[LegacyScoreResp],
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
(
|
||||||
|
"用户谱面全部成绩\n\n"
|
||||||
|
"如果 `x-api-version >= 20220705`,返回值为 `Score`列表,"
|
||||||
|
"否则为 `LegacyScoreResp`列表。"
|
||||||
|
),
|
||||||
|
list[ScoreModel] | list[LegacyScoreResp],
|
||||||
|
DEFAULT_SCORE_INCLUDES,
|
||||||
|
)
|
||||||
|
},
|
||||||
name="获取用户谱面全部成绩",
|
name="获取用户谱面全部成绩",
|
||||||
description=(
|
description="获取指定用户在指定谱面上的全部成绩列表。",
|
||||||
"获取指定用户在指定谱面上的全部成绩列表。\n\n"
|
|
||||||
"如果 `x-api-version >= 20220705`,返回值为 `ScoreResp`列表,"
|
|
||||||
"否则为 `LegacyScoreResp`列表。"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
async def get_user_all_beatmap_scores(
|
async def get_user_all_beatmap_scores(
|
||||||
db: Database,
|
db: Database,
|
||||||
@@ -359,7 +406,7 @@ async def get_user_all_beatmap_scores(
|
|||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
return [await score.to_resp(db, api_version) for score in all_user_scores]
|
return [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_user_scores]
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -413,9 +460,9 @@ async def create_solo_score(
|
|||||||
@router.put(
|
@router.put(
|
||||||
"/beatmaps/{beatmap_id}/solo/scores/{token}",
|
"/beatmaps/{beatmap_id}/solo/scores/{token}",
|
||||||
tags=["游玩"],
|
tags=["游玩"],
|
||||||
response_model=ScoreResp,
|
|
||||||
name="提交单曲成绩",
|
name="提交单曲成绩",
|
||||||
description="\n使用令牌提交单曲成绩。",
|
description="\n使用令牌提交单曲成绩。",
|
||||||
|
responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
|
||||||
)
|
)
|
||||||
async def submit_solo_score(
|
async def submit_solo_score(
|
||||||
background_task: BackgroundTasks,
|
background_task: BackgroundTasks,
|
||||||
@@ -520,6 +567,7 @@ async def create_playlist_score(
|
|||||||
tags=["游玩"],
|
tags=["游玩"],
|
||||||
name="提交房间项目成绩",
|
name="提交房间项目成绩",
|
||||||
description="\n提交房间游玩项目成绩。",
|
description="\n提交房间游玩项目成绩。",
|
||||||
|
responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
|
||||||
)
|
)
|
||||||
async def submit_playlist_score(
|
async def submit_playlist_score(
|
||||||
background_task: BackgroundTasks,
|
background_task: BackgroundTasks,
|
||||||
@@ -560,13 +608,13 @@ async def submit_playlist_score(
|
|||||||
room_id,
|
room_id,
|
||||||
playlist_id,
|
playlist_id,
|
||||||
user_id,
|
user_id,
|
||||||
score_resp.id,
|
score_resp["id"],
|
||||||
score_resp.total_score,
|
score_resp["total_score"],
|
||||||
session,
|
session,
|
||||||
redis,
|
redis,
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
if room_category == RoomCategory.DAILY_CHALLENGE and score_resp.passed:
|
if room_category == RoomCategory.DAILY_CHALLENGE and score_resp["passed"]:
|
||||||
await process_daily_challenge_score(session, user_id, room_id)
|
await process_daily_challenge_score(session, user_id, room_id)
|
||||||
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
|
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -575,15 +623,23 @@ async def submit_playlist_score(
|
|||||||
|
|
||||||
class IndexedScoreResp(MultiplayerScores):
|
class IndexedScoreResp(MultiplayerScores):
|
||||||
total: int
|
total: int
|
||||||
user_score: ScoreResp | None = None
|
user_score: MultiplayScoreDict | None = None # pyright: ignore[reportInvalidTypeForm]
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/rooms/{room_id}/playlist/{playlist_id}/scores",
|
"/rooms/{room_id}/playlist/{playlist_id}/scores",
|
||||||
response_model=IndexedScoreResp,
|
# response_model=IndexedScoreResp,
|
||||||
name="获取房间项目排行榜",
|
name="获取房间项目排行榜",
|
||||||
description="获取房间游玩项目排行榜。",
|
description="获取房间游玩项目排行榜。",
|
||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": (
|
||||||
|
f"房间项目排行榜。\n\n包含:{', '.join([f'`{inc}`' for inc in Score.MULTIPLAYER_BASE_INCLUDES])}"
|
||||||
|
),
|
||||||
|
"model": IndexedScoreResp,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def index_playlist_scores(
|
async def index_playlist_scores(
|
||||||
session: Database,
|
session: Database,
|
||||||
@@ -620,16 +676,14 @@ async def index_playlist_scores(
|
|||||||
scores = scores[:-1]
|
scores = scores[:-1]
|
||||||
|
|
||||||
user_score = None
|
user_score = None
|
||||||
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
|
score_resp = [await ScoreModel.transform(score.score, includes=Score.MULTIPLAYER_BASE_INCLUDES) for score in scores]
|
||||||
for score in score_resp:
|
for score in score_resp:
|
||||||
score.position = await get_position(room_id, playlist_id, score.id, session)
|
if (room.category == RoomCategory.DAILY_CHALLENGE and score["user_id"] == user_id and score["passed"]) or score[
|
||||||
if score.user_id == user_id:
|
"user_id"
|
||||||
|
] == user_id:
|
||||||
user_score = score
|
user_score = score
|
||||||
|
user_score["position"] = await get_position(room_id, playlist_id, score["id"], session)
|
||||||
if room.category == RoomCategory.DAILY_CHALLENGE:
|
break
|
||||||
score_resp = [s for s in score_resp if s.passed]
|
|
||||||
if user_score and not user_score.passed:
|
|
||||||
user_score = None
|
|
||||||
|
|
||||||
resp = IndexedScoreResp(
|
resp = IndexedScoreResp(
|
||||||
scores=score_resp,
|
scores=score_resp,
|
||||||
@@ -648,10 +702,16 @@ async def index_playlist_scores(
|
|||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
|
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
|
||||||
response_model=ScoreResp,
|
|
||||||
name="获取房间项目单个成绩",
|
name="获取房间项目单个成绩",
|
||||||
description="获取指定房间游玩项目中单个成绩详情。",
|
description="获取指定房间游玩项目中单个成绩详情。",
|
||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"房间项目单个成绩详情。",
|
||||||
|
ScoreModel,
|
||||||
|
[*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def show_playlist_score(
|
async def show_playlist_score(
|
||||||
session: Database,
|
session: Database,
|
||||||
@@ -687,39 +747,25 @@ async def show_playlist_score(
|
|||||||
break
|
break
|
||||||
if not score_record:
|
if not score_record:
|
||||||
raise HTTPException(status_code=404, detail="Score not found")
|
raise HTTPException(status_code=404, detail="Score not found")
|
||||||
resp = await ScoreResp.from_db(session, score_record.score)
|
includes = [
|
||||||
resp.position = await get_position(room_id, playlist_id, score_id, session)
|
*Score.MULTIPLAYER_BASE_INCLUDES,
|
||||||
|
"position",
|
||||||
|
]
|
||||||
if completed:
|
if completed:
|
||||||
scores = (
|
includes.append("scores_around")
|
||||||
await session.exec(
|
resp = await ScoreModel.transform(score_record.score, includes=includes)
|
||||||
select(PlaylistBestScore).where(
|
|
||||||
PlaylistBestScore.playlist_id == playlist_id,
|
|
||||||
PlaylistBestScore.room_id == room_id,
|
|
||||||
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
higher_scores = []
|
|
||||||
lower_scores = []
|
|
||||||
for score in scores:
|
|
||||||
resp = await ScoreResp.from_db(session, score.score)
|
|
||||||
if is_playlist and not resp.passed:
|
|
||||||
continue
|
|
||||||
if score.total_score > resp.total_score:
|
|
||||||
higher_scores.append(resp)
|
|
||||||
elif score.total_score < resp.total_score:
|
|
||||||
lower_scores.append(resp)
|
|
||||||
resp.scores_around = ScoreAround(
|
|
||||||
higher=MultiplayerScores(scores=higher_scores),
|
|
||||||
lower=MultiplayerScores(scores=lower_scores),
|
|
||||||
)
|
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
|
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
|
||||||
response_model=ScoreResp,
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"房间项目单个成绩详情。",
|
||||||
|
ScoreModel,
|
||||||
|
[*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
|
||||||
|
)
|
||||||
|
},
|
||||||
name="获取房间项目用户成绩",
|
name="获取房间项目用户成绩",
|
||||||
description="获取指定用户在房间游玩项目中的成绩。",
|
description="获取指定用户在房间游玩项目中的成绩。",
|
||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
@@ -749,8 +795,14 @@ async def get_user_playlist_score(
|
|||||||
if not score_record:
|
if not score_record:
|
||||||
raise HTTPException(status_code=404, detail="Score not found")
|
raise HTTPException(status_code=404, detail="Score not found")
|
||||||
|
|
||||||
resp = await ScoreResp.from_db(session, score_record.score)
|
resp = await ScoreModel.transform(
|
||||||
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
|
score_record.score,
|
||||||
|
includes=[
|
||||||
|
*Score.MULTIPLAYER_BASE_INCLUDES,
|
||||||
|
"position",
|
||||||
|
"scores_around",
|
||||||
|
],
|
||||||
|
)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,17 +5,16 @@ from app.config import settings
|
|||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database import (
|
from app.database import (
|
||||||
Beatmap,
|
Beatmap,
|
||||||
|
BeatmapModel,
|
||||||
BeatmapPlaycounts,
|
BeatmapPlaycounts,
|
||||||
BeatmapPlaycountsResp,
|
BeatmapsetModel,
|
||||||
BeatmapResp,
|
|
||||||
BeatmapsetResp,
|
|
||||||
User,
|
User,
|
||||||
UserResp,
|
|
||||||
)
|
)
|
||||||
|
from app.database.beatmap_playcounts import BeatmapPlaycountsModel
|
||||||
from app.database.best_scores import BestScore
|
from app.database.best_scores import BestScore
|
||||||
from app.database.events import Event
|
from app.database.events import Event
|
||||||
from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores
|
from app.database.score import Score, get_user_first_scores
|
||||||
from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED
|
from app.database.user import UserModel
|
||||||
from app.dependencies.api_version import APIVersion
|
from app.dependencies.api_version import APIVersion
|
||||||
from app.dependencies.cache import UserCacheService
|
from app.dependencies.cache import UserCacheService
|
||||||
from app.dependencies.database import Database, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
@@ -26,24 +25,15 @@ from app.models.mods import API_MODS
|
|||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
from app.models.user import BeatmapsetType
|
from app.models.user import BeatmapsetType
|
||||||
from app.service.user_cache_service import get_user_cache_service
|
from app.service.user_cache_service import get_user_cache_service
|
||||||
from app.utils import utcnow
|
from app.utils import api_doc, utcnow
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security
|
from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlmodel import exists, false, select
|
from sqlmodel import exists, false, select
|
||||||
from sqlmodel.sql.expression import col
|
from sqlmodel.sql.expression import col
|
||||||
|
|
||||||
|
|
||||||
class BatchUserResponse(BaseModel):
|
|
||||||
users: list[UserResp]
|
|
||||||
|
|
||||||
|
|
||||||
class BeatmapsPassedResponse(BaseModel):
|
|
||||||
beatmaps_passed: list[BeatmapResp]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_difficulty_reduction_mods() -> set[str]:
|
def _get_difficulty_reduction_mods() -> set[str]:
|
||||||
mods: set[str] = set()
|
mods: set[str] = set()
|
||||||
for ruleset_mods in API_MODS.values():
|
for ruleset_mods in API_MODS.values():
|
||||||
@@ -63,13 +53,15 @@ async def visible_to_current_user(user: User, current_user: User | None, session
|
|||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/users/",
|
"/users/",
|
||||||
response_model=BatchUserResponse,
|
responses={
|
||||||
|
200: api_doc("批量获取用户信息", {"users": list[UserModel]}, User.CARD_INCLUDES, name="UsersLookupResponse")
|
||||||
|
},
|
||||||
name="批量获取用户信息",
|
name="批量获取用户信息",
|
||||||
description="通过用户 ID 列表批量获取用户信息。",
|
description="通过用户 ID 列表批量获取用户信息。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
|
@router.get("/users/lookup", include_in_schema=False)
|
||||||
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
@router.get("/users/lookup/", include_in_schema=False)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
async def get_users(
|
async def get_users(
|
||||||
session: Database,
|
session: Database,
|
||||||
@@ -108,16 +100,15 @@ async def get_users(
|
|||||||
# 将查询到的用户添加到缓存并返回
|
# 将查询到的用户添加到缓存并返回
|
||||||
for searched_user in searched_users:
|
for searched_user in searched_users:
|
||||||
if searched_user.id != BANCHOBOT_ID:
|
if searched_user.id != BANCHOBOT_ID:
|
||||||
user_resp = await UserResp.from_db(
|
user_resp = await UserModel.transform(
|
||||||
searched_user,
|
searched_user,
|
||||||
session,
|
includes=User.CARD_INCLUDES,
|
||||||
include=SEARCH_INCLUDED,
|
|
||||||
)
|
)
|
||||||
cached_users.append(user_resp)
|
cached_users.append(user_resp)
|
||||||
# 异步缓存,不阻塞响应
|
# 异步缓存,不阻塞响应
|
||||||
background_task.add_task(cache_service.cache_user, user_resp)
|
background_task.add_task(cache_service.cache_user, user_resp)
|
||||||
|
|
||||||
response = BatchUserResponse(users=cached_users)
|
response = {"users": cached_users}
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
searched_users = (
|
searched_users = (
|
||||||
@@ -127,16 +118,15 @@ async def get_users(
|
|||||||
for searched_user in searched_users:
|
for searched_user in searched_users:
|
||||||
if searched_user.id == BANCHOBOT_ID:
|
if searched_user.id == BANCHOBOT_ID:
|
||||||
continue
|
continue
|
||||||
user_resp = await UserResp.from_db(
|
user_resp = await UserModel.transform(
|
||||||
searched_user,
|
searched_user,
|
||||||
session,
|
includes=User.CARD_INCLUDES,
|
||||||
include=SEARCH_INCLUDED,
|
|
||||||
)
|
)
|
||||||
users.append(user_resp)
|
users.append(user_resp)
|
||||||
# 异步缓存
|
# 异步缓存
|
||||||
background_task.add_task(cache_service.cache_user, user_resp)
|
background_task.add_task(cache_service.cache_user, user_resp)
|
||||||
|
|
||||||
response = BatchUserResponse(users=users)
|
response = {"users": users}
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@@ -200,10 +190,12 @@ async def get_user_kudosu(
|
|||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/users/{user_id}/beatmaps-passed",
|
"/users/{user_id}/beatmaps-passed",
|
||||||
response_model=BeatmapsPassedResponse,
|
|
||||||
name="获取用户已通过谱面",
|
name="获取用户已通过谱面",
|
||||||
description="获取指定用户在给定谱面集中的已通过谱面列表。",
|
description="获取指定用户在给定谱面集中的已通过谱面列表。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
|
responses={
|
||||||
|
200: api_doc("用户已通过谱面列表", {"beatmaps_passed": list[BeatmapModel]}, name="BeatmapsPassedResponse")
|
||||||
|
},
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
async def get_user_beatmaps_passed(
|
async def get_user_beatmaps_passed(
|
||||||
@@ -226,7 +218,7 @@ async def get_user_beatmaps_passed(
|
|||||||
no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True,
|
no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True,
|
||||||
):
|
):
|
||||||
if not beatmapset_ids:
|
if not beatmapset_ids:
|
||||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
return {"beatmaps_passed": []}
|
||||||
if len(beatmapset_ids) > 50:
|
if len(beatmapset_ids) > 50:
|
||||||
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
|
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
|
||||||
|
|
||||||
@@ -255,7 +247,7 @@ async def get_user_beatmaps_passed(
|
|||||||
|
|
||||||
scores = (await session.exec(score_query)).all()
|
scores = (await session.exec(score_query)).all()
|
||||||
if not scores:
|
if not scores:
|
||||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
return {"beatmaps_passed": []}
|
||||||
|
|
||||||
difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set()
|
difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set()
|
||||||
passed_beatmap_ids: set[int] = set()
|
passed_beatmap_ids: set[int] = set()
|
||||||
@@ -269,7 +261,7 @@ async def get_user_beatmaps_passed(
|
|||||||
continue
|
continue
|
||||||
passed_beatmap_ids.add(beatmap_id)
|
passed_beatmap_ids.add(beatmap_id)
|
||||||
if not passed_beatmap_ids:
|
if not passed_beatmap_ids:
|
||||||
return BeatmapsPassedResponse(beatmaps_passed=[])
|
return {"beatmaps_passed": []}
|
||||||
|
|
||||||
beatmaps = (
|
beatmaps = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
@@ -279,19 +271,24 @@ async def get_user_beatmaps_passed(
|
|||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
return BeatmapsPassedResponse(
|
return {
|
||||||
beatmaps_passed=[
|
"beatmaps_passed": [
|
||||||
await BeatmapResp.from_db(beatmap, allowed_mode, session=session, user=user) for beatmap in beatmaps
|
await BeatmapModel.transform(
|
||||||
|
beatmap,
|
||||||
|
)
|
||||||
|
for beatmap in beatmaps
|
||||||
]
|
]
|
||||||
)
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/users/{user_id}/{ruleset}",
|
"/users/{user_id}/{ruleset}",
|
||||||
response_model=UserResp,
|
|
||||||
name="获取用户信息(指定ruleset)",
|
name="获取用户信息(指定ruleset)",
|
||||||
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
|
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
|
responses={
|
||||||
|
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
async def get_user_info_ruleset(
|
async def get_user_info_ruleset(
|
||||||
@@ -325,29 +322,26 @@ async def get_user_info_ruleset(
|
|||||||
if should_not_show:
|
if should_not_show:
|
||||||
raise HTTPException(404, detail="User not found")
|
raise HTTPException(404, detail="User not found")
|
||||||
|
|
||||||
include = SEARCH_INCLUDED
|
user_resp = await UserModel.transform(
|
||||||
if searched_is_self:
|
|
||||||
include = ALL_INCLUDED
|
|
||||||
user_resp = await UserResp.from_db(
|
|
||||||
searched_user,
|
searched_user,
|
||||||
session,
|
includes=User.USER_INCLUDES,
|
||||||
include=include,
|
|
||||||
ruleset=ruleset,
|
ruleset=ruleset,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 异步缓存结果
|
# 异步缓存结果
|
||||||
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
|
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
|
||||||
|
|
||||||
return user_resp
|
return user_resp
|
||||||
|
|
||||||
|
|
||||||
@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
|
@router.get("/users/{user_id}/", include_in_schema=False)
|
||||||
@router.get(
|
@router.get(
|
||||||
"/users/{user_id}",
|
"/users/{user_id}",
|
||||||
response_model=UserResp,
|
|
||||||
name="获取用户信息",
|
name="获取用户信息",
|
||||||
description="通过用户 ID 或用户名获取单个用户的详细信息。",
|
description="通过用户 ID 或用户名获取单个用户的详细信息。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
|
responses={
|
||||||
|
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
async def get_user_info(
|
async def get_user_info(
|
||||||
@@ -381,27 +375,31 @@ async def get_user_info(
|
|||||||
if should_not_show:
|
if should_not_show:
|
||||||
raise HTTPException(404, detail="User not found")
|
raise HTTPException(404, detail="User not found")
|
||||||
|
|
||||||
include = SEARCH_INCLUDED
|
user_resp = await UserModel.transform(
|
||||||
if searched_is_self:
|
|
||||||
include = ALL_INCLUDED
|
|
||||||
user_resp = await UserResp.from_db(
|
|
||||||
searched_user,
|
searched_user,
|
||||||
session,
|
includes=User.USER_INCLUDES,
|
||||||
include=include,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 异步缓存结果
|
# 异步缓存结果
|
||||||
background_task.add_task(cache_service.cache_user, user_resp)
|
background_task.add_task(cache_service.cache_user, user_resp)
|
||||||
|
|
||||||
return user_resp
|
return user_resp
|
||||||
|
|
||||||
|
|
||||||
|
beatmapset_includes = [*BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES, "beatmaps"]
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/users/{user_id}/beatmapsets/{type}",
|
"/users/{user_id}/beatmapsets/{type}",
|
||||||
response_model=list[BeatmapsetResp | BeatmapPlaycountsResp],
|
|
||||||
name="获取用户谱面集列表",
|
name="获取用户谱面集列表",
|
||||||
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
|
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
|
responses={
|
||||||
|
200: api_doc(
|
||||||
|
"当类型为 `most_played` 时返回 `list[BeatmapPlaycountsModel]`,其他为 `list[BeatmapsetModel]`",
|
||||||
|
list[BeatmapsetModel] | list[BeatmapPlaycountsModel],
|
||||||
|
beatmapset_includes,
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
@asset_proxy_response
|
@asset_proxy_response
|
||||||
async def get_user_beatmapsets(
|
async def get_user_beatmapsets(
|
||||||
@@ -417,11 +415,7 @@ async def get_user_beatmapsets(
|
|||||||
# 先尝试从缓存获取
|
# 先尝试从缓存获取
|
||||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
|
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
|
||||||
if cached_result is not None:
|
if cached_result is not None:
|
||||||
# 根据类型恢复对象
|
return cached_result
|
||||||
if type == BeatmapsetType.MOST_PLAYED:
|
|
||||||
return [BeatmapPlaycountsResp(**item) for item in cached_result]
|
|
||||||
else:
|
|
||||||
return [BeatmapsetResp(**item) for item in cached_result]
|
|
||||||
|
|
||||||
user = await session.get(User, user_id)
|
user = await session.get(User, user_id)
|
||||||
if not user or user.id == BANCHOBOT_ID:
|
if not user or user.id == BANCHOBOT_ID:
|
||||||
@@ -444,7 +438,10 @@ async def get_user_beatmapsets(
|
|||||||
raise HTTPException(404, detail="User not found")
|
raise HTTPException(404, detail="User not found")
|
||||||
favourites = await user.awaitable_attrs.favourite_beatmapsets
|
favourites = await user.awaitable_attrs.favourite_beatmapsets
|
||||||
resp = [
|
resp = [
|
||||||
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
|
await BeatmapsetModel.transform(
|
||||||
|
favourite.beatmapset, session=session, user=user, includes=beatmapset_includes
|
||||||
|
)
|
||||||
|
for favourite in favourites
|
||||||
]
|
]
|
||||||
|
|
||||||
elif type == BeatmapsetType.MOST_PLAYED:
|
elif type == BeatmapsetType.MOST_PLAYED:
|
||||||
@@ -459,7 +456,10 @@ async def get_user_beatmapsets(
|
|||||||
.limit(limit)
|
.limit(limit)
|
||||||
.offset(offset)
|
.offset(offset)
|
||||||
)
|
)
|
||||||
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
|
resp = [
|
||||||
|
await BeatmapPlaycountsModel.transform(most_played_beatmap, user=user, includes=beatmapset_includes)
|
||||||
|
for most_played_beatmap in most_played
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||||
|
|
||||||
@@ -477,7 +477,6 @@ async def get_user_beatmapsets(
|
|||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/users/{user_id}/scores/{type}",
|
"/users/{user_id}/scores/{type}",
|
||||||
response_model=list[ScoreResp] | list[LegacyScoreResp],
|
|
||||||
name="获取用户成绩列表",
|
name="获取用户成绩列表",
|
||||||
description=(
|
description=(
|
||||||
"获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n"
|
"获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n"
|
||||||
@@ -523,6 +522,7 @@ async def get_user_scores(
|
|||||||
gamemode = mode or db_user.playmode
|
gamemode = mode or db_user.playmode
|
||||||
order_by = None
|
order_by = None
|
||||||
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
|
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
|
||||||
|
includes = Score.USER_PROFILE_INCLUDES.copy()
|
||||||
if not include_fails:
|
if not include_fails:
|
||||||
where_clause &= col(Score.passed).is_(True)
|
where_clause &= col(Score.passed).is_(True)
|
||||||
if type == "pinned":
|
if type == "pinned":
|
||||||
@@ -531,6 +531,7 @@ async def get_user_scores(
|
|||||||
elif type == "best":
|
elif type == "best":
|
||||||
where_clause &= exists().where(col(BestScore.score_id) == Score.id)
|
where_clause &= exists().where(col(BestScore.score_id) == Score.id)
|
||||||
order_by = col(Score.pp).desc()
|
order_by = col(Score.pp).desc()
|
||||||
|
includes.append("weight")
|
||||||
elif type == "recent":
|
elif type == "recent":
|
||||||
where_clause &= Score.ended_at > utcnow() - timedelta(hours=24)
|
where_clause &= Score.ended_at > utcnow() - timedelta(hours=24)
|
||||||
order_by = col(Score.ended_at).desc()
|
order_by = col(Score.ended_at).desc()
|
||||||
@@ -551,6 +552,7 @@ async def get_user_scores(
|
|||||||
await score.to_resp(
|
await score.to_resp(
|
||||||
session,
|
session,
|
||||||
api_version,
|
api_version,
|
||||||
|
includes=includes,
|
||||||
)
|
)
|
||||||
for score in scores
|
for score in scores
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,14 +3,14 @@ Beatmapset缓存服务
|
|||||||
用于缓存beatmapset数据,减少数据库查询频率
|
用于缓存beatmapset数据,减少数据库查询频率
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database.beatmapset import BeatmapsetResp
|
from app.database import BeatmapsetDict
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
|
from app.utils import safe_json_dumps
|
||||||
|
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
@@ -18,20 +18,6 @@ if TYPE_CHECKING:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DateTimeEncoder(json.JSONEncoder):
|
|
||||||
"""处理datetime序列化的JSON编码器"""
|
|
||||||
|
|
||||||
def default(self, obj):
|
|
||||||
if isinstance(obj, datetime):
|
|
||||||
return obj.isoformat()
|
|
||||||
return super().default(obj)
|
|
||||||
|
|
||||||
|
|
||||||
def safe_json_dumps(data) -> str:
|
|
||||||
"""安全的JSON序列化,处理datetime对象"""
|
|
||||||
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_hash(data) -> str:
|
def generate_hash(data) -> str:
|
||||||
"""生成数据的MD5哈希值"""
|
"""生成数据的MD5哈希值"""
|
||||||
content = data if isinstance(data, str) else safe_json_dumps(data)
|
content = data if isinstance(data, str) else safe_json_dumps(data)
|
||||||
@@ -57,15 +43,14 @@ class BeatmapsetCacheService:
|
|||||||
"""生成搜索结果缓存键"""
|
"""生成搜索结果缓存键"""
|
||||||
return f"beatmapset_search:{query_hash}:{cursor_hash}"
|
return f"beatmapset_search:{query_hash}:{cursor_hash}"
|
||||||
|
|
||||||
async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetResp | None:
|
async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetDict | None:
|
||||||
"""从缓存获取beatmapset信息"""
|
"""从缓存获取beatmapset信息"""
|
||||||
try:
|
try:
|
||||||
cache_key = self._get_beatmapset_cache_key(beatmapset_id)
|
cache_key = self._get_beatmapset_cache_key(beatmapset_id)
|
||||||
cached_data = await self.redis.get(cache_key)
|
cached_data = await self.redis.get(cache_key)
|
||||||
if cached_data:
|
if cached_data:
|
||||||
logger.debug(f"Beatmapset cache hit for {beatmapset_id}")
|
logger.debug(f"Beatmapset cache hit for {beatmapset_id}")
|
||||||
data = json.loads(cached_data)
|
return json.loads(cached_data)
|
||||||
return BeatmapsetResp(**data)
|
|
||||||
return None
|
return None
|
||||||
except (ValueError, TypeError, AttributeError) as e:
|
except (ValueError, TypeError, AttributeError) as e:
|
||||||
logger.error(f"Error getting beatmapset from cache: {e}")
|
logger.error(f"Error getting beatmapset from cache: {e}")
|
||||||
@@ -73,24 +58,21 @@ class BeatmapsetCacheService:
|
|||||||
|
|
||||||
async def cache_beatmapset(
|
async def cache_beatmapset(
|
||||||
self,
|
self,
|
||||||
beatmapset_resp: BeatmapsetResp,
|
beatmapset_resp: BeatmapsetDict,
|
||||||
expire_seconds: int | None = None,
|
expire_seconds: int | None = None,
|
||||||
):
|
):
|
||||||
"""缓存beatmapset信息"""
|
"""缓存beatmapset信息"""
|
||||||
try:
|
try:
|
||||||
if expire_seconds is None:
|
if expire_seconds is None:
|
||||||
expire_seconds = self._default_ttl
|
expire_seconds = self._default_ttl
|
||||||
if beatmapset_resp.id is None:
|
cache_key = self._get_beatmapset_cache_key(beatmapset_resp["id"])
|
||||||
logger.warning("Cannot cache beatmapset with None id")
|
cached_data = safe_json_dumps(beatmapset_resp)
|
||||||
return
|
|
||||||
cache_key = self._get_beatmapset_cache_key(beatmapset_resp.id)
|
|
||||||
cached_data = beatmapset_resp.model_dump_json()
|
|
||||||
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
|
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
|
||||||
logger.debug(f"Cached beatmapset {beatmapset_resp.id} for {expire_seconds}s")
|
logger.debug(f"Cached beatmapset {beatmapset_resp['id']} for {expire_seconds}s")
|
||||||
except (ValueError, TypeError, AttributeError) as e:
|
except (ValueError, TypeError, AttributeError) as e:
|
||||||
logger.error(f"Error caching beatmapset: {e}")
|
logger.error(f"Error caching beatmapset: {e}")
|
||||||
|
|
||||||
async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetResp | None:
|
async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetDict | None:
|
||||||
"""从缓存获取通过beatmap ID查找的beatmapset信息"""
|
"""从缓存获取通过beatmap ID查找的beatmapset信息"""
|
||||||
try:
|
try:
|
||||||
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
|
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
|
||||||
@@ -98,7 +80,7 @@ class BeatmapsetCacheService:
|
|||||||
if cached_data:
|
if cached_data:
|
||||||
logger.debug(f"Beatmap lookup cache hit for {beatmap_id}")
|
logger.debug(f"Beatmap lookup cache hit for {beatmap_id}")
|
||||||
data = json.loads(cached_data)
|
data = json.loads(cached_data)
|
||||||
return BeatmapsetResp(**data)
|
return data
|
||||||
return None
|
return None
|
||||||
except (ValueError, TypeError, AttributeError) as e:
|
except (ValueError, TypeError, AttributeError) as e:
|
||||||
logger.error(f"Error getting beatmap lookup from cache: {e}")
|
logger.error(f"Error getting beatmap lookup from cache: {e}")
|
||||||
@@ -107,7 +89,7 @@ class BeatmapsetCacheService:
|
|||||||
async def cache_beatmap_lookup(
|
async def cache_beatmap_lookup(
|
||||||
self,
|
self,
|
||||||
beatmap_id: int,
|
beatmap_id: int,
|
||||||
beatmapset_resp: BeatmapsetResp,
|
beatmapset_resp: BeatmapsetDict,
|
||||||
expire_seconds: int | None = None,
|
expire_seconds: int | None = None,
|
||||||
):
|
):
|
||||||
"""缓存通过beatmap ID查找的beatmapset信息"""
|
"""缓存通过beatmap ID查找的beatmapset信息"""
|
||||||
@@ -115,7 +97,7 @@ class BeatmapsetCacheService:
|
|||||||
if expire_seconds is None:
|
if expire_seconds is None:
|
||||||
expire_seconds = self._default_ttl
|
expire_seconds = self._default_ttl
|
||||||
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
|
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
|
||||||
cached_data = beatmapset_resp.model_dump_json()
|
cached_data = safe_json_dumps(beatmapset_resp)
|
||||||
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
|
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
|
||||||
logger.debug(f"Cached beatmap lookup {beatmap_id} for {expire_seconds}s")
|
logger.debug(f"Cached beatmap lookup {beatmap_id} for {expire_seconds}s")
|
||||||
except (ValueError, TypeError, AttributeError) as e:
|
except (ValueError, TypeError, AttributeError) as e:
|
||||||
|
|||||||
@@ -3,12 +3,12 @@ from datetime import timedelta
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import TYPE_CHECKING, NamedTuple
|
from typing import TYPE_CHECKING, NamedTuple, cast
|
||||||
|
|
||||||
from app.config import OldScoreProcessingMode, settings
|
from app.config import OldScoreProcessingMode, settings
|
||||||
from app.database.beatmap import Beatmap, BeatmapResp
|
from app.database.beatmap import Beatmap, BeatmapDict
|
||||||
from app.database.beatmap_sync import BeatmapSync, SavedBeatmapMeta
|
from app.database.beatmap_sync import BeatmapSync, SavedBeatmapMeta
|
||||||
from app.database.beatmapset import Beatmapset, BeatmapsetResp
|
from app.database.beatmapset import Beatmapset, BeatmapsetDict
|
||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
from app.dependencies.database import get_redis, with_db
|
from app.dependencies.database import get_redis, with_db
|
||||||
from app.dependencies.storage import get_storage_service
|
from app.dependencies.storage import get_storage_service
|
||||||
@@ -62,10 +62,23 @@ STATUS_FACTOR: dict[BeatmapRankStatus, float] = {
|
|||||||
SCHEDULER_INTERVAL_MINUTES = 2
|
SCHEDULER_INTERVAL_MINUTES = 2
|
||||||
|
|
||||||
|
|
||||||
|
class EnsuredBeatmap(BeatmapDict):
|
||||||
|
checksum: str
|
||||||
|
ranked: int
|
||||||
|
|
||||||
|
|
||||||
|
class EnsuredBeatmapset(BeatmapsetDict):
|
||||||
|
ranked: int
|
||||||
|
ranked_date: datetime.datetime
|
||||||
|
last_updated: datetime.datetime
|
||||||
|
play_count: int
|
||||||
|
beatmaps: list[EnsuredBeatmap]
|
||||||
|
|
||||||
|
|
||||||
class ProcessingBeatmapset:
|
class ProcessingBeatmapset:
|
||||||
def __init__(self, beatmapset: BeatmapsetResp, record: BeatmapSync) -> None:
|
def __init__(self, beatmapset: EnsuredBeatmapset, record: BeatmapSync) -> None:
|
||||||
self.beatmapset = beatmapset
|
self.beatmapset = beatmapset
|
||||||
self.status = BeatmapRankStatus(self.beatmapset.ranked)
|
self.status = BeatmapRankStatus(self.beatmapset["ranked"])
|
||||||
self.record = record
|
self.record = record
|
||||||
|
|
||||||
def calculate_next_sync_time(
|
def calculate_next_sync_time(
|
||||||
@@ -76,19 +89,19 @@ class ProcessingBeatmapset:
|
|||||||
|
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
if self.status == BeatmapRankStatus.QUALIFIED:
|
if self.status == BeatmapRankStatus.QUALIFIED:
|
||||||
assert self.beatmapset.ranked_date is not None, "ranked_date should not be None for qualified maps"
|
assert self.beatmapset["ranked_date"] is not None, "ranked_date should not be None for qualified maps"
|
||||||
time_to_ranked = (self.beatmapset.ranked_date + timedelta(days=7) - now).total_seconds()
|
time_to_ranked = (self.beatmapset["ranked_date"] + timedelta(days=7) - now).total_seconds()
|
||||||
baseline = max(MIN_DELTA, time_to_ranked / 2)
|
baseline = max(MIN_DELTA, time_to_ranked / 2)
|
||||||
next_delta = max(MIN_DELTA, baseline)
|
next_delta = max(MIN_DELTA, baseline)
|
||||||
elif self.status in {BeatmapRankStatus.WIP, BeatmapRankStatus.PENDING}:
|
elif self.status in {BeatmapRankStatus.WIP, BeatmapRankStatus.PENDING}:
|
||||||
seconds_since_update = (now - self.beatmapset.last_updated).total_seconds()
|
seconds_since_update = (now - self.beatmapset["last_updated"]).total_seconds()
|
||||||
factor_update = max(1.0, seconds_since_update / TAU)
|
factor_update = max(1.0, seconds_since_update / TAU)
|
||||||
factor_play = 1.0 + math.log(1.0 + self.beatmapset.play_count)
|
factor_play = 1.0 + math.log(1.0 + self.beatmapset["play_count"])
|
||||||
status_factor = STATUS_FACTOR[self.status]
|
status_factor = STATUS_FACTOR[self.status]
|
||||||
baseline = BASE * factor_play / factor_update * status_factor
|
baseline = BASE * factor_play / factor_update * status_factor
|
||||||
next_delta = max(MIN_DELTA, baseline * (GROWTH ** (self.record.consecutive_no_change + 1)))
|
next_delta = max(MIN_DELTA, baseline * (GROWTH ** (self.record.consecutive_no_change + 1)))
|
||||||
elif self.status == BeatmapRankStatus.GRAVEYARD:
|
elif self.status == BeatmapRankStatus.GRAVEYARD:
|
||||||
days_since_update = (now - self.beatmapset.last_updated).days
|
days_since_update = (now - self.beatmapset["last_updated"]).days
|
||||||
doubling_periods = days_since_update / GRAVEYARD_DOUBLING_PERIOD_DAYS
|
doubling_periods = days_since_update / GRAVEYARD_DOUBLING_PERIOD_DAYS
|
||||||
delta = MIN_DELTA * (2**doubling_periods)
|
delta = MIN_DELTA * (2**doubling_periods)
|
||||||
max_seconds = GRAVEYARD_MAX_DAYS * 86400
|
max_seconds = GRAVEYARD_MAX_DAYS * 86400
|
||||||
@@ -105,21 +118,24 @@ class ProcessingBeatmapset:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def beatmapset_changed(self) -> bool:
|
def beatmapset_changed(self) -> bool:
|
||||||
return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset.ranked)
|
return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset["ranked"])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def changed_beatmaps(self) -> list[ChangedBeatmap]:
|
def changed_beatmaps(self) -> list[ChangedBeatmap]:
|
||||||
changed_beatmaps = []
|
changed_beatmaps = []
|
||||||
for bm in self.beatmapset.beatmaps:
|
for bm in self.beatmapset["beatmaps"]:
|
||||||
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
|
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm["id"]), None)
|
||||||
if not saved or saved["is_deleted"]:
|
if not saved or saved["is_deleted"]:
|
||||||
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
|
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_ADDED))
|
||||||
elif saved["md5"] != bm.checksum:
|
elif saved["md5"] != bm["checksum"]:
|
||||||
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
|
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_UPDATED))
|
||||||
elif saved["beatmap_status"] != BeatmapRankStatus(bm.ranked):
|
elif saved["beatmap_status"] != BeatmapRankStatus(bm["ranked"]):
|
||||||
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.STATUS_CHANGED))
|
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.STATUS_CHANGED))
|
||||||
for saved in self.record.beatmaps:
|
for saved in self.record.beatmaps:
|
||||||
if not any(bm.id == saved["beatmap_id"] for bm in self.beatmapset.beatmaps) and not saved["is_deleted"]:
|
if (
|
||||||
|
not any(bm["id"] == saved["beatmap_id"] for bm in self.beatmapset["beatmaps"])
|
||||||
|
and not saved["is_deleted"]
|
||||||
|
):
|
||||||
changed_beatmaps.append(ChangedBeatmap(saved["beatmap_id"], BeatmapChangeType.MAP_DELETED))
|
changed_beatmaps.append(ChangedBeatmap(saved["beatmap_id"], BeatmapChangeType.MAP_DELETED))
|
||||||
return changed_beatmaps
|
return changed_beatmaps
|
||||||
|
|
||||||
@@ -132,7 +148,7 @@ class BeatmapsetUpdateService:
|
|||||||
async def add_missing_beatmapset(self, beatmapset_id: int, immediate: bool = False) -> bool:
|
async def add_missing_beatmapset(self, beatmapset_id: int, immediate: bool = False) -> bool:
|
||||||
beatmapset = await self.fetcher.get_beatmapset(beatmapset_id)
|
beatmapset = await self.fetcher.get_beatmapset(beatmapset_id)
|
||||||
if immediate:
|
if immediate:
|
||||||
await self._sync_immediately(beatmapset)
|
await self._sync_immediately(cast(EnsuredBeatmapset, beatmapset))
|
||||||
logger.debug(f"triggered immediate sync for beatmapset {beatmapset_id} ")
|
logger.debug(f"triggered immediate sync for beatmapset {beatmapset_id} ")
|
||||||
return True
|
return True
|
||||||
await self.add(beatmapset)
|
await self.add(beatmapset)
|
||||||
@@ -172,7 +188,7 @@ class BeatmapsetUpdateService:
|
|||||||
BeatmapSync(
|
BeatmapSync(
|
||||||
beatmapset_id=missing,
|
beatmapset_id=missing,
|
||||||
beatmap_status=BeatmapRankStatus.GRAVEYARD,
|
beatmap_status=BeatmapRankStatus.GRAVEYARD,
|
||||||
next_sync_time=datetime.datetime.max,
|
next_sync_time=datetime.datetime(year=6000, month=1, day=1),
|
||||||
beatmaps=[],
|
beatmaps=[],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -185,11 +201,13 @@ class BeatmapsetUpdateService:
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
self._adding_missing = False
|
self._adding_missing = False
|
||||||
|
|
||||||
async def add(self, beatmapset: BeatmapsetResp, calculate_next_sync: bool = True):
|
async def add(self, set: BeatmapsetDict, calculate_next_sync: bool = True):
|
||||||
|
beatmapset = cast(EnsuredBeatmapset, set)
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
sync_record = await session.get(BeatmapSync, beatmapset.id)
|
beatmapset_id = beatmapset["id"]
|
||||||
|
sync_record = await session.get(BeatmapSync, beatmapset_id)
|
||||||
if not sync_record:
|
if not sync_record:
|
||||||
database_beatmapset = await session.get(Beatmapset, beatmapset.id)
|
database_beatmapset = await session.get(Beatmapset, beatmapset_id)
|
||||||
if database_beatmapset:
|
if database_beatmapset:
|
||||||
status = BeatmapRankStatus(database_beatmapset.beatmap_status)
|
status = BeatmapRankStatus(database_beatmapset.beatmap_status)
|
||||||
await database_beatmapset.awaitable_attrs.beatmaps
|
await database_beatmapset.awaitable_attrs.beatmaps
|
||||||
@@ -203,19 +221,29 @@ class BeatmapsetUpdateService:
|
|||||||
for bm in database_beatmapset.beatmaps
|
for bm in database_beatmapset.beatmaps
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
status = BeatmapRankStatus(beatmapset.ranked)
|
ranked = beatmapset.get("ranked")
|
||||||
beatmaps = [
|
if ranked is None:
|
||||||
SavedBeatmapMeta(
|
raise ValueError("ranked field is required")
|
||||||
beatmap_id=bm.id,
|
status = BeatmapRankStatus(ranked)
|
||||||
md5=bm.checksum,
|
beatmap_list = beatmapset.get("beatmaps", [])
|
||||||
is_deleted=False,
|
beatmaps = []
|
||||||
beatmap_status=BeatmapRankStatus(bm.ranked),
|
for bm in beatmap_list:
|
||||||
|
bm_id = bm.get("id")
|
||||||
|
checksum = bm.get("checksum")
|
||||||
|
ranked = bm.get("ranked")
|
||||||
|
if bm_id is None or checksum is None or ranked is None:
|
||||||
|
continue
|
||||||
|
beatmaps.append(
|
||||||
|
SavedBeatmapMeta(
|
||||||
|
beatmap_id=bm_id,
|
||||||
|
md5=checksum,
|
||||||
|
is_deleted=False,
|
||||||
|
beatmap_status=BeatmapRankStatus(ranked),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for bm in beatmapset.beatmaps
|
|
||||||
]
|
|
||||||
|
|
||||||
sync_record = BeatmapSync(
|
sync_record = BeatmapSync(
|
||||||
beatmapset_id=beatmapset.id,
|
beatmapset_id=beatmapset_id,
|
||||||
beatmaps=beatmaps,
|
beatmaps=beatmaps,
|
||||||
beatmap_status=status,
|
beatmap_status=status,
|
||||||
)
|
)
|
||||||
@@ -223,13 +251,27 @@ class BeatmapsetUpdateService:
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(sync_record)
|
await session.refresh(sync_record)
|
||||||
else:
|
else:
|
||||||
sync_record.beatmaps = [
|
ranked = beatmapset.get("ranked")
|
||||||
SavedBeatmapMeta(
|
if ranked is None:
|
||||||
beatmap_id=bm.id, md5=bm.checksum, is_deleted=False, beatmap_status=BeatmapRankStatus(bm.ranked)
|
raise ValueError("ranked field is required")
|
||||||
|
beatmap_list = beatmapset.get("beatmaps", [])
|
||||||
|
beatmaps = []
|
||||||
|
for bm in beatmap_list:
|
||||||
|
bm_id = bm.get("id")
|
||||||
|
checksum = bm.get("checksum")
|
||||||
|
bm_ranked = bm.get("ranked")
|
||||||
|
if bm_id is None or checksum is None or bm_ranked is None:
|
||||||
|
continue
|
||||||
|
beatmaps.append(
|
||||||
|
SavedBeatmapMeta(
|
||||||
|
beatmap_id=bm_id,
|
||||||
|
md5=checksum,
|
||||||
|
is_deleted=False,
|
||||||
|
beatmap_status=BeatmapRankStatus(bm_ranked),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for bm in beatmapset.beatmaps
|
sync_record.beatmaps = beatmaps
|
||||||
]
|
sync_record.beatmap_status = BeatmapRankStatus(ranked)
|
||||||
sync_record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
|
|
||||||
if calculate_next_sync:
|
if calculate_next_sync:
|
||||||
processing = ProcessingBeatmapset(beatmapset, sync_record)
|
processing = ProcessingBeatmapset(beatmapset, sync_record)
|
||||||
next_time_delta = processing.calculate_next_sync_time()
|
next_time_delta = processing.calculate_next_sync_time()
|
||||||
@@ -238,17 +280,19 @@ class BeatmapsetUpdateService:
|
|||||||
await BeatmapsetUpdateService._sync_immediately(self, beatmapset)
|
await BeatmapsetUpdateService._sync_immediately(self, beatmapset)
|
||||||
return
|
return
|
||||||
sync_record.next_sync_time = utcnow() + next_time_delta
|
sync_record.next_sync_time = utcnow() + next_time_delta
|
||||||
logger.opt(colors=True).info(f"<g>[{beatmapset.id}]</g> next sync at {sync_record.next_sync_time}")
|
beatmapset_id = beatmapset.get("id")
|
||||||
|
if beatmapset_id:
|
||||||
|
logger.opt(colors=True).debug(f"<g>[{beatmapset_id}]</g> next sync at {sync_record.next_sync_time}")
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def _sync_immediately(self, beatmapset: BeatmapsetResp) -> None:
|
async def _sync_immediately(self, beatmapset: EnsuredBeatmapset) -> None:
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
record = await session.get(BeatmapSync, beatmapset.id)
|
record = await session.get(BeatmapSync, beatmapset["id"])
|
||||||
if not record:
|
if not record:
|
||||||
record = BeatmapSync(
|
record = BeatmapSync(
|
||||||
beatmapset_id=beatmapset.id,
|
beatmapset_id=beatmapset["id"],
|
||||||
beatmaps=[],
|
beatmaps=[],
|
||||||
beatmap_status=BeatmapRankStatus(beatmapset.ranked),
|
beatmap_status=BeatmapRankStatus(beatmapset["ranked"]),
|
||||||
)
|
)
|
||||||
session.add(record)
|
session.add(record)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -261,19 +305,18 @@ class BeatmapsetUpdateService:
|
|||||||
record: BeatmapSync,
|
record: BeatmapSync,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
beatmapset: BeatmapsetResp | None = None,
|
beatmapset: EnsuredBeatmapset | None = None,
|
||||||
):
|
):
|
||||||
logger.opt(colors=True).info(f"<g>[{record.beatmapset_id}]</g> syncing...")
|
logger.opt(colors=True).debug(f"<g>[{record.beatmapset_id}]</g> syncing...")
|
||||||
if beatmapset is None:
|
if beatmapset is None:
|
||||||
try:
|
try:
|
||||||
beatmapset = await self.fetcher.get_beatmapset(record.beatmapset_id)
|
beatmapset = cast(EnsuredBeatmapset, await self.fetcher.get_beatmapset(record.beatmapset_id))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, HTTPStatusError) and e.response.status_code == 404:
|
if isinstance(e, HTTPStatusError) and e.response.status_code == 404:
|
||||||
logger.opt(colors=True).warning(
|
logger.opt(colors=True).warning(
|
||||||
f"<g>[{record.beatmapset_id}]</g> beatmapset not found (404), removing from sync list"
|
f"<g>[{record.beatmapset_id}]</g> beatmapset not found (404), removing from sync list"
|
||||||
)
|
)
|
||||||
await session.delete(record)
|
await session.delete(record)
|
||||||
await session.commit()
|
|
||||||
return
|
return
|
||||||
if isinstance(e, HTTPError):
|
if isinstance(e, HTTPError):
|
||||||
logger.opt(colors=True).warning(
|
logger.opt(colors=True).warning(
|
||||||
@@ -292,20 +335,20 @@ class BeatmapsetUpdateService:
|
|||||||
if changed:
|
if changed:
|
||||||
record.beatmaps = [
|
record.beatmaps = [
|
||||||
SavedBeatmapMeta(
|
SavedBeatmapMeta(
|
||||||
beatmap_id=bm.id,
|
beatmap_id=bm["id"],
|
||||||
md5=bm.checksum,
|
md5=bm["checksum"],
|
||||||
is_deleted=False,
|
is_deleted=False,
|
||||||
beatmap_status=BeatmapRankStatus(bm.ranked),
|
beatmap_status=BeatmapRankStatus(bm["ranked"]),
|
||||||
)
|
)
|
||||||
for bm in beatmapset.beatmaps
|
for bm in beatmapset["beatmaps"]
|
||||||
]
|
]
|
||||||
record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
|
record.beatmap_status = BeatmapRankStatus(beatmapset["ranked"])
|
||||||
record.consecutive_no_change = 0
|
record.consecutive_no_change = 0
|
||||||
|
|
||||||
bg_tasks.add_task(
|
bg_tasks.add_task(
|
||||||
self._process_changed_beatmaps,
|
self._process_changed_beatmaps,
|
||||||
changed_beatmaps,
|
changed_beatmaps,
|
||||||
beatmapset.beatmaps,
|
beatmapset["beatmaps"],
|
||||||
)
|
)
|
||||||
bg_tasks.add_task(
|
bg_tasks.add_task(
|
||||||
self._process_changed_beatmapset,
|
self._process_changed_beatmapset,
|
||||||
@@ -317,13 +360,13 @@ class BeatmapsetUpdateService:
|
|||||||
next_time_delta = processing.calculate_next_sync_time()
|
next_time_delta = processing.calculate_next_sync_time()
|
||||||
if not next_time_delta:
|
if not next_time_delta:
|
||||||
logger.opt(colors=True).info(
|
logger.opt(colors=True).info(
|
||||||
f"<yellow>[{beatmapset.id}]</yellow> beatmapset has transformed to ranked or loved,"
|
f"<yellow>[{beatmapset['id']}]</yellow> beatmapset has transformed to ranked or loved,"
|
||||||
f" removing from sync list"
|
f" removing from sync list"
|
||||||
)
|
)
|
||||||
await session.delete(record)
|
await session.delete(record)
|
||||||
else:
|
else:
|
||||||
record.next_sync_time = utcnow() + next_time_delta
|
record.next_sync_time = utcnow() + next_time_delta
|
||||||
logger.opt(colors=True).info(f"<g>[{record.beatmapset_id}]</g> next sync at {record.next_sync_time}")
|
logger.opt(colors=True).debug(f"<g>[{record.beatmapset_id}]</g> next sync at {record.next_sync_time}")
|
||||||
|
|
||||||
async def _update_beatmaps(self):
|
async def _update_beatmaps(self):
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
@@ -338,18 +381,18 @@ class BeatmapsetUpdateService:
|
|||||||
await self.sync(record, session)
|
await self.sync(record, session)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
|
async def _process_changed_beatmapset(self, beatmapset: EnsuredBeatmapset):
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
db_beatmapset = await session.get(Beatmapset, beatmapset.id)
|
db_beatmapset = await session.get(Beatmapset, beatmapset["id"])
|
||||||
new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
|
new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset) # pyright: ignore[reportArgumentType]
|
||||||
if db_beatmapset:
|
if db_beatmapset:
|
||||||
await session.merge(new_beatmapset)
|
await session.merge(new_beatmapset)
|
||||||
await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset.id)
|
await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset["id"])
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[BeatmapResp]):
|
async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[EnsuredBeatmap]):
|
||||||
storage_service = get_storage_service()
|
storage_service = get_storage_service()
|
||||||
beatmaps = {bm.id: bm for bm in beatmaps_list}
|
beatmaps = {bm["id"]: bm for bm in beatmaps_list}
|
||||||
|
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
|
|
||||||
@@ -380,9 +423,9 @@ class BeatmapsetUpdateService:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
logger.opt(colors=True).info(
|
logger.opt(colors=True).info(
|
||||||
f"<g>[{beatmap.beatmapset_id}]</g> adding beatmap <blue>{beatmap.id}</blue>"
|
f"<g>[{beatmap['beatmapset_id']}]</g> adding beatmap <blue>{beatmap['id']}</blue>"
|
||||||
)
|
)
|
||||||
await Beatmap.from_resp_no_save(session, beatmap)
|
await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType]
|
||||||
else:
|
else:
|
||||||
beatmap = beatmaps.get(change.beatmap_id)
|
beatmap = beatmaps.get(change.beatmap_id)
|
||||||
if not beatmap:
|
if not beatmap:
|
||||||
@@ -391,10 +434,10 @@ class BeatmapsetUpdateService:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
logger.opt(colors=True).info(
|
logger.opt(colors=True).info(
|
||||||
f"<g>[{beatmap.beatmapset_id}]</g> processing beatmap <blue>{beatmap.id}</blue> "
|
f"<g>[{beatmap['beatmapset_id']}]</g> processing beatmap <blue>{beatmap['id']}</blue> "
|
||||||
f"change <cyan>{change.type}</cyan>"
|
f"change <cyan>{change.type}</cyan>"
|
||||||
)
|
)
|
||||||
new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap)
|
new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType]
|
||||||
existing_beatmap = await session.get(Beatmap, change.beatmap_id)
|
existing_beatmap = await session.get(Beatmap, change.beatmap_id)
|
||||||
if existing_beatmap:
|
if existing_beatmap:
|
||||||
await session.merge(new_db_beatmap)
|
await session.merge(new_db_beatmap)
|
||||||
|
|||||||
@@ -4,16 +4,15 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
from app.database.statistics import UserStatistics, UserStatisticsModel
|
||||||
from app.helpers.asset_proxy_helper import replace_asset_urls
|
from app.helpers.asset_proxy_helper import replace_asset_urls
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
from app.utils import utcnow
|
from app.utils import safe_json_dumps, utcnow
|
||||||
|
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
@@ -23,20 +22,6 @@ if TYPE_CHECKING:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DateTimeEncoder(json.JSONEncoder):
|
|
||||||
"""自定义 JSON 编码器,支持 datetime 序列化"""
|
|
||||||
|
|
||||||
def default(self, obj):
|
|
||||||
if isinstance(obj, datetime):
|
|
||||||
return obj.isoformat()
|
|
||||||
return super().default(obj)
|
|
||||||
|
|
||||||
|
|
||||||
def safe_json_dumps(data) -> str:
|
|
||||||
"""安全的 JSON 序列化,支持 datetime 对象"""
|
|
||||||
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
|
|
||||||
|
|
||||||
|
|
||||||
class RankingCacheService:
|
class RankingCacheService:
|
||||||
"""用户排行榜缓存服务"""
|
"""用户排行榜缓存服务"""
|
||||||
|
|
||||||
@@ -311,7 +296,7 @@ class RankingCacheService:
|
|||||||
col(UserStatistics.pp) > 0,
|
col(UserStatistics.pp) > 0,
|
||||||
col(UserStatistics.is_ranked).is_(True),
|
col(UserStatistics.is_ranked).is_(True),
|
||||||
]
|
]
|
||||||
include = ["user"]
|
include = UserStatistics.RANKING_INCLUDES.copy()
|
||||||
|
|
||||||
if type == "performance":
|
if type == "performance":
|
||||||
order_by = col(UserStatistics.pp).desc()
|
order_by = col(UserStatistics.pp).desc()
|
||||||
@@ -321,6 +306,7 @@ class RankingCacheService:
|
|||||||
|
|
||||||
if country:
|
if country:
|
||||||
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
||||||
|
include.append("country_rank")
|
||||||
|
|
||||||
# 获取总用户数用于统计
|
# 获取总用户数用于统计
|
||||||
total_users_query = select(UserStatistics).where(*wheres)
|
total_users_query = select(UserStatistics).where(*wheres)
|
||||||
@@ -353,9 +339,9 @@ class RankingCacheService:
|
|||||||
# 转换为响应格式并确保正确序列化
|
# 转换为响应格式并确保正确序列化
|
||||||
ranking_data = []
|
ranking_data = []
|
||||||
for statistics in statistics_data:
|
for statistics in statistics_data:
|
||||||
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
|
user_stats_resp = await UserStatisticsModel.transform(statistics, includes=include)
|
||||||
|
|
||||||
user_dict = user_stats_resp.model_dump()
|
user_dict = user_stats_resp
|
||||||
|
|
||||||
# 应用资源代理处理
|
# 应用资源代理处理
|
||||||
if settings.enable_asset_proxy:
|
if settings.enable_asset_proxy:
|
||||||
|
|||||||
@@ -8,14 +8,14 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.database.chat import ChatMessage, ChatMessageResp, MessageType
|
from app.database import ChatMessageDict
|
||||||
from app.database.user import RANKING_INCLUDES, User, UserResp
|
from app.database.chat import ChatMessage, ChatMessageModel, MessageType
|
||||||
|
from app.database.user import User, UserModel
|
||||||
from app.dependencies.database import get_redis_message, with_db
|
from app.dependencies.database import get_redis_message, with_db
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.utils import bg_tasks
|
from app.utils import bg_tasks, safe_json_dumps
|
||||||
|
|
||||||
|
|
||||||
class RedisMessageSystem:
|
class RedisMessageSystem:
|
||||||
@@ -35,7 +35,7 @@ class RedisMessageSystem:
|
|||||||
content: str,
|
content: str,
|
||||||
is_action: bool = False,
|
is_action: bool = False,
|
||||||
user_uuid: str | None = None,
|
user_uuid: str | None = None,
|
||||||
) -> ChatMessageResp:
|
) -> "ChatMessageDict":
|
||||||
"""
|
"""
|
||||||
发送消息 - 立即存储到 Redis 并返回
|
发送消息 - 立即存储到 Redis 并返回
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class RedisMessageSystem:
|
|||||||
user_uuid: 用户UUID
|
user_uuid: 用户UUID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ChatMessageResp: 消息响应对象
|
ChatMessage: 消息响应对象
|
||||||
"""
|
"""
|
||||||
# 生成消息ID和时间戳
|
# 生成消息ID和时间戳
|
||||||
message_id = await self._generate_message_id(channel_id)
|
message_id = await self._generate_message_id(channel_id)
|
||||||
@@ -57,28 +57,16 @@ class RedisMessageSystem:
|
|||||||
if not user.id:
|
if not user.id:
|
||||||
raise ValueError("User ID is required")
|
raise ValueError("User ID is required")
|
||||||
|
|
||||||
# 获取频道类型以判断是否需要存储到数据库
|
|
||||||
async with with_db() as session:
|
|
||||||
from app.database.chat import ChannelType, ChatChannel
|
|
||||||
|
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
|
|
||||||
channel_type = channel_result.first()
|
|
||||||
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
|
|
||||||
|
|
||||||
# 准备消息数据
|
# 准备消息数据
|
||||||
message_data = {
|
message_data: "ChatMessageDict" = {
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
"channel_id": channel_id,
|
"channel_id": channel_id,
|
||||||
"sender_id": user.id,
|
"sender_id": user.id,
|
||||||
"content": content,
|
"content": content,
|
||||||
"timestamp": timestamp.isoformat(),
|
"timestamp": timestamp,
|
||||||
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
|
"type": MessageType.ACTION if is_action else MessageType.PLAIN,
|
||||||
"uuid": user_uuid or "",
|
"uuid": user_uuid or "",
|
||||||
"status": "cached", # Redis 缓存状态
|
"is_action": is_action,
|
||||||
"created_at": time.time(),
|
|
||||||
"is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 立即存储到 Redis
|
# 立即存储到 Redis
|
||||||
@@ -86,51 +74,13 @@ class RedisMessageSystem:
|
|||||||
|
|
||||||
# 创建响应对象
|
# 创建响应对象
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
user_resp = await UserModel.transform(user, session=session, includes=User.LIST_INCLUDES)
|
||||||
|
message_data["sender"] = user_resp
|
||||||
|
|
||||||
# 确保 statistics 不为空
|
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
|
||||||
if user_resp.statistics is None:
|
return message_data
|
||||||
from app.database.statistics import UserStatisticsResp
|
|
||||||
|
|
||||||
user_resp.statistics = UserStatisticsResp(
|
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
|
||||||
mode=user.playmode,
|
|
||||||
global_rank=0,
|
|
||||||
country_rank=0,
|
|
||||||
pp=0.0,
|
|
||||||
ranked_score=0,
|
|
||||||
hit_accuracy=0.0,
|
|
||||||
play_count=0,
|
|
||||||
play_time=0,
|
|
||||||
total_score=0,
|
|
||||||
total_hits=0,
|
|
||||||
maximum_combo=0,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=False,
|
|
||||||
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
|
|
||||||
level={"current": 1, "progress": 0},
|
|
||||||
)
|
|
||||||
|
|
||||||
response = ChatMessageResp(
|
|
||||||
message_id=message_id,
|
|
||||||
channel_id=channel_id,
|
|
||||||
content=content,
|
|
||||||
timestamp=timestamp,
|
|
||||||
sender_id=user.id,
|
|
||||||
sender=user_resp,
|
|
||||||
is_action=is_action,
|
|
||||||
uuid=user_uuid,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_multiplayer:
|
|
||||||
logger.info(
|
|
||||||
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
|
|
||||||
" will not be persisted to database"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
|
|
||||||
"""
|
"""
|
||||||
获取频道消息 - 优先从 Redis 获取最新消息
|
获取频道消息 - 优先从 Redis 获取最新消息
|
||||||
|
|
||||||
@@ -140,9 +90,9 @@ class RedisMessageSystem:
|
|||||||
since: 起始消息ID
|
since: 起始消息ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[ChatMessageResp]: 消息列表
|
List[ChatMessageDict]: 消息列表
|
||||||
"""
|
"""
|
||||||
messages = []
|
messages: list["ChatMessageDict"] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 从 Redis 获取最新消息
|
# 从 Redis 获取最新消息
|
||||||
@@ -154,45 +104,21 @@ class RedisMessageSystem:
|
|||||||
# 获取发送者信息
|
# 获取发送者信息
|
||||||
sender = await session.get(User, msg_data["sender_id"])
|
sender = await session.get(User, msg_data["sender_id"])
|
||||||
if sender:
|
if sender:
|
||||||
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
|
user_resp = await UserModel.transform(sender, includes=User.LIST_INCLUDES)
|
||||||
|
|
||||||
if user_resp.statistics is None:
|
from app.database.chat import ChatMessageDict
|
||||||
from app.database.statistics import UserStatisticsResp
|
|
||||||
|
|
||||||
user_resp.statistics = UserStatisticsResp(
|
message_resp: ChatMessageDict = {
|
||||||
mode=sender.playmode,
|
"message_id": msg_data["message_id"],
|
||||||
global_rank=0,
|
"channel_id": msg_data["channel_id"],
|
||||||
country_rank=0,
|
"content": msg_data["content"],
|
||||||
pp=0.0,
|
"timestamp": datetime.fromisoformat(msg_data["timestamp"]), # pyright: ignore[reportArgumentType]
|
||||||
ranked_score=0,
|
"sender_id": msg_data["sender_id"],
|
||||||
hit_accuracy=0.0,
|
"sender": user_resp,
|
||||||
play_count=0,
|
"is_action": msg_data["type"] == MessageType.ACTION.value,
|
||||||
play_time=0,
|
"uuid": msg_data.get("uuid") or None,
|
||||||
total_score=0,
|
"type": MessageType(msg_data["type"]),
|
||||||
total_hits=0,
|
}
|
||||||
maximum_combo=0,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=False,
|
|
||||||
grade_counts={
|
|
||||||
"ssh": 0,
|
|
||||||
"ss": 0,
|
|
||||||
"sh": 0,
|
|
||||||
"s": 0,
|
|
||||||
"a": 0,
|
|
||||||
},
|
|
||||||
level={"current": 1, "progress": 0},
|
|
||||||
)
|
|
||||||
|
|
||||||
message_resp = ChatMessageResp(
|
|
||||||
message_id=msg_data["message_id"],
|
|
||||||
channel_id=msg_data["channel_id"],
|
|
||||||
content=msg_data["content"],
|
|
||||||
timestamp=datetime.fromisoformat(msg_data["timestamp"]),
|
|
||||||
sender_id=msg_data["sender_id"],
|
|
||||||
sender=user_resp,
|
|
||||||
is_action=msg_data["type"] == MessageType.ACTION.value,
|
|
||||||
uuid=msg_data.get("uuid") or None,
|
|
||||||
)
|
|
||||||
messages.append(message_resp)
|
messages.append(message_resp)
|
||||||
|
|
||||||
# 如果 Redis 消息不够,从数据库补充
|
# 如果 Redis 消息不够,从数据库补充
|
||||||
@@ -216,86 +142,46 @@ class RedisMessageSystem:
|
|||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
|
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: ChatMessageDict):
|
||||||
"""存储消息到 Redis"""
|
"""存储消息到 Redis"""
|
||||||
try:
|
try:
|
||||||
# 检查是否是多人房间消息
|
# 存储消息数据为 JSON 字符串
|
||||||
is_multiplayer = message_data.get("is_multiplayer", False)
|
await self.redis.set(
|
||||||
|
|
||||||
# 存储消息数据
|
|
||||||
await self.redis.hset(
|
|
||||||
f"msg:{channel_id}:{message_id}",
|
f"msg:{channel_id}:{message_id}",
|
||||||
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()},
|
safe_json_dumps(message_data),
|
||||||
|
ex=604800, # 7天过期
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置消息过期时间(7天)
|
# 添加到频道消息列表(按时间排序)
|
||||||
await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800)
|
|
||||||
|
|
||||||
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
|
|
||||||
channel_messages_key = f"channel:{channel_id}:messages"
|
channel_messages_key = f"channel:{channel_id}:messages"
|
||||||
|
|
||||||
# 更健壮的键类型检查和清理
|
# 检查并清理错误类型的键
|
||||||
try:
|
try:
|
||||||
key_type = await self.redis.type(channel_messages_key)
|
key_type = await self.redis.type(channel_messages_key)
|
||||||
if key_type == "none":
|
if key_type not in ("none", "zset"):
|
||||||
# 键不存在,这是正常的
|
|
||||||
pass
|
|
||||||
elif key_type != "zset":
|
|
||||||
# 键类型错误,需要清理
|
|
||||||
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
||||||
await self.redis.delete(channel_messages_key)
|
await self.redis.delete(channel_messages_key)
|
||||||
|
|
||||||
# 验证删除是否成功
|
|
||||||
verify_type = await self.redis.type(channel_messages_key)
|
|
||||||
if verify_type != "none":
|
|
||||||
logger.error(
|
|
||||||
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
|
|
||||||
)
|
|
||||||
# 强制删除
|
|
||||||
await self.redis.unlink(channel_messages_key)
|
|
||||||
|
|
||||||
except Exception as type_check_error:
|
except Exception as type_check_error:
|
||||||
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
||||||
# 如果检查失败,尝试强制删除键以确保清理
|
await self.redis.delete(channel_messages_key)
|
||||||
try:
|
|
||||||
await self.redis.delete(channel_messages_key)
|
|
||||||
except Exception:
|
|
||||||
# 最后的努力:使用unlink
|
|
||||||
try:
|
|
||||||
await self.redis.unlink(channel_messages_key)
|
|
||||||
except Exception as final_error:
|
|
||||||
logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}")
|
|
||||||
|
|
||||||
# 添加到频道消息列表(sorted set)
|
# 添加到频道消息列表(sorted set)
|
||||||
try:
|
await self.redis.zadd(
|
||||||
await self.redis.zadd(
|
channel_messages_key,
|
||||||
channel_messages_key,
|
mapping={f"msg:{channel_id}:{message_id}": message_id},
|
||||||
mapping={f"msg:{channel_id}:{message_id}": message_id},
|
)
|
||||||
)
|
|
||||||
except Exception as zadd_error:
|
|
||||||
logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}")
|
|
||||||
# 如果添加失败,再次尝试清理并重试
|
|
||||||
await self.redis.delete(channel_messages_key)
|
|
||||||
await self.redis.zadd(
|
|
||||||
channel_messages_key,
|
|
||||||
mapping={f"msg:{channel_id}:{message_id}": message_id},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 保持频道消息列表大小(最多1000条)
|
# 保持频道消息列表大小(最多1000条)
|
||||||
await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)
|
await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)
|
||||||
|
|
||||||
# 只有非多人房间消息才添加到待持久化队列
|
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
|
||||||
if not is_multiplayer:
|
logger.debug(f"Message {message_id} added to persistence queue")
|
||||||
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
|
|
||||||
logger.debug(f"Message {message_id} added to persistence queue")
|
|
||||||
else:
|
|
||||||
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to store message to Redis: {e}")
|
logger.error(f"Failed to store message to Redis: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
|
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
|
||||||
"""从 Redis 获取消息"""
|
"""从 Redis 获取消息"""
|
||||||
try:
|
try:
|
||||||
# 获取消息键列表,按消息ID排序
|
# 获取消息键列表,按消息ID排序
|
||||||
@@ -314,28 +200,16 @@ class RedisMessageSystem:
|
|||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
for key in message_keys:
|
for key in message_keys:
|
||||||
# 获取消息数据
|
# 获取消息数据(JSON 字符串)
|
||||||
raw_data = await self.redis.hgetall(key)
|
raw_data = await self.redis.get(key)
|
||||||
if raw_data:
|
if raw_data:
|
||||||
# 解码数据
|
try:
|
||||||
message_data: dict[str, Any] = {}
|
# 解析 JSON 字符串为字典
|
||||||
for k, v in raw_data.items():
|
message_data = json.loads(raw_data)
|
||||||
# 尝试解析 JSON
|
messages.append(message_data)
|
||||||
try:
|
except json.JSONDecodeError as e:
|
||||||
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
|
logger.error(f"Failed to decode message JSON from {key}: {e}")
|
||||||
message_data[k] = json.loads(v)
|
continue
|
||||||
elif k in ["message_id", "channel_id", "sender_id"]:
|
|
||||||
message_data[k] = int(v)
|
|
||||||
elif k == "is_multiplayer":
|
|
||||||
message_data[k] = v == "True"
|
|
||||||
elif k == "created_at":
|
|
||||||
message_data[k] = float(v)
|
|
||||||
else:
|
|
||||||
message_data[k] = v
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
message_data[k] = v
|
|
||||||
|
|
||||||
messages.append(message_data)
|
|
||||||
|
|
||||||
# 确保消息按ID正序排序(时间顺序)
|
# 确保消息按ID正序排序(时间顺序)
|
||||||
messages.sort(key=lambda x: x.get("message_id", 0))
|
messages.sort(key=lambda x: x.get("message_id", 0))
|
||||||
@@ -350,15 +224,15 @@ class RedisMessageSystem:
|
|||||||
logger.error(f"Failed to get messages from Redis: {e}")
|
logger.error(f"Failed to get messages from Redis: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
|
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageDict], limit: int):
|
||||||
"""从数据库补充历史消息"""
|
"""从数据库补充历史消息"""
|
||||||
try:
|
try:
|
||||||
# 找到最小的消息ID
|
# 找到最小的消息ID
|
||||||
min_id = float("inf")
|
min_id = float("inf")
|
||||||
if existing_messages:
|
if existing_messages:
|
||||||
for msg in existing_messages:
|
for msg in existing_messages:
|
||||||
if msg.message_id is not None and msg.message_id < min_id:
|
if msg["message_id"] is not None and msg["message_id"] < min_id:
|
||||||
min_id = msg.message_id
|
min_id = msg["message_id"]
|
||||||
|
|
||||||
needed = limit - len(existing_messages)
|
needed = limit - len(existing_messages)
|
||||||
|
|
||||||
@@ -378,13 +252,13 @@ class RedisMessageSystem:
|
|||||||
db_messages = (await session.exec(query)).all()
|
db_messages = (await session.exec(query)).all()
|
||||||
|
|
||||||
for msg in reversed(db_messages): # 按时间正序插入
|
for msg in reversed(db_messages): # 按时间正序插入
|
||||||
msg_resp = await ChatMessageResp.from_db(msg, session)
|
msg_resp = await ChatMessageModel.transform(msg, includes=["sender"])
|
||||||
existing_messages.insert(0, msg_resp)
|
existing_messages.insert(0, msg_resp)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to backfill from database: {e}")
|
logger.error(f"Failed to backfill from database: {e}")
|
||||||
|
|
||||||
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
|
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageDict]:
|
||||||
"""仅从数据库获取消息(回退方案)"""
|
"""仅从数据库获取消息(回退方案)"""
|
||||||
try:
|
try:
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
@@ -402,7 +276,7 @@ class RedisMessageSystem:
|
|||||||
|
|
||||||
messages = (await session.exec(query)).all()
|
messages = (await session.exec(query)).all()
|
||||||
|
|
||||||
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
results = await ChatMessageModel.transform_many(messages, includes=["sender"])
|
||||||
|
|
||||||
# 如果是 since > 0,保持正序;否则反转为时间正序
|
# 如果是 since > 0,保持正序;否则反转为时间正序
|
||||||
if since == 0:
|
if since == 0:
|
||||||
@@ -450,27 +324,17 @@ class RedisMessageSystem:
|
|||||||
# 解析频道ID和消息ID
|
# 解析频道ID和消息ID
|
||||||
channel_id, message_id = map(int, key.split(":"))
|
channel_id, message_id = map(int, key.split(":"))
|
||||||
|
|
||||||
# 从 Redis 获取消息数据
|
# 从 Redis 获取消息数据(JSON 字符串)
|
||||||
raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}")
|
raw_data = await self.redis.get(f"msg:{channel_id}:{message_id}")
|
||||||
|
|
||||||
if not raw_data:
|
if not raw_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 解码数据
|
# 解析 JSON 字符串为字典
|
||||||
message_data = {}
|
try:
|
||||||
for k, v in raw_data.items():
|
message_data = json.loads(raw_data)
|
||||||
message_data[k] = v
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Failed to decode message JSON for {channel_id}:{message_id}: {e}")
|
||||||
# 检查是否是多人房间消息,如果是则跳过数据库存储
|
|
||||||
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
|
|
||||||
if is_multiplayer:
|
|
||||||
# 多人房间消息不存储到数据库,直接标记为已跳过
|
|
||||||
await self.redis.hset(
|
|
||||||
f"msg:{channel_id}:{message_id}",
|
|
||||||
"status",
|
|
||||||
"skipped_multiplayer",
|
|
||||||
)
|
|
||||||
logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查消息是否已存在于数据库
|
# 检查消息是否已存在于数据库
|
||||||
@@ -491,13 +355,6 @@ class RedisMessageSystem:
|
|||||||
|
|
||||||
session.add(db_message)
|
session.add(db_message)
|
||||||
|
|
||||||
# 更新 Redis 中的状态
|
|
||||||
await self.redis.hset(
|
|
||||||
f"msg:{channel_id}:{message_id}",
|
|
||||||
"status",
|
|
||||||
"persisted",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Message {message_id} persisted to database")
|
logger.debug(f"Message {message_id} persisted to database")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
|
|
||||||
|
|
||||||
async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room:
|
async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room:
|
||||||
db_room = room.to_room()
|
db_room = Room.model_validate({"host_id": host_id, **room.model_dump(exclude={"playlist"})})
|
||||||
db_room.host_id = host_id
|
|
||||||
db_room.starts_at = utcnow()
|
db_room.starts_at = utcnow()
|
||||||
db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0)
|
db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0)
|
||||||
session.add(db_room)
|
session.add(db_room)
|
||||||
|
|||||||
@@ -3,19 +3,19 @@
|
|||||||
用于缓存用户信息,提供热缓存和实时刷新功能
|
用于缓存用户信息,提供热缓存和实时刷新功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database import User, UserResp
|
from app.database import User
|
||||||
from app.database.score import LegacyScoreResp, ScoreResp
|
from app.database.score import LegacyScoreResp
|
||||||
from app.database.user import SEARCH_INCLUDED
|
from app.database.user import UserDict, UserModel
|
||||||
from app.dependencies.database import with_db
|
from app.dependencies.database import with_db
|
||||||
from app.helpers.asset_proxy_helper import replace_asset_urls
|
from app.helpers.asset_proxy_helper import replace_asset_urls
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
from app.utils import safe_json_dumps
|
||||||
|
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
@@ -25,20 +25,6 @@ if TYPE_CHECKING:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DateTimeEncoder(json.JSONEncoder):
|
|
||||||
"""自定义 JSON 编码器,支持 datetime 序列化"""
|
|
||||||
|
|
||||||
def default(self, obj):
|
|
||||||
if isinstance(obj, datetime):
|
|
||||||
return obj.isoformat()
|
|
||||||
return super().default(obj)
|
|
||||||
|
|
||||||
|
|
||||||
def safe_json_dumps(data: Any) -> str:
|
|
||||||
"""安全的 JSON 序列化,支持 datetime 对象"""
|
|
||||||
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
class UserCacheService:
|
class UserCacheService:
|
||||||
"""用户缓存服务"""
|
"""用户缓存服务"""
|
||||||
|
|
||||||
@@ -125,7 +111,7 @@ class UserCacheService:
|
|||||||
"""生成用户谱面集缓存键"""
|
"""生成用户谱面集缓存键"""
|
||||||
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
|
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
|
||||||
|
|
||||||
async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None:
|
async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserDict | None:
|
||||||
"""从缓存获取用户信息"""
|
"""从缓存获取用户信息"""
|
||||||
try:
|
try:
|
||||||
cache_key = self._get_user_cache_key(user_id, ruleset)
|
cache_key = self._get_user_cache_key(user_id, ruleset)
|
||||||
@@ -133,7 +119,7 @@ class UserCacheService:
|
|||||||
if cached_data:
|
if cached_data:
|
||||||
logger.debug(f"User cache hit for user {user_id}")
|
logger.debug(f"User cache hit for user {user_id}")
|
||||||
data = json.loads(cached_data)
|
data = json.loads(cached_data)
|
||||||
return UserResp(**data)
|
return data
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user from cache: {e}")
|
logger.error(f"Error getting user from cache: {e}")
|
||||||
@@ -141,7 +127,7 @@ class UserCacheService:
|
|||||||
|
|
||||||
async def cache_user(
|
async def cache_user(
|
||||||
self,
|
self,
|
||||||
user_resp: UserResp,
|
user_resp: UserDict,
|
||||||
ruleset: GameMode | None = None,
|
ruleset: GameMode | None = None,
|
||||||
expire_seconds: int | None = None,
|
expire_seconds: int | None = None,
|
||||||
):
|
):
|
||||||
@@ -149,13 +135,10 @@ class UserCacheService:
|
|||||||
try:
|
try:
|
||||||
if expire_seconds is None:
|
if expire_seconds is None:
|
||||||
expire_seconds = settings.user_cache_expire_seconds
|
expire_seconds = settings.user_cache_expire_seconds
|
||||||
if user_resp.id is None:
|
cache_key = self._get_user_cache_key(user_resp["id"], ruleset)
|
||||||
logger.warning("Cannot cache user with None id")
|
cached_data = safe_json_dumps(user_resp)
|
||||||
return
|
|
||||||
cache_key = self._get_user_cache_key(user_resp.id, ruleset)
|
|
||||||
cached_data = user_resp.model_dump_json()
|
|
||||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||||
logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s")
|
logger.debug(f"Cached user {user_resp['id']} for {expire_seconds}s")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching user: {e}")
|
logger.error(f"Error caching user: {e}")
|
||||||
|
|
||||||
@@ -168,10 +151,9 @@ class UserCacheService:
|
|||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
is_legacy: bool = False,
|
is_legacy: bool = False,
|
||||||
) -> list[ScoreResp] | list[LegacyScoreResp] | None:
|
) -> list[UserDict] | list[LegacyScoreResp] | None:
|
||||||
"""从缓存获取用户成绩"""
|
"""从缓存获取用户成绩"""
|
||||||
try:
|
try:
|
||||||
model = LegacyScoreResp if is_legacy else ScoreResp
|
|
||||||
cache_key = self._get_user_scores_cache_key(
|
cache_key = self._get_user_scores_cache_key(
|
||||||
user_id, score_type, include_fail, mode, limit, offset, is_legacy
|
user_id, score_type, include_fail, mode, limit, offset, is_legacy
|
||||||
)
|
)
|
||||||
@@ -179,7 +161,7 @@ class UserCacheService:
|
|||||||
if cached_data:
|
if cached_data:
|
||||||
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
|
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
|
||||||
data = json.loads(cached_data)
|
data = json.loads(cached_data)
|
||||||
return [model(**score_data) for score_data in data] # pyright: ignore[reportReturnType]
|
return [LegacyScoreResp(**score_data) for score_data in data] if is_legacy else data
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user scores from cache: {e}")
|
logger.error(f"Error getting user scores from cache: {e}")
|
||||||
@@ -189,7 +171,7 @@ class UserCacheService:
|
|||||||
self,
|
self,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
score_type: str,
|
score_type: str,
|
||||||
scores: list[ScoreResp] | list[LegacyScoreResp],
|
scores: list[UserDict] | list[LegacyScoreResp],
|
||||||
include_fail: bool,
|
include_fail: bool,
|
||||||
mode: GameMode | None = None,
|
mode: GameMode | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
@@ -204,8 +186,12 @@ class UserCacheService:
|
|||||||
cache_key = self._get_user_scores_cache_key(
|
cache_key = self._get_user_scores_cache_key(
|
||||||
user_id, score_type, include_fail, mode, limit, offset, is_legacy
|
user_id, score_type, include_fail, mode, limit, offset, is_legacy
|
||||||
)
|
)
|
||||||
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
|
if len(scores) == 0:
|
||||||
scores_json_list = [score.model_dump_json() for score in scores]
|
return
|
||||||
|
if isinstance(scores[0], dict):
|
||||||
|
scores_json_list = [safe_json_dumps(score) for score in scores]
|
||||||
|
else:
|
||||||
|
scores_json_list = [score.model_dump_json() for score in scores] # pyright: ignore[reportAttributeAccessIssue]
|
||||||
cached_data = f"[{','.join(scores_json_list)}]"
|
cached_data = f"[{','.join(scores_json_list)}]"
|
||||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||||
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
|
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
|
||||||
@@ -308,7 +294,7 @@ class UserCacheService:
|
|||||||
for user in users:
|
for user in users:
|
||||||
if user.id != BANCHOBOT_ID:
|
if user.id != BANCHOBOT_ID:
|
||||||
try:
|
try:
|
||||||
await self._cache_single_user(user, session)
|
await self._cache_single_user(user)
|
||||||
cached_count += 1
|
cached_count += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to cache user {user.id}: {e}")
|
logger.error(f"Failed to cache user {user.id}: {e}")
|
||||||
@@ -320,10 +306,10 @@ class UserCacheService:
|
|||||||
finally:
|
finally:
|
||||||
self._refreshing = False
|
self._refreshing = False
|
||||||
|
|
||||||
async def _cache_single_user(self, user: User, session: AsyncSession):
|
async def _cache_single_user(self, user: User):
|
||||||
"""缓存单个用户"""
|
"""缓存单个用户"""
|
||||||
try:
|
try:
|
||||||
user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
|
user_resp = await UserModel.transform(user, includes=User.USER_INCLUDES)
|
||||||
|
|
||||||
# 应用资源代理处理
|
# 应用资源代理处理
|
||||||
if settings.enable_asset_proxy:
|
if settings.enable_asset_proxy:
|
||||||
@@ -347,7 +333,7 @@ class UserCacheService:
|
|||||||
# 立即重新加载用户信息
|
# 立即重新加载用户信息
|
||||||
user = await session.get(User, user_id)
|
user = await session.get(User, user_id)
|
||||||
if user and user.id != BANCHOBOT_ID:
|
if user and user.id != BANCHOBOT_ID:
|
||||||
await self._cache_single_user(user, session)
|
await self._cache_single_user(user)
|
||||||
logger.info(f"Refreshed cache for user {user_id} after score submit")
|
logger.info(f"Refreshed cache for user {user_id} after score submit")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error refreshing user cache on score submit: {e}")
|
logger.error(f"Error refreshing user cache on score submit: {e}")
|
||||||
|
|||||||
53
app/utils.py
53
app/utils.py
@@ -4,10 +4,13 @@ from datetime import UTC, datetime
|
|||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
from types import NoneType, UnionType
|
||||||
|
from typing import TYPE_CHECKING, Any, ParamSpec, TypedDict, TypeVar, Union, get_args, get_origin
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -299,3 +302,51 @@ def hex_to_hue(hex_color: str) -> int:
|
|||||||
hue = (60 * ((r - g) / delta) + 240) % 360
|
hue = (60 * ((r - g) / delta) + 240) % 360
|
||||||
|
|
||||||
return int(hue)
|
return int(hue)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_json_dumps(data) -> str:
|
||||||
|
return json.dumps(jsonable_encoder(data), ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def type_is_optional(typ: type):
|
||||||
|
origin_type = get_origin(typ)
|
||||||
|
args = get_args(typ)
|
||||||
|
return (origin_type is UnionType or origin_type is Union) and len(args) == 2 and NoneType in args
|
||||||
|
|
||||||
|
|
||||||
|
def _get_type(typ: type, includes: tuple[str, ...]) -> Any:
|
||||||
|
from app.database._base import DatabaseModel
|
||||||
|
|
||||||
|
origin = get_origin(typ)
|
||||||
|
if issubclass(typ, DatabaseModel):
|
||||||
|
return typ.generate_typeddict(includes)
|
||||||
|
elif origin is list:
|
||||||
|
item_type = typ.__args__[0]
|
||||||
|
return list[_get_type(item_type, includes)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
|
||||||
|
elif origin is dict:
|
||||||
|
key_type, value_type = typ.__args__
|
||||||
|
return dict[key_type, _get_type(value_type, includes)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
|
||||||
|
elif type_is_optional(typ):
|
||||||
|
inner_type = next(arg for arg in get_args(typ) if arg is not NoneType)
|
||||||
|
return Union[_get_type(inner_type, includes), None] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] # noqa: UP007
|
||||||
|
elif origin is UnionType or origin is Union:
|
||||||
|
new_types = []
|
||||||
|
for arg in get_args(typ):
|
||||||
|
new_types.append(_get_type(arg, includes)) # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
|
||||||
|
return Union[tuple(new_types)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] # noqa: UP007
|
||||||
|
else:
|
||||||
|
return typ
|
||||||
|
|
||||||
|
|
||||||
|
def api_doc(desc: str, model: Any, includes: list[str] = [], *, name: str = "APIDict"):
|
||||||
|
if includes:
|
||||||
|
includes_str = ", ".join(f"`{inc}`" for inc in includes)
|
||||||
|
desc += f"\n\n包含:{includes_str}"
|
||||||
|
if isinstance(model, dict):
|
||||||
|
fields = {}
|
||||||
|
for k, v in model.items():
|
||||||
|
fields[k] = _get_type(v, tuple(includes))
|
||||||
|
typed_dict = TypedDict(name, fields) # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
|
||||||
|
else:
|
||||||
|
typed_dict = _get_type(model, tuple(includes))
|
||||||
|
return {"description": desc, "model": typed_dict}
|
||||||
|
|||||||
@@ -0,0 +1,498 @@
|
|||||||
|
"""project: remove unused fields in database
|
||||||
|
|
||||||
|
Revision ID: 23707640303c
|
||||||
|
Revises: 3f0f22f38c3d
|
||||||
|
Create Date: 2025-11-23 08:14:05.284238
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import mysql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "23707640303c"
|
||||||
|
down_revision: str | Sequence[str] | None = "3f0f22f38c3d"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.alter_column(
|
||||||
|
"beatmaps",
|
||||||
|
"mode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.drop_column("beatmaps", "current_user_playcount")
|
||||||
|
op.alter_column("beatmapsync", "updated_at", existing_type=mysql.DATETIME(), nullable=True)
|
||||||
|
op.alter_column(
|
||||||
|
"best_scores",
|
||||||
|
"gamemode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"lazer_user_statistics",
|
||||||
|
"mode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.alter_column("lazer_user_statistics", "ranked_score", existing_type=mysql.BIGINT(), nullable=True)
|
||||||
|
op.alter_column(
|
||||||
|
"lazer_users",
|
||||||
|
"playmode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"lazer_users",
|
||||||
|
"g0v0_playmode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.drop_column("lazer_users", "beatmap_playcounts_count")
|
||||||
|
op.alter_column("login_sessions", "is_new_device", existing_type=mysql.TINYINT(display_width=1), nullable=False)
|
||||||
|
op.alter_column(
|
||||||
|
"rank_history",
|
||||||
|
"mode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"rank_top",
|
||||||
|
"mode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"rooms",
|
||||||
|
"type",
|
||||||
|
existing_type=mysql.ENUM("PLAYLISTS", "HEAD_TO_HEAD", "TEAM_VERSUS", "MATCHMAKING"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.execute("UPDATE rooms SET channel_id = 0 WHERE channel_id IS NULL")
|
||||||
|
op.alter_column("rooms", "channel_id", existing_type=mysql.INTEGER(), nullable=False)
|
||||||
|
op.alter_column(
|
||||||
|
"score_tokens",
|
||||||
|
"ruleset_id",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"scores",
|
||||||
|
"gamemode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"teams",
|
||||||
|
"playmode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"total_score_best_scores",
|
||||||
|
"gamemode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.alter_column(
|
||||||
|
"total_score_best_scores",
|
||||||
|
"gamemode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"teams",
|
||||||
|
"playmode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"scores",
|
||||||
|
"gamemode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"score_tokens",
|
||||||
|
"ruleset_id",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column("rooms", "channel_id", existing_type=mysql.INTEGER(), nullable=True)
|
||||||
|
op.alter_column(
|
||||||
|
"rooms",
|
||||||
|
"type",
|
||||||
|
existing_type=mysql.ENUM("PLAYLISTS", "HEAD_TO_HEAD", "TEAM_VERSUS", "MATCHMAKING"),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"rank_top",
|
||||||
|
"mode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"rank_history",
|
||||||
|
"mode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column("login_sessions", "is_new_device", existing_type=mysql.TINYINT(display_width=1), nullable=True)
|
||||||
|
op.add_column(
|
||||||
|
"lazer_users", sa.Column("beatmap_playcounts_count", mysql.INTEGER(), autoincrement=False, nullable=False)
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"lazer_users",
|
||||||
|
"g0v0_playmode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"lazer_users",
|
||||||
|
"playmode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column("lazer_user_statistics", "ranked_score", existing_type=mysql.BIGINT(), nullable=False)
|
||||||
|
op.alter_column(
|
||||||
|
"lazer_user_statistics",
|
||||||
|
"mode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"best_scores",
|
||||||
|
"gamemode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.alter_column("beatmapsync", "updated_at", existing_type=mysql.DATETIME(), nullable=False)
|
||||||
|
op.add_column("beatmaps", sa.Column("current_user_playcount", mysql.INTEGER(), autoincrement=False, nullable=False))
|
||||||
|
op.alter_column(
|
||||||
|
"beatmaps",
|
||||||
|
"mode",
|
||||||
|
existing_type=mysql.ENUM(
|
||||||
|
"OSU",
|
||||||
|
"TAIKO",
|
||||||
|
"FRUITS",
|
||||||
|
"MANIA",
|
||||||
|
"OSURX",
|
||||||
|
"OSUAP",
|
||||||
|
"TAIKORX",
|
||||||
|
"FRUITSRX",
|
||||||
|
"SENTAKKI",
|
||||||
|
"TAU",
|
||||||
|
"RUSH",
|
||||||
|
"HISHIGATA",
|
||||||
|
"SOYOKAZE",
|
||||||
|
),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
Reference in New Issue
Block a user