From 9ce99398abce3bf7c4bd05246b0d30f958576b21 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Wed, 30 Jul 2025 16:17:09 +0000 Subject: [PATCH 01/45] refactor(user): refactor user database MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Breaking Change** 用户表变为 lazer_users 建议删除与用户关联的表进行迁移 --- README.md | 423 +++++++------- app/auth.py | 8 +- app/database/__init__.py | 64 +-- app/database/achievement.py | 40 ++ app/database/auth.py | 8 +- app/database/beatmap.py | 3 +- app/database/beatmapset.py | 3 +- app/database/best_score.py | 4 +- app/database/daily_challenge.py | 58 ++ app/database/lazer_user.py | 300 ++++++++++ app/database/legacy.py | 94 ---- app/database/relationship.py | 26 +- app/database/score.py | 41 +- app/database/score_token.py | 11 +- app/database/statistics.py | 95 ++++ app/database/team.py | 10 +- app/database/user.py | 527 ------------------ app/database/user_account_history.py | 45 ++ app/dependencies/user.py | 13 +- app/models/model.py | 15 + app/models/room.py | 18 +- app/models/user.py | 130 +---- app/router/auth.py | 139 ++--- app/router/beatmap.py | 15 +- app/router/beatmapset.py | 4 +- app/router/me.py | 40 +- app/router/relationship.py | 11 +- app/router/score.py | 19 +- app/router/user.py | 111 ++-- app/signalr/hub/metadata.py | 2 +- app/signalr/hub/spectator.py | 7 +- app/signalr/router.py | 4 +- app/utils.py | 459 --------------- create_sample_data.py | 242 -------- main.py | 4 - ...c71791_score_remove_best_id_in_database.py | 38 -- ..._score_add_nlarge_tick_hit_nsmall_tick_.py | 36 -- 37 files changed, 994 insertions(+), 2073 deletions(-) create mode 100644 app/database/achievement.py create mode 100644 app/database/daily_challenge.py create mode 100644 app/database/lazer_user.py delete mode 100644 app/database/legacy.py create mode 100644 app/database/statistics.py delete mode 100644 app/database/user.py create mode 100644 app/database/user_account_history.py create mode 100644 app/models/model.py delete mode 100644 create_sample_data.py delete mode 100644 migrations/versions/78be13c71791_score_remove_best_id_in_database.py delete mode 100644 migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py diff --git a/README.md b/README.md index 267e2b5..a4e1e22 100644 --- a/README.md +++ b/README.md @@ -1,218 +1,205 @@ -# osu! API 模拟服务器 - -这是一个使用 FastAPI + MySQL + Redis 实现的 osu! API 模拟服务器,提供了完整的用户认证和数据管理功能。 - -## 功能特性 - -- **OAuth 2.0 认证**: 支持密码流和刷新令牌流 -- **用户数据管理**: 完整的用户信息、统计数据、成就等 -- **多游戏模式支持**: osu!, taiko, fruits, mania -- **数据库持久化**: MySQL 存储用户数据 -- **缓存支持**: Redis 缓存令牌和会话信息 -- **容器化部署**: Docker 和 Docker Compose 支持 - -## API 端点 - -### 认证端点 -- `POST /oauth/token` - OAuth 令牌获取/刷新 - -### 用户端点 -- `GET /api/v2/me/{ruleset}` - 获取当前用户信息 - -### 其他端点 -- `GET /` - 根端点 -- `GET /health` - 健康检查 - -## 快速开始 - -### 使用 Docker Compose (推荐) - -1. 克隆项目 -```bash -git clone -cd osu_lazer_api -``` - -2. 启动服务 -```bash -docker-compose up -d -``` - -3. 创建示例数据 -```bash -docker-compose exec api python create_sample_data.py -``` - -4. 测试 API -```bash -# 获取访问令牌 -curl -X POST http://localhost:8000/oauth/token \ - -H "Content-Type: application/x-www-form-urlencoded" \ - -d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*" - -# 使用令牌获取用户信息 -curl -X GET http://localhost:8000/api/v2/me/osu \ - -H "Authorization: Bearer YOUR_ACCESS_TOKEN" -``` - -### 本地开发 - -1. 安装依赖 -```bash -pip install -r requirements.txt -``` - -2. 配置环境变量 -```bash -# 复制服务器配置文件 -cp .env .env.local - -# 复制客户端配置文件(用于测试脚本) -cp .env.client .env.client.local -``` - -3. 启动 MySQL 和 Redis -```bash -# 使用 Docker -docker run -d --name mysql -e MYSQL_ROOT_PASSWORD=password -e MYSQL_DATABASE=osu_api -p 3306:3306 mysql:8.0 -docker run -d --name redis -p 6379:6379 redis:7-alpine -``` - -4. 创建示例数据 -```bash -python create_sample_data.py -``` - -5. 启动应用 -```bash -uvicorn main:app --reload -``` - -6. 测试 API -```bash -# 使用测试脚本(会自动加载 .env 文件) -python test_api.py - -# 或使用原始示例脚本 -python osu_api_example.py -``` - -## 项目结构 - -``` -osu_lazer_api/ -├── app/ -│ ├── __init__.py -│ ├── models.py # Pydantic 数据模型 -│ ├── database.py # SQLAlchemy 数据库模型 -│ ├── config.py # 配置设置 -│ ├── dependencies.py # 依赖注入 -│ ├── auth.py # 认证和令牌管理 -│ └── utils.py # 工具函数 -├── main.py # FastAPI 应用主文件 -├── create_sample_data.py # 示例数据创建脚本 -├── requirements.txt # Python 依赖 -├── .env # 环境变量配置 -├── docker-compose.yml # Docker Compose 配置 -├── Dockerfile # Docker 镜像配置 -└── README.md # 项目说明 -``` - -## 示例用户 - -创建示例数据后,您可以使用以下凭据进行测试: - -- **用户名**: `Googujiang` -- **密码**: `password123` -- **用户ID**: `15651670` - -## 环境变量配置 - -项目包含两个环境配置文件: - -### 服务器配置 (`.env`) -用于配置 FastAPI 服务器的运行参数: - -| 变量名 | 描述 | 默认值 | -|--------|------|--------| -| `DATABASE_URL` | MySQL 数据库连接字符串 | `mysql+pymysql://root:password@localhost:3306/osu_api` | -| `REDIS_URL` | Redis 连接字符串 | `redis://localhost:6379/0` | -| `SECRET_KEY` | JWT 签名密钥 | `your-secret-key-here` | -| `ACCESS_TOKEN_EXPIRE_MINUTES` | 访问令牌过期时间(分钟) | `1440` | -| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | -| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | -| `HOST` | 服务器监听地址 | `0.0.0.0` | -| `PORT` | 服务器监听端口 | `8000` | -| `DEBUG` | 调试模式 | `True` | - -### 客户端配置 (`.env.client`) -用于配置客户端脚本的 API 连接参数: - -| 变量名 | 描述 | 默认值 | -|--------|------|--------| -| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | -| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | -| `OSU_API_URL` | API 服务器地址 | `http://localhost:8000` | - -> **注意**: 在生产环境中,请务必更改默认的密钥和密码! - -## API 使用示例 - -### 获取访问令牌 - -```bash -curl -X POST http://localhost:8000/oauth/token \ - -H "Content-Type: application/x-www-form-urlencoded" \ - -d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*" -``` - -响应: -```json -{ - "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...", - "token_type": "Bearer", - "expires_in": 86400, - "refresh_token": "abc123...", - "scope": "*" -} -``` - -### 获取用户信息 - -```bash -curl -X GET http://localhost:8000/api/v2/me/osu \ - -H "Authorization: Bearer YOUR_ACCESS_TOKEN" -``` - -### 刷新令牌 - -```bash -curl -X POST http://localhost:8000/oauth/token \ - -H "Content-Type: application/x-www-form-urlencoded" \ - -d "grant_type=refresh_token&refresh_token=YOUR_REFRESH_TOKEN&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" -``` - -## 开发 - -### 添加新用户 - -您可以通过修改 `create_sample_data.py` 文件来添加更多示例用户,或者扩展 API 来支持用户注册功能。 - -### 扩展功能 - -- 添加更多 API 端点(排行榜、谱面信息等) -- 实现实时功能(WebSocket) -- 添加管理面板 -- 实现数据导入/导出功能 - -### 迁移数据库 - -参考[数据库迁移指南](./MIGRATE_GUIDE.md) - -## 许可证 - -MIT License - -## 贡献 - -欢迎提交 Issue 和 Pull Request! +# osu! API 模拟服务器 + +这是一个使用 FastAPI + MySQL + Redis 实现的 osu! API 模拟服务器,提供了完整的用户认证和数据管理功能。 + +## 功能特性 + +- **OAuth 2.0 认证**: 支持密码流和刷新令牌流 +- **用户数据管理**: 完整的用户信息、统计数据、成就等 +- **多游戏模式支持**: osu!, taiko, fruits, mania +- **数据库持久化**: MySQL 存储用户数据 +- **缓存支持**: Redis 缓存令牌和会话信息 +- **容器化部署**: Docker 和 Docker Compose 支持 + +## API 端点 + +### 认证端点 +- `POST /oauth/token` - OAuth 令牌获取/刷新 + +### 用户端点 +- `GET /api/v2/me/{ruleset}` - 获取当前用户信息 + +### 其他端点 +- `GET /` - 根端点 +- `GET /health` - 健康检查 + +## 快速开始 + +### 使用 Docker Compose (推荐) + +1. 克隆项目 +```bash +git clone +cd osu_lazer_api +``` + +2. 启动服务 +```bash +docker-compose up -d +``` + +3. 创建示例数据 +```bash +docker-compose exec api python create_sample_data.py +``` + +4. 测试 API +```bash +# 获取访问令牌 +curl -X POST http://localhost:8000/oauth/token \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*" + +# 使用令牌获取用户信息 +curl -X GET http://localhost:8000/api/v2/me/osu \ + -H "Authorization: Bearer YOUR_ACCESS_TOKEN" +``` + +### 本地开发 + +1. 安装依赖 +```bash +pip install -r requirements.txt +``` + +2. 配置环境变量 +```bash +# 复制服务器配置文件 +cp .env .env.local + +# 复制客户端配置文件(用于测试脚本) +cp .env.client .env.client.local +``` + +3. 启动 MySQL 和 Redis +```bash +# 使用 Docker +docker run -d --name mysql -e MYSQL_ROOT_PASSWORD=password -e MYSQL_DATABASE=osu_api -p 3306:3306 mysql:8.0 +docker run -d --name redis -p 6379:6379 redis:7-alpine +``` + + +4. 启动应用 +```bash +uvicorn main:app --reload +``` + +## 项目结构 + +``` +osu_lazer_api/ +├── app/ +│ ├── __init__.py +│ ├── models.py # Pydantic 数据模型 +│ ├── database.py # SQLAlchemy 数据库模型 +│ ├── config.py # 配置设置 +│ ├── dependencies.py # 依赖注入 +│ ├── auth.py # 认证和令牌管理 +│ └── utils.py # 工具函数 +├── main.py # FastAPI 应用主文件 +├── create_sample_data.py # 示例数据创建脚本 +├── requirements.txt # Python 依赖 +├── .env # 环境变量配置 +├── docker-compose.yml # Docker Compose 配置 +├── Dockerfile # Docker 镜像配置 +└── README.md # 项目说明 +``` + +## 示例用户 + +创建示例数据后,您可以使用以下凭据进行测试: + +- **用户名**: `Googujiang` +- **密码**: `password123` +- **用户ID**: `15651670` + +## 环境变量配置 + +项目包含两个环境配置文件: + +### 服务器配置 (`.env`) +用于配置 FastAPI 服务器的运行参数: + +| 变量名 | 描述 | 默认值 | +|--------|------|--------| +| `DATABASE_URL` | MySQL 数据库连接字符串 | `mysql+pymysql://root:password@localhost:3306/osu_api` | +| `REDIS_URL` | Redis 连接字符串 | `redis://localhost:6379/0` | +| `SECRET_KEY` | JWT 签名密钥 | `your-secret-key-here` | +| `ACCESS_TOKEN_EXPIRE_MINUTES` | 访问令牌过期时间(分钟) | `1440` | +| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | +| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | +| `HOST` | 服务器监听地址 | `0.0.0.0` | +| `PORT` | 服务器监听端口 | `8000` | +| `DEBUG` | 调试模式 | `True` | + +### 客户端配置 (`.env.client`) +用于配置客户端脚本的 API 连接参数: + +| 变量名 | 描述 | 默认值 | +|--------|------|--------| +| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | +| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | +| `OSU_API_URL` | API 服务器地址 | `http://localhost:8000` | + +> **注意**: 在生产环境中,请务必更改默认的密钥和密码! + +## API 使用示例 + +### 获取访问令牌 + +```bash +curl -X POST http://localhost:8000/oauth/token \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*" +``` + +响应: +```json +{ + "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...", + "token_type": "Bearer", + "expires_in": 86400, + "refresh_token": "abc123...", + "scope": "*" +} +``` + +### 获取用户信息 + +```bash +curl -X GET http://localhost:8000/api/v2/me/osu \ + -H "Authorization: Bearer YOUR_ACCESS_TOKEN" +``` + +### 刷新令牌 + +```bash +curl -X POST http://localhost:8000/oauth/token \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "grant_type=refresh_token&refresh_token=YOUR_REFRESH_TOKEN&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" +``` + +## 开发 + +### 添加新用户 + +您可以通过修改 `create_sample_data.py` 文件来添加更多示例用户,或者扩展 API 来支持用户注册功能。 + +### 扩展功能 + +- 添加更多 API 端点(排行榜、谱面信息等) +- 实现实时功能(WebSocket) +- 添加管理面板 +- 实现数据导入/导出功能 + +### 迁移数据库 + +参考[数据库迁移指南](./MIGRATE_GUIDE.md) + +## 许可证 + +MIT License + +## 贡献 + +欢迎提交 Issue 和 Pull Request! diff --git a/app/auth.py b/app/auth.py index 4c690f8..4762662 100644 --- a/app/auth.py +++ b/app/auth.py @@ -8,7 +8,7 @@ import string from app.config import settings from app.database import ( OAuthToken, - User as DBUser, + User, ) from app.log import logger @@ -74,7 +74,7 @@ def get_password_hash(password: str) -> str: async def authenticate_user_legacy( db: AsyncSession, name: str, password: str -) -> DBUser | None: +) -> User | None: """ 验证用户身份 - 使用类似 from_login 的逻辑 """ @@ -82,7 +82,7 @@ async def authenticate_user_legacy( pw_md5 = hashlib.md5(password.encode()).hexdigest() # 2. 根据用户名查找用户 - statement = select(DBUser).where(DBUser.name == name) + statement = select(User).where(User.username == name) user = (await db.exec(statement)).first() if not user: return None @@ -113,7 +113,7 @@ async def authenticate_user_legacy( async def authenticate_user( db: AsyncSession, username: str, password: str -) -> DBUser | None: +) -> User | None: """验证用户身份""" return await authenticate_user_legacy(db, username, password) diff --git a/app/database/__init__.py b/app/database/__init__.py index 191a193..91bc7cc 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -1,3 +1,4 @@ +from .achievement import UserAchievement, UserAchievementResp from .auth import OAuthToken from .beatmap import ( Beatmap as Beatmap, @@ -8,7 +9,11 @@ from .beatmapset import ( BeatmapsetResp as BeatmapsetResp, ) from .best_score import BestScore -from .legacy import LegacyOAuthToken, LegacyUserStatistics +from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .lazer_user import ( + User, + UserResp, +) from .relationship import Relationship, RelationshipResp, RelationshipType from .score import ( Score, @@ -17,29 +22,17 @@ from .score import ( ScoreStatistics, ) from .score_token import ScoreToken, ScoreTokenResp +from .statistics import ( + UserStatistics, + UserStatisticsResp, +) from .team import Team, TeamMember -from .user import ( - DailyChallengeStats, - LazerUserAchievement, - LazerUserBadge, - LazerUserBanners, - LazerUserCountry, - LazerUserCounts, - LazerUserKudosu, - LazerUserMonthlyPlaycounts, - LazerUserPreviousUsername, - LazerUserProfile, - LazerUserProfileSections, - LazerUserReplaysWatched, - LazerUserStatistics, - RankHistory, - User, - UserAchievement, - UserAvatar, +from .user_account_history import ( + UserAccountHistory, + UserAccountHistoryResp, + UserAccountHistoryType, ) -BeatmapsetResp.model_rebuild() -BeatmapResp.model_rebuild() __all__ = [ "Beatmap", "BeatmapResp", @@ -47,22 +40,8 @@ __all__ = [ "BeatmapsetResp", "BestScore", "DailyChallengeStats", - "LazerUserAchievement", - "LazerUserBadge", - "LazerUserBanners", - "LazerUserCountry", - "LazerUserCounts", - "LazerUserKudosu", - "LazerUserMonthlyPlaycounts", - "LazerUserPreviousUsername", - "LazerUserProfile", - "LazerUserProfileSections", - "LazerUserReplaysWatched", - "LazerUserStatistics", - "LegacyOAuthToken", - "LegacyUserStatistics", + "DailyChallengeStatsResp", "OAuthToken", - "RankHistory", "Relationship", "RelationshipResp", "RelationshipType", @@ -75,6 +54,17 @@ __all__ = [ "Team", "TeamMember", "User", + "UserAccountHistory", + "UserAccountHistoryResp", + "UserAccountHistoryType", "UserAchievement", - "UserAvatar", + "UserAchievement", + "UserAchievementResp", + "UserResp", + "UserStatistics", + "UserStatisticsResp", ] + +for i in __all__: + if i.endswith("Resp"): + globals()[i].model_rebuild() # type: ignore[call-arg] diff --git a/app/database/achievement.py b/app/database/achievement.py new file mode 100644 index 0000000..4be587f --- /dev/null +++ b/app/database/achievement.py @@ -0,0 +1,40 @@ +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from app.models.model import UTCBaseModel + +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + + +class UserAchievementBase(SQLModel, UTCBaseModel): + achievement_id: int = Field(primary_key=True) + achieved_at: datetime = Field( + default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) + ) + + +class UserAchievement(UserAchievementBase, table=True): + __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] + + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True + ) + user: "User" = Relationship(back_populates="achievement") + + +class UserAchievementResp(UserAchievementBase): + @classmethod + def from_db(cls, db_model: UserAchievement) -> "UserAchievementResp": + return cls.model_validate(db_model) diff --git a/app/database/auth.py b/app/database/auth.py index ae49676..554dced 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -1,19 +1,21 @@ from datetime import datetime from typing import TYPE_CHECKING +from app.models.model import UTCBaseModel + from sqlalchemy import Column, DateTime from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: - from .user import User + from .lazer_user import User -class OAuthToken(SQLModel, table=True): +class OAuthToken(UTCBaseModel, SQLModel, table=True): __tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) access_token: str = Field(max_length=500, unique=True) refresh_token: str = Field(max_length=500, unique=True) diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 48e7fa0..751bc5c 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING from app.models.beatmap import BeatmapRankStatus +from app.models.model import UTCBaseModel from app.models.score import MODE_TO_INT, GameMode from .beatmapset import Beatmapset, BeatmapsetResp @@ -20,7 +21,7 @@ class BeatmapOwner(SQLModel): username: str -class BeatmapBase(SQLModel): +class BeatmapBase(SQLModel, UTCBaseModel): # Beatmap url: str mode: GameMode diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 1e6ba27..2ef6280 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING, TypedDict, cast from app.models.beatmap import BeatmapRankStatus, Genre, Language +from app.models.model import UTCBaseModel from app.models.score import GameMode from pydantic import BaseModel, model_serializer @@ -82,7 +83,7 @@ class BeatmapTranslationText(BaseModel): id: int | None = None -class BeatmapsetBase(SQLModel): +class BeatmapsetBase(SQLModel, UTCBaseModel): # Beatmapset artist: str = Field(index=True) artist_unicode: str = Field(index=True) diff --git a/app/database/best_score.py b/app/database/best_score.py index 313da3e..9993b63 100644 --- a/app/database/best_score.py +++ b/app/database/best_score.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from app.models.score import GameMode -from .user import User +from .lazer_user import User from sqlmodel import ( BigInteger, @@ -22,7 +22,7 @@ if TYPE_CHECKING: class BestScore(SQLModel, table=True): __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) score_id: int = Field( sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) diff --git a/app/database/daily_challenge.py b/app/database/daily_challenge.py new file mode 100644 index 0000000..abf874f --- /dev/null +++ b/app/database/daily_challenge.py @@ -0,0 +1,58 @@ +from datetime import datetime +from typing import TYPE_CHECKING + +from app.models.model import UTCBaseModel + +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + + +class DailyChallengeStatsBase(SQLModel, UTCBaseModel): + daily_streak_best: int = Field(default=0) + daily_streak_current: int = Field(default=0) + last_update: datetime | None = Field(default=None, sa_column=Column(DateTime)) + last_weekly_streak: datetime | None = Field( + default=None, sa_column=Column(DateTime) + ) + playcount: int = Field(default=0) + top_10p_placements: int = Field(default=0) + top_50p_placements: int = Field(default=0) + weekly_streak_best: int = Field(default=0) + weekly_streak_current: int = Field(default=0) + + +class DailyChallengeStats(DailyChallengeStatsBase, table=True): + __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] + + user_id: int | None = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + unique=True, + index=True, + primary_key=True, + ), + ) + user: "User" = Relationship(back_populates="daily_challenge_stats") + + +class DailyChallengeStatsResp(DailyChallengeStatsBase): + user_id: int + + @classmethod + def from_db( + cls, + obj: DailyChallengeStats, + ) -> "DailyChallengeStatsResp": + return cls.model_validate(obj) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py new file mode 100644 index 0000000..9627015 --- /dev/null +++ b/app/database/lazer_user.py @@ -0,0 +1,300 @@ +from datetime import UTC, datetime +from typing import TYPE_CHECKING, NotRequired, TypedDict + +from app.models.model import UTCBaseModel +from app.models.score import GameMode +from app.models.user import Country, Page, RankHistory + +from .achievement import UserAchievement, UserAchievementResp +from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .statistics import UserStatistics, UserStatisticsResp +from .team import Team, TeamMember +from .user_account_history import UserAccountHistory, UserAccountHistoryResp + +from sqlalchemy.orm import joinedload, selectinload +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + Relationship, + SQLModel, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from app.database.relationship import RelationshipResp + + +class Kudosu(TypedDict): + available: int + total: int + + +class RankHighest(TypedDict): + rank: int + updated_at: datetime + + +class UserProfileCover(TypedDict): + url: str + custom_url: NotRequired[str] + id: NotRequired[str] + + +Badge = TypedDict( + "Badge", + { + "awarded_at": datetime, + "description": str, + "image@2x_url": str, + "image_url": str, + "url": str, + }, +) + + +class UserBase(UTCBaseModel, SQLModel): + avatar_url: str = "" + country_code: str = Field(default="CN", max_length=2, index=True) + # ? default_group: str|None + is_active: bool = True + is_bot: bool = False + is_supporter: bool = False + last_visit: datetime = Field( + default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) + ) + pm_friends_only: bool = False + profile_colour: str | None = None + username: str = Field(max_length=32, unique=True, index=True) + page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw="")) + previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + # TODO: replays_watched_counts + support_level: int = 0 + badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON)) + + # optional + is_restricted: bool = False + # blocks + cover: UserProfileCover = Field( + default=UserProfileCover( + url="https://assets.ppy.sh/user-profile-covers/default.jpeg" + ), + sa_column=Column(JSON), + ) + beatmap_playcounts_count: int = 0 + # kudosu + + # UserExtended + playmode: GameMode = GameMode.OSU + discord: str | None = None + has_supported: bool = False + interests: str | None = None + join_date: datetime = Field(default=datetime.now(UTC)) + location: str | None = None + max_blocks: int = 50 + max_friends: int = 500 + occupation: str | None = None + playstyle: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + # TODO: post_count + profile_hue: int | None = None + profile_order: list[str] = Field( + default_factory=lambda: [ + "me", + "recent_activity", + "top_ranks", + "medals", + "historical", + "beatmaps", + "kudosu", + ], + sa_column=Column(JSON), + ) + title: str | None = None + title_url: str | None = None + twitter: str | None = None + website: str | None = None + + # undocumented + comments_count: int = 0 + post_count: int = 0 + is_admin: bool = False + is_gmt: bool = False + is_qat: bool = False + is_bng: bool = False + + +class User(UserBase, table=True): + __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] + + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), + ) + account_history: list[UserAccountHistory] = Relationship() + statistics: list[UserStatistics] = Relationship() + achievement: list[UserAchievement] = Relationship(back_populates="user") + team_membership: TeamMember | None = Relationship(back_populates="user") + daily_challenge_stats: DailyChallengeStats | None = Relationship( + back_populates="user" + ) + + email: str = Field(max_length=254, unique=True, index=True, exclude=True) + priv: int = Field(default=1, exclude=True) + pw_bcrypt: str = Field(max_length=60, exclude=True) + silence_end_at: datetime | None = Field( + default=None, sa_column=Column(DateTime(timezone=True)), exclude=True + ) + donor_end_at: datetime | None = Field( + default=None, sa_column=Column(DateTime(timezone=True)), exclude=True + ) + + @classmethod + def all_select_option(cls): + return ( + selectinload(cls.account_history), # pyright: ignore[reportArgumentType] + selectinload(cls.statistics), # pyright: ignore[reportArgumentType] + selectinload(cls.achievement), # pyright: ignore[reportArgumentType] + joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType] + joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] + ) + + +class UserResp(UserBase): + id: int | None = None + is_online: bool = True # TODO + groups: list = [] # TODO + country: Country = Field(default_factory=lambda: Country(code="CN", name="China")) + favourite_beatmapset_count: int = 0 # TODO + graveyard_beatmapset_count: int = 0 # TODO + guest_beatmapset_count: int = 0 # TODO + loved_beatmapset_count: int = 0 # TODO + mapping_follower_count: int = 0 # TODO + nominated_beatmapset_count: int = 0 # TODO + pending_beatmapset_count: int = 0 # TODO + ranked_beatmapset_count: int = 0 # TODO + follow_user_mapping: list[int] = Field(default_factory=list) + follower_count: int = 0 + friends: list["RelationshipResp"] | None = None + scores_best_count: int = 0 + scores_first_count: int = 0 + scores_recent_count: int = 0 + scores_pinned_count: int = 0 + account_history: list[UserAccountHistoryResp] = [] + active_tournament_banners: list[dict] = [] # TODO + kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO + monthly_playcounts: list = Field(default_factory=list) # TODO + unread_pm_count: int = 0 # TODO + rank_history: RankHistory | None = None # TODO + rank_highest: RankHighest | None = None # TODO + statistics: UserStatisticsResp | None = None + statistics_rulesets: dict[str, UserStatisticsResp] | None = None + user_achievements: list[UserAchievementResp] = Field(default_factory=list) + cover_url: str = "" # deprecated + team: Team | None = None + session_verified: bool = True + daily_challenge_user_stats: DailyChallengeStatsResp | None = None # TODO + + # TODO: monthly_playcounts, unread_pm_count, rank_history, user_preferences + + @classmethod + async def from_db( + cls, + obj: User, + session: AsyncSession, + include: list[str] = [], + ruleset: GameMode | None = None, + ) -> "UserResp": + from .best_score import BestScore + from .relationship import Relationship, RelationshipResp, RelationshipType + + u = cls.model_validate(obj.model_dump()) + u.id = obj.id + u.follower_count = ( + await session.exec( + select(func.count()) + .select_from(Relationship) + .where( + Relationship.target_id == obj.id, + Relationship.type == RelationshipType.FOLLOW, + ) + ) + ).one() + u.scores_best_count = ( + await session.exec( + select(func.count()) + .select_from(BestScore) + .where( + BestScore.user_id == obj.id, + ) + .limit(200) + ) + ).one() + u.cover_url = ( + obj.cover.get( + "url", "https://assets.ppy.sh/user-profile-covers/default.jpeg" + ) + if obj.cover + else "https://assets.ppy.sh/user-profile-covers/default.jpeg" + ) + + if "friends" in include: + u.friends = [ + await RelationshipResp.from_db(session, r) + for r in ( + await session.exec( + select(Relationship) + .options( + joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType] + *User.all_select_option() + ) + ) + .where( + Relationship.user_id == obj.id, + Relationship.type == RelationshipType.FOLLOW, + ) + ) + ).all() + ] + + if "team" in include: + if obj.team_membership: + u.team = obj.team_membership.team + + if "account_history" in include: + u.account_history = [ + UserAccountHistoryResp.from_db(ah) for ah in obj.account_history + ] + + if "daily_challenge_user_stats": + if obj.daily_challenge_stats: + u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( + obj.daily_challenge_stats + ) + + if "statistics" in include: + current_stattistics = None + for i in obj.statistics: + if i.mode == (ruleset or obj.playmode): + current_stattistics = i + break + u.statistics = ( + UserStatisticsResp.from_db(current_stattistics) + if current_stattistics + else None + ) + + if "statistics_rulesets" in include: + u.statistics_rulesets = { + i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics + } + + if "achievements" in include: + u.user_achievements = [ + UserAchievementResp.from_db(ua) for ua in obj.achievement + ] + + return u diff --git a/app/database/legacy.py b/app/database/legacy.py deleted file mode 100644 index ff1e957..0000000 --- a/app/database/legacy.py +++ /dev/null @@ -1,94 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING - -from sqlalchemy import JSON, Column, DateTime -from sqlalchemy.orm import Mapped -from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel - -if TYPE_CHECKING: - from .user import User -# ============================================ -# 旧的兼容性表模型(保留以便向后兼容) -# ============================================ - - -class LegacyUserStatistics(SQLModel, table=True): - __tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - mode: str = Field(max_length=10) # osu, taiko, fruits, mania - - # 基本统计 - count_100: int = Field(default=0) - count_300: int = Field(default=0) - count_50: int = Field(default=0) - count_miss: int = Field(default=0) - - # 等级信息 - level_current: int = Field(default=1) - level_progress: int = Field(default=0) - - # 排名信息 - global_rank: int | None = Field(default=None) - global_rank_exp: int | None = Field(default=None) - country_rank: int | None = Field(default=None) - - # PP 和分数 - pp: float = Field(default=0.0) - pp_exp: float = Field(default=0.0) - ranked_score: int = Field(default=0) - hit_accuracy: float = Field(default=0.0) - total_score: int = Field(default=0) - total_hits: int = Field(default=0) - maximum_combo: int = Field(default=0) - - # 游戏统计 - play_count: int = Field(default=0) - play_time: int = Field(default=0) - replays_watched_by_others: int = Field(default=0) - is_ranked: bool = Field(default=False) - - # 成绩等级计数 - grade_ss: int = Field(default=0) - grade_ssh: int = Field(default=0) - grade_s: int = Field(default=0) - grade_sh: int = Field(default=0) - grade_a: int = Field(default=0) - - # 最高排名记录 - rank_highest: int | None = Field(default=None) - rank_highest_updated_at: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: Mapped["User"] = Relationship(back_populates="statistics") - - -class LegacyOAuthToken(SQLModel, table=True): - __tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - access_token: str = Field(max_length=255, index=True) - refresh_token: str = Field(max_length=255, index=True) - expires_at: datetime = Field(sa_column=Column(DateTime)) - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON)) - replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON)) - - # 用户关系 - user: "User" = Relationship() diff --git a/app/database/relationship.py b/app/database/relationship.py index 61dc109..07daa25 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -1,8 +1,6 @@ from enum import Enum -from app.models.user import User as APIUser - -from .user import User as DBUser +from .lazer_user import User, UserResp from pydantic import BaseModel from sqlmodel import ( @@ -28,7 +26,7 @@ class Relationship(SQLModel, table=True): default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), + ForeignKey("lazer_users.id"), primary_key=True, index=True, ), @@ -37,20 +35,20 @@ class Relationship(SQLModel, table=True): default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), + ForeignKey("lazer_users.id"), primary_key=True, index=True, ), ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) - target: DBUser = SQLRelationship( + target: User = SQLRelationship( sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} ) class RelationshipResp(BaseModel): target_id: int - target: APIUser + target: UserResp mutual: bool = False type: RelationshipType @@ -58,8 +56,6 @@ class RelationshipResp(BaseModel): async def from_db( cls, session: AsyncSession, relationship: Relationship ) -> "RelationshipResp": - from app.utils import convert_db_user_to_api_user - target_relationship = ( await session.exec( select(Relationship).where( @@ -75,7 +71,17 @@ class RelationshipResp(BaseModel): ) return cls( target_id=relationship.target_id, - target=await convert_db_user_to_api_user(relationship.target), + target=await UserResp.from_db( + relationship.target, + session, + include=[ + "team", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + "achievements", + ], + ), mutual=mutual, type=relationship.type, ) diff --git a/app/database/score.py b/app/database/score.py index 046f83c..c805563 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -12,9 +12,8 @@ from app.calculator import ( calculate_weighted_pp, clamp, ) -from app.database.score_token import ScoreToken -from app.database.user import LazerUserStatistics, User from app.models.beatmap import BeatmapRankStatus +from app.models.model import UTCBaseModel from app.models.mods import APIMod, mods_can_get_pp from app.models.score import ( INT_TO_MODE, @@ -26,11 +25,12 @@ from app.models.score import ( ScoreStatistics, SoloScoreSubmissionInfo, ) -from app.models.user import User as APIUser from .beatmap import Beatmap, BeatmapResp from .beatmapset import Beatmapset, BeatmapsetResp from .best_score import BestScore +from .lazer_user import User, UserResp +from .score_token import ScoreToken from redis import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime @@ -54,7 +54,7 @@ if TYPE_CHECKING: from app.fetcher import Fetcher -class ScoreBase(SQLModel): +class ScoreBase(SQLModel, UTCBaseModel): # 基本字段 accuracy: float map_md5: str = Field(max_length=32, index=True) @@ -94,7 +94,7 @@ class Score(ScoreBase, table=True): default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), + ForeignKey("lazer_users.id"), index=True, ), ) @@ -112,8 +112,8 @@ class Score(ScoreBase, table=True): gamemode: GameMode = Field(index=True) # optional - beatmap: "Beatmap" = Relationship() - user: "User" = Relationship() + beatmap: Beatmap = Relationship() + user: User = Relationship() @property def is_perfect_combo(self) -> bool: @@ -173,7 +173,7 @@ class ScoreResp(ScoreBase): ruleset_id: int | None = None beatmap: BeatmapResp | None = None beatmapset: BeatmapsetResp | None = None - user: APIUser | None = None + user: UserResp | None = None statistics: ScoreStatistics | None = None maximum_statistics: ScoreStatistics | None = None rank_global: int | None = None @@ -183,8 +183,6 @@ class ScoreResp(ScoreBase): async def from_db( cls, session: AsyncSession, score: Score, user: User | None = None ) -> "ScoreResp": - from app.utils import convert_db_user_to_api_user - s = cls.model_validate(score.model_dump()) assert score.id s.beatmap = BeatmapResp.from_db(score.beatmap) @@ -221,7 +219,12 @@ class ScoreResp(ScoreBase): HitResult.GREAT: score.beatmap.max_combo, } if user: - s.user = await convert_db_user_to_api_user(user) + s.user = await UserResp.from_db( + user, + session, + include=["statistics", "team", "daily_challenge_user_stats"], + ruleset=score.gamemode, + ) s.rank_global = ( await get_score_position_by_id( session, @@ -494,21 +497,20 @@ async def get_user_best_pp( async def process_user( session: AsyncSession, user: User, score: Score, ranked: bool = False ): + assert user.id previous_score_best = await get_user_best_score_in_beatmap( session, score.beatmap_id, user.id, score.gamemode ) statistics = None add_to_db = False - for i in user.lazer_statistics: + for i in user.statistics: if i.mode == score.gamemode.value: statistics = i break if statistics is None: - statistics = LazerUserStatistics( - mode=score.gamemode.value, - user_id=user.id, + raise ValueError( + f"User {user.id} does not have statistics for mode {score.gamemode.value}" ) - add_to_db = True # pc, pt, tth, tts statistics.total_score += score.total_score @@ -546,6 +548,10 @@ async def process_user( statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) statistics.play_count += 1 statistics.play_time += int((score.ended_at - score.started_at).total_seconds()) + statistics.count_100 += score.n100 + score.nkatu + statistics.count_300 += score.n300 + score.ngeki + statistics.count_50 += score.n50 + statistics.count_miss += score.nmiss statistics.total_hits += ( score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu ) @@ -564,8 +570,6 @@ async def process_user( statistics.pp = pp_sum statistics.hit_accuracy = acc_sum - statistics.updated_at = datetime.now(UTC) - if add_to_db: session.add(statistics) await session.commit() @@ -582,6 +586,7 @@ async def process_score( session: AsyncSession, redis: Redis, ) -> Score: + assert user.id can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) score = Score( accuracy=info.accuracy, diff --git a/app/database/score_token.py b/app/database/score_token.py index 6a6edb3..4467b8b 100644 --- a/app/database/score_token.py +++ b/app/database/score_token.py @@ -1,15 +1,16 @@ from datetime import datetime +from app.models.model import UTCBaseModel from app.models.score import GameMode from .beatmap import Beatmap -from .user import User +from .lazer_user import User from sqlalchemy import Column, DateTime, Index from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel -class ScoreTokenBase(SQLModel): +class ScoreTokenBase(SQLModel, UTCBaseModel): score_id: int | None = Field(sa_column=Column(BigInteger), default=None) ruleset_id: GameMode playlist_item_id: int | None = Field(default=None) # playlist @@ -34,10 +35,10 @@ class ScoreToken(ScoreTokenBase, table=True): autoincrement=True, ), ) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) beatmap_id: int = Field(foreign_key="beatmaps.id") - user: "User" = Relationship() - beatmap: "Beatmap" = Relationship() + user: User = Relationship() + beatmap: Beatmap = Relationship() class ScoreTokenResp(ScoreTokenBase): diff --git a/app/database/statistics.py b/app/database/statistics.py new file mode 100644 index 0000000..cac2971 --- /dev/null +++ b/app/database/statistics.py @@ -0,0 +1,95 @@ +from typing import TYPE_CHECKING + +from app.models.score import GameMode + +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + + +class UserStatisticsBase(SQLModel): + mode: GameMode + count_100: int = Field(default=0, sa_column=Column(BigInteger)) + count_300: int = Field(default=0, sa_column=Column(BigInteger)) + count_50: int = Field(default=0, sa_column=Column(BigInteger)) + count_miss: int = Field(default=0, sa_column=Column(BigInteger)) + + global_rank: int | None = Field(default=None) + country_rank: int | None = Field(default=None) + + pp: float = Field(default=0.0) + ranked_score: int = Field(default=0) + hit_accuracy: float = Field(default=0.00) + total_score: int = Field(default=0, sa_column=Column(BigInteger)) + total_hits: int = Field(default=0, sa_column=Column(BigInteger)) + maximum_combo: int = Field(default=0) + + play_count: int = Field(default=0) + play_time: int = Field(default=0, sa_column=Column(BigInteger)) + replays_watched_by_others: int = Field(default=0) + is_ranked: bool = Field(default=True) + + +class UserStatistics(UserStatisticsBase, table=True): + __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + index=True, + ), + ) + grade_ss: int = Field(default=0) + grade_ssh: int = Field(default=0) + grade_s: int = Field(default=0) + grade_sh: int = Field(default=0) + grade_a: int = Field(default=0) + + level_current: int = Field(default=1) + level_progress: int = Field(default=0) + + user: "User" = Relationship(back_populates="statistics") # type: ignore[valid-type] + + +class UserStatisticsResp(UserStatisticsBase): + grade_counts: dict[str, int] = Field( + default_factory=lambda: { + "ss": 0, + "ssh": 0, + "s": 0, + "sh": 0, + "a": 0, + } + ) + level: dict[str, int] = Field( + default_factory=lambda: { + "current": 1, + "progress": 0, + } + ) + + @classmethod + def from_db(cls, obj: UserStatistics) -> "UserStatisticsResp": + s = cls.model_validate(obj) + s.grade_counts = { + "ss": obj.grade_ss, + "ssh": obj.grade_ssh, + "s": obj.grade_s, + "sh": obj.grade_sh, + "a": obj.grade_a, + } + s.level = { + "current": obj.level_current, + "progress": obj.level_progress, + } + return s diff --git a/app/database/team.py b/app/database/team.py index 360e805..146ca9f 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -1,14 +1,16 @@ from datetime import datetime from typing import TYPE_CHECKING +from app.models.model import UTCBaseModel + from sqlalchemy import Column, DateTime from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: - from .user import User + from .lazer_user import User -class Team(SQLModel, table=True): +class Team(SQLModel, UTCBaseModel, table=True): __tablename__ = "teams" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) @@ -22,11 +24,11 @@ class Team(SQLModel, table=True): members: list["TeamMember"] = Relationship(back_populates="team") -class TeamMember(SQLModel, table=True): +class TeamMember(SQLModel, UTCBaseModel, table=True): __tablename__ = "team_members" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) team_id: int = Field(foreign_key="teams.id") joined_at: datetime = Field( default_factory=datetime.utcnow, sa_column=Column(DateTime) diff --git a/app/database/user.py b/app/database/user.py deleted file mode 100644 index a188497..0000000 --- a/app/database/user.py +++ /dev/null @@ -1,527 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Optional - -from .legacy import LegacyUserStatistics -from .team import TeamMember - -from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text -from sqlalchemy.dialects.mysql import VARCHAR -from sqlalchemy.orm import joinedload, selectinload -from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select - - -class User(SQLModel, table=True): - __tablename__ = "users" # pyright: ignore[reportAssignmentType] - - # 主键 - id: int = Field( - default=None, sa_column=Column(BigInteger, primary_key=True, index=True) - ) - - # 基本信息(匹配 migrations_old 中的结构) - name: str = Field(max_length=32, unique=True, index=True) # 用户名 - safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名 - email: str = Field(max_length=254, unique=True, index=True) - priv: int = Field(default=1) # 权限 - pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码 - country: str = Field(default="CN", max_length=2) # 国家代码 - - # 状态和时间 - silence_end: int = Field(default=0) - donor_end: int = Field(default=0) - creation_time: int = Field(default=0) # Unix 时间戳 - latest_activity: int = Field(default=0) # Unix 时间戳 - - # 游戏相关 - preferred_mode: int = Field(default=0) # 偏好游戏模式 - play_style: int = Field(default=0) # 游戏风格 - - # 扩展信息 - clan_id: int = Field(default=0) - clan_priv: int = Field(default=0) - custom_badge_name: str | None = Field(default=None, max_length=16) - custom_badge_icon: str | None = Field(default=None, max_length=64) - userpage_content: str | None = Field(default=None, max_length=2048) - api_key: str | None = Field(default=None, max_length=36, unique=True) - - # 虚拟字段用于兼容性 - @property - def username(self): - return self.name - - @property - def country_code(self): - return self.country - - @property - def join_date(self): - creation_time = getattr(self, "creation_time", 0) - return ( - datetime.fromtimestamp(creation_time) - if creation_time > 0 - else datetime.utcnow() - ) - - @property - def last_visit(self): - latest_activity = getattr(self, "latest_activity", 0) - return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None - - @property - def is_supporter(self): - return self.lazer_profile.is_supporter if self.lazer_profile else False - - # 关联关系 - lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user") - lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user") - lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user") - lazer_achievements: list["LazerUserAchievement"] = Relationship( - back_populates="user" - ) - lazer_profile_sections: list["LazerUserProfileSections"] = Relationship( - back_populates="user" - ) - statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user") - team_membership: Optional["TeamMember"] = Relationship(back_populates="user") - daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship( - back_populates="user" - ) - rank_history: list["RankHistory"] = Relationship(back_populates="user") - avatar: Optional["UserAvatar"] = Relationship(back_populates="user") - active_banners: list["LazerUserBanners"] = Relationship(back_populates="user") - lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user") - lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship( - back_populates="user" - ) - lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship( - back_populates="user" - ) - lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship( - back_populates="user" - ) - - @classmethod - def all_select_option(cls): - return ( - joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType] - joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType] - joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] - joinedload(cls.avatar), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType] - selectinload(cls.statistics), # pyright: ignore[reportArgumentType] - joinedload(cls.team_membership), # pyright: ignore[reportArgumentType] - selectinload(cls.rank_history), # pyright: ignore[reportArgumentType] - selectinload(cls.active_banners), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType] - ) - - @classmethod - def all_select_clause(cls): - return select(cls).options(*cls.all_select_option()) - - -# ============================================ -# Lazer API 专用表模型 -# ============================================ - - -class LazerUserProfile(SQLModel, table=True): - __tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - - # 基本状态字段 - is_active: bool = Field(default=True) - is_bot: bool = Field(default=False) - is_deleted: bool = Field(default=False) - is_online: bool = Field(default=True) - is_supporter: bool = Field(default=False) - is_restricted: bool = Field(default=False) - session_verified: bool = Field(default=False) - has_supported: bool = Field(default=False) - pm_friends_only: bool = Field(default=False) - - # 基本资料字段 - default_group: str = Field(default="default", max_length=50) - last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime)) - join_date: datetime | None = Field(default=None, sa_column=Column(DateTime)) - profile_colour: str | None = Field(default=None, max_length=7) - profile_hue: int | None = Field(default=None) - - # 社交媒体和个人资料字段 - avatar_url: str | None = Field(default=None, max_length=500) - cover_url: str | None = Field(default=None, max_length=500) - discord: str | None = Field(default=None, max_length=100) - twitter: str | None = Field(default=None, max_length=100) - website: str | None = Field(default=None, max_length=500) - title: str | None = Field(default=None, max_length=100) - title_url: str | None = Field(default=None, max_length=500) - interests: str | None = Field(default=None, sa_column=Column(Text)) - location: str | None = Field(default=None, max_length=100) - - occupation: str | None = Field(default=None) # 职业字段,默认为 None - - # 游戏相关字段 - playmode: str = Field(default="osu", max_length=10) - support_level: int = Field(default=0) - max_blocks: int = Field(default=100) - max_friends: int = Field(default=500) - post_count: int = Field(default=0) - - # 页面内容 - page_html: str | None = Field(default=None, sa_column=Column(Text)) - page_raw: str | None = Field(default=None, sa_column=Column(Text)) - - profile_order: str = Field( - default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu" - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_profile") - - -class LazerUserProfileSections(SQLModel, table=True): - __tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - section_name: str = Field(sa_column=Column(VARCHAR(50))) - display_order: int | None = Field(default=None) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_profile_sections") - - -class LazerUserCountry(SQLModel, table=True): - __tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - code: str = Field(max_length=2) - name: str = Field(max_length=100) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - -class LazerUserKudosu(SQLModel, table=True): - __tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - available: int = Field(default=0) - total: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - -class LazerUserCounts(SQLModel, table=True): - __tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - - # 统计计数字段 - beatmap_playcounts_count: int = Field(default=0) - comments_count: int = Field(default=0) - favourite_beatmapset_count: int = Field(default=0) - follower_count: int = Field(default=0) - graveyard_beatmapset_count: int = Field(default=0) - guest_beatmapset_count: int = Field(default=0) - loved_beatmapset_count: int = Field(default=0) - mapping_follower_count: int = Field(default=0) - nominated_beatmapset_count: int = Field(default=0) - pending_beatmapset_count: int = Field(default=0) - ranked_beatmapset_count: int = Field(default=0) - ranked_and_approved_beatmapset_count: int = Field(default=0) - unranked_beatmapset_count: int = Field(default=0) - scores_best_count: int = Field(default=0) - scores_first_count: int = Field(default=0) - scores_pinned_count: int = Field(default=0) - scores_recent_count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_counts") - - -class LazerUserStatistics(SQLModel, table=True): - __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - mode: str = Field(default="osu", max_length=10, primary_key=True) - - # 基本命中统计 - count_100: int = Field(default=0) - count_300: int = Field(default=0) - count_50: int = Field(default=0) - count_miss: int = Field(default=0) - - # 等级信息 - level_current: int = Field(default=1) - level_progress: int = Field(default=0) - - # 排名信息 - global_rank: int | None = Field(default=None) - global_rank_exp: int | None = Field(default=None) - country_rank: int | None = Field(default=None) - - # PP 和分数 - pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2))) - pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2))) - ranked_score: int = Field(default=0, sa_column=Column(BigInteger)) - hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2))) - total_score: int = Field(default=0, sa_column=Column(BigInteger)) - total_hits: int = Field(default=0, sa_column=Column(BigInteger)) - maximum_combo: int = Field(default=0) - - # 游戏统计 - play_count: int = Field(default=0) - play_time: int = Field(default=0) # 秒 - replays_watched_by_others: int = Field(default=0) - is_ranked: bool = Field(default=False) - - # 成绩等级计数 - grade_ss: int = Field(default=0) - grade_ssh: int = Field(default=0) - grade_s: int = Field(default=0) - grade_sh: int = Field(default=0) - grade_a: int = Field(default=0) - - # 最高排名记录 - rank_highest: int | None = Field(default=None) - rank_highest_updated_at: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_statistics") - - -class LazerUserBanners(SQLModel, table=True): - __tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - tournament_id: int - image_url: str = Field(sa_column=Column(VARCHAR(500))) - is_active: bool | None = Field(default=None) - - # 修正user关系的back_populates值 - user: "User" = Relationship(back_populates="active_banners") - - -class LazerUserAchievement(SQLModel, table=True): - __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - achievement_id: int - achieved_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_achievements") - - -class LazerUserBadge(SQLModel, table=True): - __tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - badge_id: int - awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime)) - description: str | None = Field(default=None, sa_column=Column(Text)) - image_url: str | None = Field(default=None, max_length=500) - url: str | None = Field(default=None, max_length=500) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_badges") - - -class LazerUserMonthlyPlaycounts(SQLModel, table=True): - __tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - start_date: datetime = Field(sa_column=Column(Date)) - play_count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_monthly_playcounts") - - -class LazerUserPreviousUsername(SQLModel, table=True): - __tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - username: str = Field(max_length=32) - changed_at: datetime = Field(sa_column=Column(DateTime)) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_previous_usernames") - - -class LazerUserReplaysWatched(SQLModel, table=True): - __tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - start_date: datetime = Field(sa_column=Column(Date)) - count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_replays_watched") - - -# 类型转换用的 UserAchievement(不是 SQLAlchemy 模型) -@dataclass -class UserAchievement: - achieved_at: datetime - achievement_id: int - - -class DailyChallengeStats(SQLModel, table=True): - __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True) - ) - - daily_streak_best: int = Field(default=0) - daily_streak_current: int = Field(default=0) - last_update: datetime | None = Field(default=None, sa_column=Column(DateTime)) - last_weekly_streak: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - playcount: int = Field(default=0) - top_10p_placements: int = Field(default=0) - top_50p_placements: int = Field(default=0) - weekly_streak_best: int = Field(default=0) - weekly_streak_current: int = Field(default=0) - - user: "User" = Relationship(back_populates="daily_challenge_stats") - - -class RankHistory(SQLModel, table=True): - __tablename__ = "rank_history" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - mode: str = Field(max_length=10) - rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks - date_recorded: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="rank_history") - - -class UserAvatar(SQLModel, table=True): - __tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - filename: str = Field(max_length=255) - original_filename: str = Field(max_length=255) - file_size: int - mime_type: str = Field(max_length=100) - is_active: bool = Field(default=True) - created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - r2_original_url: str | None = Field(default=None, max_length=500) - r2_game_url: str | None = Field(default=None, max_length=500) - - user: "User" = Relationship(back_populates="avatar") diff --git a/app/database/user_account_history.py b/app/database/user_account_history.py new file mode 100644 index 0000000..217c8eb --- /dev/null +++ b/app/database/user_account_history.py @@ -0,0 +1,45 @@ +from datetime import UTC, datetime +from enum import Enum + +from app.models.model import UTCBaseModel + +from sqlmodel import BigInteger, Column, Field, ForeignKey, Integer, SQLModel + + +class UserAccountHistoryType(str, Enum): + NOTE = "note" + RESTRICTION = "restriction" + SLIENCE = "silence" + TOURNAMENT_BAN = "tournament_ban" + + +class UserAccountHistoryBase(SQLModel, UTCBaseModel): + description: str | None = None + length: int + permanent: bool = False + timestamp: datetime = Field(default=datetime.now(UTC)) + type: UserAccountHistoryType + + +class UserAccountHistory(UserAccountHistoryBase, table=True): + __tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType] + + id: int | None = Field( + sa_column=Column( + Integer, + autoincrement=True, + index=True, + primary_key=True, + ) + ) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + + +class UserAccountHistoryResp(UserAccountHistoryBase): + id: int | None = None + + @classmethod + def from_db(cls, db_model: UserAccountHistory) -> "UserAccountHistoryResp": + return cls.model_validate(db_model) diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 0c8f8bc..769247c 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -1,14 +1,13 @@ from __future__ import annotations from app.auth import get_token_by_access_token -from app.database import ( - User as DBUser, -) +from app.database import User from .database import get_db from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession security = HTTPBearer() @@ -17,7 +16,7 @@ security = HTTPBearer() async def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db), -) -> DBUser: +) -> User: """获取当前认证用户""" token = credentials.credentials @@ -27,13 +26,15 @@ async def get_current_user( return user -async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None: +async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None: token_record = await get_token_by_access_token(db, token) if not token_record: return None user = ( await db.exec( - DBUser.all_select_clause().where(DBUser.id == token_record.user_id) + select(User) + .options(*User.all_select_option()) + .where(User.id == token_record.user_id) ) ).first() return user diff --git a/app/models/model.py b/app/models/model.py new file mode 100644 index 0000000..bc00585 --- /dev/null +++ b/app/models/model.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from pydantic import BaseModel, field_serializer + + +class UTCBaseModel(BaseModel): + @field_serializer("*", when_used="json") + def serialize_datetime(self, v, _info): + if isinstance(v, datetime): + if v.tzinfo is None: + v = v.replace(tzinfo=UTC) + return v.astimezone(UTC).isoformat() + return v diff --git a/app/models/room.py b/app/models/room.py index 9ca24d2..85aae24 100644 --- a/app/models/room.py +++ b/app/models/room.py @@ -3,11 +3,13 @@ from __future__ import annotations from datetime import datetime from enum import Enum +from app.database import User from app.database.beatmap import Beatmap -from app.database.user import User from app.models.mods import APIMod -from pydantic import BaseModel +from .model import UTCBaseModel + +from pydantic import BaseModel, Field class RoomCategory(str, Enum): @@ -40,15 +42,15 @@ class RoomStatus(str, Enum): PLAYING = "playing" -class PlaylistItem(BaseModel): +class PlaylistItem(UTCBaseModel): id: int | None owner_id: int ruleset_id: int expired: bool playlist_order: int | None played_at: datetime | None - allowed_mods: list[APIMod] = [] - required_mods: list[APIMod] = [] + allowed_mods: list[APIMod] = Field(default_factory=list) + required_mods: list[APIMod] = Field(default_factory=list) beatmap_id: int beatmap: Beatmap | None freestyle: bool @@ -75,7 +77,7 @@ class PlaylistAggregateScore(BaseModel): playlist_item_attempts: list[ItemAttemptsCount] -class Room(BaseModel): +class Room(UTCBaseModel): id: int | None name: str = "" password: str | None @@ -86,9 +88,9 @@ class Room(BaseModel): starts_at: datetime | None ends_at: datetime | None participant_count: int = 0 - recent_participants: list[User] = [] + recent_participants: list[User] = Field(default_factory=list) max_attempts: int | None - playlist: list[PlaylistItem] = [] + playlist: list[PlaylistItem] = Field(default_factory=list) playlist_item_stats: RoomPlaylistItemStats | None difficulty_range: RoomDifficultyRange | None type: MatchType = MatchType.PLAYLISTS diff --git a/app/models/user.py b/app/models/user.py index dd90e47..3052eef 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -2,15 +2,11 @@ from __future__ import annotations from datetime import datetime from enum import Enum -from typing import TYPE_CHECKING -from .score import GameMode +from .model import UTCBaseModel from pydantic import BaseModel -if TYPE_CHECKING: - from app.database import LazerUserAchievement, Team - class PlayStyle(str, Enum): MOUSE = "mouse" @@ -77,24 +73,7 @@ class MonthlyPlaycount(BaseModel): count: int -class UserAchievement(BaseModel): - achieved_at: datetime - achievement_id: int - - # 添加数据库模型转换方法 - def to_db_model(self, user_id: int) -> "LazerUserAchievement": - from app.database import ( - LazerUserAchievement, - ) - - return LazerUserAchievement( - user_id=user_id, - achievement_id=self.achievement_id, - achieved_at=self.achieved_at, - ) - - -class RankHighest(BaseModel): +class RankHighest(UTCBaseModel): rank: int updated_at: datetime @@ -104,111 +83,6 @@ class RankHistory(BaseModel): data: list[int] -class DailyChallengeStats(BaseModel): - daily_streak_best: int = 0 - daily_streak_current: int = 0 - last_update: datetime | None = None - last_weekly_streak: datetime | None = None - playcount: int = 0 - top_10p_placements: int = 0 - top_50p_placements: int = 0 - user_id: int - weekly_streak_best: int = 0 - weekly_streak_current: int = 0 - - class Page(BaseModel): html: str = "" raw: str = "" - - -class User(BaseModel): - # 基本信息 - id: int - username: str - avatar_url: str - country_code: str - default_group: str = "default" - is_active: bool = True - is_bot: bool = False - is_deleted: bool = False - is_online: bool = True - is_supporter: bool = False - is_restricted: bool = False - last_visit: datetime | None = None - pm_friends_only: bool = False - profile_colour: str | None = None - - # 个人资料 - cover_url: str | None = None - discord: str | None = None - has_supported: bool = False - interests: str | None = None - join_date: datetime - location: str | None = None - max_blocks: int = 100 - max_friends: int = 500 - occupation: str | None = None - playmode: GameMode = GameMode.OSU - playstyle: list[PlayStyle] = [] - post_count: int = 0 - profile_hue: int | None = None - profile_order: list[str] = [ - "me", - "recent_activity", - "top_ranks", - "medals", - "historical", - "beatmaps", - "kudosu", - ] - title: str | None = None - title_url: str | None = None - twitter: str | None = None - website: str | None = None - session_verified: bool = False - support_level: int = 0 - - # 关联对象 - country: Country - cover: Cover - kudosu: Kudosu - statistics: Statistics - statistics_rulesets: dict[str, Statistics] - - # 计数信息 - beatmap_playcounts_count: int = 0 - comments_count: int = 0 - favourite_beatmapset_count: int = 0 - follower_count: int = 0 - graveyard_beatmapset_count: int = 0 - guest_beatmapset_count: int = 0 - loved_beatmapset_count: int = 0 - mapping_follower_count: int = 0 - nominated_beatmapset_count: int = 0 - pending_beatmapset_count: int = 0 - ranked_beatmapset_count: int = 0 - ranked_and_approved_beatmapset_count: int = 0 - unranked_beatmapset_count: int = 0 - scores_best_count: int = 0 - scores_first_count: int = 0 - scores_pinned_count: int = 0 - scores_recent_count: int = 0 - - # 历史数据 - account_history: list[dict] = [] - active_tournament_banner: dict | None = None - active_tournament_banners: list[dict] = [] - badges: list[dict] = [] - current_season_stats: dict | None = None - daily_challenge_user_stats: DailyChallengeStats | None = None - groups: list[dict] = [] - monthly_playcounts: list[MonthlyPlaycount] = [] - page: Page = Page() - previous_usernames: list[str] = [] - rank_highest: RankHighest | None = None - rank_history: RankHistory | None = None - rankHistory: RankHistory | None = None # 兼容性别名 - replays_watched_counts: list[dict] = [] - team: "Team | None" = None - user_achievements: list[UserAchievement] = [] diff --git a/app/router/auth.py b/app/router/auth.py index 0f41b32..7a2a14d 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import timedelta +from datetime import UTC, datetime, timedelta import re from app.auth import ( @@ -12,17 +12,21 @@ from app.auth import ( store_token, ) from app.config import settings -from app.database import User as DBUser +from app.database import DailyChallengeStats, User +from app.database.statistics import UserStatistics from app.dependencies import get_db +from app.log import logger from app.models.oauth import ( OAuthErrorResponse, RegistrationRequestErrors, TokenResponse, UserRegistrationErrors, ) +from app.models.score import GameMode from fastapi import APIRouter, Depends, Form from fastapi.responses import JSONResponse +from sqlalchemy import text from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -110,12 +114,12 @@ async def register_user( email_errors = validate_email(user_email) password_errors = validate_password(user_password) - result = await db.exec(select(DBUser).where(DBUser.name == user_username)) + result = await db.exec(select(User).where(User.username == user_username)) existing_user = result.first() if existing_user: username_errors.append("Username is already taken") - result = await db.exec(select(DBUser).where(DBUser.email == user_email)) + result = await db.exec(select(User).where(User.email == user_email)) existing_email = result.first() if existing_email: email_errors.append("Email is already taken") @@ -135,119 +139,41 @@ async def register_user( try: # 创建新用户 - from datetime import datetime - import time + # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) + result = await db.execute( # pyright: ignore[reportDeprecated] + text( + "SELECT AUTO_INCREMENT FROM information_schema.TABLES " + "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'" + ) + ) + next_id = result.one()[0] + if next_id <= 2: + await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3")) + await db.commit() - new_user = DBUser( - name=user_username, - safe_name=user_username.lower(), # 安全用户名(小写) + new_user = User( + username=user_username, email=user_email, pw_bcrypt=get_password_hash(user_password), priv=1, # 普通用户权限 - country="CN", # 默认国家 - creation_time=int(time.time()), - latest_activity=int(time.time()), - preferred_mode=0, # 默认模式 - play_style=0, # 默认游戏风格 + country_code="CN", # 默认国家 + join_date=datetime.now(UTC), + last_visit=datetime.now(UTC), ) - db.add(new_user) await db.commit() await db.refresh(new_user) - - # 保存用户ID,因为会话可能会关闭 - user_id = new_user.id - - if user_id <= 2: - await db.rollback() - try: - from sqlalchemy import text - - # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) - await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3")) - await db.commit() - - # 重新创建用户 - new_user = DBUser( - name=user_username, - safe_name=user_username.lower(), - email=user_email, - pw_bcrypt=get_password_hash(user_password), - priv=1, - country="CN", - creation_time=int(time.time()), - latest_activity=int(time.time()), - preferred_mode=0, - play_style=0, - ) - - db.add(new_user) - await db.commit() - await db.refresh(new_user) - user_id = new_user.id - - # 最终检查ID是否有效 - if user_id <= 2: - await db.rollback() - errors = RegistrationRequestErrors( - message=( - "Failed to create account with valid ID. " - "Please contact support." - ) - ) - return JSONResponse( - status_code=500, content={"form_error": errors.model_dump()} - ) - - except Exception as fix_error: - await db.rollback() - print(f"Failed to fix AUTO_INCREMENT: {fix_error}") - errors = RegistrationRequestErrors( - message="Failed to create account with valid ID. Please try again." - ) - return JSONResponse( - status_code=500, content={"form_error": errors.model_dump()} - ) - - # 创建默认的 lazer_profile - from app.database.user import LazerUserProfile - - lazer_profile = LazerUserProfile( - user_id=user_id, - is_active=True, - is_bot=False, - is_deleted=False, - is_online=True, - is_supporter=False, - is_restricted=False, - session_verified=False, - has_supported=False, - pm_friends_only=False, - default_group="default", - join_date=datetime.utcnow(), - playmode="osu", - support_level=0, - max_blocks=50, - max_friends=250, - post_count=0, - ) - - db.add(lazer_profile) + assert new_user.id is not None, "New user ID should not be None" + for i in GameMode: + statistics = UserStatistics(mode=i, user_id=new_user.id) + db.add(statistics) + daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id) + db.add(daily_challenge_user_stats) await db.commit() - - # 返回成功响应 - return JSONResponse( - status_code=201, - content={"message": "Account created successfully", "user_id": user_id}, - ) - - except Exception as e: + except Exception: await db.rollback() # 打印详细错误信息用于调试 - print(f"Registration error: {e}") - import traceback - - traceback.print_exc() + logger.exception(f"Registration error for user {user_username}") # 返回通用错误 errors = RegistrationRequestErrors( @@ -323,6 +249,7 @@ async def oauth_token( refresh_token_str = generate_refresh_token() # 存储令牌 + assert user.id await store_token( db, user.id, diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 71d554f..4af9c9a 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -5,12 +5,7 @@ import hashlib import json from app.calculator import calculate_beatmap_attribute -from app.database import ( - Beatmap, - BeatmapResp, - User as DBUser, -) -from app.database.beatmapset import Beatmapset +from app.database import Beatmap, BeatmapResp, Beatmapset, User from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -39,7 +34,7 @@ async def lookup_beatmap( id: int | None = Query(default=None, alias="id"), md5: str | None = Query(default=None, alias="checksum"), filename: str | None = Query(default=None, alias="filename"), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -62,7 +57,7 @@ async def lookup_beatmap( @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) async def get_beatmap( bid: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -81,7 +76,7 @@ class BatchGetResp(BaseModel): @router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp) async def batch_get_beatmaps( b_ids: list[int] = Query(alias="id", default_factory=list), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if not b_ids: @@ -126,7 +121,7 @@ async def batch_get_beatmaps( ) async def get_beatmap_attributes( beatmap: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), mods: list[str] = Query(default_factory=list), ruleset: GameMode | None = Query(default=None), ruleset_id: int | None = Query(default=None), diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index db2dd77..e551727 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -3,7 +3,7 @@ from __future__ import annotations from app.database import ( Beatmapset, BeatmapsetResp, - User as DBUser, + User, ) from app.dependencies.database import get_db from app.dependencies.fetcher import get_fetcher @@ -22,7 +22,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) async def get_beatmapset( sid: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): diff --git a/app/router/me.py b/app/router/me.py index 93dcbdc..e3aa734 100644 --- a/app/router/me.py +++ b/app/router/me.py @@ -1,28 +1,34 @@ from __future__ import annotations -from typing import Literal - -from app.database import ( - User as DBUser, -) +from app.database import User, UserResp from app.dependencies import get_current_user -from app.models.user import ( - User as ApiUser, -) -from app.utils import convert_db_user_to_api_user +from app.dependencies.database import get_db +from app.models.score import GameMode from .api_router import router from fastapi import Depends +from sqlmodel.ext.asyncio.session import AsyncSession -@router.get("/me/{ruleset}", response_model=ApiUser) -@router.get("/me/", response_model=ApiUser) +@router.get("/me/{ruleset}", response_model=UserResp) +@router.get("/me/", response_model=UserResp) async def get_user_info_default( - ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", - current_user: DBUser = Depends(get_current_user), + ruleset: GameMode | None = None, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), ): - """获取当前用户信息(默认使用osu模式)""" - # 默认使用osu模式 - api_user = await convert_db_user_to_api_user(current_user, ruleset) - return api_user + return await UserResp.from_db( + current_user, + session, + [ + "friends", + "team", + "account_history", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + "achievements", + ], + ruleset, + ) diff --git a/app/router/relationship.py b/app/router/relationship.py index 9ed5b0f..9e39e8b 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -8,6 +8,7 @@ from app.dependencies.user import get_current_user from .api_router import router from fastapi import Depends, HTTPException, Query, Request +from pydantic import BaseModel from sqlalchemy.orm import joinedload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -36,7 +37,11 @@ async def get_relationship( return [await RelationshipResp.from_db(db, rel) for rel in relationships] -@router.post("/friends", tags=["relationship"], response_model=RelationshipResp) +class AddFriendResp(BaseModel): + user_relation: RelationshipResp + + +@router.post("/friends", tags=["relationship"], response_model=AddFriendResp) @router.post("/blocks", tags=["relationship"]) async def add_relationship( request: Request, @@ -98,7 +103,9 @@ async def add_relationship( ) ).first() assert relationship, "Relationship should exist after commit" - return await RelationshipResp.from_db(db, relationship) + return AddFriendResp( + user_relation=await RelationshipResp.from_db(db, relationship) + ) @router.delete("/friends/{target}", tags=["relationship"]) diff --git a/app/router/score.py b/app/router/score.py index cc38dcc..baab3a2 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,11 +1,7 @@ from __future__ import annotations -from app.database import ( - User as DBUser, -) -from app.database.beatmap import Beatmap -from app.database.score import Score, ScoreResp, process_score, process_user -from app.database.score_token import ScoreToken, ScoreTokenResp +from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User +from app.database.score import process_score, process_user from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -41,7 +37,7 @@ async def get_beatmap_scores( mode: GameMode | None = Query(None), # mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询 type: str = Query(None), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -94,7 +90,7 @@ async def get_user_beatmap_score( legacy_only: bool = Query(None), mode: str = Query(None), mods: str = Query(None), # TODO:添加mods筛选 - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -134,7 +130,7 @@ async def get_user_all_beatmap_scores( user: int, legacy_only: bool = Query(None), ruleset: str = Query(None), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -166,9 +162,10 @@ async def create_solo_score( version_hash: str = Form(""), beatmap_hash: str = Form(), ruleset_id: int = Form(..., ge=0, le=3), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): + assert current_user.id async with db: score_token = ScoreToken( user_id=current_user.id, @@ -190,7 +187,7 @@ async def submit_solo_score( beatmap: int, token: int, info: SoloScoreSubmissionInfo, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), diff --git a/app/router/user.py b/app/router/user.py index 6e169c3..cfe136c 100644 --- a/app/router/user.py +++ b/app/router/user.py @@ -1,12 +1,8 @@ from __future__ import annotations -from typing import Literal - -from app.database import User as DBUser +from app.database import User, UserResp from app.dependencies.database import get_db -from app.models.score import INT_TO_MODE -from app.models.user import User as ApiUser -from app.utils import convert_db_user_to_api_user +from app.models.score import GameMode from .api_router import router @@ -17,28 +13,17 @@ from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql.expression import col -# ---------- Shared Utility ---------- -async def get_user_by_lookup( - db: AsyncSession, lookup: str, key: str = "id" -) -> DBUser | None: - """根据查找方式获取用户""" - if key == "id": - try: - user_id = int(lookup) - result = await db.exec(select(DBUser).where(DBUser.id == user_id)) - return result.first() - except ValueError: - return None - elif key == "username": - result = await db.exec(select(DBUser).where(DBUser.name == lookup)) - return result.first() - else: - return None - - -# ---------- Batch Users ---------- class BatchUserResponse(BaseModel): - users: list[ApiUser] + users: list[UserResp] + + +SEARCH_INCLUDE = [ + "team", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + "achievements", +] @router.get("/users", response_model=BatchUserResponse) @@ -52,74 +37,54 @@ async def get_users( if user_ids: searched_users = ( await session.exec( - DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids)) + select(User) + .options(*User.all_select_option()) + .limit(50) + .where(col(User.id).in_(user_ids)) ) ).all() else: searched_users = ( - await session.exec(DBUser.all_select_clause().limit(50)) + await session.exec( + select(User).options(*User.all_select_option()).limit(50) + ) ).all() return BatchUserResponse( users=[ - await convert_db_user_to_api_user( - searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value + await UserResp.from_db( + searched_user, + session, + include=SEARCH_INCLUDE, ) for searched_user in searched_users ] ) -# # ---------- Individual User ---------- -# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser) -# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser) -# async def get_user_with_mode( -# user_lookup: str, -# mode: Literal["osu", "taiko", "fruits", "mania"], -# key: Literal["id", "username"] = Query("id"), -# current_user: DBUser = Depends(get_current_user), -# db: AsyncSession = Depends(get_db), -# ): -# """获取指定游戏模式的用户信息""" -# user = await get_user_by_lookup(db, user_lookup, key) -# if not user: -# raise HTTPException(status_code=404, detail="User not found") - -# return await convert_db_user_to_api_user(user, mode) - - -# @router.get("/users/{user_lookup}", response_model=ApiUser) -# @router.get("/users/{user_lookup}/", response_model=ApiUser) -# async def get_user_default( -# user_lookup: str, -# key: Literal["id", "username"] = Query("id"), -# current_user: DBUser = Depends(get_current_user), -# db: AsyncSession = Depends(get_db), -# ): -# """获取用户信息(默认使用osu模式,但包含所有模式的统计信息)""" -# user = await get_user_by_lookup(db, user_lookup, key) -# if not user: -# raise HTTPException(status_code=404, detail="User not found") - -# return await convert_db_user_to_api_user(user, "osu") - - -@router.get("/users/{user}/{ruleset}", response_model=ApiUser) -@router.get("/users/{user}/", response_model=ApiUser) -@router.get("/users/{user}", response_model=ApiUser) +@router.get("/users/{user}/{ruleset}", response_model=UserResp) +@router.get("/users/{user}/", response_model=UserResp) +@router.get("/users/{user}", response_model=UserResp) async def get_user_info( user: str, - ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", + ruleset: GameMode | None = None, session: AsyncSession = Depends(get_db), ): searched_user = ( await session.exec( - DBUser.all_select_clause().where( - DBUser.id == int(user) + select(User) + .options(*User.all_select_option()) + .where( + User.id == int(user) if user.isdigit() - else DBUser.name == user.removeprefix("@") + else User.username == user.removeprefix("@") ) ) ).first() if not searched_user: raise HTTPException(404, detail="User not found") - return await convert_db_user_to_api_user(searched_user, ruleset=ruleset) + return await UserResp.from_db( + searched_user, + session, + include=SEARCH_INCLUDE, + ruleset=ruleset, + ) diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 821d831..2712883 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -4,7 +4,7 @@ import asyncio from collections.abc import Coroutine from typing import override -from app.database.relationship import Relationship, RelationshipType +from app.database import Relationship, RelationshipType from app.dependencies.database import engine from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index 0d0899e..f388c92 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -7,10 +7,9 @@ import struct import time from typing import override -from app.database import Beatmap +from app.database import Beatmap, User from app.database.score import Score from app.database.score_token import ScoreToken -from app.database.user import User from app.dependencies.database import engine from app.models.beatmap import BeatmapRankStatus from app.models.mods import mods_to_int @@ -197,7 +196,7 @@ class SpectatorHub(Hub[StoreClientState]): ).first() if not user: return - name = user.name + name = user.username store.state = state store.beatmap_status = beatmap.beatmap_status store.checksum = beatmap.checksum @@ -339,7 +338,7 @@ class SpectatorHub(Hub[StoreClientState]): async with AsyncSession(engine) as session: async with session.begin(): username = ( - await session.exec(select(User.name).where(User.id == user_id)) + await session.exec(select(User.username).where(User.id == user_id)) ).first() if not username: return diff --git a/app/signalr/router.py b/app/signalr/router.py index 237a575..72b22ac 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -6,7 +6,7 @@ import time from typing import Literal import uuid -from app.database import User as DBUser +from app.database import User from app.dependencies import get_current_user from app.dependencies.database import get_db from app.dependencies.user import get_current_user_by_token @@ -25,7 +25,7 @@ router = APIRouter() async def negotiate( hub: Literal["spectator", "multiplayer", "metadata"], negotiate_version: int = Query(1, alias="negotiateVersion"), - user: DBUser = Depends(get_current_user), + user: User = Depends(get_current_user), ): connectionId = str(user.id) connectionToken = f"{connectionId}:{uuid.uuid4()}" diff --git a/app/utils.py b/app/utils.py index 9008706..09e8fdc 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,465 +1,6 @@ from __future__ import annotations -from datetime import UTC, datetime - -from app.database import ( - LazerUserCounts, - LazerUserProfile, - LazerUserStatistics, - User as DBUser, -) -from app.models.user import ( - Country, - Cover, - DailyChallengeStats, - GradeCounts, - Kudosu, - Level, - Page, - RankHighest, - RankHistory, - Statistics, - User, - UserAchievement, -) - def unix_timestamp_to_windows(timestamp: int) -> int: """Convert a Unix timestamp to a Windows timestamp.""" return (timestamp + 62135596800) * 10_000_000 - - -async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User: - """将数据库用户模型转换为API用户模型(使用 Lazer 表)""" - - # 从db_user获取基本字段值 - user_id = getattr(db_user, "id") - user_name = getattr(db_user, "name") - user_country = getattr(db_user, "country") - user_country_code = user_country # 在User模型中,country字段就是country_code - - # 获取 Lazer 用户资料 - profile = db_user.lazer_profile - if not profile: - # 如果没有 lazer 资料,使用默认值 - profile = LazerUserProfile( - user_id=user_id, - ) - - # 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系 - lzrcnt = db_user.lazer_counts - - if not lzrcnt: - # 如果没有 lazer 计数,使用默认值 - lzrcnt = LazerUserCounts(user_id=user_id) - - # 获取指定模式的统计信息 - user_stats = None - if db_user.lazer_statistics: - for stat in db_user.lazer_statistics: - if stat.mode == ruleset: - user_stats = stat - break - - if not user_stats: - # 如果没有找到指定模式的统计,创建默认统计 - user_stats = LazerUserStatistics(user_id=user_id) - - # 获取国家信息 - country_code = db_user.country_code if db_user.country_code is not None else "XX" - - country = Country(code=str(country_code), name=get_country_name(str(country_code))) - - # 获取 Kudosu 信息 - kudosu = Kudosu(available=0, total=0) - - # 获取计数信息 - # counts = LazerUserCounts(user_id=user_id) - - # 转换统计信息 - statistics = Statistics( - count_100=user_stats.count_100, - count_300=user_stats.count_300, - count_50=user_stats.count_50, - count_miss=user_stats.count_miss, - level=Level( - current=user_stats.level_current, progress=user_stats.level_progress - ), - global_rank=user_stats.global_rank, - global_rank_exp=user_stats.global_rank_exp, - pp=float(user_stats.pp) if user_stats.pp else 0.0, - pp_exp=float(user_stats.pp_exp) if user_stats.pp_exp else 0.0, - ranked_score=user_stats.ranked_score, - hit_accuracy=float(user_stats.hit_accuracy) if user_stats.hit_accuracy else 0.0, - play_count=user_stats.play_count, - play_time=user_stats.play_time, - total_score=user_stats.total_score, - total_hits=user_stats.total_hits, - maximum_combo=user_stats.maximum_combo, - replays_watched_by_others=user_stats.replays_watched_by_others, - is_ranked=user_stats.is_ranked, - grade_counts=GradeCounts( - ss=user_stats.grade_ss, - ssh=user_stats.grade_ssh, - s=user_stats.grade_s, - sh=user_stats.grade_sh, - a=user_stats.grade_a, - ), - country_rank=user_stats.country_rank, - rank={"country": user_stats.country_rank} if user_stats.country_rank else None, - ) - - # 转换所有模式的统计信息 - statistics_rulesets = {} - if db_user.lazer_statistics: - for stat in db_user.lazer_statistics: - statistics_rulesets[stat.mode] = Statistics( - count_100=stat.count_100, - count_300=stat.count_300, - count_50=stat.count_50, - count_miss=stat.count_miss, - level=Level(current=stat.level_current, progress=stat.level_progress), - global_rank=stat.global_rank, - global_rank_exp=stat.global_rank_exp, - pp=float(stat.pp) if stat.pp else 0.0, - pp_exp=float(stat.pp_exp) if stat.pp_exp else 0.0, - ranked_score=stat.ranked_score, - hit_accuracy=float(stat.hit_accuracy) if stat.hit_accuracy else 0.0, - play_count=stat.play_count, - play_time=stat.play_time, - total_score=stat.total_score, - total_hits=stat.total_hits, - maximum_combo=stat.maximum_combo, - replays_watched_by_others=stat.replays_watched_by_others, - is_ranked=stat.is_ranked, - grade_counts=GradeCounts( - ss=stat.grade_ss, - ssh=stat.grade_ssh, - s=stat.grade_s, - sh=stat.grade_sh, - a=stat.grade_a, - ), - country_rank=stat.country_rank, - rank={"country": stat.country_rank} if stat.country_rank else None, - ) - - # 转换国家信息 - country = Country(code=user_country_code, name=get_country_name(user_country_code)) - - # 转换封面信息 - cover_url = ( - profile.cover_url - if profile and profile.cover_url - else "https://assets.ppy.sh/user-profile-covers/default.jpeg" - ) - cover = Cover( - custom_url=profile.cover_url if profile else None, url=str(cover_url), id=None - ) - - # 转换 Kudosu 信息 - kudosu = Kudosu(available=0, total=0) - - # 转换成就信息 - user_achievements = [] - if db_user.lazer_achievements: - for achievement in db_user.lazer_achievements: - user_achievements.append( - UserAchievement( - achieved_at=achievement.achieved_at, - achievement_id=achievement.achievement_id, - ) - ) - - # 转换排名历史 - rank_history = None - rank_history_data = None - for rh in db_user.rank_history: - if rh.mode == ruleset: - rank_history_data = rh.rank_data - break - - if rank_history_data: - rank_history = RankHistory(mode=ruleset, data=rank_history_data) - - # 转换每日挑战统计 - # daily_challenge_stats = None - # if db_user.daily_challenge_stats: - # dcs = db_user.daily_challenge_stats - # daily_challenge_stats = DailyChallengeStats( - # daily_streak_best=dcs.daily_streak_best, - # daily_streak_current=dcs.daily_streak_current, - # last_update=dcs.last_update, - # last_weekly_streak=dcs.last_weekly_streak, - # playcount=dcs.playcount, - # top_10p_placements=dcs.top_10p_placements, - # top_50p_placements=dcs.top_50p_placements, - # user_id=dcs.user_id, - # weekly_streak_best=dcs.weekly_streak_best, - # weekly_streak_current=dcs.weekly_streak_current, - # ) - - # 转换最高排名 - rank_highest = None - if user_stats.rank_highest: - rank_highest = RankHighest( - rank=user_stats.rank_highest, - updated_at=user_stats.rank_highest_updated_at or datetime.utcnow(), - ) - - # 转换团队信息 - team = None - if db_user.team_membership: - team_member = db_user.team_membership # 假设用户只属于一个团队 - team = team_member.team - - # 创建用户对象 - # 从db_user获取基本字段值 - user_id = getattr(db_user, "id") - user_name = getattr(db_user, "name") - user_country = getattr(db_user, "country") - - # 获取用户头像URL - avatar_url = None - - # 首先检查 profile 中的 avatar_url - if profile and hasattr(profile, "avatar_url") and profile.avatar_url: - avatar_url = str(profile.avatar_url) - - # 然后检查是否有关联的头像记录 - if avatar_url is None and hasattr(db_user, "avatar") and db_user.avatar is not None: - if db_user.avatar.r2_game_url: - # 优先使用游戏用的头像URL - avatar_url = str(db_user.avatar.r2_game_url) - elif db_user.avatar.r2_original_url: - # 其次使用原始头像URL - avatar_url = str(db_user.avatar.r2_original_url) - - # 如果还是没有找到,通过查询获取 - # if db_session and avatar_url is None: - # try: - # # 导入UserAvatar模型 - - # # 尝试查找用户的头像记录 - # statement = select(UserAvatar).where( - # UserAvatar.user_id == user_id, UserAvatar.is_active == True - # ) - # avatar_record = db_session.exec(statement).first() - # if avatar_record is not None: - # if avatar_record.r2_game_url is not None: - # # 优先使用游戏用的头像URL - # avatar_url = str(avatar_record.r2_game_url) - # elif avatar_record.r2_original_url is not None: - # # 其次使用原始头像URL - # avatar_url = str(avatar_record.r2_original_url) - # except Exception as e: - # print(f"获取用户头像时出错: {e}") - # print(f"最终头像URL: {avatar_url}") - # 如果仍然没有找到头像URL,则使用默认URL - if avatar_url is None: - avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1" - - # 处理 profile_order 列表排序 - profile_order = [ - "me", - "recent_activity", - "top_ranks", - "medals", - "historical", - "beatmaps", - "kudosu", - ] - if profile and profile.profile_order: - profile_order = profile.profile_order.split(",") - - # 在convert_db_user_to_api_user函数中添加active_tournament_banners处理 - active_tournament_banners = [] - if db_user.active_banners: - for banner in db_user.active_banners: - active_tournament_banners.append( - { - "tournament_id": banner.tournament_id, - "image_url": banner.image_url, - "is_active": banner.is_active, - } - ) - - # 在convert_db_user_to_api_user函数中添加badges处理 - badges = [] - if db_user.lazer_badges: - for badge in db_user.lazer_badges: - badges.append( - { - "badge_id": badge.badge_id, - "awarded_at": badge.awarded_at, - "description": badge.description, - "image_url": badge.image_url, - "url": badge.url, - } - ) - - # 在convert_db_user_to_api_user函数中添加monthly_playcounts处理 - monthly_playcounts = [] - if db_user.lazer_monthly_playcounts: - for playcount in db_user.lazer_monthly_playcounts: - monthly_playcounts.append( - { - "start_date": playcount.start_date.isoformat() - if playcount.start_date - else None, - "play_count": playcount.play_count, - } - ) - - # 在convert_db_user_to_api_user函数中添加previous_usernames处理 - previous_usernames = [] - if db_user.lazer_previous_usernames: - for username in db_user.lazer_previous_usernames: - previous_usernames.append( - { - "username": username.username, - "changed_at": username.changed_at.isoformat() - if username.changed_at - else None, - } - ) - - # 在convert_db_user_to_api_user函数中添加replays_watched_counts处理 - replays_watched_counts = [] - if hasattr(db_user, "lazer_replays_watched") and db_user.lazer_replays_watched: - for replay in db_user.lazer_replays_watched: - replays_watched_counts.append( - { - "start_date": replay.start_date.isoformat() - if replay.start_date - else None, - "count": replay.count, - } - ) - - # 创建用户对象 - user = User( - id=user_id, - username=user_name, - avatar_url=avatar_url, - country_code=str(country_code), - default_group=profile.default_group if profile else "default", - is_active=profile.is_active, - is_bot=profile.is_bot, - is_deleted=profile.is_deleted, - is_online=profile.is_online, - is_supporter=profile.is_supporter, - is_restricted=profile.is_restricted, - last_visit=db_user.last_visit, - pm_friends_only=profile.pm_friends_only, - profile_colour=profile.profile_colour, - cover_url=profile.cover_url - if profile and profile.cover_url - else "https://assets.ppy.sh/user-profile-covers/default.jpeg", - discord=profile.discord if profile else None, - has_supported=profile.has_supported if profile else False, - interests=profile.interests if profile else None, - join_date=profile.join_date if profile.join_date else datetime.now(UTC), - location=profile.location if profile else None, - max_blocks=profile.max_blocks if profile and profile.max_blocks else 100, - max_friends=profile.max_friends if profile and profile.max_friends else 500, - post_count=profile.post_count if profile and profile.post_count else 0, - profile_hue=profile.profile_hue if profile and profile.profile_hue else None, - profile_order=profile_order, # 使用排序后的 profile_order - title=profile.title if profile else None, - title_url=profile.title_url if profile else None, - twitter=profile.twitter if profile else None, - website=profile.website if profile else None, - session_verified=True, - support_level=profile.support_level if profile else 0, - country=country, - cover=cover, - kudosu=kudosu, - statistics=statistics, - statistics_rulesets=statistics_rulesets, - beatmap_playcounts_count=lzrcnt.beatmap_playcounts_count if lzrcnt else 0, - comments_count=lzrcnt.comments_count if lzrcnt else 0, - favourite_beatmapset_count=lzrcnt.favourite_beatmapset_count if lzrcnt else 0, - follower_count=lzrcnt.follower_count if lzrcnt else 0, - graveyard_beatmapset_count=lzrcnt.graveyard_beatmapset_count if lzrcnt else 0, - guest_beatmapset_count=lzrcnt.guest_beatmapset_count if lzrcnt else 0, - loved_beatmapset_count=lzrcnt.loved_beatmapset_count if lzrcnt else 0, - mapping_follower_count=lzrcnt.mapping_follower_count if lzrcnt else 0, - nominated_beatmapset_count=lzrcnt.nominated_beatmapset_count if lzrcnt else 0, - pending_beatmapset_count=lzrcnt.pending_beatmapset_count if lzrcnt else 0, - ranked_beatmapset_count=lzrcnt.ranked_beatmapset_count if lzrcnt else 0, - ranked_and_approved_beatmapset_count=lzrcnt.ranked_and_approved_beatmapset_count - if lzrcnt - else 0, - unranked_beatmapset_count=lzrcnt.unranked_beatmapset_count if lzrcnt else 0, - scores_best_count=lzrcnt.scores_best_count if lzrcnt else 0, - scores_first_count=lzrcnt.scores_first_count if lzrcnt else 0, - scores_pinned_count=lzrcnt.scores_pinned_count, - scores_recent_count=lzrcnt.scores_recent_count if lzrcnt else 0, - account_history=[], # TODO: 获取用户历史账户信息 - # active_tournament_banner=len(active_tournament_banners), - active_tournament_banners=active_tournament_banners, - badges=badges, - current_season_stats=None, - daily_challenge_user_stats=DailyChallengeStats( - user_id=user_id, - daily_streak_best=db_user.daily_challenge_stats.daily_streak_best - if db_user.daily_challenge_stats - else 0, - daily_streak_current=db_user.daily_challenge_stats.daily_streak_current - if db_user.daily_challenge_stats - else 0, - last_update=db_user.daily_challenge_stats.last_update - if db_user.daily_challenge_stats - else None, - last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak - if db_user.daily_challenge_stats - else None, - playcount=db_user.daily_challenge_stats.playcount - if db_user.daily_challenge_stats - else 0, - top_10p_placements=db_user.daily_challenge_stats.top_10p_placements - if db_user.daily_challenge_stats - else 0, - top_50p_placements=db_user.daily_challenge_stats.top_50p_placements - if db_user.daily_challenge_stats - else 0, - weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best - if db_user.daily_challenge_stats - else 0, - weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current - if db_user.daily_challenge_stats - else 0, - ), - groups=[], - monthly_playcounts=monthly_playcounts, - page=Page(html=profile.page_html or "", raw=profile.page_raw or "") - if profile.page_html or profile.page_raw - else Page(), - previous_usernames=previous_usernames, - rank_highest=rank_highest, - rank_history=rank_history, - rankHistory=rank_history, - replays_watched_counts=replays_watched_counts, - team=team, - user_achievements=user_achievements, - ) - - return user - - -def get_country_name(country_code: str) -> str: - """根据国家代码获取国家名称""" - country_names = { - "CN": "China", - "JP": "Japan", - "US": "United States", - "GB": "United Kingdom", - "DE": "Germany", - "FR": "France", - "KR": "South Korea", - "CA": "Canada", - "AU": "Australia", - "BR": "Brazil", - # 可以添加更多国家 - } - return country_names.get(country_code, "Unknown") diff --git a/create_sample_data.py b/create_sample_data.py deleted file mode 100644 index 5dcd79a..0000000 --- a/create_sample_data.py +++ /dev/null @@ -1,242 +0,0 @@ -#!/usr/bin/env python3 -""" -osu! API 模拟服务器的示例数据填充脚本 -""" - -from __future__ import annotations - -import asyncio -from datetime import datetime -import random - -from app.auth import get_password_hash -from app.database import ( - User, -) -from app.database.beatmap import Beatmap -from app.database.beatmapset import Beatmapset -from app.database.score import Score -from app.dependencies.database import create_tables, engine -from app.models.beatmap import BeatmapRankStatus, Genre, Language -from app.models.mods import APIMod -from app.models.score import GameMode, Rank - -from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession - - -async def create_sample_user(): - """创建示例用户数据""" - async with AsyncSession(engine) as session: - async with session.begin(): - # 检查用户是否已存在 - result = await session.exec(select(User).where(User.name == "Googujiang")) - result2 = await session.exec( - select(User).where(User.name == "MingxuanGame") - ) - existing_user = result.first() - existing_user2 = result2.first() - if existing_user is not None and existing_user2 is not None: - print("示例用户已存在,跳过创建") - return - - # 当前时间戳 - # current_timestamp = int(time.time()) - join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp()) - last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp()) - - # 创建用户 - user = User( - name="Googujiang", - safe_name="googujiang", # 安全用户名(小写) - email="googujiang@example.com", - priv=1, # 默认权限 - pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式 - country="JP", - silence_end=0, - donor_end=0, - creation_time=join_timestamp, - latest_activity=last_visit_timestamp, - clan_id=0, - clan_priv=0, - preferred_mode=0, # 0 = osu! - play_style=0, - custom_badge_name=None, - custom_badge_icon=None, - userpage_content="「世界に忘れられた」", - api_key=None, - ) - user2 = User( - name="MingxuanGame", - safe_name="mingxuangame", # 安全用户名(小写) - email="mingxuangame@example.com", - priv=1, # 默认权限 - pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式 - country="US", - silence_end=0, - donor_end=0, - creation_time=join_timestamp, - latest_activity=last_visit_timestamp, - clan_id=0, - clan_priv=0, - preferred_mode=0, # 0 = osu! - play_style=0, - custom_badge_name=None, - custom_badge_icon=None, - userpage_content="For love and fun!", - api_key=None, - ) - - session.add(user) - session.add(user2) - print(f"成功创建示例用户: {user.name} (ID: {user.id})") - print(f"安全用户名: {user.safe_name}") - print(f"邮箱: {user.email}") - print(f"国家: {user.country}") - print(f"成功创建示例用户: {user2.name} (ID: {user2.id})") - print(f"安全用户名: {user2.safe_name}") - print(f"邮箱: {user2.email}") - print(f"国家: {user2.country}") - - -async def create_sample_beatmap_data(): - """创建示例谱面数据""" - async with AsyncSession(engine) as session: - async with session.begin(): - user_id = random.randint(1, 1000) - # 检查谱面集是否已存在 - statement = select(Beatmapset).where(Beatmapset.id == 1) - result = await session.exec(statement) - existing_beatmapset = result.first() - if existing_beatmapset: - print("示例谱面集已存在,跳过创建") - return existing_beatmapset - - # 创建谱面集 - beatmapset = Beatmapset( - id=1, - artist="Example Artist", - artist_unicode="Example Artist", - covers=None, - creator="Googujiang", - favourite_count=0, - hype_current=0, - hype_required=0, - nsfw=False, - play_count=0, - preview_url="", - source="", - spotlight=False, - title="Example Song", - title_unicode="Example Song", - user_id=user_id, - video=False, - availability_info=None, - download_disabled=False, - bpm=180.0, - can_be_hyped=False, - discussion_locked=False, - last_updated=datetime.now(), - ranked_date=datetime.now(), - storyboard=False, - submitted_date=datetime.now(), - current_nominations=[], - beatmap_status=BeatmapRankStatus.RANKED, - beatmap_genre=Genre.ANY, # 使用整数表示Genre枚举 - beatmap_language=Language.ANY, # 使用整数表示Language枚举 - nominations_required=0, - nominations_current=0, - pack_tags=[], - ratings=[], - ) - session.add(beatmapset) - - # 创建谱面 - beatmap = Beatmap( - id=1, - url="", - mode=GameMode.OSU, - beatmapset_id=1, - difficulty_rating=5.5, - beatmap_status=BeatmapRankStatus.RANKED, - total_length=195, - user_id=user_id, - version="Example Difficulty", - checksum="example_checksum", - current_user_playcount=0, - max_combo=1200, - ar=9.0, - cs=4.0, - drain=5.0, - accuracy=8.0, - bpm=180.0, - count_circles=1000, - count_sliders=200, - count_spinners=1, - deleted_at=None, - hit_length=180, - last_updated=datetime.now(), - passcount=10, - playcount=50, - ) - session.add(beatmap) - - # 创建成绩 - score = Score( - id=1, - accuracy=0.9876, - map_md5="example_checksum", - user_id=1, - best_id=1, - build_id=None, - classic_total_score=1234567, - ended_at=datetime.now(), - has_replay=True, - max_combo=1100, - mods=[ - APIMod(acronym="HD", settings={}), - APIMod(acronym="DT", settings={}), - ], - passed=True, - playlist_item_id=None, - pp=250.5, - preserve=True, - rank=Rank.S, - room_id=None, - gamemode=GameMode.OSU, - started_at=datetime.now(), - total_score=1234567, - type="solo_score", - position=None, - beatmap_id=1, - n300=950, - n100=30, - n50=20, - nmiss=5, - ngeki=150, - nkatu=50, - nlarge_tick_miss=None, - nslider_tail_hit=None, - ) - session.add(score) - - print(f"成功创建示例谱面集: {beatmapset.title} (ID: {beatmapset.id})") - print(f"成功创建示例谱面: {beatmap.version} (ID: {beatmap.id})") - print(f"成功创建示例成绩: ID {score.id}") - return beatmapset - - -async def main(): - print("开始创建示例数据...") - await create_tables() - await create_sample_user() - await create_sample_beatmap_data() - print("示例数据创建完成!") - # print(f"用户名: {user.name}") - # print("密码: password123") - # print("现在您可以使用这些凭据来测试API了。") - await engine.dispose() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/main.py b/main.py index 526d593..92d4402 100644 --- a/main.py +++ b/main.py @@ -4,16 +4,12 @@ from contextlib import asynccontextmanager from datetime import datetime from app.config import settings -from app.database import Team # noqa: F401 from app.dependencies.database import create_tables, engine from app.dependencies.fetcher import get_fetcher -from app.models.user import User from app.router import api_router, auth_router, fetcher_router, signalr_router from fastapi import FastAPI -User.model_rebuild() - @asynccontextmanager async def lifespan(app: FastAPI): diff --git a/migrations/versions/78be13c71791_score_remove_best_id_in_database.py b/migrations/versions/78be13c71791_score_remove_best_id_in_database.py deleted file mode 100644 index d0cab2b..0000000 --- a/migrations/versions/78be13c71791_score_remove_best_id_in_database.py +++ /dev/null @@ -1,38 +0,0 @@ -"""score: remove best_id in database - -Revision ID: 78be13c71791 -Revises: dc4d25c428c7 -Create Date: 2025-07-29 07:57:33.764517 - -""" - -from __future__ import annotations - -from collections.abc import Sequence - -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import mysql - -# revision identifiers, used by Alembic. -revision: str = "78be13c71791" -down_revision: str | Sequence[str] | None = "dc4d25c428c7" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("scores", "best_id") - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.add_column( - "scores", - sa.Column("best_id", mysql.INTEGER(), autoincrement=False, nullable=True), - ) - # ### end Alembic commands ### diff --git a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py b/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py deleted file mode 100644 index d90ec3d..0000000 --- a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py +++ /dev/null @@ -1,36 +0,0 @@ -"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator - -Revision ID: dc4d25c428c7 -Revises: -Create Date: 2025-07-29 01:43:40.221070 - -""" - -from __future__ import annotations - -from collections.abc import Sequence - -from alembic import op -import sqlalchemy as sa - -# revision identifiers, used by Alembic. -revision: str = "dc4d25c428c7" -down_revision: str | Sequence[str] | None = None -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True)) - op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("scores", "nsmall_tick_hit") - op.drop_column("scores", "nlarge_tick_hit") - # ### end Alembic commands ### From a15c3cef04649b97a5cf9c33283a523f570bf716 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 02:13:18 +0000 Subject: [PATCH 02/45] feat(user): add monthly playcounts --- app/database/lazer_user.py | 34 +++++++++++++++++++++-- app/database/monthly_playcounts.py | 43 ++++++++++++++++++++++++++++++ app/database/score.py | 23 +++++++++++++--- app/router/me.py | 11 ++------ app/router/user.py | 14 +++------- 5 files changed, 99 insertions(+), 26 deletions(-) create mode 100644 app/database/monthly_playcounts.py diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 9627015..9b98c98 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -7,6 +7,7 @@ from app.models.user import Country, Page, RankHistory from .achievement import UserAchievement, UserAchievementResp from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .monthly_playcounts import MonthlyPlaycounts, MonthlyPlaycountsResp from .statistics import UserStatistics, UserStatisticsResp from .team import Team, TeamMember from .user_account_history import UserAccountHistory, UserAccountHistoryResp @@ -141,6 +142,7 @@ class User(UserBase, table=True): daily_challenge_stats: DailyChallengeStats | None = Relationship( back_populates="user" ) + monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") email: str = Field(max_length=254, unique=True, index=True, exclude=True) priv: int = Field(default=1, exclude=True) @@ -160,6 +162,7 @@ class User(UserBase, table=True): selectinload(cls.achievement), # pyright: ignore[reportArgumentType] joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType] joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] + selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType] ) @@ -186,7 +189,7 @@ class UserResp(UserBase): account_history: list[UserAccountHistoryResp] = [] active_tournament_banners: list[dict] = [] # TODO kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO - monthly_playcounts: list = Field(default_factory=list) # TODO + monthly_playcounts: list[MonthlyPlaycountsResp] = Field(default_factory=list) unread_pm_count: int = 0 # TODO rank_history: RankHistory | None = None # TODO rank_highest: RankHighest | None = None # TODO @@ -196,7 +199,7 @@ class UserResp(UserBase): cover_url: str = "" # deprecated team: Team | None = None session_verified: bool = True - daily_challenge_user_stats: DailyChallengeStatsResp | None = None # TODO + daily_challenge_user_stats: DailyChallengeStatsResp | None = None # TODO: monthly_playcounts, unread_pm_count, rank_history, user_preferences @@ -292,9 +295,36 @@ class UserResp(UserBase): i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics } + if "monthly_playcounts" in include: + u.monthly_playcounts = [ + MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts + ] + if "achievements" in include: u.user_achievements = [ UserAchievementResp.from_db(ua) for ua in obj.achievement ] return u + + +ALL_INCLUDED = [ + "friends", + "team", + "account_history", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + "achievements", + "monthly_playcounts", +] + + +SEARCH_INCLUDED = [ + "team", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + "achievements", + "monthly_playcounts", +] diff --git a/app/database/monthly_playcounts.py b/app/database/monthly_playcounts.py new file mode 100644 index 0000000..46192d1 --- /dev/null +++ b/app/database/monthly_playcounts.py @@ -0,0 +1,43 @@ +from datetime import date +from typing import TYPE_CHECKING + +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + + +class MonthlyPlaycounts(SQLModel, table=True): + __tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType] + + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, primary_key=True, autoincrement=True), + ) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + year: int = Field(index=True) + month: int = Field(index=True) + playcount: int = Field(default=0) + + user: "User" = Relationship(back_populates="monthly_playcounts") + + +class MonthlyPlaycountsResp(SQLModel): + start_date: date + count: int + + @classmethod + def from_db(cls, db_model: MonthlyPlaycounts) -> "MonthlyPlaycountsResp": + return cls( + start_date=date(db_model.year, db_model.month, 1), + count=db_model.playcount, + ) diff --git a/app/database/score.py b/app/database/score.py index c805563..32c8cf5 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -1,6 +1,6 @@ import asyncio from collections.abc import Sequence -from datetime import UTC, datetime +from datetime import UTC, date, datetime import math from typing import TYPE_CHECKING @@ -30,6 +30,7 @@ from .beatmap import Beatmap, BeatmapResp from .beatmapset import Beatmapset, BeatmapsetResp from .best_score import BestScore from .lazer_user import User, UserResp +from .monthly_playcounts import MonthlyPlaycounts from .score_token import ScoreToken from redis import Redis @@ -501,8 +502,22 @@ async def process_user( previous_score_best = await get_user_best_score_in_beatmap( session, score.beatmap_id, user.id, score.gamemode ) - statistics = None add_to_db = False + mouthly_playcount = ( + await session.exec( + select(MonthlyPlaycounts).where( + MonthlyPlaycounts.user_id == user.id, + MonthlyPlaycounts.year == date.today().year, + MonthlyPlaycounts.month == date.today().month, + ) + ) + ).first() + if mouthly_playcount is None: + mouthly_playcount = MonthlyPlaycounts( + user_id=user.id, year=date.today().year, month=date.today().month + ) + add_to_db = True + statistics = None for i in user.statistics: if i.mode == score.gamemode.value: statistics = i @@ -547,6 +562,7 @@ async def process_user( statistics.level_current = calculate_score_to_level(statistics.ranked_score) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) statistics.play_count += 1 + mouthly_playcount.playcount += 1 statistics.play_time += int((score.ended_at - score.started_at).total_seconds()) statistics.count_100 += score.n100 + score.nkatu statistics.count_300 += score.n300 + score.ngeki @@ -569,9 +585,8 @@ async def process_user( acc_sum = clamp(acc_sum, 0.0, 100.0) statistics.pp = pp_sum statistics.hit_accuracy = acc_sum - if add_to_db: - session.add(statistics) + session.add(mouthly_playcount) await session.commit() await session.refresh(user) diff --git a/app/router/me.py b/app/router/me.py index e3aa734..b6d7d26 100644 --- a/app/router/me.py +++ b/app/router/me.py @@ -1,6 +1,7 @@ from __future__ import annotations from app.database import User, UserResp +from app.database.lazer_user import ALL_INCLUDED from app.dependencies import get_current_user from app.dependencies.database import get_db from app.models.score import GameMode @@ -21,14 +22,6 @@ async def get_user_info_default( return await UserResp.from_db( current_user, session, - [ - "friends", - "team", - "account_history", - "daily_challenge_user_stats", - "statistics", - "statistics_rulesets", - "achievements", - ], + ALL_INCLUDED, ruleset, ) diff --git a/app/router/user.py b/app/router/user.py index cfe136c..3df5a49 100644 --- a/app/router/user.py +++ b/app/router/user.py @@ -1,6 +1,7 @@ from __future__ import annotations from app.database import User, UserResp +from app.database.lazer_user import SEARCH_INCLUDED from app.dependencies.database import get_db from app.models.score import GameMode @@ -17,15 +18,6 @@ class BatchUserResponse(BaseModel): users: list[UserResp] -SEARCH_INCLUDE = [ - "team", - "daily_challenge_user_stats", - "statistics", - "statistics_rulesets", - "achievements", -] - - @router.get("/users", response_model=BatchUserResponse) @router.get("/users/lookup", response_model=BatchUserResponse) @router.get("/users/lookup/", response_model=BatchUserResponse) @@ -54,7 +46,7 @@ async def get_users( await UserResp.from_db( searched_user, session, - include=SEARCH_INCLUDE, + include=SEARCH_INCLUDED, ) for searched_user in searched_users ] @@ -85,6 +77,6 @@ async def get_user_info( return await UserResp.from_db( searched_user, session, - include=SEARCH_INCLUDE, + include=SEARCH_INCLUDED, ruleset=ruleset, ) From bcca895f4d043df2a33d67da6c1544b1724f6b42 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 02:13:56 +0000 Subject: [PATCH 03/45] fix(spectator): don't save replay for passed score --- app/signalr/hub/spectator.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index f388c92..bd311ec 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -296,19 +296,18 @@ class SpectatorHub(Hub[StoreClientState]): score_record.id, ) # save replay - if store.state.state == SpectatedUserState.Passed: - score_record.has_replay = True - await session.commit() - await session.refresh(score_record) - save_replay( - ruleset_id=store.ruleset_id, - md5=store.checksum, - username=store.score.score_info.user.name, - score=score_record, - statistics=store.score.score_info.statistics, - maximum_statistics=store.score.score_info.maximum_statistics, - frames=store.score.replay_frames, - ) + score_record.has_replay = True + await session.commit() + await session.refresh(score_record) + save_replay( + ruleset_id=store.ruleset_id, + md5=store.checksum, + username=store.score.score_info.user.name, + score=score_record, + statistics=store.score.score_info.statistics, + maximum_statistics=store.score.score_info.maximum_statistics, + frames=store.score.replay_frames, + ) async def _end_session(self, user_id: int, state: SpectatorState) -> None: if state.state == SpectatedUserState.Playing: From 1281e75bb16df766dc9b2b57a82244b40a90d9a6 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 02:29:51 +0000 Subject: [PATCH 04/45] feat(beatmapset): support download beatmapset --- app/router/beatmapset.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index e551727..80396fd 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -12,7 +12,8 @@ from app.fetcher import Fetcher from .api_router import router -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Query +from fastapi.responses import RedirectResponse from httpx import HTTPStatusError from sqlalchemy.orm import selectinload from sqlmodel import select @@ -42,3 +43,20 @@ async def get_beatmapset( else: resp = BeatmapsetResp.from_db(beatmapset) return resp + + +@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"]) +async def download_beatmapset( + beatmapset: int, + no_video: bool = Query(True, alias="noVideo"), + current_user: User = Depends(get_current_user), +): + if current_user.country_code == "CN": + return RedirectResponse( + f"https://txy1.sayobot.cn/beatmaps/download/" + f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto" + ) + else: + return RedirectResponse( + f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}" + ) From be401e88850b79a48398c94cfdf9b7fc820c3779 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 20:11:22 +0800 Subject: [PATCH 05/45] =?UTF-8?q?refactor(database):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=85=B3=E8=81=94=E5=AF=B9=E8=B1=A1?= =?UTF-8?q?=E7=9A=84=E8=BD=BD=E5=85=A5=20(#10)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/database/beatmap.py | 23 ++++++------------ app/database/beatmapset.py | 9 ++++---- app/database/lazer_user.py | 43 +++++++++++++--------------------- app/database/relationship.py | 6 +++-- app/database/score.py | 45 ++++++++---------------------------- app/database/team.py | 8 +++++-- app/dependencies/user.py | 8 +------ app/models/beatmap.py | 11 +++++---- app/router/beatmap.py | 33 +++++--------------------- app/router/beatmapset.py | 11 ++------- app/router/relationship.py | 13 +++-------- app/router/score.py | 12 ++++------ app/router/user.py | 17 +++----------- 13 files changed, 73 insertions(+), 166 deletions(-) diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 751bc5c..2ab5ad0 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -8,7 +8,6 @@ from app.models.score import MODE_TO_INT, GameMode from .beatmapset import Beatmapset, BeatmapsetResp from sqlalchemy import DECIMAL, Column, DateTime -from sqlalchemy.orm import joinedload from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -67,7 +66,9 @@ class Beatmap(BeatmapBase, table=True): beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmap_status: BeatmapRankStatus # optional - beatmapset: Beatmapset = Relationship(back_populates="beatmaps") + beatmapset: Beatmapset = Relationship( + back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"} + ) @property def can_ranked(self) -> bool: @@ -88,13 +89,7 @@ class Beatmap(BeatmapBase, table=True): session.add(beatmap) await session.commit() beatmap = ( - await session.exec( - select(Beatmap) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) - .where(Beatmap.id == resp.id) - ) + await session.exec(select(Beatmap).where(Beatmap.id == resp.id)) ).first() assert beatmap is not None, "Beatmap should not be None after commit" return beatmap @@ -132,13 +127,9 @@ class Beatmap(BeatmapBase, table=True): ) -> "Beatmap": beatmap = ( await session.exec( - select(Beatmap) - .where( + select(Beatmap).where( Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 ) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) ) ).first() if not beatmap: @@ -165,7 +156,7 @@ class BeatmapResp(BeatmapBase): url: str = "" @classmethod - def from_db( + async def from_db( cls, beatmap: Beatmap, query_mode: GameMode | None = None, @@ -179,5 +170,5 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] if not from_set: - beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset) return cls.model_validate(beatmap_) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 2ef6280..5a618b7 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -7,6 +7,7 @@ from app.models.score import GameMode from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text +from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import Field, Relationship, SQLModel from sqlmodel.ext.asyncio.session import AsyncSession @@ -130,7 +131,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): tags: str = Field(default="", sa_column=Column(Text)) -class Beatmapset(BeatmapsetBase, table=True): +class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): __tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) @@ -200,12 +201,12 @@ class BeatmapsetResp(BeatmapsetBase): nominations: BeatmapNominations | None = None @classmethod - def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": + async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": from .beatmap import BeatmapResp beatmaps = [ - BeatmapResp.from_db(beatmap, from_set=True) - for beatmap in beatmapset.beatmaps + await BeatmapResp.from_db(beatmap, from_set=True) + for beatmap in await beatmapset.awaitable_attrs.beatmaps ] return cls.model_validate( { diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 9b98c98..d502ccb 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp from .team import Team, TeamMember from .user_account_history import UserAccountHistory, UserAccountHistoryResp -from sqlalchemy.orm import joinedload, selectinload +from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( JSON, BigInteger, @@ -128,7 +128,7 @@ class UserBase(UTCBaseModel, SQLModel): is_bng: bool = False -class User(UserBase, table=True): +class User(AsyncAttrs, UserBase, table=True): __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] id: int | None = Field( @@ -154,17 +154,6 @@ class User(UserBase, table=True): default=None, sa_column=Column(DateTime(timezone=True)), exclude=True ) - @classmethod - def all_select_option(cls): - return ( - selectinload(cls.account_history), # pyright: ignore[reportArgumentType] - selectinload(cls.statistics), # pyright: ignore[reportArgumentType] - selectinload(cls.achievement), # pyright: ignore[reportArgumentType] - joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType] - joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] - selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType] - ) - class UserResp(UserBase): id: int | None = None @@ -249,13 +238,7 @@ class UserResp(UserBase): await RelationshipResp.from_db(session, r) for r in ( await session.exec( - select(Relationship) - .options( - joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType] - *User.all_select_option() - ) - ) - .where( + select(Relationship).where( Relationship.user_id == obj.id, Relationship.type == RelationshipType.FOLLOW, ) @@ -264,23 +247,26 @@ class UserResp(UserBase): ] if "team" in include: - if obj.team_membership: + if await obj.awaitable_attrs.team_membership: + assert obj.team_membership u.team = obj.team_membership.team if "account_history" in include: u.account_history = [ - UserAccountHistoryResp.from_db(ah) for ah in obj.account_history + UserAccountHistoryResp.from_db(ah) + for ah in await obj.awaitable_attrs.account_history ] if "daily_challenge_user_stats": - if obj.daily_challenge_stats: + if await obj.awaitable_attrs.daily_challenge_stats: + assert obj.daily_challenge_stats u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( obj.daily_challenge_stats ) if "statistics" in include: current_stattistics = None - for i in obj.statistics: + for i in await obj.awaitable_attrs.statistics: if i.mode == (ruleset or obj.playmode): current_stattistics = i break @@ -292,17 +278,20 @@ class UserResp(UserBase): if "statistics_rulesets" in include: u.statistics_rulesets = { - i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics + i.mode.value: UserStatisticsResp.from_db(i) + for i in await obj.awaitable_attrs.statistics } if "monthly_playcounts" in include: u.monthly_playcounts = [ - MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts + MonthlyPlaycountsResp.from_db(pc) + for pc in await obj.awaitable_attrs.monthly_playcounts ] if "achievements" in include: u.user_achievements = [ - UserAchievementResp.from_db(ua) for ua in obj.achievement + UserAchievementResp.from_db(ua) + for ua in await obj.awaitable_attrs.achievement ] return u diff --git a/app/database/relationship.py b/app/database/relationship.py index 07daa25..7a351aa 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -42,7 +42,10 @@ class Relationship(SQLModel, table=True): ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) target: User = SQLRelationship( - sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} + sa_relationship_kwargs={ + "foreign_keys": "[Relationship.target_id]", + "lazy": "selectin", + } ) @@ -79,7 +82,6 @@ class RelationshipResp(BaseModel): "daily_challenge_user_stats", "statistics", "statistics_rulesets", - "achievements", ], ), mutual=mutual, diff --git a/app/database/score.py b/app/database/score.py index 32c8cf5..4cee832 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -27,7 +27,7 @@ from app.models.score import ( ) from .beatmap import Beatmap, BeatmapResp -from .beatmapset import Beatmapset, BeatmapsetResp +from .beatmapset import BeatmapsetResp from .best_score import BestScore from .lazer_user import User, UserResp from .monthly_playcounts import MonthlyPlaycounts @@ -35,7 +35,8 @@ from .score_token import ScoreToken from redis import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime -from sqlalchemy.orm import aliased, joinedload +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import aliased from sqlmodel import ( JSON, BigInteger, @@ -55,7 +56,7 @@ if TYPE_CHECKING: from app.fetcher import Fetcher -class ScoreBase(SQLModel, UTCBaseModel): +class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): # 基本字段 accuracy: float map_md5: str = Field(max_length=32, index=True) @@ -114,27 +115,12 @@ class Score(ScoreBase, table=True): # optional beatmap: Beatmap = Relationship() - user: User = Relationship() + user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) @property def is_perfect_combo(self) -> bool: return self.max_combo == self.beatmap.max_combo - @staticmethod - def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]: - clause = select(Score).options( - joinedload(Score.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - ) - if with_user: - return clause.options( - joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType] - ) - return clause - @staticmethod def select_clause_unique( *where_clauses: ColumnExpressionArgument[bool] | bool, @@ -148,18 +134,7 @@ class Score(ScoreBase, table=True): ) subq = select(Score, rownum).where(*where_clauses).subquery() best = aliased(Score, subq, adapt_on_names=True) - return ( - select(best) - .where(subq.c.rn == 1) - .options( - joinedload(best.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType] - ) - ) + return select(best).where(subq.c.rn == 1) class ScoreResp(ScoreBase): @@ -186,8 +161,9 @@ class ScoreResp(ScoreBase): ) -> "ScoreResp": s = cls.model_validate(score.model_dump()) assert score.id - s.beatmap = BeatmapResp.from_db(score.beatmap) - s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset) + await score.awaitable_attrs.beatmap + s.beatmap = await BeatmapResp.from_db(score.beatmap) + s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = MODE_TO_INT[score.gamemode] @@ -303,7 +279,6 @@ async def get_leaderboard( query = ( select(Score) .join(Beatmap) - .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType] .where( Score.map_md5 == beatmap_md5, Score.gamemode == mode, @@ -452,7 +427,7 @@ async def get_user_best_score_in_beatmap( ) -> Score | None: return ( await session.exec( - Score.select_clause(False) + select(Score) .where( Score.gamemode == mode if mode is not None else True, Score.beatmap_id == beatmap, diff --git a/app/database/team.py b/app/database/team.py index 146ca9f..562b0c8 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -34,5 +34,9 @@ class TeamMember(SQLModel, UTCBaseModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: "User" = Relationship(back_populates="team_membership") - team: "Team" = Relationship(back_populates="members") + user: "User" = Relationship( + back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"} + ) + team: "Team" = Relationship( + back_populates="members", sa_relationship_kwargs={"lazy": "joined"} + ) diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 769247c..5537f4f 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -30,11 +30,5 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None token_record = await get_token_by_access_token(db, token) if not token_record: return None - user = ( - await db.exec( - select(User) - .options(*User.all_select_option()) - .where(User.id == token_record.user_id) - ) - ).first() + user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() return user diff --git a/app/models/beatmap.py b/app/models/beatmap.py index 4f12e13..fae18ba 100644 --- a/app/models/beatmap.py +++ b/app/models/beatmap.py @@ -42,11 +42,12 @@ class Language(IntEnum): KOREAN = 6 FRENCH = 7 GERMAN = 8 - ITALIAN = 9 - SPANISH = 10 - RUSSIAN = 11 - POLISH = 12 - OTHER = 13 + SWEDISH = 9 + ITALIAN = 10 + SPANISH = 11 + RUSSIAN = 12 + POLISH = 13 + OTHER = 14 class BeatmapAttributes(BaseModel): diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 4af9c9a..df5f20d 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -5,7 +5,7 @@ import hashlib import json from app.calculator import calculate_beatmap_attribute -from app.database import Beatmap, BeatmapResp, Beatmapset, User +from app.database import Beatmap, BeatmapResp, User from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -24,7 +24,6 @@ from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel from redis import Redis import rosu_pp_py as rosu -from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -51,7 +50,7 @@ async def lookup_beatmap( if beatmap is None: raise HTTPException(status_code=404, detail="Beatmap not found") - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @@ -63,7 +62,7 @@ async def get_beatmap( ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -83,35 +82,15 @@ async def batch_get_beatmaps( # select 50 beatmaps by last_updated beatmaps = ( await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .order_by(col(Beatmap.last_updated).desc()) - .limit(50) + select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50) ) ).all() else: beatmaps = ( - await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .where(col(Beatmap.id).in_(b_ids)) - .limit(50) - ) + await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) ).all() - return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) + return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps]) @router.post( diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index 80396fd..b82678d 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -15,7 +15,6 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from fastapi.responses import RedirectResponse from httpx import HTTPStatusError -from sqlalchemy.orm import selectinload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -27,13 +26,7 @@ async def get_beatmapset( db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): - beatmapset = ( - await db.exec( - select(Beatmapset) - .options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] - .where(Beatmapset.id == sid) - ) - ).first() + beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first() if not beatmapset: try: resp = await fetcher.get_beatmapset(sid) @@ -41,7 +34,7 @@ async def get_beatmapset( except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmapset not found") else: - resp = BeatmapsetResp.from_db(beatmapset) + resp = await BeatmapsetResp.from_db(beatmapset) return resp diff --git a/app/router/relationship.py b/app/router/relationship.py index 9e39e8b..02292c9 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -9,7 +9,6 @@ from .api_router import router from fastapi import Depends, HTTPException, Query, Request from pydantic import BaseModel -from sqlalchemy.orm import joinedload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -27,14 +26,12 @@ async def get_relationship( else RelationshipType.BLOCK ) relationships = await db.exec( - select(Relationship) - .options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType] - .where( + select(Relationship).where( Relationship.user_id == current_user.id, Relationship.type == relationship_type, ) ) - return [await RelationshipResp.from_db(db, rel) for rel in relationships] + return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()] class AddFriendResp(BaseModel): @@ -92,14 +89,10 @@ async def add_relationship( if origin_type == RelationshipType.FOLLOW: relationship = ( await db.exec( - select(Relationship) - .where( + select(Relationship).where( Relationship.user_id == current_user_id, Relationship.target_id == target, ) - .options( - joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType] - ) ) ).first() assert relationship, "Relationship should exist after commit" diff --git a/app/router/score.py b/app/router/score.py index baab3a2..cd0a236 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -99,7 +99,7 @@ async def get_user_beatmap_score( ) user_score = ( await db.exec( - Score.select_clause(True) + select(Score) .where( Score.gamemode == mode if mode is not None else True, Score.beatmap_id == beatmap, @@ -139,7 +139,7 @@ async def get_user_all_beatmap_scores( ) all_user_scores = ( await db.exec( - Score.select_clause() + select(Score) .where( Score.gamemode == ruleset if ruleset is not None else True, Score.beatmap_id == beatmap, @@ -207,9 +207,7 @@ async def submit_solo_score( if score_token.score_id: score = ( await db.exec( - select(Score) - .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] - .where( + select(Score).where( Score.id == score_token.score_id, Score.user_id == current_user.id, ) @@ -243,8 +241,6 @@ async def submit_solo_score( score_id = score.id score_token.score_id = score_id await process_user(db, current_user, score, ranked) - score = ( - await db.exec(Score.select_clause().where(Score.id == score_id)) - ).first() + score = (await db.exec(select(Score).where(Score.id == score_id))).first() assert score is not None return await ScoreResp.from_db(db, score, current_user) diff --git a/app/router/user.py b/app/router/user.py index 3df5a49..649f1d4 100644 --- a/app/router/user.py +++ b/app/router/user.py @@ -28,19 +28,10 @@ async def get_users( ): if user_ids: searched_users = ( - await session.exec( - select(User) - .options(*User.all_select_option()) - .limit(50) - .where(col(User.id).in_(user_ids)) - ) + await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids))) ).all() else: - searched_users = ( - await session.exec( - select(User).options(*User.all_select_option()).limit(50) - ) - ).all() + searched_users = (await session.exec(select(User).limit(50))).all() return BatchUserResponse( users=[ await UserResp.from_db( @@ -63,9 +54,7 @@ async def get_user_info( ): searched_user = ( await session.exec( - select(User) - .options(*User.all_select_option()) - .where( + select(User).where( User.id == int(user) if user.isdigit() else User.username == user.removeprefix("@") From 1635641654860591f069a492f0b4d0b0cd23009b Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 14:11:42 +0000 Subject: [PATCH 06/45] feat(score): support leaderboard for country/friends/team/selected mods --- app/database/__init__.py | 3 +- app/database/best_score.py | 23 +- app/database/pp_best_score.py | 41 ++++ app/database/score.py | 382 +++++++++++++++++++--------------- app/models/score.py | 2 +- app/router/score.py | 49 ++--- 6 files changed, 284 insertions(+), 216 deletions(-) create mode 100644 app/database/pp_best_score.py diff --git a/app/database/__init__.py b/app/database/__init__.py index 91bc7cc..12fa867 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -14,6 +14,7 @@ from .lazer_user import ( User, UserResp, ) +from .pp_best_score import PPBestScore from .relationship import Relationship, RelationshipResp, RelationshipType from .score import ( Score, @@ -35,13 +36,13 @@ from .user_account_history import ( __all__ = [ "Beatmap", - "BeatmapResp", "Beatmapset", "BeatmapsetResp", "BestScore", "DailyChallengeStats", "DailyChallengeStatsResp", "OAuthToken", + "PPBestScore", "Relationship", "RelationshipResp", "RelationshipType", diff --git a/app/database/best_score.py b/app/database/best_score.py index 9993b63..42b0024 100644 --- a/app/database/best_score.py +++ b/app/database/best_score.py @@ -1,14 +1,14 @@ from typing import TYPE_CHECKING -from app.models.score import GameMode +from app.models.score import GameMode, Rank from .lazer_user import User from sqlmodel import ( + JSON, BigInteger, Column, Field, - Float, ForeignKey, Relationship, SQLModel, @@ -20,7 +20,7 @@ if TYPE_CHECKING: class BestScore(SQLModel, table=True): - __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] + __tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType] user_id: int = Field( sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) @@ -29,13 +29,20 @@ class BestScore(SQLModel, table=True): ) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) gamemode: GameMode = Field(index=True) - pp: float = Field( - sa_column=Column(Float, default=0), + total_score: int = Field( + default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score")) ) - acc: float = Field( - sa_column=Column(Float, default=0), + mods: list[str] = Field( + default_factory=list, + sa_column=Column(JSON), ) + rank: Rank user: User = Relationship() - score: "Score" = Relationship() + score: "Score" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[BestScore.score_id]", + "lazy": "joined", + } + ) beatmap: "Beatmap" = Relationship() diff --git a/app/database/pp_best_score.py b/app/database/pp_best_score.py new file mode 100644 index 0000000..ffc74d3 --- /dev/null +++ b/app/database/pp_best_score.py @@ -0,0 +1,41 @@ +from typing import TYPE_CHECKING + +from app.models.score import GameMode + +from .lazer_user import User + +from sqlmodel import ( + BigInteger, + Column, + Field, + Float, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .beatmap import Beatmap + from .score import Score + + +class PPBestScore(SQLModel, table=True): + __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + score_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) + ) + beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) + gamemode: GameMode = Field(index=True) + pp: float = Field( + sa_column=Column(Float, default=0), + ) + acc: float = Field( + sa_column=Column(Float, default=0), + ) + + user: User = Relationship() + score: "Score" = Relationship() + beatmap: "Beatmap" = Relationship() diff --git a/app/database/score.py b/app/database/score.py index 4cee832..c5f1a38 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -1,6 +1,7 @@ import asyncio from collections.abc import Sequence from datetime import UTC, date, datetime +import json import math from typing import TYPE_CHECKING @@ -12,7 +13,7 @@ from app.calculator import ( calculate_weighted_pp, clamp, ) -from app.models.beatmap import BeatmapRankStatus +from app.database.team import TeamMember from app.models.model import UTCBaseModel from app.models.mods import APIMod, mods_can_get_pp from app.models.score import ( @@ -31,12 +32,18 @@ from .beatmapset import BeatmapsetResp from .best_score import BestScore from .lazer_user import User, UserResp from .monthly_playcounts import MonthlyPlaycounts +from .pp_best_score import PPBestScore +from .relationship import ( + Relationship as DBRelationship, + RelationshipType, +) from .score_token import ScoreToken from redis import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import aliased +from sqlalchemy.sql.elements import ColumnElement from sqlmodel import ( JSON, BigInteger, @@ -45,9 +52,10 @@ from sqlmodel import ( Relationship, SQLModel, col, - false, func, select, + text, + true, ) from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql._expression_select_cls import SelectOfScalar @@ -156,9 +164,7 @@ class ScoreResp(ScoreBase): rank_country: int | None = None @classmethod - async def from_db( - cls, session: AsyncSession, score: Score, user: User | None = None - ) -> "ScoreResp": + async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": s = cls.model_validate(score.model_dump()) assert score.id await score.awaitable_attrs.beatmap @@ -195,30 +201,30 @@ class ScoreResp(ScoreBase): s.maximum_statistics = { HitResult.GREAT: score.beatmap.max_combo, } - if user: - s.user = await UserResp.from_db( - user, - session, - include=["statistics", "team", "daily_challenge_user_stats"], - ruleset=score.gamemode, - ) + s.user = await UserResp.from_db( + score.user, + session, + include=["statistics", "team", "daily_challenge_user_stats"], + ruleset=score.gamemode, + ) s.rank_global = ( await get_score_position_by_id( session, - score.map_md5, + score.beatmap_id, score.id, mode=score.gamemode, - user=user or score.user, + user=score.user, ) or None ) s.rank_country = ( await get_score_position_by_id( session, - score.map_md5, + score.beatmap_id, score.id, score.gamemode, - user or score.user, + score.user, + type=LeaderboardType.COUNTRY, ) or None ) @@ -228,134 +234,137 @@ class ScoreResp(ScoreBase): async def get_best_id(session: AsyncSession, score_id: int) -> None: rownum = ( func.row_number() - .over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc()) + .over( + partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc() + ) .label("rn") ) - subq = select(BestScore, rownum).subquery() + subq = select(PPBestScore, rownum).subquery() stmt = select(subq.c.rn).where(subq.c.score_id == score_id) result = await session.exec(stmt) return result.one_or_none() +async def _score_where( + type: LeaderboardType, + beatmap: int, + mode: GameMode, + mods: list[str] | None = None, + user: User | None = None, +) -> list[ColumnElement[bool]] | None: + wheres = [ + col(BestScore.beatmap_id) == beatmap, + col(BestScore.gamemode) == mode, + ] + + if type == LeaderboardType.FRIENDS: + if user and user.is_supporter: + subq = ( + select(DBRelationship.target_id) + .where( + DBRelationship.type == RelationshipType.FOLLOW, + DBRelationship.user_id == user.id, + ) + .subquery() + ) + wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id))) + else: + return None + elif type == LeaderboardType.COUNTRY: + if user and user.is_supporter: + wheres.append( + col(BestScore.user).has(col(User.country_code) == user.country_code) + ) + else: + return None + elif type == LeaderboardType.TEAM: + if user: + team_membership = await user.awaitable_attrs.team_membership + if team_membership: + team_id = team_membership.team_id + wheres.append( + col(BestScore.user).has( + col(User.team_membership).has(TeamMember.team_id == team_id) + ) + ) + if mods: + if user and user.is_supporter: + wheres.append( + text( + "JSON_CONTAINS(total_score_best_scores.mods, :w)" + " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" + ) # pyright: ignore[reportArgumentType] + ) + else: + return None + return wheres + + async def get_leaderboard( session: AsyncSession, - beatmap_md5: str, + beatmap: int, mode: GameMode, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, user: User | None = None, limit: int = 50, -) -> list[Score]: - scores = [] - if type == LeaderboardType.GLOBAL: - query = ( - select(Score) - .where( - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - Score.mods == mods if user and user.is_supporter else false(), - ) - .limit(limit) - .order_by( - col(Score.total_score).desc(), - ) - ) - result = await session.exec(query) - scores = list[Score](result.all()) - elif type == LeaderboardType.FRIENDS and user and user.is_supporter: - # TODO - ... - elif type == LeaderboardType.TEAM and user and user.team_membership: - team_id = user.team_membership.team_id - query = ( - select(Score) - .join(Beatmap) - .where( - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Score.user.team_membership).is_not(None), - Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess] - Score.mods == mods if user and user.is_supporter else false(), - ) - .limit(limit) - .order_by( - col(Score.total_score).desc(), - ) - ) - result = await session.exec(query) - scores = list[Score](result.all()) +) -> tuple[list[Score], Score | None]: + wheres = await _score_where(type, beatmap, mode, mods, user) + if wheres is None: + return [], None + query = ( + select(BestScore) + .where(*wheres) + .limit(limit) + .order_by(col(BestScore.total_score).desc()) + ) + if mods: + query = query.params(w=json.dumps(mods)) + scores = [s.score for s in await session.exec(query)] + user_score = None if user: - user_score = ( - await session.exec( - select(Score).where( - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - Score.user_id == user.id, - col(Score.passed).is_(True), + self_query = ( + select(BestScore) + .where(BestScore.user_id == user.id) + .order_by(col(BestScore.total_score).desc()) + .limit(1) + ) + if mods: + self_query = self_query.where( + text( + "JSON_CONTAINS(total_score_best_scores.mods, :w)" + " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" ) - ) - ).first() + ).params(w=json.dumps(mods)) + user_bs = (await session.exec(self_query)).first() + if user_bs: + user_score = user_bs.score if user_score and user_score not in scores: scores.append(user_score) - return scores + return scores, user_score async def get_score_position_by_user( session: AsyncSession, - beatmap_md5: str, + beatmap: int, user: User, mode: GameMode, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, ) -> int: - where_clause = [ - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - ] - if mods and user.is_supporter: - where_clause.append(Score.mods == mods) - else: - where_clause.append(false()) - if type == LeaderboardType.FRIENDS and user.is_supporter: - # TODO - ... - elif type == LeaderboardType.TEAM and user.team_membership: - team_id = user.team_membership.team_id - where_clause.append( - col(Score.user.team_membership).is_not(None), - ) - where_clause.append( - Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess] - ) + wheres = await _score_where(type, beatmap, mode, mods, user=user) + if wheres is None: + return 0 rownum = ( func.row_number() .over( - partition_by=Score.map_md5, - order_by=col(Score.total_score).desc(), + partition_by=col(BestScore.beatmap_id), + order_by=col(BestScore.total_score).desc(), ) .label("row_number") ) - subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery() - stmt = select(subq.c.row_number).where(subq.c.user == user) + subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery() + stmt = select(subq.c.row_number).where(subq.c.user_id == user.id) result = await session.exec(stmt) s = result.one_or_none() return s if s else 0 @@ -363,57 +372,26 @@ async def get_score_position_by_user( async def get_score_position_by_id( session: AsyncSession, - beatmap_md5: str, + beatmap: int, score_id: int, mode: GameMode, user: User | None = None, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, ) -> int: - where_clause = [ - Score.map_md5 == beatmap_md5, - Score.id == score_id, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - ] - if mods and user and user.is_supporter: - where_clause.append(Score.mods == mods) - elif mods: - where_clause.append(false()) + wheres = await _score_where(type, beatmap, mode, mods, user=user) + if wheres is None: + return 0 rownum = ( func.row_number() .over( - partition_by=[col(Score.user_id), col(Score.map_md5)], - order_by=col(Score.total_score).desc(), + partition_by=col(BestScore.beatmap_id), + order_by=col(BestScore.total_score).desc(), ) - .label("rownum") + .label("row_number") ) - subq = ( - select(Score.user_id, Score.id, Score.total_score, rownum) - .join(Beatmap) - .where(*where_clause) - .subquery() - ) - best_scores = aliased(subq) - overall_rank = ( - func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank") - ) - final_q = ( - select(best_scores.c.id, overall_rank) - .select_from(best_scores) - .where(best_scores.c.rownum == 1) - .subquery() - ) - - stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id) + subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery() + stmt = select(subq.c.row_number).where(subq.c.score_id == score_id) result = await session.exec(stmt) s = result.one_or_none() return s if s else 0 @@ -424,16 +402,38 @@ async def get_user_best_score_in_beatmap( beatmap: int, user: int, mode: GameMode | None = None, -) -> Score | None: +) -> BestScore | None: return ( await session.exec( - select(Score) + select(BestScore) .where( - Score.gamemode == mode if mode is not None else True, - Score.beatmap_id == beatmap, - Score.user_id == user, + BestScore.gamemode == mode if mode is not None else true(), + BestScore.beatmap_id == beatmap, + BestScore.user_id == user, ) - .order_by(col(Score.total_score).desc()) + .order_by(col(BestScore.total_score).desc()) + ) + ).first() + + +# FIXME +async def get_user_best_score_with_mod_in_beatmap( + session: AsyncSession, + beatmap: int, + user: int, + mod: list[str], + mode: GameMode | None = None, +) -> BestScore | None: + return ( + await session.exec( + select(BestScore) + .where( + BestScore.gamemode == mode if mode is not None else True, + BestScore.beatmap_id == beatmap, + BestScore.user_id == user, + # BestScore.mods == mod, + ) + .order_by(col(BestScore.total_score).desc()) ) ).first() @@ -443,13 +443,13 @@ async def get_user_best_pp_in_beatmap( beatmap: int, user: int, mode: GameMode, -) -> BestScore | None: +) -> PPBestScore | None: return ( await session.exec( - select(BestScore).where( - BestScore.beatmap_id == beatmap, - BestScore.user_id == user, - BestScore.gamemode == mode, + select(PPBestScore).where( + PPBestScore.beatmap_id == beatmap, + PPBestScore.user_id == user, + PPBestScore.gamemode == mode, ) ) ).first() @@ -459,12 +459,12 @@ async def get_user_best_pp( session: AsyncSession, user: int, limit: int = 200, -) -> Sequence[BestScore]: +) -> Sequence[PPBestScore]: return ( await session.exec( - select(BestScore) - .where(BestScore.user_id == user) - .order_by(col(BestScore.pp).desc()) + select(PPBestScore) + .where(PPBestScore.user_id == user) + .order_by(col(PPBestScore.pp).desc()) .limit(limit) ) ).all() @@ -474,9 +474,15 @@ async def process_user( session: AsyncSession, user: User, score: Score, ranked: bool = False ): assert user.id + assert score.id + mod_for_save = list({mod["acronym"] for mod in score.mods}) previous_score_best = await get_user_best_score_in_beatmap( session, score.beatmap_id, user.id, score.gamemode ) + previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap( + session, score.beatmap_id, user.id, mod_for_save, score.gamemode + ) + print(previous_score_best, previous_score_best_mod) add_to_db = False mouthly_playcount = ( await session.exec( @@ -493,7 +499,7 @@ async def process_user( ) add_to_db = True statistics = None - for i in user.statistics: + for i in await user.awaitable_attrs.statistics: if i.mode == score.gamemode.value: statistics = i break @@ -506,7 +512,7 @@ async def process_user( statistics.total_score += score.total_score difference = ( score.total_score - previous_score_best.total_score - if previous_score_best and previous_score_best.id != score.id + if previous_score_best else score.total_score ) if difference > 0 and score.passed and ranked: @@ -533,9 +539,41 @@ async def process_user( statistics.grade_sh -= 1 case Rank.A: statistics.grade_a -= 1 + else: + previous_score_best = BestScore( + user_id=user.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + score_id=score.id, + total_score=score.total_score, + rank=score.rank, + mods=mod_for_save, + ) + session.add(previous_score_best) + statistics.ranked_score += difference statistics.level_current = calculate_score_to_level(statistics.ranked_score) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) + if score.passed and ranked: + if previous_score_best_mod is not None: + previous_score_best_mod.mods = mod_for_save + previous_score_best_mod.score_id = score.id + previous_score_best_mod.rank = score.rank + previous_score_best_mod.total_score = score.total_score + elif ( + previous_score_best is not None and previous_score_best.score_id != score.id + ): + session.add( + BestScore( + user_id=user.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + score_id=score.id, + total_score=score.total_score, + rank=score.rank, + mods=mod_for_save, + ) + ) statistics.play_count += 1 mouthly_playcount.playcount += 1 statistics.play_time += int((score.ended_at - score.started_at).total_seconds()) @@ -623,7 +661,7 @@ async def process_score( ) if previous_pp_best is None or score.pp > previous_pp_best.pp: assert score.id - best_score = BestScore( + best_score = PPBestScore( user_id=user_id, score_id=score.id, beatmap_id=beatmap_id, diff --git a/app/models/score.py b/app/models/score.py index b613ae2..bfc9f53 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -132,7 +132,7 @@ class HitResultInt(IntEnum): class LeaderboardType(Enum): GLOBAL = "global" - FRIENDS = "friends" + FRIENDS = "friend" COUNTRY = "country" TEAM = "team" diff --git a/app/router/score.py b/app/router/score.py index cd0a236..6c6a475 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,7 +1,7 @@ from __future__ import annotations from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User -from app.database.score import process_score, process_user +from app.database.score import get_leaderboard, process_score, process_user from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -9,6 +9,7 @@ from app.models.beatmap import BeatmapRankStatus from app.models.score import ( INT_TO_MODE, GameMode, + LeaderboardType, Rank, SoloScoreSubmissionInfo, ) @@ -19,7 +20,7 @@ from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel from redis import Redis from sqlalchemy.orm import joinedload -from sqlmodel import col, select, true +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -33,44 +34,26 @@ class BeatmapScores(BaseModel): ) async def get_beatmap_scores( beatmap: int, + mode: GameMode, legacy_only: bool = Query(None), # TODO:加入对这个参数的查询 - mode: GameMode | None = Query(None), - # mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询 - type: str = Query(None), + mods: list[str] = Query(default_factory=set, alias="mods[]"), + type: LeaderboardType = Query(LeaderboardType.GLOBAL), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), + limit: int = Query(50, ge=1, le=200), ): if legacy_only: raise HTTPException( status_code=404, detail="this server only contains lazer scores" ) - all_scores = ( - await db.exec( - Score.select_clause_unique( - Score.beatmap_id == beatmap, - col(Score.passed).is_(True), - Score.gamemode == mode if mode is not None else true(), - ) - ) - ).all() - - user_score = ( - await db.exec( - Score.select_clause_unique( - Score.beatmap_id == beatmap, - Score.user_id == current_user.id, - col(Score.passed).is_(True), - Score.gamemode == mode if mode is not None else true(), - ) - ) - ).first() + all_scores, user_score = await get_leaderboard( + db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods + ) return BeatmapScores( - scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores], - userScore=await ScoreResp.from_db(db, user_score, user_score.user) - if user_score - else None, + scores=[await ScoreResp.from_db(db, score) for score in all_scores], + userScore=await ScoreResp.from_db(db, user_score) if user_score else None, ) @@ -116,7 +99,7 @@ async def get_user_beatmap_score( else: return BeatmapUserScore( position=user_score.position if user_score.position is not None else 0, - score=await ScoreResp.from_db(db, user_score, user_score.user), + score=await ScoreResp.from_db(db, user_score), ) @@ -149,9 +132,7 @@ async def get_user_all_beatmap_scores( ) ).all() - return [ - await ScoreResp.from_db(db, score, current_user) for score in all_user_scores - ] + return [await ScoreResp.from_db(db, score) for score in all_user_scores] @router.post( @@ -243,4 +224,4 @@ async def submit_solo_score( await process_user(db, current_user, score, ranked) score = (await db.exec(select(Score).where(Score.id == score_id))).first() assert score is not None - return await ScoreResp.from_db(db, score, current_user) + return await ScoreResp.from_db(db, score) From c5fc6afc189fbe665801412c1cff9cc7a308ccd6 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 14:38:10 +0000 Subject: [PATCH 07/45] feat(redis): use asyncio --- app/database/score.py | 2 +- app/dependencies/database.py | 11 ++-------- app/dependencies/fetcher.py | 23 ++++++++++----------- app/fetcher/_base.py | 38 +++++++++++++++++------------------ app/fetcher/osu_dot_direct.py | 6 +++--- app/router/beatmap.py | 8 ++++---- app/router/room.py | 22 +++++++++----------- app/router/score.py | 2 +- main.py | 5 +++-- 9 files changed, 53 insertions(+), 64 deletions(-) diff --git a/app/database/score.py b/app/database/score.py index c5f1a38..642eac1 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -39,7 +39,7 @@ from .relationship import ( ) from .score_token import ScoreToken -from redis import Redis +from redis.asyncio import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import aliased diff --git a/app/dependencies/database.py b/app/dependencies/database.py index fe09139..77b15c3 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -5,15 +5,11 @@ import json from app.config import settings from pydantic import BaseModel +import redis.asyncio as redis from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession -try: - import redis -except ImportError: - redis = None - def json_serializer(value): if isinstance(value, BaseModel | SQLModel): @@ -25,10 +21,7 @@ def json_serializer(value): engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) # Redis 连接 -if redis: - redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) -else: - redis_client = None +redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) # 数据库依赖 diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index d3c216a..806eb87 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -8,7 +8,7 @@ from app.log import logger fetcher: Fetcher | None = None -def get_fetcher() -> Fetcher: +async def get_fetcher() -> Fetcher: global fetcher if fetcher is None: fetcher = Fetcher( @@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher: settings.FETCHER_CALLBACK_URL, ) redis = get_redis() - if redis: - access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}") - if access_token: - fetcher.access_token = str(access_token) - refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}") - if refresh_token: - fetcher.refresh_token = str(refresh_token) - if not fetcher.access_token or not fetcher.refresh_token: - logger.opt(colors=True).info( - f"Login to initialize fetcher: {fetcher.authorize_url}" - ) + access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") + if access_token: + fetcher.access_token = str(access_token) + refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}") + if refresh_token: + fetcher.refresh_token = str(refresh_token) + if not fetcher.access_token or not fetcher.refresh_token: + logger.opt(colors=True).info( + f"Login to initialize fetcher: {fetcher.authorize_url}" + ) return fetcher diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index 08e3508..2717a35 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -59,16 +59,15 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) async def refresh_access_token(self) -> None: async with AsyncClient() as client: @@ -87,13 +86,12 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index 08b8dfc..cb3897f 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -4,7 +4,7 @@ from ._base import BaseFetcher from httpx import AsyncClient from loguru import logger -import redis +import redis.asyncio as redis class OsuDotDirectFetcher(BaseFetcher): @@ -23,7 +23,7 @@ class OsuDotDirectFetcher(BaseFetcher): self, redis: redis.Redis, beatmap_id: int ) -> str: if redis.exists(f"beatmap:{beatmap_id}:raw"): - return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] + return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] raw = await self.get_beatmap_raw(beatmap_id) - redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) + await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) return raw diff --git a/app/router/beatmap.py b/app/router/beatmap.py index df5f20d..0a25562 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -22,7 +22,7 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis import rosu_pp_py as rosu from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -127,8 +127,8 @@ async def get_beatmap_attributes( f"beatmap:{beatmap}:{ruleset}:" f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" ) - if redis.exists(key): - return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] + if await redis.exists(key): + return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] try: resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) @@ -138,7 +138,7 @@ async def get_beatmap_attributes( ) except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] raise HTTPException(status_code=400, detail=str(e)) - redis.set(key, attr.model_dump_json()) + await redis.set(key, attr.model_dump_json()) return attr except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmap not found") diff --git a/app/router/room.py b/app/router/room.py index ed540fc..3a65617 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -6,7 +6,8 @@ from app.models.room import Room from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, Query +from redis.asyncio import Redis from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -19,17 +20,14 @@ async def get_all_rooms( status: str = Query(None), category: str = Query(None), db: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), ): all_room_ids = (await db.exec(select(RoomIndex).where(True))).all() - redis = get_redis() roomsList: list[Room] = [] - if redis: - for room_index in all_room_ids: - dumped_room = redis.get(str(room_index.id)) - if dumped_room: - actual_room = Room.model_validate_json(str(dumped_room)) - if actual_room.status == status and actual_room.category == category: - roomsList.append(actual_room) - return roomsList - else: - raise HTTPException(status_code=500, detail="Redis Error") + for room_index in all_room_ids: + dumped_room = await redis.get(str(room_index.id)) + if dumped_room: + actual_room = Room.model_validate_json(str(dumped_room)) + if actual_room.status == status and actual_room.category == category: + roomsList.append(actual_room) + return roomsList diff --git a/app/router/score.py b/app/router/score.py index 6c6a475..2f1303e 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -18,7 +18,7 @@ from .api_router import router from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession diff --git a/main.py b/main.py index 92d4402..f5d20c1 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from datetime import datetime from app.config import settings -from app.dependencies.database import create_tables, engine +from app.dependencies.database import create_tables, engine, redis_client from app.dependencies.fetcher import get_fetcher from app.router import api_router, auth_router, fetcher_router, signalr_router @@ -15,10 +15,11 @@ from fastapi import FastAPI async def lifespan(app: FastAPI): # on startup await create_tables() - get_fetcher() # 初始化 fetcher + await get_fetcher() # 初始化 fetcher # on shutdown yield await engine.dispose() + await redis_client.aclose() app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan) From 86a6fd1b69b962692ead1146ff6bc1addf1455ec Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 02:49:49 +0000 Subject: [PATCH 08/45] feat(user): support `online` & `last_visit` --- app/database/lazer_user.py | 5 ++++- app/database/score.py | 1 - app/fetcher/osu_dot_direct.py | 2 +- app/signalr/hub/metadata.py | 18 +++++++++++++++++- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index d502ccb..1337cc2 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -1,6 +1,7 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, NotRequired, TypedDict +from app.dependencies.database import get_redis from app.models.model import UTCBaseModel from app.models.score import GameMode from app.models.user import Country, Page, RankHistory @@ -157,7 +158,7 @@ class User(AsyncAttrs, UserBase, table=True): class UserResp(UserBase): id: int | None = None - is_online: bool = True # TODO + is_online: bool = False groups: list = [] # TODO country: Country = Field(default_factory=lambda: Country(code="CN", name="China")) favourite_beatmapset_count: int = 0 # TODO @@ -225,6 +226,8 @@ class UserResp(UserBase): .limit(200) ) ).one() + redis = get_redis() + u.is_online = await redis.exists(f"metadata:online:{obj.id}") u.cover_url = ( obj.cover.get( "url", "https://assets.ppy.sh/user-profile-covers/default.jpeg" diff --git a/app/database/score.py b/app/database/score.py index 642eac1..32ddb6c 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -482,7 +482,6 @@ async def process_user( previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap( session, score.beatmap_id, user.id, mod_for_save, score.gamemode ) - print(previous_score_best, previous_score_best_mod) add_to_db = False mouthly_playcount = ( await session.exec( diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index cb3897f..6e18435 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -22,7 +22,7 @@ class OsuDotDirectFetcher(BaseFetcher): async def get_or_fetch_beatmap_raw( self, redis: redis.Redis, beatmap_id: int ) -> str: - if redis.exists(f"beatmap:{beatmap_id}:raw"): + if await redis.exists(f"beatmap:{beatmap_id}:raw"): return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] raw = await self.get_beatmap_raw(beatmap_id) await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 2712883..227cf7b 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -2,10 +2,12 @@ from __future__ import annotations import asyncio from collections.abc import Coroutine +from datetime import UTC, datetime from typing import override from app.database import Relationship, RelationshipType -from app.dependencies.database import engine +from app.database.lazer_user import User +from app.dependencies.database import engine, get_redis from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity from .hub import Client, Hub @@ -54,6 +56,18 @@ class MetadataHub(Hub[MetadataClientState]): async def _clean_state(self, state: MetadataClientState) -> None: if state.pushable: await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None)) + redis = get_redis() + if await redis.exists(f"metadata:online:{state.connection_id}"): + await redis.delete(f"metadata:online:{state.connection_id}") + async with AsyncSession(engine) as session: + async with session.begin(): + user = ( + await session.exec( + select(User).where(User.id == int(state.connection_id)) + ) + ).one() + user.last_visit = datetime.now(UTC) + await session.commit() @override def create_state(self, client: Client) -> MetadataClientState: @@ -93,6 +107,8 @@ class MetadataHub(Hub[MetadataClientState]): ) ) await asyncio.gather(*tasks) + redis = get_redis() + await redis.set(f"metadata:online:{user_id}", "") async def UpdateStatus(self, client: Client, status: int) -> None: status_ = OnlineStatus(status) From d938998239c0b445a01261c492137856f88a9683 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 04:22:17 +0000 Subject: [PATCH 09/45] feat(beatmapset): support post favoutite to beatmapset --- app/database/__init__.py | 2 + app/database/beatmap.py | 8 +- app/database/beatmapset.py | 119 ++++++++++++------ app/database/favourite_beatmapset.py | 53 ++++++++ app/database/lazer_user.py | 15 ++- app/database/score.py | 6 +- app/router/beatmap.py | 11 +- app/router/beatmapset.py | 45 +++++-- ...8ebf_beatmapset_support_favourite_count.py | 40 ++++++ 9 files changed, 249 insertions(+), 50 deletions(-) create mode 100644 app/database/favourite_beatmapset.py create mode 100644 migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py diff --git a/app/database/__init__.py b/app/database/__init__.py index 12fa867..6e2e8c5 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -10,6 +10,7 @@ from .beatmapset import ( ) from .best_score import BestScore from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .favourite_beatmapset import FavouriteBeatmapset from .lazer_user import ( User, UserResp, @@ -41,6 +42,7 @@ __all__ = [ "BestScore", "DailyChallengeStats", "DailyChallengeStatsResp", + "FavouriteBeatmapset", "OAuthToken", "PPBestScore", "Relationship", diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 2ab5ad0..c55643a 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -14,6 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.fetcher import Fetcher + from .lazer_user import User + class BeatmapOwner(SQLModel): id: int @@ -161,6 +163,8 @@ class BeatmapResp(BeatmapBase): beatmap: Beatmap, query_mode: GameMode | None = None, from_set: bool = False, + session: AsyncSession | None = None, + user: "User | None" = None, ) -> "BeatmapResp": beatmap_ = beatmap.model_dump() if query_mode is not None and beatmap.mode != query_mode: @@ -170,5 +174,7 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] if not from_set: - beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db( + beatmap.beatmapset, session=session, user=user + ) return cls.model_validate(beatmap_) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 5a618b7..49313b2 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -5,14 +5,17 @@ from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.model import UTCBaseModel from app.models.score import GameMode +from .lazer_user import BASE_INCLUDES, User, UserResp + from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text from sqlalchemy.ext.asyncio import AsyncAttrs -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import Field, Relationship, SQLModel, col, func, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .beatmap import Beatmap, BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset class BeatmapCovers(SQLModel): @@ -90,7 +93,6 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): artist_unicode: str = Field(index=True) covers: BeatmapCovers | None = Field(sa_column=Column(JSON)) creator: str - favourite_count: int nsfw: bool = Field(default=False) play_count: int preview_url: str @@ -114,11 +116,9 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): pack_tags: list[str] = Field(default=[], sa_column=Column(JSON)) ratings: list[int] = Field(default=None, sa_column=Column(JSON)) - # TODO: recent_favourites: Optional[list[User]] = None # TODO: related_users: Optional[list[User]] = None # TODO: user: Optional[User] = Field(default=None) track_id: int | None = Field(default=None) # feature artist? - # TODO: has_favourited # BeatmapsetExtended bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2))) @@ -152,6 +152,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): hype_required: int = Field(default=0) availability_info: str | None = Field(default=None) download_disabled: bool = Field(default=False) + favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod async def from_resp( @@ -199,40 +200,88 @@ class BeatmapsetResp(BeatmapsetBase): genre: BeatmapTranslationText | None = None language: BeatmapTranslationText | None = None nominations: BeatmapNominations | None = None + has_favourited: bool = False + favourite_count: int = 0 + recent_favourites: list[UserResp] = Field(default_factory=list) @classmethod - async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": + async def from_db( + cls, + beatmapset: Beatmapset, + include: list[str] = [], + session: AsyncSession | None = None, + user: User | None = None, + ) -> "BeatmapsetResp": from .beatmap import BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset - beatmaps = [ - await BeatmapResp.from_db(beatmap, from_set=True) - for beatmap in await beatmapset.awaitable_attrs.beatmaps - ] + update = { + "beatmaps": [ + await BeatmapResp.from_db(beatmap, from_set=True) + for beatmap in await beatmapset.awaitable_attrs.beatmaps + ], + "hype": BeatmapHype( + current=beatmapset.hype_current, required=beatmapset.hype_required + ), + "availability": BeatmapAvailability( + more_information=beatmapset.availability_info, + download_disabled=beatmapset.download_disabled, + ), + "genre": BeatmapTranslationText( + name=beatmapset.beatmap_genre.name, + id=beatmapset.beatmap_genre.value, + ), + "language": BeatmapTranslationText( + name=beatmapset.beatmap_language.name, + id=beatmapset.beatmap_language.value, + ), + "nominations": BeatmapNominations( + required=beatmapset.nominations_required, + current=beatmapset.nominations_current, + ), + "status": beatmapset.beatmap_status.name.lower(), + "ranked": beatmapset.beatmap_status.value, + "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, + **beatmapset.model_dump(), + } + if session and user: + existing_favourite = ( + await session.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id + ) + ) + ).first() + update["has_favourited"] = existing_favourite is not None + + if session and "recent_favourites" in include: + recent_favourites = ( + await session.exec( + select(FavouriteBeatmapset) + .where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id, + ) + .order_by(col(FavouriteBeatmapset.date).desc()) + .limit(50) + ) + ).all() + update["recent_favourites"] = [ + await UserResp.from_db( + await favourite.awaitable_attrs.user, + session=session, + include=BASE_INCLUDES, + ) + for favourite in recent_favourites + ] + + if session: + update["favourite_count"] = ( + await session.exec( + select(func.count()) + .select_from(FavouriteBeatmapset) + .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) + ) + ).one() return cls.model_validate( - { - "beatmaps": beatmaps, - "hype": BeatmapHype( - current=beatmapset.hype_current, required=beatmapset.hype_required - ), - "availability": BeatmapAvailability( - more_information=beatmapset.availability_info, - download_disabled=beatmapset.download_disabled, - ), - "genre": BeatmapTranslationText( - name=beatmapset.beatmap_genre.name, - id=beatmapset.beatmap_genre.value, - ), - "language": BeatmapTranslationText( - name=beatmapset.beatmap_language.name, - id=beatmapset.beatmap_language.value, - ), - "nominations": BeatmapNominations( - required=beatmapset.nominations_required, - current=beatmapset.nominations_current, - ), - "status": beatmapset.beatmap_status.name.lower(), - "ranked": beatmapset.beatmap_status.value, - "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, - **beatmapset.model_dump(), - } + update, ) diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py new file mode 100644 index 0000000..51bd578 --- /dev/null +++ b/app/database/favourite_beatmapset.py @@ -0,0 +1,53 @@ +import datetime + +from app.database.beatmapset import Beatmapset +from app.database.lazer_user import User + +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + + +class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True): + __tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + index=True, + ), + ) + beatmapset_id: int = Field( + default=None, + sa_column=Column( + ForeignKey("beatmapsets.id"), + index=True, + ), + ) + date: datetime.datetime = Field( + default=datetime.datetime.now(datetime.UTC), + sa_column=Column( + DateTime, + ), + ) + + user: User = Relationship(back_populates="favourite_beatmapsets") + beatmapset: Beatmapset = Relationship( + sa_relationship_kwargs={ + "lazy": "selectin", + }, + back_populates="favourites", + ) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 1337cc2..3bd751b 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -1,7 +1,6 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, NotRequired, TypedDict -from app.dependencies.database import get_redis from app.models.model import UTCBaseModel from app.models.score import GameMode from app.models.user import Country, Page, RankHistory @@ -28,7 +27,8 @@ from sqlmodel import ( from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: - from app.database.relationship import RelationshipResp + from .favourite_beatmapset import FavouriteBeatmapset + from .relationship import RelationshipResp class Kudosu(TypedDict): @@ -144,6 +144,9 @@ class User(AsyncAttrs, UserBase, table=True): back_populates="user" ) monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") + favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship( + back_populates="user" + ) email: str = Field(max_length=254, unique=True, index=True, exclude=True) priv: int = Field(default=1, exclude=True) @@ -201,6 +204,8 @@ class UserResp(UserBase): include: list[str] = [], ruleset: GameMode | None = None, ) -> "UserResp": + from app.dependencies.database import get_redis + from .best_score import BestScore from .relationship import Relationship, RelationshipResp, RelationshipType @@ -320,3 +325,9 @@ SEARCH_INCLUDED = [ "achievements", "monthly_playcounts", ] + +BASE_INCLUDES = [ + "team", + "daily_challenge_user_stats", + "statistics", +] diff --git a/app/database/score.py b/app/database/score.py index 32ddb6c..79cb005 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -169,7 +169,9 @@ class ScoreResp(ScoreBase): assert score.id await score.awaitable_attrs.beatmap s.beatmap = await BeatmapResp.from_db(score.beatmap) - s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset) + s.beatmapset = await BeatmapsetResp.from_db( + score.beatmap.beatmapset, session=session, user=score.user + ) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = MODE_TO_INT[score.gamemode] @@ -669,7 +671,7 @@ async def process_score( acc=score.accuracy, ) session.add(best_score) - session.delete(previous_pp_best) if previous_pp_best else None + await session.delete(previous_pp_best) if previous_pp_best else None await session.commit() await session.refresh(score) await session.refresh(score_token) diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 0a25562..9574bdb 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -50,7 +50,7 @@ async def lookup_beatmap( if beatmap is None: raise HTTPException(status_code=404, detail="Beatmap not found") - return await BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @@ -62,7 +62,7 @@ async def get_beatmap( ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) - return await BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -90,7 +90,12 @@ async def batch_get_beatmaps( await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) ).all() - return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps]) + return BatchGetResp( + beatmaps=[ + await BeatmapResp.from_db(bm, session=db, user=current_user) + for bm in beatmaps + ] + ) @router.post( diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index b82678d..b4d2e4c 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -1,10 +1,8 @@ from __future__ import annotations -from app.database import ( - Beatmapset, - BeatmapsetResp, - User, -) +from typing import Literal + +from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.dependencies.database import get_db from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -12,7 +10,7 @@ from app.fetcher import Fetcher from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, Form, HTTPException, Query from fastapi.responses import RedirectResponse from httpx import HTTPStatusError from sqlmodel import select @@ -34,7 +32,9 @@ async def get_beatmapset( except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmapset not found") else: - resp = await BeatmapsetResp.from_db(beatmapset) + resp = await BeatmapsetResp.from_db( + beatmapset, session=db, include=["recent_favourites"], user=current_user + ) return resp @@ -53,3 +53,34 @@ async def download_beatmapset( return RedirectResponse( f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}" ) + + +@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"]) +async def favourite_beatmapset( + beatmapset: int, + action: Literal["favourite", "unfavourite"] = Form(), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + existing_favourite = ( + await db.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.user_id == current_user.id, + FavouriteBeatmapset.beatmapset_id == beatmapset, + ) + ) + ).first() + + if action == "favourite" and existing_favourite: + raise HTTPException(status_code=400, detail="Already favourited") + elif action == "unfavourite" and not existing_favourite: + raise HTTPException(status_code=400, detail="Not favourited") + + if action == "favourite": + favourite = FavouriteBeatmapset( + user_id=current_user.id, beatmapset_id=beatmapset + ) + db.add(favourite) + else: + await db.delete(existing_favourite) + await db.commit() diff --git a/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py new file mode 100644 index 0000000..84bae15 --- /dev/null +++ b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py @@ -0,0 +1,40 @@ +"""beatmapset: support favourite count + +Revision ID: 1178d0758ebf +Revises: +Create Date: 2025-08-01 04:05:09.882800 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "1178d0758ebf" +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("beatmapsets", "favourite_count") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "beatmapsets", + sa.Column( + "favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False + ), + ) + # ### end Alembic commands ### From 74e4b1cb530a67e8dcd85c29073562a5343719f0 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 04:27:44 +0000 Subject: [PATCH 10/45] fix(relationship): fix unique relationship --- app/database/relationship.py | 7 ++- ...02_relationship_fix_unique_relationship.py | 54 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 migrations/versions/58a11441d302_relationship_fix_unique_relationship.py diff --git a/app/database/relationship.py b/app/database/relationship.py index 7a351aa..b941c28 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -22,12 +22,16 @@ class RelationshipType(str, Enum): class Relationship(SQLModel, table=True): __tablename__ = "relationship" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) user_id: int = Field( default=None, sa_column=Column( BigInteger, ForeignKey("lazer_users.id"), - primary_key=True, index=True, ), ) @@ -36,7 +40,6 @@ class Relationship(SQLModel, table=True): sa_column=Column( BigInteger, ForeignKey("lazer_users.id"), - primary_key=True, index=True, ), ) diff --git a/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py new file mode 100644 index 0000000..e383621 --- /dev/null +++ b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py @@ -0,0 +1,54 @@ +"""relationship: fix unique relationship + +Revision ID: 58a11441d302 +Revises: 1178d0758ebf +Create Date: 2025-08-01 04:23:02.498166 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "58a11441d302" +down_revision: str | Sequence[str] | None = "1178d0758ebf" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "relationship", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + ) + op.drop_constraint("PRIMARY", "relationship", type_="primary") + op.create_primary_key("pk_relationship", "relationship", ["id"]) + op.alter_column( + "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True + ) + op.alter_column( + "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("pk_relationship", "relationship", type_="primary") + op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"]) + op.alter_column( + "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False + ) + op.alter_column( + "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False + ) + op.drop_column("relationship", "id") + # ### end Alembic commands ### From d399cb52e261571fb946bb0b6e809023f932bd6d Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 11:00:57 +0000 Subject: [PATCH 11/45] fix(signarl): wrong msgpack encode --- app/models/signalr.py | 46 +++++++++++++++++-- app/signalr/packet.py | 2 +- .../msgpack_lazer_api/msgpack_lazer_api.pyi | 2 +- packages/msgpack_lazer_api/src/decode.rs | 4 +- packages/msgpack_lazer_api/src/encode.rs | 10 ++-- 5 files changed, 50 insertions(+), 14 deletions(-) diff --git a/app/models/signalr.py b/app/models/signalr.py index 37b2741..202da4f 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,10 +1,12 @@ from __future__ import annotations import datetime -from typing import Any, get_origin +from enum import Enum +from typing import Any from pydantic import ( BaseModel, + BeforeValidator, ConfigDict, Field, TypeAdapter, @@ -17,22 +19,56 @@ def serialize_to_list(value: BaseModel) -> list[Any]: data = [] for field, info in value.__class__.model_fields.items(): v = getattr(value, field) - anno = get_origin(info.annotation) - if anno and issubclass(anno, BaseModel): + typ = v.__class__ + if issubclass(typ, BaseModel): data.append(serialize_to_list(v)) - elif anno and issubclass(anno, list): + elif issubclass(typ, list): data.append( TypeAdapter( info.annotation, config=ConfigDict(arbitrary_types_allowed=True) ).dump_python(v) ) - elif isinstance(v, datetime.datetime): + elif issubclass(typ, datetime.datetime): data.append([v, 0]) + elif issubclass(typ, Enum): + list_ = list(typ) + data.append(list_.index(v) if v in list_ else v.value) else: data.append(v) return data +def _by_index(v: Any, class_: type[Enum]): + enum_list = list(class_) + if not isinstance(v, int): + return v + if 0 <= v < len(enum_list): + return enum_list[v] + raise ValueError( + f"Value {v} is out of range for enum " + f"{class_.__name__} with {len(enum_list)} items" + ) + + +def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator: + return BeforeValidator(lambda v: _by_index(v, enum_class)) + + +def msgpack_union(v): + data = v[1] + data.append(v[0]) + return data + + +def msgpack_union_dump(v: BaseModel) -> list[Any]: + _type = getattr(v, "type", None) + if _type is None: + raise ValueError( + f"Model {v.__class__.__name__} does not have a '_type' attribute" + ) + return [_type, serialize_to_list(v)] + + class MessagePackArrayModel(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index e361ef8..387231c 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -158,7 +158,7 @@ class MsgpackProtocol: result_kind = 2 if packet.error: result_kind = 1 - elif packet.result is None: + elif packet.result is not None: result_kind = 3 payload.extend( [ diff --git a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi index 88b79c5..b8653f0 100644 --- a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi +++ b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi @@ -5,7 +5,7 @@ class APIMod: @property def acronym(self) -> str: ... @property - def settings(self) -> str: ... + def settings(self) -> dict[str, Any]: ... def encode(obj: Any) -> bytes: ... def decode(data: bytes) -> Any: ... diff --git a/packages/msgpack_lazer_api/src/decode.rs b/packages/msgpack_lazer_api/src/decode.rs index 15156ca..b8e239b 100644 --- a/packages/msgpack_lazer_api/src/decode.rs +++ b/packages/msgpack_lazer_api/src/decode.rs @@ -13,6 +13,8 @@ pub fn read_object( match rmp::decode::read_marker(cursor) { Ok(marker) => match marker { rmp::Marker::Null => Ok(py.None()), + rmp::Marker::True => Ok(true.into_py_any(py)?), + rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixPos(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::FixNeg(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::U8 => { @@ -86,8 +88,6 @@ pub fn read_object( cursor.read_exact(&mut data).map_err(to_py_err)?; Ok(data.into_pyobject(py)?.into_any().unbind()) } - rmp::Marker::True => Ok(true.into_py_any(py)?), - rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixStr(len) => read_string(py, cursor, len as u32), rmp::Marker::Str8 => { let mut buf = [0u8; 1]; diff --git a/packages/msgpack_lazer_api/src/encode.rs b/packages/msgpack_lazer_api/src/encode.rs index 88a732b..0e0907c 100644 --- a/packages/msgpack_lazer_api/src/encode.rs +++ b/packages/msgpack_lazer_api/src/encode.rs @@ -110,12 +110,12 @@ pub fn write_object(buf: &mut Vec, obj: &Bound<'_, PyAny>) { write_list(buf, list); } else if let Ok(string) = obj.downcast::() { write_string(buf, string); - } else if let Ok(integer) = obj.downcast::() { - write_integer(buf, integer); - } else if let Ok(float) = obj.downcast::() { - write_float(buf, float); } else if let Ok(boolean) = obj.downcast::() { - write_bool(buf, boolean); + write_bool(buf, boolean); + } else if let Ok(float) = obj.downcast::() { + write_float(buf, float); + } else if let Ok(integer) = obj.downcast::() { + write_integer(buf, integer); } else if let Ok(bytes) = obj.downcast::() { write_bin(buf, bytes); } else if let Ok(dict) = obj.downcast::() { From a25cb852d9cd0a39b7b4beb44ab2edd70538d09e Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 11:08:59 +0000 Subject: [PATCH 12/45] feat(multiplay): support `CreateRoom` hub method --- app/database/__init__.py | 8 + app/database/playlist_attempts.py | 9 + app/database/playlists.py | 85 ++++++++ app/database/room.py | 137 ++++++++++++- app/models/mods.py | 14 +- app/models/multiplayer_hub.py | 168 ++++++++++++++++ app/models/room.py | 313 +----------------------------- app/router/__init__.py | 8 +- app/router/room.py | 143 ++++++-------- app/signalr/hub/multiplayer.py | 101 +++++++++- main.py | 7 +- 11 files changed, 590 insertions(+), 403 deletions(-) create mode 100644 app/database/playlist_attempts.py create mode 100644 app/database/playlists.py create mode 100644 app/models/multiplayer_hub.py diff --git a/app/database/__init__.py b/app/database/__init__.py index 6e2e8c5..2c01f7a 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -15,8 +15,11 @@ from .lazer_user import ( User, UserResp, ) +from .playlist_attempts import ItemAttemptsCount +from .playlists import Playlist, PlaylistResp from .pp_best_score import PPBestScore from .relationship import Relationship, RelationshipResp, RelationshipType +from .room import Room, RoomResp from .score import ( Score, ScoreBase, @@ -43,11 +46,16 @@ __all__ = [ "DailyChallengeStats", "DailyChallengeStatsResp", "FavouriteBeatmapset", + "ItemAttemptsCount", "OAuthToken", "PPBestScore", + "Playlist", + "PlaylistResp", "Relationship", "RelationshipResp", "RelationshipType", + "Room", + "RoomResp", "Score", "ScoreBase", "ScoreResp", diff --git a/app/database/playlist_attempts.py b/app/database/playlist_attempts.py new file mode 100644 index 0000000..5b4710a --- /dev/null +++ b/app/database/playlist_attempts.py @@ -0,0 +1,9 @@ +from sqlmodel import Field, SQLModel + + +class ItemAttemptsCount(SQLModel, table=True): + __tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType] + id: int = Field(foreign_key="room_playlists.db_id", primary_key=True, index=True) + room_id: int = Field(foreign_key="rooms.id", index=True) + attempts: int = Field(default=0) + passed: int = Field(default=0) diff --git a/app/database/playlists.py b/app/database/playlists.py new file mode 100644 index 0000000..42567b6 --- /dev/null +++ b/app/database/playlists.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import TYPE_CHECKING + +from app.models.model import UTCBaseModel +from app.models.mods import APIMod, msgpack_to_apimod +from app.models.multiplayer_hub import PlaylistItem + +from .beatmap import Beatmap, BeatmapResp + +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .room import Room + + +class PlaylistBase(SQLModel, UTCBaseModel): + id: int = 0 + owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) + ruleset_id: int = Field(ge=0, le=3) + expired: bool = Field(default=False) + playlist_order: int = Field(default=0) + played_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True)), + default=None, + ) + allowed_mods: list[APIMod] = Field( + default_factory=list, + sa_column=Column(JSON), + ) + required_mods: list[APIMod] = Field( + default_factory=list, + sa_column=Column(JSON), + ) + beatmap_id: int = Field( + foreign_key="beatmaps.id", + ) + freestyle: bool = Field(default=False) + + +class Playlist(PlaylistBase, table=True): + __tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType] + db_id: int = Field(default=None, primary_key=True, index=True, exclude=True) + room_id: int = Field(foreign_key="rooms.id", exclude=True) + + beatmap: Beatmap = Relationship( + sa_relationship_kwargs={ + "lazy": "joined", + } + ) + room: "Room" = Relationship() + + @classmethod + async def from_hub(cls, playlist: PlaylistItem, room_id: int) -> "Playlist": + return cls( + id=playlist.id, + owner_id=playlist.owner_id, + ruleset_id=playlist.ruleset_id, + beatmap_id=playlist.beatmap_id, + required_mods=[msgpack_to_apimod(mod) for mod in playlist.required_mods], + allowed_mods=[msgpack_to_apimod(mod) for mod in playlist.allowed_mods], + expired=playlist.expired, + playlist_order=playlist.order, + played_at=playlist.played_at, + freestyle=playlist.freestyle, + room_id=room_id, + ) + + +class PlaylistResp(PlaylistBase): + beatmap: BeatmapResp | None = None + + @classmethod + async def from_db(cls, playlist: Playlist) -> "PlaylistResp": + resp = cls.model_validate(playlist) + resp.beatmap = await BeatmapResp.from_db(playlist.beatmap) + return resp diff --git a/app/database/room.py b/app/database/room.py index 0b79ee6..8eb882d 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -1,6 +1,135 @@ -from sqlmodel import Field, SQLModel +from datetime import UTC, datetime + +from app.models.multiplayer_hub import ServerMultiplayerRoom +from app.models.room import ( + MatchType, + QueueMode, + RoomCategory, + RoomDifficultyRange, + RoomPlaylistItemStats, + RoomStatus, +) + +from .lazer_user import User, UserResp +from .playlist_attempts import ItemAttemptsCount +from .playlists import Playlist, PlaylistResp + +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) -class RoomIndex(SQLModel, table=True): - __tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType] - id: int = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue] +class RoomBase(SQLModel): + name: str = Field(index=True) + category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True) + duration: int | None = Field(default=None) # minutes + starts_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + ended_at: datetime | None = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=None, + ) + participant_count: int = Field(default=0) + max_attempts: int | None = Field(default=None) # playlists + type: MatchType + queue_mode: QueueMode + auto_skip: bool + auto_start_duration: int + status: RoomStatus + # TODO: channel_id + # recent_participants: list[User] + + +class Room(RoomBase, table=True): + __tablename__ = "rooms" # pyright: ignore[reportAssignmentType] + id: int = Field(default=None, primary_key=True, index=True) + host_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + + host: User = Relationship() + playlist: list[Playlist] = Relationship( + sa_relationship_kwargs={ + "lazy": "joined", + "cascade": "all, delete-orphan", + "overlaps": "room", + } + ) + # playlist_item_attempts: list["ItemAttemptsCount"] = Relationship( + # sa_relationship_kwargs={ + # "lazy": "joined", + # "cascade": "all, delete-orphan", + # "primaryjoin": "ItemAttemptsCount.room_id == Room.id", + # } + # ) + + +class RoomResp(RoomBase): + id: int + password: str | None = None + host: UserResp | None = None + playlist: list[PlaylistResp] = [] + playlist_item_stats: RoomPlaylistItemStats | None = None + difficulty_range: RoomDifficultyRange | None = None + current_playlist_item: PlaylistResp | None = None + playlist_item_attempts: list[ItemAttemptsCount] = [] + + @classmethod + async def from_db(cls, room: Room) -> "RoomResp": + resp = cls.model_validate(room.model_dump()) + + stats = RoomPlaylistItemStats(count_active=0, count_total=0) + difficulty_range = RoomDifficultyRange( + min=0, + max=0, + ) + rulesets = set() + for playlist in room.playlist: + stats.count_total += 1 + if not playlist.expired: + stats.count_active += 1 + rulesets.add(playlist.ruleset_id) + difficulty_range.min = min( + difficulty_range.min, playlist.beatmap.difficulty_rating + ) + difficulty_range.max = max( + difficulty_range.max, playlist.beatmap.difficulty_rating + ) + resp.playlist.append(await PlaylistResp.from_db(playlist)) + stats.ruleset_ids = list(rulesets) + resp.playlist_item_stats = stats + resp.difficulty_range = difficulty_range + resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None + # resp.playlist_item_attempts = room.playlist_item_attempts + + return resp + + @classmethod + async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp": + room = server_room.room + resp = cls( + id=room.room_id, + name=room.settings.name, + type=room.settings.match_type, + queue_mode=room.settings.queue_mode, + auto_skip=room.settings.auto_skip, + auto_start_duration=room.settings.auto_start_duration, + status=server_room.status, + category=server_room.category, + # duration = room.settings.duration, + starts_at=server_room.start_at, + participant_count=len(room.users), + ) + return resp diff --git a/app/models/mods.py b/app/models/mods.py index abcd2cd..4b20138 100644 --- a/app/models/mods.py +++ b/app/models/mods.py @@ -5,10 +5,12 @@ from typing import Literal, NotRequired, TypedDict from app.path import STATIC_DIR +from msgpack_lazer_api import APIMod as MsgpackAPIMod + class APIMod(TypedDict): acronym: str - settings: NotRequired[dict[str, bool | float | str]] + settings: NotRequired[dict[str, bool | float | str | int]] # https://github.com/ppy/osu-api/wiki#mods @@ -167,3 +169,13 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool: if expected_value != NO_CHECK and value != expected_value: return False return True + + +def msgpack_to_apimod(mod: MsgpackAPIMod) -> APIMod: + """ + Convert a MsgpackAPIMod to an APIMod. + """ + return APIMod( + acronym=mod.acronym, + settings=mod.settings, + ) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py new file mode 100644 index 0000000..fa5e935 --- /dev/null +++ b/app/models/multiplayer_hub.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import datetime +from typing import Annotated, Any, Literal + +from .room import ( + DownloadState, + MatchType, + MultiplayerRoomState, + MultiplayerUserState, + QueueMode, + RoomCategory, + RoomStatus, +) +from .signalr import ( + EnumByIndex, + MessagePackArrayModel, + UserState, + msgpack_union, + msgpack_union_dump, +) + +from msgpack_lazer_api import APIMod +from pydantic import BaseModel, Field, field_serializer, field_validator + + +class MultiplayerClientState(UserState): + room_id: int = 0 + + +class MultiplayerRoomSettings(MessagePackArrayModel): + name: str = "Unnamed Room" + playlist_item_id: int = 0 + password: str = "" + match_type: Annotated[MatchType, EnumByIndex(MatchType)] = MatchType.HEAD_TO_HEAD + queue_mode: Annotated[QueueMode, EnumByIndex(QueueMode)] = QueueMode.HOST_ONLY + auto_start_duration: int = 0 + auto_skip: bool = False + + +class BeatmapAvailability(MessagePackArrayModel): + state: Annotated[DownloadState, EnumByIndex(DownloadState)] = DownloadState.UNKNOWN + progress: float | None = None + + +class _MatchUserState(MessagePackArrayModel): ... + + +class TeamVersusUserState(_MatchUserState): + team_id: int + + type: Literal[0] = Field(0, exclude=True) + + +MatchUserState = TeamVersusUserState + + +class _MatchRoomState(MessagePackArrayModel): ... + + +class MultiplayerTeam(MessagePackArrayModel): + id: int + name: str + + +class TeamVersusRoomState(_MatchRoomState): + teams: list[MultiplayerTeam] = Field( + default_factory=lambda: [ + MultiplayerTeam(id=0, name="Team Red"), + MultiplayerTeam(id=1, name="Team Blue"), + ] + ) + + type: Literal[0] = Field(0, exclude=True) + + +MatchRoomState = TeamVersusRoomState + + +class PlaylistItem(MessagePackArrayModel): + id: int + owner_id: int + beatmap_id: int + checksum: str + ruleset_id: int + required_mods: list[APIMod] = Field(default_factory=list) + allowed_mods: list[APIMod] = Field(default_factory=list) + expired: bool + order: int + played_at: datetime.datetime | None = None + star: float + freestyle: bool + + +class _MultiplayerCountdown(MessagePackArrayModel): + id: int + remaining: int + is_exclusive: bool + + +class MatchStartCountdown(_MultiplayerCountdown): + type: Literal[0] = Field(0, exclude=True) + + +class ForceGameplayStartCountdown(_MultiplayerCountdown): + type: Literal[1] = Field(1, exclude=True) + + +class ServerShuttingDownCountdown(_MultiplayerCountdown): + type: Literal[2] = Field(2, exclude=True) + + +MultiplayerCountdown = ( + MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown +) + + +class MultiplayerRoomUser(MessagePackArrayModel): + user_id: int + state: Annotated[MultiplayerUserState, EnumByIndex(MultiplayerUserState)] = ( + MultiplayerUserState.IDLE + ) + availability: BeatmapAvailability = BeatmapAvailability( + state=DownloadState.UNKNOWN, progress=None + ) + mods: list[APIMod] = Field(default_factory=list) + match_state: MatchUserState | None = None + ruleset_id: int | None = None # freestyle + beatmap_id: int | None = None # freestyle + + @field_validator("match_state", mode="before") + def union_validate(v: Any): + if isinstance(v, list): + return msgpack_union(v) + return v + + @field_serializer("match_state") + def union_serialize(v: Any): + return msgpack_union_dump(v) + + +class MultiplayerRoom(MessagePackArrayModel): + room_id: int + state: Annotated[MultiplayerRoomState, EnumByIndex(MultiplayerRoomState)] + settings: MultiplayerRoomSettings + users: list[MultiplayerRoomUser] = Field(default_factory=list) + host: MultiplayerRoomUser | None = None + match_state: MatchRoomState | None = None + playlist: list[PlaylistItem] = Field(default_factory=list) + active_cooldowns: list[MultiplayerCountdown] = Field(default_factory=list) + channel_id: int + + @field_validator("match_state", mode="before") + def union_validate(v: Any): + if isinstance(v, list): + return msgpack_union(v) + return v + + @field_serializer("match_state") + def union_serialize(v: Any): + return msgpack_union_dump(v) + + +class ServerMultiplayerRoom(BaseModel): + room: MultiplayerRoom + category: RoomCategory + status: RoomStatus + start_at: datetime.datetime diff --git a/app/models/room.py b/app/models/room.py index 2d01a26..42f897c 100644 --- a/app/models/room.py +++ b/app/models/room.py @@ -1,17 +1,8 @@ from __future__ import annotations -from datetime import datetime, timedelta from enum import Enum -from app.database.beatmap import Beatmap, BeatmapResp -from app.database.user import User as DBUser -from app.fetcher import Fetcher -from app.models.mods import APIMod -from app.models.user import User -from app.utils import convert_db_user_to_api_user - -from pydantic import BaseModel, Field -from sqlmodel.ext.asyncio.session import AsyncSession +from pydantic import BaseModel class RoomCategory(str, Enum): @@ -64,51 +55,13 @@ class MultiplayerUserState(str, Enum): class DownloadState(str, Enum): - UNKOWN = "unkown" + UNKNOWN = "unknown" NOT_DOWNLOADED = "not_downloaded" DOWNLOADING = "downloading" IMPORTING = "importing" LOCALLY_AVAILABLE = "locally_available" -class PlaylistItem(BaseModel): - id: int - owner_id: int - ruleset_id: int - expired: bool - playlist_order: int | None - played_at: datetime | None - allowed_mods: list[APIMod] = [] - required_mods: list[APIMod] = [] - beatmap_id: int - beatmap: BeatmapResp | None - freestyle: bool - - class Config: - exclude_none = True - - @classmethod - async def from_mpListItem( - cls, item: MultiPlayerListItem, db: AsyncSession, fetcher: Fetcher - ): - s = cls.model_validate(item.model_dump()) - s.id = item.id - s.owner_id = item.OwnerID - s.ruleset_id = item.RulesetID - s.expired = item.Expired - s.playlist_order = item.PlaylistOrder - s.played_at = item.PlayedAt - s.required_mods = item.RequierdMods - s.allowed_mods = item.AllowedMods - s.freestyle = item.Freestyle - cur_beatmap = await Beatmap.get_or_fetch( - db, fetcher=fetcher, bid=item.BeatmapID - ) - s.beatmap = BeatmapResp.from_db(cur_beatmap) - s.beatmap_id = item.BeatmapID - return s - - class RoomPlaylistItemStats(BaseModel): count_active: int count_total: int @@ -120,269 +73,7 @@ class RoomDifficultyRange(BaseModel): max: float -class ItemAttemptsCount(BaseModel): - id: int - attempts: int - passed: bool - - -class PlaylistAggregateScore(BaseModel): - playlist_item_attempts: list[ItemAttemptsCount] - - -class MultiplayerRoomSettings(BaseModel): - Name: str = "Unnamed Room" - PlaylistItemId: int - Password: str = "" - MatchType: MatchType - QueueMode: QueueMode - AutoStartDuration: timedelta - AutoSkip: bool - - @classmethod - def from_apiRoom(cls, room: Room): - s = cls.model_validate(room.model_dump()) - s.Name = room.name - s.Password = room.password if room.password is not None else "" - s.MatchType = room.type - s.QueueMode = room.queue_mode - s.AutoStartDuration = timedelta(seconds=room.auto_start_duration) - s.AutoSkip = room.auto_skip - return s - - -class BeatmapAvailability(BaseModel): - State: DownloadState - DownloadProgress: float | None - - -class MatchUserState(BaseModel): - class Config: - extra = "allow" - - -class TeamVersusState(MatchUserState): - TeamId: int - - -MatchUserStateType = TeamVersusState | MatchUserState - - -class MultiplayerRoomUser(BaseModel): - UserID: int - State: MultiplayerUserState = MultiplayerUserState.IDLE - BeatmapAvailability: BeatmapAvailability - Mods: list[APIMod] = [] - MatchUserState: MatchUserStateType | None - RulesetId: int | None - BeatmapId: int | None - User: User | None - - @classmethod - async def from_id(cls, id: int, db: AsyncSession): - actualUser = ( - await db.exec( - DBUser.all_select_clause().where( - DBUser.id == id, - ) - ) - ).first() - user = ( - await convert_db_user_to_api_user(actualUser) - if actualUser is not None - else None - ) - return MultiplayerRoomUser( - UserID=id, - MatchUserState=None, - BeatmapAvailability=BeatmapAvailability( - State=DownloadState.UNKOWN, DownloadProgress=None - ), - RulesetId=None, - BeatmapId=None, - User=user, - ) - - -class MatchRoomState(BaseModel): - class Config: - extra = "allow" - - -class MultiPlayerTeam(BaseModel): - id: int = 0 - name: str = "" - - -class TeamVersusRoomState(BaseModel): - teams: list[MultiPlayerTeam] = [] - - class Config: - pass - - @classmethod - def create_default(cls): - return cls( - teams=[ - MultiPlayerTeam(id=0, name="Team Red"), - MultiPlayerTeam(id=1, name="Team Blue"), - ] - ) - - -MatchRoomStateType = TeamVersusRoomState | MatchRoomState - - -class MultiPlayerListItem(BaseModel): - id: int - OwnerID: int - BeatmapID: int - BeatmapChecksum: str = "" - RulesetID: int - RequierdMods: list[APIMod] - AllowedMods: list[APIMod] - Expired: bool - PlaylistOrder: int | None - PlayedAt: datetime | None - StarRating: float - Freestyle: bool - - @classmethod - async def from_apiItem(cls, item: PlaylistItem, db: AsyncSession, fetcher: Fetcher): - s = cls.model_validate(item.model_dump()) - s.id = item.id - s.OwnerID = item.owner_id - if item.beatmap is None: # 从客户端接受的一定没有这字段 - cur_beatmap = await Beatmap.get_or_fetch( - db, fetcher=fetcher, bid=item.beatmap_id - ) - s.BeatmapID = cur_beatmap.id if cur_beatmap.id is not None else 0 - s.BeatmapChecksum = cur_beatmap.checksum - s.StarRating = cur_beatmap.difficulty_rating - s.RulesetID = item.ruleset_id - s.RequierdMods = item.required_mods - s.AllowedMods = item.allowed_mods - s.Expired = item.expired - s.PlaylistOrder = item.playlist_order if item.playlist_order is not None else 0 - s.PlayedAt = item.played_at - s.Freestyle = item.freestyle - return s - - -class MultiplayerCountdown(BaseModel): - id: int = 0 - time_remaining: timedelta = timedelta(seconds=0) - is_exclusive: bool = True - - class Config: - extra = "allow" - - -class MatchStartCountdown(MultiplayerCountdown): - pass - - -class ForceGameplayStartCountdown(MultiplayerCountdown): - pass - - -class ServerShuttingCountdown(MultiplayerCountdown): - pass - - -MultiplayerCountdownType = ( - MatchStartCountdown - | ForceGameplayStartCountdown - | ServerShuttingCountdown - | MultiplayerCountdown -) - - class PlaylistStatus(BaseModel): count_active: int count_total: int ruleset_ids: list[int] - - -class MultiplayerRoom(BaseModel): - RoomId: int - State: MultiplayerRoomState - Settings: MultiplayerRoomSettings = MultiplayerRoomSettings( - PlaylistItemId=0, - MatchType=MatchType.HEAD_TO_HEAD, - QueueMode=QueueMode.HOST_ONLY, - AutoStartDuration=timedelta(0), - AutoSkip=False, - ) - Users: list[MultiplayerRoomUser] - Host: MultiplayerRoomUser - MatchState: MatchRoomState | None - Playlist: list[MultiPlayerListItem] - ActivecCountDowns: list[MultiplayerCountdownType] - ChannelID: int - - @classmethod - def CanAddPlayistItem(cls, user: MultiplayerRoomUser) -> bool: - return user == cls.Host or cls.Settings.QueueMode != QueueMode.HOST_ONLY - - @classmethod - async def from_apiRoom(cls, room: Room, db: AsyncSession, fetcher: Fetcher): - s = cls.model_validate(room.model_dump()) - s.RoomId = room.room_id if room.room_id is not None else 0 - s.ChannelID = room.channel_id - s.Settings = MultiplayerRoomSettings.from_apiRoom(room) - s.Host = await MultiplayerRoomUser.from_id(room.host.id if room.host else 0, db) - s.Playlist = [ - await MultiPlayerListItem.from_apiItem(item, db, fetcher) - for item in room.playlist - ] - return s - - -class Room(BaseModel): - room_id: int - name: str - password: str | None - has_password: bool = Field(exclude=True) - host: User | None - category: RoomCategory - duration: int | None - starts_at: datetime | None - ends_at: datetime | None - max_particapants: int | None = Field(exclude=True) - particapant_count: int - recent_particapants: list[User] - type: MatchType - max_attempts: int | None - playlist: list[PlaylistItem] - playlist_item_status: list[RoomPlaylistItemStats] - difficulity_range: RoomDifficultyRange - queue_mode: QueueMode - auto_skip: bool - auto_start_duration: int - current_user_score: PlaylistAggregateScore | None - current_playlist_item: PlaylistItem | None - channel_id: int - status: RoomStatus - availability: RoomAvailability = Field(exclude=True) - - class Config: - exclude_none = True - - @classmethod - async def from_mpRoom( - cls, room: MultiplayerRoom, db: AsyncSession, fetcher: Fetcher - ): - s = cls.model_validate(room.model_dump()) - s.room_id = room.RoomId - s.name = room.Settings.Name - s.password = room.Settings.Password - s.type = room.Settings.MatchType - s.queue_mode = room.Settings.QueueMode - s.auto_skip = room.Settings.AutoSkip - s.host = room.Host.User - s.playlist = [ - await PlaylistItem.from_mpListItem(item, db, fetcher) - for item in room.Playlist - ] - return s diff --git a/app/router/__init__.py b/app/router/__init__.py index 1e87343..22f6c70 100644 --- a/app/router/__init__.py +++ b/app/router/__init__.py @@ -7,6 +7,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401 beatmapset, me, relationship, + room, score, user, ) @@ -14,4 +15,9 @@ from .api_router import router as api_router from .auth import router as auth_router from .fetcher import fetcher_router as fetcher_router -__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"] +__all__ = [ + "api_router", + "auth_router", + "fetcher_router", + "signalr_router", +] diff --git a/app/router/room.py b/app/router/room.py index a2347ec..ba909c6 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -1,109 +1,86 @@ from __future__ import annotations -from app.database.room import RoomIndex +from typing import Literal + +from app.database.room import RoomResp from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.fetcher import Fetcher -from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room +from app.models.room import RoomStatus +from app.signalr.hub import MultiplayerHubs from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, Query from redis.asyncio import Redis -from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -@router.get("/rooms", tags=["rooms"], response_model=list[Room]) +@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp]) async def get_all_rooms( - mode: str | None = Query(None), # TODO: 对房间根据状态进行筛选 - status: str | None = Query(None), - category: str | None = Query( + mode: Literal["open", "ended", "participated", "owned", None] = Query( None - ), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗) + ), # TODO: 对房间根据状态进行筛选 + category: str = Query(default="realtime"), # TODO + status: RoomStatus | None = Query(None), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), redis: Redis = Depends(get_redis), ): - all_roomID = (await db.exec(select(RoomIndex))).all() - redis = get_redis() - if redis is not None: - resp: list[Room] = [] - for id in all_roomID: - dumped_room = redis.get(str(id)) - validated_room = MultiplayerRoom.model_validate_json(str(dumped_room)) - flag: bool = False - if status is not None: - if ( - validated_room.State == MultiplayerRoomState.OPEN - and status == "idle" - ): - flag = True - elif validated_room != MultiplayerRoomState.CLOSED: - flag = True - if flag: - resp.append( - await Room.from_mpRoom( - MultiplayerRoom.model_validate_json(str(dumped_room)), - db, - fetcher, - ) - ) - return resp - else: - raise HTTPException(status_code=500, detail="Redis Error") + rooms = MultiplayerHubs.rooms.values() + return [await RoomResp.from_hub(room) for room in rooms] -@router.get("/rooms/{room}", tags=["room"], response_model=Room) -async def get_room( - room: int, - db: AsyncSession = Depends(get_db), - fetcher: Fetcher = Depends(get_fetcher), -): - redis = get_redis() - if redis: - dumped_room = str(redis.get(str(room))) - if dumped_room is not None: - resp = await Room.from_mpRoom( - MultiplayerRoom.model_validate_json(str(dumped_room)), db, fetcher - ) - return resp - else: - raise HTTPException(status_code=404, detail="Room Not Found") - else: - raise HTTPException(status_code=500, detail="Redis error") +# @router.get("/rooms/{room}", tags=["room"], response_model=Room) +# async def get_room( +# room: int, +# db: AsyncSession = Depends(get_db), +# fetcher: Fetcher = Depends(get_fetcher), +# ): +# redis = get_redis() +# if redis: +# dumped_room = str(redis.get(str(room))) +# if dumped_room is not None: +# resp = await Room.from_mpRoom( +# MultiplayerRoom.model_validate_json(str(dumped_room)), db, fetcher +# ) +# return resp +# else: +# raise HTTPException(status_code=404, detail="Room Not Found") +# else: +# raise HTTPException(status_code=500, detail="Redis error") -class APICreatedRoom(Room): - error: str | None +# class APICreatedRoom(Room): +# error: str | None -@router.post("/rooms", tags=["beatmap"], response_model=APICreatedRoom) -async def create_room( - room: Room, - db: AsyncSession = Depends(get_db), - fetcher: Fetcher = Depends(get_fetcher), -): - redis = get_redis() - if redis: - room_index = RoomIndex() - db.add(room_index) - await db.commit() - await db.refresh(room_index) - server_room = await MultiplayerRoom.from_apiRoom(room, db, fetcher) - redis.set(str(room_index.id), server_room.model_dump_json()) - room.room_id = room_index.id - return APICreatedRoom(**room.model_dump(), error=None) - else: - raise HTTPException(status_code=500, detail="redis error") +# @router.post("/rooms", tags=["beatmap"], response_model=APICreatedRoom) +# async def create_room( +# room: Room, +# db: AsyncSession = Depends(get_db), +# fetcher: Fetcher = Depends(get_fetcher), +# ): +# redis = get_redis() +# if redis: +# room_index = RoomIndex() +# db.add(room_index) +# await db.commit() +# await db.refresh(room_index) +# server_room = await MultiplayerRoom.from_apiRoom(room, db, fetcher) +# redis.set(str(room_index.id), server_room.model_dump_json()) +# room.room_id = room_index.id +# return APICreatedRoom(**room.model_dump(), error=None) +# else: +# raise HTTPException(status_code=500, detail="redis error") -@router.delete("/rooms/{room}", tags=["room"]) -async def remove_room(room: int, db: AsyncSession = Depends(get_db)): - redis = get_redis() - if redis: - redis.delete(str(room)) - room_index = await db.get(RoomIndex, room) - if room_index: - await db.delete(room_index) - await db.commit() +# @router.delete("/rooms/{room}", tags=["room"]) +# async def remove_room(room: int, db: AsyncSession = Depends(get_db)): +# redis = get_redis() +# if redis: +# redis.delete(str(room)) +# room_index = await db.get(RoomIndex, room) +# if room_index: +# await db.delete(room_index) +# await db.commit() diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 72b4a52..23ca69b 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -1,6 +1,103 @@ from __future__ import annotations -from .hub import Hub +from typing import override + +from app.database import Room +from app.database.playlists import Playlist +from app.dependencies.database import engine +from app.log import logger +from app.models.multiplayer_hub import ( + MultiplayerClientState, + MultiplayerRoom, + MultiplayerRoomUser, + ServerMultiplayerRoom, +) +from app.models.room import RoomCategory, RoomStatus +from app.models.signalr import serialize_to_list +from app.signalr.exception import InvokeException + +from .hub import Client, Hub + +from sqlmodel.ext.asyncio.session import AsyncSession -class MultiplayerHub(Hub): ... +class MultiplayerHub(Hub[MultiplayerClientState]): + @override + def __init__(self): + super().__init__() + self.rooms: dict[int, ServerMultiplayerRoom] = {} + + @staticmethod + def group_id(room: int) -> str: + return f"room:{room}" + + @override + def create_state(self, client: Client) -> MultiplayerClientState: + return MultiplayerClientState( + connection_id=client.connection_id, + connection_token=client.connection_token, + ) + + async def CreateRoom(self, client: Client, room: MultiplayerRoom): + logger.info(f"[MultiplayerHub] {client.user_id} creating room") + async with AsyncSession(engine) as session: + async with session: + db_room = Room( + name=room.settings.name, + category=RoomCategory.NORMAL, + type=room.settings.match_type, + queue_mode=room.settings.queue_mode, + auto_skip=room.settings.auto_skip, + auto_start_duration=room.settings.auto_start_duration, + host_id=client.user_id, + status=RoomStatus.IDLE, + ) + session.add(db_room) + await session.commit() + await session.refresh(db_room) + playitem = room.playlist[0] + playitem.owner_id = client.user_id + playitem.order = 1 + db_playlist = await Playlist.from_hub(playitem, db_room.id) + session.add(db_playlist) + room.room_id = db_room.id + starts_at = db_room.starts_at + await session.commit() + await session.refresh(db_playlist) + # room.playlist.append() + server_room = ServerMultiplayerRoom( + room=room, + category=RoomCategory.NORMAL, + status=RoomStatus.IDLE, + start_at=starts_at, + ) + self.rooms[room.room_id] = server_room + return await self.JoinRoomWithPassword( + client, room.room_id, room.settings.password + ) + + async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str): + logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}") + store = self.get_or_create_state(client) + if store.room_id != 0: + raise InvokeException("You are already in a room") + user = MultiplayerRoomUser(user_id=client.user_id) + if room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[room_id] + room = server_room.room + for u in room.users: + if u.user_id == client.user_id: + raise InvokeException("You are already in this room") + if room.settings.password != password: + raise InvokeException("Incorrect password") + if room.host is None: + # from CreateRoom + room.host = user + store.room_id = room_id + await self.broadcast_group_call( + self.group_id(room_id), "UserJoined", serialize_to_list(user) + ) + room.users.append(user) + self.add_to_group(client, self.group_id(room_id)) + return serialize_to_list(room) diff --git a/main.py b/main.py index 72444ef..b12f543 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,12 @@ from datetime import datetime from app.config import settings from app.dependencies.database import create_tables, engine, redis_client from app.dependencies.fetcher import get_fetcher -from app.router import api_router, auth_router, fetcher_router, signalr_router +from app.router import ( + api_router, + auth_router, + fetcher_router, + signalr_router, +) from fastapi import FastAPI From 0b68bdc0c11fc75a6a76e97a4f874c98a784c639 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 01:55:30 +0000 Subject: [PATCH 13/45] fix(beatmap,beatmapset): fix lookup --- app/router/beatmap.py | 2 +- app/router/beatmapset.py | 28 +++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 9574bdb..7dfd0f9 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -74,7 +74,7 @@ class BatchGetResp(BaseModel): @router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp) @router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp) async def batch_get_beatmaps( - b_ids: list[int] = Query(alias="id", default_factory=list), + b_ids: list[int] = Query(alias="ids[]", default_factory=list), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index b4d2e4c..f77c2ed 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Literal -from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User +from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.dependencies.database import get_db from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -17,6 +17,32 @@ from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp) +async def lookup_beatmapset( + beatmap_id: int = Query(), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), + fetcher: Fetcher = Depends(get_fetcher), +): + beatmapset_id = ( + await db.exec(select(Beatmap.beatmapset_id).where(Beatmap.id == beatmap_id)) + ).first() + if not beatmapset_id: + try: + resp = await fetcher.get_beatmap(beatmap_id) + await Beatmap.from_resp(db, resp) + await db.refresh(current_user) + except HTTPStatusError: + raise HTTPException(status_code=404, detail="Beatmapset not found") + beatmapset = ( + await db.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id)) + ).first() + if not beatmapset: + raise HTTPException(status_code=404, detail="Beatmapset not found") + resp = await BeatmapsetResp.from_db(beatmapset, session=db, user=current_user) + return resp + + @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) async def get_beatmapset( sid: int, From 884a3f1cc23e298e8396172a1ee63086e6393242 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 01:56:00 +0000 Subject: [PATCH 14/45] fix(leaderboard): missing filter condition for user score --- app/database/score.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/app/database/score.py b/app/database/score.py index 79cb005..1bd5978 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -328,6 +328,10 @@ async def get_leaderboard( self_query = ( select(BestScore) .where(BestScore.user_id == user.id) + .where( + col(BestScore.beatmap_id) == beatmap, + col(BestScore.gamemode) == mode, + ) .order_by(col(BestScore.total_score).desc()) .limit(1) ) From 86e2313c50d2e5c537eb2ddf29df83ad903c23a1 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 01:56:54 +0000 Subject: [PATCH 15/45] feat(multiplayer): support add/edit/remove playlist item --- app/database/playlists.py | 62 +++++++- app/{signalr => }/exception.py | 0 app/models/multiplayer_hub.py | 256 ++++++++++++++++++++++++++++++++- app/signalr/hub/hub.py | 2 +- app/signalr/hub/multiplayer.py | 136 ++++++++++++++++-- 5 files changed, 441 insertions(+), 15 deletions(-) rename app/{signalr => }/exception.py (100%) diff --git a/app/database/playlists.py b/app/database/playlists.py index 42567b6..10ad86b 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -16,7 +16,10 @@ from sqlmodel import ( ForeignKey, Relationship, SQLModel, + func, + select, ) +from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .room import Room @@ -59,9 +62,20 @@ class Playlist(PlaylistBase, table=True): room: "Room" = Relationship() @classmethod - async def from_hub(cls, playlist: PlaylistItem, room_id: int) -> "Playlist": + async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int: + stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where( + cls.room_id == room_id + ) + result = await session.exec(stmt) + return result.one() + + @classmethod + async def from_hub( + cls, playlist: PlaylistItem, room_id: int, session: AsyncSession + ) -> "Playlist": + next_id = await cls.get_next_id_for_room(room_id, session=session) return cls( - id=playlist.id, + id=next_id, owner_id=playlist.owner_id, ruleset_id=playlist.ruleset_id, beatmap_id=playlist.beatmap_id, @@ -74,6 +88,50 @@ class Playlist(PlaylistBase, table=True): room_id=room_id, ) + @classmethod + async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): + db_playlist = await session.exec( + select(cls).where(cls.id == playlist.id, cls.room_id == room_id) + ) + db_playlist = db_playlist.first() + if db_playlist is None: + raise ValueError("Playlist item not found") + db_playlist.owner_id = playlist.owner_id + db_playlist.ruleset_id = playlist.ruleset_id + db_playlist.beatmap_id = playlist.beatmap_id + db_playlist.required_mods = [ + msgpack_to_apimod(mod) for mod in playlist.required_mods + ] + db_playlist.allowed_mods = [ + msgpack_to_apimod(mod) for mod in playlist.allowed_mods + ] + db_playlist.expired = playlist.expired + db_playlist.playlist_order = playlist.order + db_playlist.played_at = playlist.played_at + db_playlist.freestyle = playlist.freestyle + await session.commit() + + @classmethod + async def add_to_db( + cls, playlist: PlaylistItem, room_id: int, session: AsyncSession + ): + db_playlist = await cls.from_hub(playlist, room_id, session) + session.add(db_playlist) + await session.commit() + await session.refresh(db_playlist) + playlist.id = db_playlist.id + + @classmethod + async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession): + db_playlist = await session.exec( + select(cls).where(cls.id == item_id, cls.room_id == room_id) + ) + db_playlist = db_playlist.first() + if db_playlist is None: + raise ValueError("Playlist item not found") + await session.delete(db_playlist) + await session.commit() + class PlaylistResp(PlaylistBase): beatmap: BeatmapResp | None = None diff --git a/app/signalr/exception.py b/app/exception.py similarity index 100% rename from app/signalr/exception.py rename to app/exception.py diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index fa5e935..39ced12 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -1,7 +1,12 @@ from __future__ import annotations +from dataclasses import dataclass import datetime -from typing import Annotated, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal + +from app.database.beatmap import Beatmap +from app.dependencies.database import engine +from app.exception import InvokeException from .room import ( DownloadState, @@ -21,7 +26,14 @@ from .signalr import ( ) from msgpack_lazer_api import APIMod -from pydantic import BaseModel, Field, field_serializer, field_validator +from pydantic import Field, field_serializer, field_validator +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from app.signalr.hub import MultiplayerHub + +HOST_LIMIT = 50 +PER_USER_LIMIT = 3 class MultiplayerClientState(UserState): @@ -161,8 +173,246 @@ class MultiplayerRoom(MessagePackArrayModel): return msgpack_union_dump(v) -class ServerMultiplayerRoom(BaseModel): +class MultiplayerQueue: + def __init__(self, room: "ServerMultiplayerRoom", hub: "MultiplayerHub"): + self.server_room = room + self.hub = hub + self.current_index = 0 + + @property + def upcoming_items(self): + return sorted( + (item for item in self.room.playlist if not item.expired), + key=lambda i: i.order, + ) + + @property + def room(self): + return self.server_room.room + + async def update_order(self): + from app.database import Playlist + + match self.room.settings.queue_mode: + case QueueMode.ALL_PLAYERS_ROUND_ROBIN: + ordered_active_items = [] + + is_first_set = True + first_set_order_by_user_id = {} + + active_items = [item for item in self.room.playlist if not item.expired] + active_items.sort(key=lambda x: x.id) + + user_item_groups = {} + for item in active_items: + if item.owner_id not in user_item_groups: + user_item_groups[item.owner_id] = [] + user_item_groups[item.owner_id].append(item) + + max_items = max( + (len(items) for items in user_item_groups.values()), default=0 + ) + + for i in range(max_items): + current_set = [] + for user_id, items in user_item_groups.items(): + if i < len(items): + current_set.append(items[i]) + + if is_first_set: + current_set.sort(key=lambda item: (item.order, item.id)) + ordered_active_items.extend(current_set) + first_set_order_by_user_id = { + item.owner_id: idx + for idx, item in enumerate(ordered_active_items) + } + else: + current_set.sort( + key=lambda item: first_set_order_by_user_id.get( + item.owner_id, 0 + ) + ) + ordered_active_items.extend(current_set) + + is_first_set = False + + for idx, item in enumerate(ordered_active_items): + item.order = idx + case _: + ordered_active_items = sorted( + (item for item in self.room.playlist if not item.expired), + key=lambda x: x.id, + ) + async with AsyncSession(engine) as session: + for idx, item in enumerate(ordered_active_items): + if item.order == idx: + continue + item.order = idx + await Playlist.update(item, self.room.room_id, session) + await self.hub.playlist_changed( + self.server_room, item, beatmap_changed=False + ) + + async def update_current_item(self): + upcoming_items = self.upcoming_items + next_item = ( + upcoming_items[0] + if upcoming_items + else max( + self.room.playlist, + key=lambda i: i.played_at or datetime.datetime.min, + ) + ) + self.current_index = self.room.playlist.index(next_item) + last_id = self.room.settings.playlist_item_id + self.room.settings.playlist_item_id = next_item.id + if last_id != next_item.id: + await self.hub.setting_changed(self.server_room, True) + + async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser): + from app.database import Playlist + + is_host = self.room.host and self.room.host.user_id == user.user_id + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host: + raise InvokeException("You are not the host") + + limit = HOST_LIMIT if is_host else PER_USER_LIMIT + if ( + len( + list( + filter( + lambda x: x.owner_id == user.user_id, + self.room.playlist, + ) + ) + ) + >= limit + ): + raise InvokeException(f"You can only have {limit} items in the queue") + + if item.freestyle and len(item.allowed_mods) > 0: + raise InvokeException("Freestyle items cannot have allowed mods") + + async with AsyncSession(engine) as session: + async with session: + beatmap = await session.get(Beatmap, item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + if item.checksum != beatmap.checksum: + raise InvokeException("Checksum mismatch") + # TODO: mods validation + item.owner_id = user.user_id + item.star = float( + beatmap.difficulty_rating + ) # FIXME: beatmap use decimal + await Playlist.add_to_db(item, self.room.room_id, session) + self.room.playlist.append(item) + await self.hub.playlist_added(self.server_room, item) + await self.update_order() + await self.update_current_item() + + async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser): + from app.database import Playlist + + if item.freestyle and len(item.allowed_mods) > 0: + raise InvokeException("Freestyle items cannot have allowed mods") + + async with AsyncSession(engine) as session: + async with session: + beatmap = await session.get(Beatmap, item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + if item.checksum != beatmap.checksum: + raise InvokeException("Checksum mismatch") + + existing_item = next( + (i for i in self.room.playlist if i.id == item.id), None + ) + if existing_item is None: + raise InvokeException( + "Attempted to change an item that doesn't exist" + ) + + if existing_item.owner_id != user.user_id and self.room.host != user: + raise InvokeException( + "Attempted to change an item which is not owned by the user" + ) + + if existing_item.expired: + raise InvokeException( + "Attempted to change an item which has already been played" + ) + + # TODO: mods validation + item.owner_id = user.user_id + item.star = float(beatmap.difficulty_rating) + item.order = existing_item.order + + await Playlist.update(item, self.room.room_id, session) + + # Update item in playlist + for idx, playlist_item in enumerate(self.room.playlist): + if playlist_item.id == item.id: + self.room.playlist[idx] = item + break + + await self.hub.playlist_changed( + self.server_room, + item, + beatmap_changed=item.checksum != existing_item.checksum, + ) + + async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): + from app.database import Playlist + + item = next( + (i for i in self.room.playlist if i.id == playlist_item_id), + None, + ) + + if item is None: + raise InvokeException("Item does not exist in the room") + + # Check if it's the only item and current item + if item == self.current_item: + upcoming_items = [i for i in self.room.playlist if not i.expired] + if len(upcoming_items) == 1: + raise InvokeException("The only item in the room cannot be removed") + + if item.owner_id != user.user_id and self.room.host != user: + raise InvokeException( + "Attempted to remove an item which is not owned by the user" + ) + + if item.expired: + raise InvokeException( + "Attempted to remove an item which has already been played" + ) + + async with AsyncSession(engine) as session: + await Playlist.delete_item(item.id, self.room.room_id, session) + + self.room.playlist.remove(item) + self.current_index = self.room.playlist.index(self.upcoming_items[0]) + + await self.update_order() + await self.update_current_item() + await self.hub.playlist_removed(self.server_room, item.id) + + @property + def current_item(self): + """Get the current playlist item""" + current_id = self.room.settings.playlist_item_id + return next( + (item for item in self.room.playlist if item.id == current_id), + None, + ) + + +@dataclass +class ServerMultiplayerRoom: room: MultiplayerRoom category: RoomCategory status: RoomStatus start_at: datetime.datetime + queue: MultiplayerQueue | None = None diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 276140f..4e2c9d6 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -6,9 +6,9 @@ import time from typing import Any from app.config import settings +from app.exception import InvokeException from app.log import logger from app.models.signalr import UserState -from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, CompletionPacket, diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 23ca69b..477396b 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -5,16 +5,19 @@ from typing import override from app.database import Room from app.database.playlists import Playlist from app.dependencies.database import engine +from app.exception import InvokeException from app.log import logger from app.models.multiplayer_hub import ( + BeatmapAvailability, MultiplayerClientState, + MultiplayerQueue, MultiplayerRoom, MultiplayerRoomUser, + PlaylistItem, ServerMultiplayerRoom, ) from app.models.room import RoomCategory, RoomStatus from app.models.signalr import serialize_to_list -from app.signalr.exception import InvokeException from .hub import Client, Hub @@ -40,6 +43,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): async def CreateRoom(self, client: Client, room: MultiplayerRoom): logger.info(f"[MultiplayerHub] {client.user_id} creating room") + store = self.get_or_create_state(client) + if store.room_id != 0: + raise InvokeException("You are already in a room") async with AsyncSession(engine) as session: async with session: db_room = Room( @@ -55,22 +61,22 @@ class MultiplayerHub(Hub[MultiplayerClientState]): session.add(db_room) await session.commit() await session.refresh(db_room) - playitem = room.playlist[0] - playitem.owner_id = client.user_id - playitem.order = 1 - db_playlist = await Playlist.from_hub(playitem, db_room.id) - session.add(db_playlist) + item = room.playlist[0] + item.owner_id = client.user_id room.room_id = db_room.id starts_at = db_room.starts_at - await session.commit() - await session.refresh(db_playlist) - # room.playlist.append() + await Playlist.add_to_db(item, db_room.id, session) server_room = ServerMultiplayerRoom( room=room, category=RoomCategory.NORMAL, status=RoomStatus.IDLE, start_at=starts_at, ) + queue = MultiplayerQueue( + room=server_room, + hub=self, + ) + server_room.queue = queue self.rooms[room.room_id] = server_room return await self.JoinRoomWithPassword( client, room.room_id, room.settings.password @@ -101,3 +107,115 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room.users.append(user) self.add_to_group(client, self.group_id(room_id)) return serialize_to_list(room) + + async def ChangeBeatmapAvailability( + self, client: Client, beatmap_availability: BeatmapAvailability + ): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + availability = user.availability + if ( + availability.state == beatmap_availability.state + and availability.progress == beatmap_availability.progress + ): + return + user.availability = availability + await self.broadcast_group_call( + self.group_id(store.room_id), + "UserBeatmapAvailabilityChanged", + user.user_id, + serialize_to_list(beatmap_availability), + ) + + async def AddPlaylistItem(self, client: Client, item: PlaylistItem): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.add_item( + item, + user, + ) + + async def EditPlaylistItem(self, client: Client, item: PlaylistItem): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.edit_item( + item, + user, + ) + + async def RemovePlaylistItem(self, client: Client, item_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.remove_item( + item_id, + user, + ) + + async def setting_changed(self, room: ServerMultiplayerRoom, beatmap_changed: bool): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "SettingsChanged", + serialize_to_list(room.room.settings), + ) + + async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemAdded", + serialize_to_list(item), + ) + + async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemRemoved", + item_id, + ) + + async def playlist_changed( + self, room: ServerMultiplayerRoom, item: PlaylistItem, beatmap_changed: bool + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemChanged", + serialize_to_list(item), + ) From 693c18ba6e2f561b3fdcff9231e6537884dff7a7 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 04:24:13 +0000 Subject: [PATCH 16/45] feat(multiplayer): support change mods/playstyles(freestyle) --- app/models/multiplayer_hub.py | 79 ++++++++++++++- app/signalr/hub/multiplayer.py | 176 +++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+), 1 deletion(-) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 39ced12..9bccb71 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -103,6 +103,84 @@ class PlaylistItem(MessagePackArrayModel): star: float freestyle: bool + def validate_user_mods( + self, + user: "MultiplayerRoomUser", + proposed_mods: list[APIMod], + ) -> tuple[bool, list[APIMod]]: + """ + Validates user mods against playlist item rules and returns valid mods. + Returns (is_valid, valid_mods). + """ + from typing import Literal, cast + + from app.models.mods import API_MODS, init_mods + + if not API_MODS: + init_mods() + + ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id + ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id) + + valid_mods = [] + all_proposed_valid = True + + # Check if mods are valid for the ruleset + for mod in proposed_mods: + if ruleset_key not in API_MODS or mod.acronym not in API_MODS[ruleset_key]: + all_proposed_valid = False + continue + valid_mods.append(mod) + + # Check mod compatibility within user mods + incompatible_mods = set() + final_valid_mods = [] + for mod in valid_mods: + if mod.acronym in incompatible_mods: + all_proposed_valid = False + continue + setting_mods = API_MODS[ruleset_key].get(mod.acronym) + if setting_mods: + incompatible_mods.update(setting_mods["IncompatibleMods"]) + final_valid_mods.append(mod) + + # If not freestyle, check against allowed mods + if not self.freestyle: + allowed_acronyms = {mod.acronym for mod in self.allowed_mods} + filtered_valid_mods = [] + for mod in final_valid_mods: + if mod.acronym not in allowed_acronyms: + all_proposed_valid = False + else: + filtered_valid_mods.append(mod) + final_valid_mods = filtered_valid_mods + + # Check compatibility with required mods + required_mod_acronyms = {mod.acronym for mod in self.required_mods} + all_mod_acronyms = { + mod.acronym for mod in final_valid_mods + } | required_mod_acronyms + + # Check for incompatibility between required and user mods + filtered_valid_mods = [] + for mod in final_valid_mods: + mod_acronym = mod.acronym + is_compatible = True + + for other_acronym in all_mod_acronyms: + if other_acronym == mod_acronym: + continue + setting_mods = API_MODS[ruleset_key].get(mod_acronym) + if setting_mods and other_acronym in setting_mods["IncompatibleMods"]: + is_compatible = False + all_proposed_valid = False + break + + if is_compatible: + filtered_valid_mods.append(mod) + + return all_proposed_valid, filtered_valid_mods + class _MultiplayerCountdown(MessagePackArrayModel): id: int @@ -405,7 +483,6 @@ class MultiplayerQueue: current_id = self.room.settings.playlist_item_id return next( (item for item in self.room.playlist if item.id == current_id), - None, ) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 477396b..bd34be0 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import override from app.database import Room +from app.database.beatmap import Beatmap from app.database.playlists import Playlist from app.dependencies.database import engine from app.exception import InvokeException @@ -17,10 +18,13 @@ from app.models.multiplayer_hub import ( ServerMultiplayerRoom, ) from app.models.room import RoomCategory, RoomStatus +from app.models.score import GameMode from app.models.signalr import serialize_to_list from .hub import Client, Hub +from msgpack_lazer_api import APIMod +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -219,3 +223,175 @@ class MultiplayerHub(Hub[MultiplayerClientState]): "PlaylistItemChanged", serialize_to_list(item), ) + + async def ChangeUserStyle( + self, client: Client, beatmap_id: int | None, ruleset_id: int | None + ): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.change_user_style( + beatmap_id, + ruleset_id, + server_room, + user, + ) + + async def validate_styles(self, room: ServerMultiplayerRoom): + assert room.queue + if not room.queue.current_item.freestyle: + for user in room.room.users: + await self.change_user_style( + None, + None, + room, + user, + ) + async with AsyncSession(engine) as session: + beatmap = await session.get(Beatmap, room.queue.current_item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + beatmap_ids = ( + await session.exec( + select(Beatmap.id, Beatmap.mode).where( + Beatmap.beatmapset_id == beatmap.beatmapset_id, + ) + ) + ).all() + for user in room.room.users: + beatmap_id = user.beatmap_id + ruleset_id = user.ruleset_id + user_beatmap = next( + (b for b in beatmap_ids if b[0] == beatmap_id), + None, + ) + if beatmap_id is not None and user_beatmap is None: + beatmap_id = None + beatmap_ruleset = user_beatmap[1] if user_beatmap else beatmap.mode + if ( + ruleset_id is not None + and beatmap_ruleset != GameMode.OSU + and ruleset_id != beatmap_ruleset + ): + ruleset_id = None + await self.change_user_style( + beatmap_id, + ruleset_id, + room, + user, + ) + + for user in room.room.users: + is_valid, valid_mods = room.queue.current_item.validate_user_mods( + user, user.mods + ) + if not is_valid: + await self.change_user_mods(valid_mods, room, user) + + async def change_user_style( + self, + beatmap_id: int | None, + ruleset_id: int | None, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + ): + if user.beatmap_id == beatmap_id and user.ruleset_id == ruleset_id: + return + + if beatmap_id is not None or ruleset_id is not None: + assert room.queue + if not room.queue.current_item.freestyle: + raise InvokeException("Current item does not allow free user styles.") + + async with AsyncSession(engine) as session: + item_beatmap = await session.get( + Beatmap, room.queue.current_item.beatmap_id + ) + if item_beatmap is None: + raise InvokeException("Item beatmap not found") + + user_beatmap = ( + item_beatmap + if beatmap_id is None + else await session.get(Beatmap, beatmap_id) + ) + + if user_beatmap is None: + raise InvokeException("Invalid beatmap selected.") + + if user_beatmap.beatmapset_id != item_beatmap.beatmapset_id: + raise InvokeException( + "Selected beatmap is not from the same beatmap set." + ) + + if ( + ruleset_id is not None + and user_beatmap.mode != GameMode.OSU + and ruleset_id != user_beatmap.mode + ): + raise InvokeException( + "Selected ruleset is not supported for the given beatmap." + ) + + user.beatmap_id = beatmap_id + user.ruleset_id = ruleset_id + + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserStyleChanged", + user.user_id, + beatmap_id, + ruleset_id, + ) + + async def ChangeUserMods(self, client: Client, new_mods: list[APIMod]): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.change_user_mods(new_mods, server_room, user) + + async def change_user_mods( + self, + new_mods: list[APIMod], + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + ): + assert room.queue + is_valid, valid_mods = room.queue.current_item.validate_user_mods( + user, new_mods + ) + if not is_valid: + incompatible_mods = [ + mod.acronym for mod in new_mods if mod not in valid_mods + ] + raise InvokeException( + f"Incompatible mods were selected: {','.join(incompatible_mods)}" + ) + + if user.mods == valid_mods: + return + + user.mods = valid_mods + + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserModsChanged", + user.user_id, + valid_mods, + ) From c83f950d132c6c1860b0f382f6de53578972d737 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 14:59:12 +0000 Subject: [PATCH 17/45] fix(signalr): encode enum by index --- app/models/signalr.py | 34 +++++++++++++++++----------------- app/signalr/hub/hub.py | 8 +++++++- app/signalr/packet.py | 4 +++- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/app/models/signalr.py b/app/models/signalr.py index 202da4f..9e189e9 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -15,26 +15,26 @@ from pydantic import ( ) +def serialize_msgpack(v: Any) -> Any: + typ = v.__class__ + if issubclass(typ, BaseModel): + return serialize_to_list(v) + elif issubclass(typ, list): + return TypeAdapter( + typ, config=ConfigDict(arbitrary_types_allowed=True) + ).dump_python(v) + elif issubclass(typ, datetime.datetime): + return [v, 0] + elif issubclass(typ, Enum): + list_ = list(typ) + return list_.index(v) if v in list_ else v.value + return v + + def serialize_to_list(value: BaseModel) -> list[Any]: data = [] for field, info in value.__class__.model_fields.items(): - v = getattr(value, field) - typ = v.__class__ - if issubclass(typ, BaseModel): - data.append(serialize_to_list(v)) - elif issubclass(typ, list): - data.append( - TypeAdapter( - info.annotation, config=ConfigDict(arbitrary_types_allowed=True) - ).dump_python(v) - ) - elif issubclass(typ, datetime.datetime): - data.append([v, 0]) - elif issubclass(typ, Enum): - list_ = list(typ) - data.append(list_.index(v) if v in list_ else v.value) - else: - data.append(v) + data.append(serialize_msgpack(v=getattr(value, field))) return data diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 4e2c9d6..85292df 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -2,13 +2,15 @@ from __future__ import annotations from abc import abstractmethod import asyncio +from enum import Enum +import inspect import time from typing import Any from app.config import settings from app.exception import InvokeException from app.log import logger -from app.models.signalr import UserState +from app.models.signalr import UserState, _by_index from app.signalr.packet import ( ClosePacket, CompletionPacket, @@ -265,6 +267,10 @@ class Hub[TState: UserState]: continue if issubclass(param.annotation, BaseModel): call_params.append(param.annotation.model_validate(args.pop(0))) + elif inspect.isclass(param.annotation) and issubclass( + param.annotation, Enum + ): + call_params.append(_by_index(args.pop(0), param.annotation)) else: call_params.append(args.pop(0)) return await method_(client, *call_params) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 387231c..de5ce8a 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -8,6 +8,8 @@ from typing import ( Protocol as TypingProtocol, ) +from app.models.signalr import serialize_msgpack + import msgpack_lazer_api as m SEP = b"\x1e" @@ -151,7 +153,7 @@ class MsgpackProtocol: ] ) if packet.arguments is not None: - payload.append(packet.arguments) + payload.append([serialize_msgpack(arg) for arg in packet.arguments]) if packet.stream_ids is not None: payload.append(packet.stream_ids) elif isinstance(packet, CompletionPacket): From 41631b839f35befc1962793d5d1c828174aaa5b5 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 15:02:12 +0000 Subject: [PATCH 18/45] fix(user): last_visit is nullable --- app/database/lazer_user.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 3bd751b..2717c3a 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -66,7 +66,7 @@ class UserBase(UTCBaseModel, SQLModel): is_active: bool = True is_bot: bool = False is_supporter: bool = False - last_visit: datetime = Field( + last_visit: datetime | None = Field( default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) ) pm_friends_only: bool = False From a11ea743a71ffe14388bf4c5cda7132fa5fc02c9 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 11:00:57 +0000 Subject: [PATCH 19/45] fix(signarl): wrong msgpack encode --- app/models/signalr.py | 46 +++++++++++++++++-- app/signalr/packet.py | 2 +- .../msgpack_lazer_api/msgpack_lazer_api.pyi | 2 +- packages/msgpack_lazer_api/src/decode.rs | 4 +- packages/msgpack_lazer_api/src/encode.rs | 10 ++-- 5 files changed, 50 insertions(+), 14 deletions(-) diff --git a/app/models/signalr.py b/app/models/signalr.py index 37b2741..202da4f 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,10 +1,12 @@ from __future__ import annotations import datetime -from typing import Any, get_origin +from enum import Enum +from typing import Any from pydantic import ( BaseModel, + BeforeValidator, ConfigDict, Field, TypeAdapter, @@ -17,22 +19,56 @@ def serialize_to_list(value: BaseModel) -> list[Any]: data = [] for field, info in value.__class__.model_fields.items(): v = getattr(value, field) - anno = get_origin(info.annotation) - if anno and issubclass(anno, BaseModel): + typ = v.__class__ + if issubclass(typ, BaseModel): data.append(serialize_to_list(v)) - elif anno and issubclass(anno, list): + elif issubclass(typ, list): data.append( TypeAdapter( info.annotation, config=ConfigDict(arbitrary_types_allowed=True) ).dump_python(v) ) - elif isinstance(v, datetime.datetime): + elif issubclass(typ, datetime.datetime): data.append([v, 0]) + elif issubclass(typ, Enum): + list_ = list(typ) + data.append(list_.index(v) if v in list_ else v.value) else: data.append(v) return data +def _by_index(v: Any, class_: type[Enum]): + enum_list = list(class_) + if not isinstance(v, int): + return v + if 0 <= v < len(enum_list): + return enum_list[v] + raise ValueError( + f"Value {v} is out of range for enum " + f"{class_.__name__} with {len(enum_list)} items" + ) + + +def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator: + return BeforeValidator(lambda v: _by_index(v, enum_class)) + + +def msgpack_union(v): + data = v[1] + data.append(v[0]) + return data + + +def msgpack_union_dump(v: BaseModel) -> list[Any]: + _type = getattr(v, "type", None) + if _type is None: + raise ValueError( + f"Model {v.__class__.__name__} does not have a '_type' attribute" + ) + return [_type, serialize_to_list(v)] + + class MessagePackArrayModel(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index e361ef8..387231c 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -158,7 +158,7 @@ class MsgpackProtocol: result_kind = 2 if packet.error: result_kind = 1 - elif packet.result is None: + elif packet.result is not None: result_kind = 3 payload.extend( [ diff --git a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi index 88b79c5..b8653f0 100644 --- a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi +++ b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi @@ -5,7 +5,7 @@ class APIMod: @property def acronym(self) -> str: ... @property - def settings(self) -> str: ... + def settings(self) -> dict[str, Any]: ... def encode(obj: Any) -> bytes: ... def decode(data: bytes) -> Any: ... diff --git a/packages/msgpack_lazer_api/src/decode.rs b/packages/msgpack_lazer_api/src/decode.rs index 15156ca..b8e239b 100644 --- a/packages/msgpack_lazer_api/src/decode.rs +++ b/packages/msgpack_lazer_api/src/decode.rs @@ -13,6 +13,8 @@ pub fn read_object( match rmp::decode::read_marker(cursor) { Ok(marker) => match marker { rmp::Marker::Null => Ok(py.None()), + rmp::Marker::True => Ok(true.into_py_any(py)?), + rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixPos(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::FixNeg(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::U8 => { @@ -86,8 +88,6 @@ pub fn read_object( cursor.read_exact(&mut data).map_err(to_py_err)?; Ok(data.into_pyobject(py)?.into_any().unbind()) } - rmp::Marker::True => Ok(true.into_py_any(py)?), - rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixStr(len) => read_string(py, cursor, len as u32), rmp::Marker::Str8 => { let mut buf = [0u8; 1]; diff --git a/packages/msgpack_lazer_api/src/encode.rs b/packages/msgpack_lazer_api/src/encode.rs index 88a732b..0e0907c 100644 --- a/packages/msgpack_lazer_api/src/encode.rs +++ b/packages/msgpack_lazer_api/src/encode.rs @@ -110,12 +110,12 @@ pub fn write_object(buf: &mut Vec, obj: &Bound<'_, PyAny>) { write_list(buf, list); } else if let Ok(string) = obj.downcast::() { write_string(buf, string); - } else if let Ok(integer) = obj.downcast::() { - write_integer(buf, integer); - } else if let Ok(float) = obj.downcast::() { - write_float(buf, float); } else if let Ok(boolean) = obj.downcast::() { - write_bool(buf, boolean); + write_bool(buf, boolean); + } else if let Ok(float) = obj.downcast::() { + write_float(buf, float); + } else if let Ok(integer) = obj.downcast::() { + write_integer(buf, integer); } else if let Ok(bytes) = obj.downcast::() { write_bin(buf, bytes); } else if let Ok(dict) = obj.downcast::() { From 5ccb35dc8be2ca234d5c04b1f3a11f8e95fea094 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 14:59:12 +0000 Subject: [PATCH 20/45] fix(signalr): encode enum by index --- app/models/signalr.py | 34 +++++++++++++++++----------------- app/signalr/hub/hub.py | 8 +++++++- app/signalr/packet.py | 4 +++- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/app/models/signalr.py b/app/models/signalr.py index 202da4f..9e189e9 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -15,26 +15,26 @@ from pydantic import ( ) +def serialize_msgpack(v: Any) -> Any: + typ = v.__class__ + if issubclass(typ, BaseModel): + return serialize_to_list(v) + elif issubclass(typ, list): + return TypeAdapter( + typ, config=ConfigDict(arbitrary_types_allowed=True) + ).dump_python(v) + elif issubclass(typ, datetime.datetime): + return [v, 0] + elif issubclass(typ, Enum): + list_ = list(typ) + return list_.index(v) if v in list_ else v.value + return v + + def serialize_to_list(value: BaseModel) -> list[Any]: data = [] for field, info in value.__class__.model_fields.items(): - v = getattr(value, field) - typ = v.__class__ - if issubclass(typ, BaseModel): - data.append(serialize_to_list(v)) - elif issubclass(typ, list): - data.append( - TypeAdapter( - info.annotation, config=ConfigDict(arbitrary_types_allowed=True) - ).dump_python(v) - ) - elif issubclass(typ, datetime.datetime): - data.append([v, 0]) - elif issubclass(typ, Enum): - list_ = list(typ) - data.append(list_.index(v) if v in list_ else v.value) - else: - data.append(v) + data.append(serialize_msgpack(v=getattr(value, field))) return data diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 276140f..a11fbe7 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -2,12 +2,14 @@ from __future__ import annotations from abc import abstractmethod import asyncio +from enum import Enum +import inspect import time from typing import Any from app.config import settings from app.log import logger -from app.models.signalr import UserState +from app.models.signalr import UserState, _by_index from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, @@ -265,6 +267,10 @@ class Hub[TState: UserState]: continue if issubclass(param.annotation, BaseModel): call_params.append(param.annotation.model_validate(args.pop(0))) + elif inspect.isclass(param.annotation) and issubclass( + param.annotation, Enum + ): + call_params.append(_by_index(args.pop(0), param.annotation)) else: call_params.append(args.pop(0)) return await method_(client, *call_params) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 387231c..de5ce8a 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -8,6 +8,8 @@ from typing import ( Protocol as TypingProtocol, ) +from app.models.signalr import serialize_msgpack + import msgpack_lazer_api as m SEP = b"\x1e" @@ -151,7 +153,7 @@ class MsgpackProtocol: ] ) if packet.arguments is not None: - payload.append(packet.arguments) + payload.append([serialize_msgpack(arg) for arg in packet.arguments]) if packet.stream_ids is not None: payload.append(packet.stream_ids) elif isinstance(packet, CompletionPacket): From 0f1a57afba5b73339a817f075ae9e3141bc0d48b Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 15:02:12 +0000 Subject: [PATCH 21/45] fix(user): last_visit is nullable --- app/database/lazer_user.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 3bd751b..2717c3a 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -66,7 +66,7 @@ class UserBase(UTCBaseModel, SQLModel): is_active: bool = True is_bot: bool = False is_supporter: bool = False - last_visit: datetime = Field( + last_visit: datetime | None = Field( default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) ) pm_friends_only: bool = False From 9f7ab812134910abd5a905633547d4792833fb41 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 3 Aug 2025 09:45:04 +0000 Subject: [PATCH 22/45] feat(signalr): support json & msgpack protocol for all hubs --- app/models/metadata_hub.py | 124 ++++------ app/models/score.py | 44 +--- app/models/signalr.py | 67 +----- app/models/spectator_hub.py | 50 ++-- app/signalr/hub/hub.py | 22 +- app/signalr/hub/metadata.py | 22 +- app/signalr/hub/spectator.py | 25 +- app/signalr/packet.py | 224 +++++++++++++++++- app/utils.py | 52 ++++ .../msgpack_lazer_api/msgpack_lazer_api.pyi | 7 - packages/msgpack_lazer_api/src/decode.rs | 15 +- packages/msgpack_lazer_api/src/encode.rs | 62 +++-- packages/msgpack_lazer_api/src/lib.rs | 25 -- 13 files changed, 432 insertions(+), 307 deletions(-) diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index 8ae3e65..3206d03 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -1,114 +1,85 @@ from __future__ import annotations from enum import IntEnum -from typing import Any, Literal +from typing import ClassVar, Literal -from app.models.signalr import UserState +from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field -class _UserActivity(BaseModel): - model_config = ConfigDict(serialize_by_alias=True) - type: Literal[ - "ChoosingBeatmap", - "InSoloGame", - "WatchingReplay", - "SpectatingUser", - "SearchingForLobby", - "InLobby", - "InMultiplayerGame", - "SpectatingMultiplayerGame", - "InPlaylistGame", - "EditingBeatmap", - "ModdingBeatmap", - "TestingBeatmap", - "InDailyChallengeLobby", - "PlayingDailyChallenge", - ] = Field(alias="$dtype") - value: Any | None = Field(alias="$value") +class _UserActivity(SignalRUnionMessage): ... class ChoosingBeatmap(_UserActivity): - type: Literal["ChoosingBeatmap"] = Field(alias="$dtype") - - -class InGameValue(BaseModel): - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") - ruleset_id: int = Field(alias="RulesetID") - ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb") + union_type: ClassVar[Literal[11]] = 11 class _InGame(_UserActivity): - value: InGameValue = Field(alias="$value") + beatmap_id: int + beatmap_display_title: str + ruleset_id: int + ruleset_playing_verb: str class InSoloGame(_InGame): - type: Literal["InSoloGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[12]] = 12 class InMultiplayerGame(_InGame): - type: Literal["InMultiplayerGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[23]] = 23 class SpectatingMultiplayerGame(_InGame): - type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[24]] = 24 class InPlaylistGame(_InGame): - type: Literal["InPlaylistGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[31]] = 31 -class EditingBeatmapValue(BaseModel): - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") +class PlayingDailyChallenge(_InGame): + union_type: ClassVar[Literal[52]] = 52 class EditingBeatmap(_UserActivity): - type: Literal["EditingBeatmap"] = Field(alias="$dtype") - value: EditingBeatmapValue = Field(alias="$value") + union_type: ClassVar[Literal[41]] = 41 + beatmap_id: int + beatmap_display_title: str -class TestingBeatmap(_UserActivity): - type: Literal["TestingBeatmap"] = Field(alias="$dtype") +class TestingBeatmap(EditingBeatmap): + union_type: ClassVar[Literal[43]] = 43 -class ModdingBeatmap(_UserActivity): - type: Literal["ModdingBeatmap"] = Field(alias="$dtype") - - -class WatchingReplayValue(BaseModel): - score_id: int = Field(alias="ScoreID") - player_name: str = Field(alias="PlayerName") - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") +class ModdingBeatmap(EditingBeatmap): + union_type: ClassVar[Literal[42]] = 42 class WatchingReplay(_UserActivity): - type: Literal["WatchingReplay"] = Field(alias="$dtype") - value: int | None = Field(alias="$value") # Replay ID + union_type: ClassVar[Literal[13]] = 13 + score_id: int + player_name: str + beatmap_id: int + beatmap_display_title: str class SpectatingUser(WatchingReplay): - type: Literal["SpectatingUser"] = Field(alias="$dtype") + union_type: ClassVar[Literal[14]] = 14 class SearchingForLobby(_UserActivity): - type: Literal["SearchingForLobby"] = Field(alias="$dtype") - - -class InLobbyValue(BaseModel): - room_id: int = Field(alias="RoomID") - room_name: str = Field(alias="RoomName") + union_type: ClassVar[Literal[21]] = 21 class InLobby(_UserActivity): - type: Literal["InLobby"] = "InLobby" + union_type: ClassVar[Literal[22]] = 22 + room_id: int + room_name: str class InDailyChallengeLobby(_UserActivity): - type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype") + union_type: ClassVar[Literal[51]] = 51 UserActivity = ( @@ -128,23 +99,28 @@ UserActivity = ( ) -class MetadataClientState(UserState): - user_activity: UserActivity | None = None - status: OnlineStatus | None = None - - def to_dict(self) -> dict[str, Any] | None: - if self.status is None or self.status == OnlineStatus.OFFLINE: - return None - dumped = self.model_dump(by_alias=True, exclude_none=True) - return { - "Activity": dumped.get("user_activity"), - "Status": dumped.get("status"), - } +class UserPresence(BaseModel): + activity: UserActivity | None = Field( + default=None, metadata=SignalRMeta(use_upper_case=True) + ) + status: OnlineStatus | None = Field( + default=None, metadata=SignalRMeta(use_upper_case=True) + ) @property def pushable(self) -> bool: return self.status is not None and self.status != OnlineStatus.OFFLINE + @property + def for_push(self) -> "UserPresence | None": + return UserPresence( + activity=self.activity, + status=self.status, + ) + + +class MetadataClientState(UserPresence, UserState): ... + class OnlineStatus(IntEnum): OFFLINE = 0 # 隐身 diff --git a/app/models/score.py b/app/models/score.py index bfc9f53..cef6b28 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -1,6 +1,6 @@ from __future__ import annotations -from enum import Enum, IntEnum +from enum import Enum from typing import Literal, TypedDict from .mods import API_MODS, APIMod, init_mods @@ -93,43 +93,6 @@ class HitResult(str, Enum): ) -class HitResultInt(IntEnum): - PERFECT = 0 - GREAT = 1 - GOOD = 2 - OK = 3 - MEH = 4 - MISS = 5 - - LARGE_TICK_HIT = 6 - SMALL_TICK_HIT = 7 - SLIDER_TAIL_HIT = 8 - - LARGE_BONUS = 9 - SMALL_BONUS = 10 - - LARGE_TICK_MISS = 11 - SMALL_TICK_MISS = 12 - - IGNORE_HIT = 13 - IGNORE_MISS = 14 - - NONE = 15 - COMBO_BREAK = 16 - - LEGACY_COMBO_INCREASE = 99 - - def is_hit(self) -> bool: - return self not in ( - HitResultInt.NONE, - HitResultInt.IGNORE_MISS, - HitResultInt.COMBO_BREAK, - HitResultInt.LARGE_TICK_MISS, - HitResultInt.SMALL_TICK_MISS, - HitResultInt.MISS, - ) - - class LeaderboardType(Enum): GLOBAL = "global" FRIENDS = "friend" @@ -138,7 +101,6 @@ class LeaderboardType(Enum): ScoreStatistics = dict[HitResult, int] -ScoreStatisticsInt = dict[HitResultInt, int] class SoloScoreSubmissionInfo(BaseModel): @@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel): class LegacyReplaySoloScoreInfo(TypedDict): online_id: int mods: list[APIMod] - statistics: ScoreStatisticsInt - maximum_statistics: ScoreStatisticsInt + statistics: ScoreStatistics + maximum_statistics: ScoreStatistics client_version: str rank: Rank user_id: int diff --git a/app/models/signalr.py b/app/models/signalr.py index 9e189e9..90ef95f 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,41 +1,21 @@ from __future__ import annotations -import datetime +from dataclasses import dataclass from enum import Enum -from typing import Any +from typing import Any, ClassVar from pydantic import ( BaseModel, BeforeValidator, - ConfigDict, Field, - TypeAdapter, - model_serializer, - model_validator, ) -def serialize_msgpack(v: Any) -> Any: - typ = v.__class__ - if issubclass(typ, BaseModel): - return serialize_to_list(v) - elif issubclass(typ, list): - return TypeAdapter( - typ, config=ConfigDict(arbitrary_types_allowed=True) - ).dump_python(v) - elif issubclass(typ, datetime.datetime): - return [v, 0] - elif issubclass(typ, Enum): - list_ = list(typ) - return list_.index(v) if v in list_ else v.value - return v - - -def serialize_to_list(value: BaseModel) -> list[Any]: - data = [] - for field, info in value.__class__.model_fields.items(): - data.append(serialize_msgpack(v=getattr(value, field))) - return data +@dataclass +class SignalRMeta: + member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute + json_ignore: bool = False # implement of JsonIgnore (json) attribute + use_upper_case: bool = False # use upper CamelCase for field names def _by_index(v: Any, class_: type[Enum]): @@ -54,37 +34,8 @@ def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator: return BeforeValidator(lambda v: _by_index(v, enum_class)) -def msgpack_union(v): - data = v[1] - data.append(v[0]) - return data - - -def msgpack_union_dump(v: BaseModel) -> list[Any]: - _type = getattr(v, "type", None) - if _type is None: - raise ValueError( - f"Model {v.__class__.__name__} does not have a '_type' attribute" - ) - return [_type, serialize_to_list(v)] - - -class MessagePackArrayModel(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - @model_validator(mode="before") - @classmethod - def unpack(cls, v: Any) -> Any: - if isinstance(v, list): - fields = list(cls.model_fields.keys()) - if len(v) != len(fields): - raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}") - return dict(zip(fields, v)) - return v - - @model_serializer - def serialize(self) -> list[Any]: - return serialize_to_list(self) +class SignalRUnionMessage(BaseModel): + union_type: ClassVar[int] class Transport(BaseModel): diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index 994e083..a9e9042 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -5,14 +5,14 @@ from enum import IntEnum from typing import Any from app.models.beatmap import BeatmapRankStatus +from app.models.mods import APIMod from .score import ( - ScoreStatisticsInt, + ScoreStatistics, ) -from .signalr import MessagePackArrayModel, UserState +from .signalr import SignalRMeta, UserState -from msgpack_lazer_api import APIMod -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, Field, field_validator class SpectatedUserState(IntEnum): @@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum): Quit = 5 -class SpectatorState(MessagePackArrayModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +class SpectatorState(BaseModel): beatmap_id: int | None = None ruleset_id: int | None = None # 0,1,2,3 mods: list[APIMod] = Field(default_factory=list) state: SpectatedUserState - maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict) + maximum_statistics: ScoreStatistics = Field(default_factory=dict) def __eq__(self, other: object) -> bool: if not isinstance(other, SpectatorState): @@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel): ) -class ScoreProcessorStatistics(MessagePackArrayModel): - base_score: int - maximum_base_score: int +class ScoreProcessorStatistics(BaseModel): + base_score: float + maximum_base_score: float accuracy_judgement_count: int combo_portion: float - bouns_portion: float + bonus_portion: float -class FrameHeader(MessagePackArrayModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +class FrameHeader(BaseModel): total_score: int - acc: float + accuracy: float combo: int max_combo: int - statistics: ScoreStatisticsInt = Field(default_factory=dict) + statistics: ScoreStatistics = Field(default_factory=dict) score_processor_statistics: ScoreProcessorStatistics received_time: datetime.datetime mods: list[APIMod] = Field(default_factory=list) @@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel): # SMOKE = 16 -class LegacyReplayFrame(MessagePackArrayModel): +class LegacyReplayFrame(BaseModel): time: float # from ReplayFrame,the parent of LegacyReplayFrame - x: float | None = None - y: float | None = None + mouse_x: float | None = None + mouse_y: float | None = None button_state: int + header: FrameHeader | None = Field( + default=None, metadata=[SignalRMeta(member_ignore=True)] + ) -class FrameDataBundle(MessagePackArrayModel): + +class FrameDataBundle(BaseModel): header: FrameHeader frames: list[LegacyReplayFrame] @@ -106,18 +106,16 @@ class APIUser(BaseModel): class ScoreInfo(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - mods: list[APIMod] user: APIUser ruleset: int - maximum_statistics: ScoreStatisticsInt + maximum_statistics: ScoreStatistics id: int | None = None total_score: int | None = None - acc: float | None = None + accuracy: float | None = None max_combo: int | None = None combo: int | None = None - statistics: ScoreStatisticsInt = Field(default_factory=dict) + statistics: ScoreStatistics = Field(default_factory=dict) class StoreScore(BaseModel): diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index a11fbe7..f3c5b29 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -2,14 +2,12 @@ from __future__ import annotations from abc import abstractmethod import asyncio -from enum import Enum -import inspect import time from typing import Any from app.config import settings from app.log import logger -from app.models.signalr import UserState, _by_index +from app.models.signalr import UserState from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, @@ -23,7 +21,6 @@ from app.signalr.store import ResultStore from app.signalr.utils import get_signature from fastapi import WebSocket -from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect @@ -51,7 +48,7 @@ class Client: self.connection_id = connection_id self.connection_token = connection_token self.connection = connection - self.procotol = protocol + self.protocol = protocol self._listen_task: asyncio.Task | None = None self._ping_task: asyncio.Task | None = None self._store = ResultStore() @@ -64,14 +61,14 @@ class Client: return int(self.connection_id) async def send_packet(self, packet: Packet): - await self.connection.send_bytes(self.procotol.encode(packet)) + await self.connection.send_bytes(self.protocol.encode(packet)) async def receive_packets(self) -> list[Packet]: message = await self.connection.receive() d = message.get("bytes") or message.get("text", "").encode() if not d: return [] - return self.procotol.decode(d) + return self.protocol.decode(d) async def _ping(self): while True: @@ -265,14 +262,9 @@ class Hub[TState: UserState]: for name, param in signature.parameters.items(): if name == "self" or param.annotation is Client: continue - if issubclass(param.annotation, BaseModel): - call_params.append(param.annotation.model_validate(args.pop(0))) - elif inspect.isclass(param.annotation) and issubclass( - param.annotation, Enum - ): - call_params.append(_by_index(args.pop(0), param.annotation)) - else: - call_params.append(args.pop(0)) + call_params.append( + client.protocol.validate_object(args.pop(0), param.annotation) + ) return await method_(client, *call_params) async def call(self, client: Client, method: str, *args: Any) -> Any: diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 227cf7b..64232c0 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -12,7 +12,6 @@ from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActiv from .hub import Client, Hub -from pydantic import TypeAdapter from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -32,7 +31,7 @@ class MetadataHub(Hub[MetadataClientState]): ) -> set[Coroutine]: if store is not None and not store.pushable: return set() - data = store.to_dict() if store else None + data = store.for_push if store else None return { self.broadcast_group_call( self.online_presence_watchers_group(), @@ -103,7 +102,7 @@ class MetadataHub(Hub[MetadataClientState]): self.friend_presence_watchers_group(friend_id), "FriendPresenceUpdated", friend_id, - friend_state.to_dict(), + friend_state if friend_state.pushable else None, ) ) await asyncio.gather(*tasks) @@ -123,27 +122,24 @@ class MetadataHub(Hub[MetadataClientState]): client, "UserPresenceUpdated", user_id, - store.to_dict(), + store.for_push, ) ) await asyncio.gather(*tasks) - async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None: + async def UpdateActivity( + self, client: Client, activity: UserActivity | None + ) -> None: user_id = int(client.connection_id) - activity = ( - TypeAdapter(UserActivity).validate_python(activity_dict) - if activity_dict - else None - ) store = self.get_or_create_state(client) - store.user_activity = activity + store.activity = activity tasks = self.broadcast_tasks(user_id, store) tasks.add( self.call_noblock( client, "UserPresenceUpdated", user_id, - store.to_dict(), + store.for_push, ) ) await asyncio.gather(*tasks) @@ -155,7 +151,7 @@ class MetadataHub(Hub[MetadataClientState]): client, "UserPresenceUpdated", user_id, - store.to_dict(), + store, ) for user_id, store in self.state.items() if store.pushable diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index bd311ec..b9a3c99 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -13,8 +13,7 @@ from app.database.score_token import ScoreToken from app.dependencies.database import engine from app.models.beatmap import BeatmapRankStatus from app.models.mods import mods_to_int -from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt -from app.models.signalr import serialize_to_list +from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics from app.models.spectator_hub import ( APIUser, FrameDataBundle, @@ -69,8 +68,8 @@ def save_replay( md5: str, username: str, score: Score, - statistics: ScoreStatisticsInt, - maximum_statistics: ScoreStatisticsInt, + statistics: ScoreStatistics, + maximum_statistics: ScoreStatistics, frames: list[LegacyReplayFrame], ) -> None: data = bytearray() @@ -107,8 +106,8 @@ def save_replay( last_time = 0 for frame in frames: frame_strs.append( - f"{frame.time - last_time}|{frame.x or 0.0}" - f"|{frame.y or 0.0}|{frame.button_state}" + f"{frame.time - last_time}|{frame.mouse_x or 0.0}" + f"|{frame.mouse_y or 0.0}|{frame.button_state}" ) last_time = frame.time frame_strs.append("-12345|0|0|0") @@ -165,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]): async def on_client_connect(self, client: Client) -> None: tasks = [ - self.call_noblock( - client, "UserBeganPlaying", user_id, serialize_to_list(store.state) - ) + self.call_noblock(client, "UserBeganPlaying", user_id, store.state) for user_id, store in self.state.items() if store.state is not None ] @@ -214,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserBeganPlaying", user_id, - serialize_to_list(state), + state, ) async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None: @@ -222,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]): state = self.get_or_create_state(client) if not state.score: return - state.score.score_info.acc = frame_data.header.acc + state.score.score_info.accuracy = frame_data.header.accuracy state.score.score_info.combo = frame_data.header.combo state.score.score_info.max_combo = frame_data.header.max_combo state.score.score_info.statistics = frame_data.header.statistics @@ -233,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserSentFrames", user_id, - frame_data.model_dump(), + frame_data, ) async def EndPlaySession(self, client: Client, state: SpectatorState) -> None: @@ -316,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserFinishedPlaying", user_id, - serialize_to_list(state) if state else None, + state, ) async def StartWatchingUser(self, client: Client, target_id: int) -> None: @@ -327,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]): client, "UserBeganPlaying", target_id, - serialize_to_list(target_store.state), + target_store.state, ) store = self.get_or_create_state(client) store.watched_user.add(target_id) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index de5ce8a..be98c39 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -1,16 +1,24 @@ from __future__ import annotations from dataclasses import dataclass -from enum import IntEnum +import datetime +from enum import Enum, IntEnum +import inspect import json +from types import NoneType, UnionType from typing import ( Any, Protocol as TypingProtocol, + Union, + get_args, + get_origin, ) -from app.models.signalr import serialize_msgpack +from app.models.signalr import SignalRMeta, SignalRUnionMessage +from app.utils import camel_to_snake, snake_to_camel import msgpack_lazer_api as m +from pydantic import BaseModel SEP = b"\x1e" @@ -75,8 +83,61 @@ class Protocol(TypingProtocol): @staticmethod def encode(packet: Packet) -> bytes: ... + @classmethod + def validate_object(cls, v: Any, typ: type) -> Any: ... + class MsgpackProtocol: + @classmethod + def serialize_msgpack(cls, v: Any) -> Any: + typ = v.__class__ + if issubclass(typ, BaseModel): + return cls.serialize_to_list(v) + elif issubclass(typ, list): + return [cls.serialize_msgpack(item) for item in v] + elif issubclass(typ, datetime.datetime): + return [v, 0] + elif isinstance(v, dict): + return { + cls.serialize_msgpack(k): cls.serialize_msgpack(value) + for k, value in v.items() + } + elif issubclass(typ, Enum): + list_ = list(typ) + return list_.index(v) if v in list_ else v.value + return v + + @classmethod + def serialize_to_list(cls, value: BaseModel) -> list[Any]: + values = [] + for field, info in value.__class__.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.member_ignore: + continue + values.append(cls.serialize_msgpack(v=getattr(value, field))) + if issubclass(value.__class__, SignalRUnionMessage): + return [value.__class__.union_type, values] + else: + return values + + @staticmethod + def process_object(v: Any, typ: type[BaseModel]) -> Any: + if isinstance(v, list): + d = {} + for i, f in enumerate(typ.model_fields.items()): + field, info = f + if info.exclude: + continue + anno = info.annotation + if anno is None: + d[camel_to_snake(field)] = v[i] + continue + d[field] = MsgpackProtocol.validate_object(v[i], anno) + return d + return v + @staticmethod def _encode_varint(value: int) -> bytes: result = [] @@ -142,6 +203,49 @@ class MsgpackProtocol: ] raise ValueError(f"Unsupported packet type: {packet_type}") + @classmethod + def validate_object(cls, v: Any, typ: type) -> Any: + if issubclass(typ, BaseModel): + return typ.model_validate(obj=cls.process_object(v, typ)) + elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): + return v[0] + elif isinstance(v, list): + return [cls.validate_object(item, get_args(typ)[0]) for item in v] + elif inspect.isclass(typ) and issubclass(typ, Enum): + list_ = list(typ) + return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v) + elif get_origin(typ) is dict: + return { + cls.validate_object(k, get_args(typ)[0]): cls.validate_object( + v, get_args(typ)[1] + ) + for k, v in v.items() + } + elif (origin := get_origin(typ)) is Union or origin is UnionType: + args = get_args(typ) + if len(args) == 2 and NoneType in args: + non_none_args = [arg for arg in args if arg is not NoneType] + if len(non_none_args) == 1: + if v is None: + return None + return cls.validate_object(v, non_none_args[0]) + + # suppose use `MessagePack-CSharp Union | None` + # except `X (Other Type) | None` + if NoneType in args and v is None: + return None + if not all(issubclass(arg, SignalRUnionMessage) for arg in args): + raise ValueError( + f"Cannot validate {v} to {typ}, " + "only SignalRUnionMessage subclasses are supported" + ) + union_type = v[0] + for arg in args: + assert issubclass(arg, SignalRUnionMessage) + if arg.union_type == union_type: + return cls.validate_object(v[1], arg) + return v + @staticmethod def encode(packet: Packet) -> bytes: payload = [packet.type.value, packet.header or {}] @@ -153,7 +257,9 @@ class MsgpackProtocol: ] ) if packet.arguments is not None: - payload.append([serialize_msgpack(arg) for arg in packet.arguments]) + payload.append( + [MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments] + ) if packet.stream_ids is not None: payload.append(packet.stream_ids) elif isinstance(packet, CompletionPacket): @@ -166,7 +272,9 @@ class MsgpackProtocol: [ packet.invocation_id, result_kind, - packet.error or packet.result or None, + packet.error + or MsgpackProtocol.serialize_msgpack(packet.result) + or None, ] ) elif isinstance(packet, ClosePacket): @@ -183,6 +291,62 @@ class MsgpackProtocol: class JSONProtocol: + @classmethod + def serialize_to_json(cls, v: Any): + typ = v.__class__ + if issubclass(typ, BaseModel): + return cls.serialize_model(v) + elif isinstance(v, dict): + return { + cls.serialize_to_json(k): cls.serialize_to_json(value) + for k, value in v.items() + } + elif isinstance(v, list): + return [cls.serialize_to_json(item) for item in v] + elif isinstance(v, datetime.datetime): + return v.isoformat() + elif isinstance(v, Enum): + return v.value + return v + + @classmethod + def serialize_model(cls, v: BaseModel) -> dict[str, Any]: + d = {} + for field, info in v.__class__.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.json_ignore: + continue + d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = ( + cls.serialize_to_json(getattr(v, field)) + ) + if issubclass(v.__class__, SignalRUnionMessage): + return { + "$dtype": v.__class__.__name__, + "$value": d, + } + return d + + @staticmethod + def process_object( + v: Any, typ: type[BaseModel], from_union: bool = False + ) -> dict[str, Any]: + d = {} + for field, info in typ.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.json_ignore: + continue + value = v.get(snake_to_camel(field, not from_union)) + anno = typ.model_fields[field].annotation + if anno is None: + d[field] = value + continue + d[field] = JSONProtocol.validate_object(value, anno) + return d + @staticmethod def decode(input: bytes) -> list[Packet]: packets_raw = input.removesuffix(SEP).split(SEP) @@ -227,6 +391,52 @@ class JSONProtocol: ] raise ValueError(f"Unsupported packet type: {packet_type}") + @classmethod + def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any: + if issubclass(typ, BaseModel): + return typ.model_validate(JSONProtocol.process_object(v, typ, from_union)) + elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): + return datetime.datetime.fromisoformat(v) + elif isinstance(v, list): + return [cls.validate_object(item, get_args(typ)[0]) for item in v] + elif inspect.isclass(typ) and issubclass(typ, Enum): + list_ = list(typ) + return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v) + elif get_origin(typ) is dict: + return { + cls.validate_object(k, get_args(typ)[0]): cls.validate_object( + v, get_args(typ)[1] + ) + for k, v in v.items() + } + elif (origin := get_origin(typ)) is Union or origin is UnionType: + args = get_args(typ) + if len(args) == 2 and NoneType in args: + non_none_args = [arg for arg in args if arg is not NoneType] + if len(non_none_args) == 1: + if v is None: + return None + return cls.validate_object(v, non_none_args[0]) + + # suppose use `MessagePack-CSharp Union | None` + # except `X (Other Type) | None` + if NoneType in args and v is None: + return None + if not all( + issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args + ): + raise ValueError( + f"Cannot validate {v} to {typ}, " + "only SignalRUnionMessage subclasses are supported" + ) + # https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs + union_type = v["$dtype"] + for arg in args: + assert issubclass(arg, SignalRUnionMessage) + if arg.__name__ == union_type: + return cls.validate_object(v["$value"], arg, True) + return v + @staticmethod def encode(packet: Packet) -> bytes: payload: dict[str, Any] = { @@ -243,7 +453,9 @@ class JSONProtocol: if packet.invocation_id is not None: payload["invocationId"] = packet.invocation_id if packet.arguments is not None: - payload["arguments"] = packet.arguments + payload["arguments"] = [ + JSONProtocol.serialize_to_json(arg) for arg in packet.arguments + ] if packet.stream_ids is not None: payload["streamIds"] = packet.stream_ids elif isinstance(packet, CompletionPacket): @@ -255,7 +467,7 @@ class JSONProtocol: if packet.error is not None: payload["error"] = packet.error if packet.result is not None: - payload["result"] = packet.result + payload["result"] = JSONProtocol.serialize_to_json(packet.result) elif isinstance(packet, PingPacket): pass elif isinstance(packet, ClosePacket): diff --git a/app/utils.py b/app/utils.py index 09e8fdc..0d759a1 100644 --- a/app/utils.py +++ b/app/utils.py @@ -4,3 +4,55 @@ from __future__ import annotations def unix_timestamp_to_windows(timestamp: int) -> int: """Convert a Unix timestamp to a Windows timestamp.""" return (timestamp + 62135596800) * 10_000_000 + + +def camel_to_snake(name: str) -> str: + """Convert a camelCase string to snake_case.""" + result = [] + last_chr = "" + for char in name: + if char.isupper(): + if not last_chr.isupper() and result: + result.append("_") + result.append(char.lower()) + else: + result.append(char) + last_chr = char + return "".join(result) + + +def snake_to_camel(name: str, lower_case: bool = True) -> str: + """Convert a snake_case string to camelCase.""" + if not name: + return name + + parts = name.split("_") + if not parts: + return name + + # 常见缩写词列表 + abbreviations = { + "id", + "url", + "api", + "http", + "https", + "xml", + "json", + "css", + "html", + "sql", + "db", + } + + result = [] + for part in parts: + if part.lower() in abbreviations: + result.append(part.upper()) + else: + if result or not lower_case: + result.append(part.capitalize()) + else: + result.append(part.lower()) + + return "".join(result) diff --git a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi index b8653f0..433c53b 100644 --- a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi +++ b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi @@ -1,11 +1,4 @@ from typing import Any -class APIMod: - def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ... - @property - def acronym(self) -> str: ... - @property - def settings(self) -> dict[str, Any]: ... - def encode(obj: Any) -> bytes: ... def decode(data: bytes) -> Any: ... diff --git a/packages/msgpack_lazer_api/src/decode.rs b/packages/msgpack_lazer_api/src/decode.rs index b8e239b..1e36c42 100644 --- a/packages/msgpack_lazer_api/src/decode.rs +++ b/packages/msgpack_lazer_api/src/decode.rs @@ -1,8 +1,6 @@ -use crate::APIMod; use chrono::{TimeZone, Utc}; use pyo3::types::PyDict; use pyo3::{prelude::*, IntoPyObjectExt}; -use std::collections::HashMap; use std::io::Read; pub fn read_object( @@ -206,13 +204,12 @@ fn read_array( let obj1 = read_object(py, cursor, false)?; if obj1.extract::(py).map_or(false, |k| k.len() == 2) { let obj2 = read_object(py, cursor, true)?; - return Ok(APIMod { - acronym: obj1.extract::(py)?, - settings: obj2.extract::>(py)?, - } - .into_pyobject(py)? - .into_any() - .unbind()); + + let api_mod_dict = PyDict::new(py); + api_mod_dict.set_item("acronym", obj1)?; + api_mod_dict.set_item("settings", obj2)?; + + return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind()); } else { items.push(obj1); i += 1; diff --git a/packages/msgpack_lazer_api/src/encode.rs b/packages/msgpack_lazer_api/src/encode.rs index 0e0907c..3ff4864 100644 --- a/packages/msgpack_lazer_api/src/encode.rs +++ b/packages/msgpack_lazer_api/src/encode.rs @@ -1,8 +1,7 @@ -use crate::APIMod; -use chrono::{DateTime, Utc}; -use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods}; +use chrono::{DateTime, Utc}; +use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods}; use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString}; -use pyo3::{Bound, PyAny, PyRef, Python}; +use pyo3::{Bound, PyAny}; use std::io::Write; fn write_list(buf: &mut Vec, obj: &Bound<'_, PyList>) { @@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec, obj: &Bound<'_, PyDict>) { } } -fn write_nil(buf: &mut Vec){ +fn write_nil(buf: &mut Vec) { rmp::encode::write_nil(buf).unwrap(); } -// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs -fn write_api_mod(buf: &mut Vec, api_mod: PyRef) { - rmp::encode::write_array_len(buf, 2).unwrap(); - rmp::encode::write_str(buf, &api_mod.acronym).unwrap(); - rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap(); - for (k, v) in api_mod.settings.iter() { - rmp::encode::write_str(buf, k).unwrap(); - Python::with_gil(|py| write_object(buf, &v.bind(py))); +fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool { + if let Ok(Some(acronym)) = dict.get_item("acronym") { + if let Ok(acronym_str) = acronym.extract::() { + return acronym_str.len() == 2; + } } + false +} + +// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs +fn write_api_mod(buf: &mut Vec, api_mod: &Bound<'_, PyDict>) -> PyResult<()> { + let acronym = api_mod + .get_item("acronym")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?; + let acronym_str = acronym.extract::()?; + + let settings = api_mod + .get_item("settings")? + .unwrap_or_else(|| PyDict::new(acronym.py()).into_any()); + let settings_dict = settings.downcast::()?; + + rmp::encode::write_array_len(buf, 2).unwrap(); + rmp::encode::write_str(buf, &acronym_str).unwrap(); + rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap(); + + for (k, v) in settings_dict.iter() { + let key_str = k.extract::()?; + rmp::encode::write_str(buf, &key_str).unwrap(); + write_object(buf, &v); + } + + Ok(()) } fn write_datetime(buf: &mut Vec, obj: &Bound<'_, PyDateTime>) { @@ -111,21 +133,23 @@ pub fn write_object(buf: &mut Vec, obj: &Bound<'_, PyAny>) { } else if let Ok(string) = obj.downcast::() { write_string(buf, string); } else if let Ok(boolean) = obj.downcast::() { - write_bool(buf, boolean); + write_bool(buf, boolean); } else if let Ok(float) = obj.downcast::() { - write_float(buf, float); + write_float(buf, float); } else if let Ok(integer) = obj.downcast::() { - write_integer(buf, integer); + write_integer(buf, integer); } else if let Ok(bytes) = obj.downcast::() { write_bin(buf, bytes); } else if let Ok(dict) = obj.downcast::() { - write_hashmap(buf, dict); + if is_api_mod(dict) { + write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict)); + } else { + write_hashmap(buf, dict); + } } else if let Ok(_none) = obj.downcast::() { write_nil(buf); } else if let Ok(datetime) = obj.downcast::() { write_datetime(buf, datetime); - } else if let Ok(api_mod) = obj.extract::>() { - write_api_mod(buf, api_mod); } else { panic!("Unsupported type"); } diff --git a/packages/msgpack_lazer_api/src/lib.rs b/packages/msgpack_lazer_api/src/lib.rs index fda540c..220e645 100644 --- a/packages/msgpack_lazer_api/src/lib.rs +++ b/packages/msgpack_lazer_api/src/lib.rs @@ -2,30 +2,6 @@ mod decode; mod encode; use pyo3::prelude::*; -use std::collections::HashMap; - -#[pyclass] -struct APIMod { - #[pyo3(get, set)] - acronym: String, - #[pyo3(get, set)] - settings: HashMap, -} - -#[pymethods] -impl APIMod { - #[new] - fn new(acronym: String, settings: HashMap) -> Self { - APIMod { acronym, settings } - } - - fn __repr__(&self) -> String { - format!( - "APIMod(acronym='{}', settings={:?})", - self.acronym, self.settings - ) - } -} #[pyfunction] #[pyo3(name = "encode")] @@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult { fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(encode_py, m)?)?; m.add_function(wrap_pyfunction!(decode_py, m)?)?; - m.add_class::()?; Ok(()) } From b7bc87b8b63543e75db317f9ac52b73fe8cbae9c Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 3 Aug 2025 11:01:25 +0000 Subject: [PATCH 23/45] fix(signalr): fix SignalRMeta cannot be read --- app/models/metadata_hub.py | 14 +++++++------- app/models/spectator_hub.py | 8 ++++---- app/signalr/packet.py | 14 +++++++++----- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index 3206d03..a678d7f 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import IntEnum -from typing import ClassVar, Literal +from typing import Annotated, ClassVar, Literal from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState @@ -100,12 +100,12 @@ UserActivity = ( class UserPresence(BaseModel): - activity: UserActivity | None = Field( - default=None, metadata=SignalRMeta(use_upper_case=True) - ) - status: OnlineStatus | None = Field( - default=None, metadata=SignalRMeta(use_upper_case=True) - ) + activity: Annotated[ + UserActivity | None, Field(default=None), SignalRMeta(use_upper_case=True) + ] + status: Annotated[ + OnlineStatus | None, Field(default=None), SignalRMeta(use_upper_case=True) + ] @property def pushable(self) -> bool: diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index a9e9042..9f35932 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime from enum import IntEnum -from typing import Any +from typing import Annotated, Any from app.models.beatmap import BeatmapRankStatus from app.models.mods import APIMod @@ -89,9 +89,9 @@ class LegacyReplayFrame(BaseModel): mouse_y: float | None = None button_state: int - header: FrameHeader | None = Field( - default=None, metadata=[SignalRMeta(member_ignore=True)] - ) + header: Annotated[ + FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True) + ] class FrameDataBundle(BaseModel): diff --git a/app/signalr/packet.py b/app/signalr/packet.py index be98c39..70c2276 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -126,15 +126,19 @@ class MsgpackProtocol: def process_object(v: Any, typ: type[BaseModel]) -> Any: if isinstance(v, list): d = {} - for i, f in enumerate(typ.model_fields.items()): - field, info = f - if info.exclude: + i = 0 + for field, info in typ.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.member_ignore: continue anno = info.annotation if anno is None: d[camel_to_snake(field)] = v[i] - continue - d[field] = MsgpackProtocol.validate_object(v[i], anno) + else: + d[field] = MsgpackProtocol.validate_object(v[i], anno) + i += 1 return d return v From 2600fa499f05d266072fb447cd708af4587fc1c9 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 3 Aug 2025 12:53:22 +0000 Subject: [PATCH 24/45] feat(multiplayer): support play WIP --- app/database/playlists.py | 18 +-- app/database/room.py | 2 +- app/database/score.py | 7 +- app/models/mods.py | 12 -- app/models/multiplayer_hub.py | 282 ++++++++++++++++++++++----------- app/models/room.py | 9 ++ app/models/signalr.py | 1 + app/router/score.py | 219 ++++++++++++++++++------- app/signalr/hub/multiplayer.py | 262 ++++++++++++++++++++++++++++-- app/signalr/packet.py | 46 +++++- app/utils.py | 4 +- 11 files changed, 666 insertions(+), 196 deletions(-) diff --git a/app/database/playlists.py b/app/database/playlists.py index 10ad86b..328f17d 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING from app.models.model import UTCBaseModel -from app.models.mods import APIMod, msgpack_to_apimod +from app.models.mods import APIMod from app.models.multiplayer_hub import PlaylistItem from .beatmap import Beatmap, BeatmapResp @@ -79,10 +79,10 @@ class Playlist(PlaylistBase, table=True): owner_id=playlist.owner_id, ruleset_id=playlist.ruleset_id, beatmap_id=playlist.beatmap_id, - required_mods=[msgpack_to_apimod(mod) for mod in playlist.required_mods], - allowed_mods=[msgpack_to_apimod(mod) for mod in playlist.allowed_mods], + required_mods=playlist.required_mods, + allowed_mods=playlist.allowed_mods, expired=playlist.expired, - playlist_order=playlist.order, + playlist_order=playlist.playlist_order, played_at=playlist.played_at, freestyle=playlist.freestyle, room_id=room_id, @@ -99,14 +99,10 @@ class Playlist(PlaylistBase, table=True): db_playlist.owner_id = playlist.owner_id db_playlist.ruleset_id = playlist.ruleset_id db_playlist.beatmap_id = playlist.beatmap_id - db_playlist.required_mods = [ - msgpack_to_apimod(mod) for mod in playlist.required_mods - ] - db_playlist.allowed_mods = [ - msgpack_to_apimod(mod) for mod in playlist.allowed_mods - ] + db_playlist.required_mods = playlist.required_mods + db_playlist.allowed_mods = playlist.allowed_mods db_playlist.expired = playlist.expired - db_playlist.playlist_order = playlist.order + db_playlist.playlist_order = playlist.playlist_order db_playlist.played_at = playlist.played_at db_playlist.freestyle = playlist.freestyle await session.commit() diff --git a/app/database/room.py b/app/database/room.py index 8eb882d..80457b6 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -125,7 +125,7 @@ class RoomResp(RoomBase): type=room.settings.match_type, queue_mode=room.settings.queue_mode, auto_skip=room.settings.auto_skip, - auto_start_duration=room.settings.auto_start_duration, + auto_start_duration=int(room.settings.auto_start_duration.total_seconds()), status=server_room.status, category=server_room.category, # duration = room.settings.duration, diff --git a/app/database/score.py b/app/database/score.py index 1bd5978..abc3d75 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -91,7 +91,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): # optional # TODO: current_user_attributes - position: int | None = Field(default=None) # multiplayer + # position: int | None = Field(default=None) # multiplayer class Score(ScoreBase, table=True): @@ -162,6 +162,7 @@ class ScoreResp(ScoreBase): maximum_statistics: ScoreStatistics | None = None rank_global: int | None = None rank_country: int | None = None + position: int = 1 # TODO @classmethod async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": @@ -618,6 +619,8 @@ async def process_score( fetcher: "Fetcher", session: AsyncSession, redis: Redis, + item_id: int | None = None, + room_id: int | None = None, ) -> Score: assert user.id can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) @@ -649,6 +652,8 @@ async def process_score( nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0), nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0), nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0), + playlist_item_id=item_id, + room_id=room_id, ) if can_get_pp: beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) diff --git a/app/models/mods.py b/app/models/mods.py index 4b20138..299a05f 100644 --- a/app/models/mods.py +++ b/app/models/mods.py @@ -5,8 +5,6 @@ from typing import Literal, NotRequired, TypedDict from app.path import STATIC_DIR -from msgpack_lazer_api import APIMod as MsgpackAPIMod - class APIMod(TypedDict): acronym: str @@ -169,13 +167,3 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool: if expected_value != NO_CHECK and value != expected_value: return False return True - - -def msgpack_to_apimod(mod: MsgpackAPIMod) -> APIMod: - """ - Convert a MsgpackAPIMod to an APIMod. - """ - return APIMod( - acronym=mod.acronym, - settings=mod.settings, - ) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 9bccb71..ba8a050 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -1,13 +1,16 @@ from __future__ import annotations -from dataclasses import dataclass -import datetime -from typing import TYPE_CHECKING, Annotated, Any, Literal +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal from app.database.beatmap import Beatmap from app.dependencies.database import engine from app.exception import InvokeException +from .mods import APIMod from .room import ( DownloadState, MatchType, @@ -18,15 +21,14 @@ from .room import ( RoomStatus, ) from .signalr import ( - EnumByIndex, - MessagePackArrayModel, + SignalRMeta, + SignalRUnionMessage, UserState, - msgpack_union, - msgpack_union_dump, ) -from msgpack_lazer_api import APIMod -from pydantic import Field, field_serializer, field_validator +from pydantic import BaseModel, Field +from sqlalchemy import update +from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: @@ -40,37 +42,37 @@ class MultiplayerClientState(UserState): room_id: int = 0 -class MultiplayerRoomSettings(MessagePackArrayModel): +class MultiplayerRoomSettings(BaseModel): name: str = "Unnamed Room" - playlist_item_id: int = 0 + playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] password: str = "" - match_type: Annotated[MatchType, EnumByIndex(MatchType)] = MatchType.HEAD_TO_HEAD - queue_mode: Annotated[QueueMode, EnumByIndex(QueueMode)] = QueueMode.HOST_ONLY - auto_start_duration: int = 0 + match_type: MatchType = MatchType.HEAD_TO_HEAD + queue_mode: QueueMode = QueueMode.HOST_ONLY + auto_start_duration: timedelta = timedelta(seconds=0) auto_skip: bool = False -class BeatmapAvailability(MessagePackArrayModel): - state: Annotated[DownloadState, EnumByIndex(DownloadState)] = DownloadState.UNKNOWN +class BeatmapAvailability(BaseModel): + state: DownloadState = DownloadState.UNKNOWN progress: float | None = None -class _MatchUserState(MessagePackArrayModel): ... +class _MatchUserState(SignalRUnionMessage): ... class TeamVersusUserState(_MatchUserState): team_id: int - type: Literal[0] = Field(0, exclude=True) + union_type: ClassVar[Literal[0]] = 0 MatchUserState = TeamVersusUserState -class _MatchRoomState(MessagePackArrayModel): ... +class _MatchRoomState(SignalRUnionMessage): ... -class MultiplayerTeam(MessagePackArrayModel): +class MultiplayerTeam(BaseModel): id: int name: str @@ -83,24 +85,24 @@ class TeamVersusRoomState(_MatchRoomState): ] ) - type: Literal[0] = Field(0, exclude=True) + union_type: ClassVar[Literal[0]] = 0 MatchRoomState = TeamVersusRoomState -class PlaylistItem(MessagePackArrayModel): - id: int +class PlaylistItem(BaseModel): + id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] owner_id: int beatmap_id: int - checksum: str + beatmap_checksum: str ruleset_id: int required_mods: list[APIMod] = Field(default_factory=list) allowed_mods: list[APIMod] = Field(default_factory=list) expired: bool - order: int - played_at: datetime.datetime | None = None - star: float + playlist_order: int + played_at: datetime | None = None + star_rating: float freestyle: bool def validate_user_mods( @@ -127,7 +129,10 @@ class PlaylistItem(MessagePackArrayModel): # Check if mods are valid for the ruleset for mod in proposed_mods: - if ruleset_key not in API_MODS or mod.acronym not in API_MODS[ruleset_key]: + if ( + ruleset_key not in API_MODS + or mod["acronym"] not in API_MODS[ruleset_key] + ): all_proposed_valid = False continue valid_mods.append(mod) @@ -136,35 +141,35 @@ class PlaylistItem(MessagePackArrayModel): incompatible_mods = set() final_valid_mods = [] for mod in valid_mods: - if mod.acronym in incompatible_mods: + if mod["acronym"] in incompatible_mods: all_proposed_valid = False continue - setting_mods = API_MODS[ruleset_key].get(mod.acronym) + setting_mods = API_MODS[ruleset_key].get(mod["acronym"]) if setting_mods: incompatible_mods.update(setting_mods["IncompatibleMods"]) final_valid_mods.append(mod) # If not freestyle, check against allowed mods if not self.freestyle: - allowed_acronyms = {mod.acronym for mod in self.allowed_mods} + allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} filtered_valid_mods = [] for mod in final_valid_mods: - if mod.acronym not in allowed_acronyms: + if mod["acronym"] not in allowed_acronyms: all_proposed_valid = False else: filtered_valid_mods.append(mod) final_valid_mods = filtered_valid_mods # Check compatibility with required mods - required_mod_acronyms = {mod.acronym for mod in self.required_mods} + required_mod_acronyms = {mod["acronym"] for mod in self.required_mods} all_mod_acronyms = { - mod.acronym for mod in final_valid_mods + mod["acronym"] for mod in final_valid_mods } | required_mod_acronyms # Check for incompatibility between required and user mods filtered_valid_mods = [] for mod in final_valid_mods: - mod_acronym = mod.acronym + mod_acronym = mod["acronym"] is_compatible = True for other_acronym in all_mod_acronyms: @@ -181,23 +186,29 @@ class PlaylistItem(MessagePackArrayModel): return all_proposed_valid, filtered_valid_mods + def clone(self) -> "PlaylistItem": + copy = self.model_copy() + copy.required_mods = list(self.required_mods) + copy.allowed_mods = list(self.allowed_mods) + return copy -class _MultiplayerCountdown(MessagePackArrayModel): - id: int - remaining: int - is_exclusive: bool + +class _MultiplayerCountdown(BaseModel): + id: int = 0 + remaining: timedelta + is_exclusive: bool = False class MatchStartCountdown(_MultiplayerCountdown): - type: Literal[0] = Field(0, exclude=True) + union_type: ClassVar[Literal[0]] = 0 class ForceGameplayStartCountdown(_MultiplayerCountdown): - type: Literal[1] = Field(1, exclude=True) + union_type: ClassVar[Literal[1]] = 1 class ServerShuttingDownCountdown(_MultiplayerCountdown): - type: Literal[2] = Field(2, exclude=True) + union_type: ClassVar[Literal[2]] = 2 MultiplayerCountdown = ( @@ -205,11 +216,9 @@ MultiplayerCountdown = ( ) -class MultiplayerRoomUser(MessagePackArrayModel): +class MultiplayerRoomUser(BaseModel): user_id: int - state: Annotated[MultiplayerUserState, EnumByIndex(MultiplayerUserState)] = ( - MultiplayerUserState.IDLE - ) + state: MultiplayerUserState = MultiplayerUserState.IDLE availability: BeatmapAvailability = BeatmapAvailability( state=DownloadState.UNKNOWN, progress=None ) @@ -218,50 +227,33 @@ class MultiplayerRoomUser(MessagePackArrayModel): ruleset_id: int | None = None # freestyle beatmap_id: int | None = None # freestyle - @field_validator("match_state", mode="before") - def union_validate(v: Any): - if isinstance(v, list): - return msgpack_union(v) - return v - @field_serializer("match_state") - def union_serialize(v: Any): - return msgpack_union_dump(v) - - -class MultiplayerRoom(MessagePackArrayModel): +class MultiplayerRoom(BaseModel): room_id: int - state: Annotated[MultiplayerRoomState, EnumByIndex(MultiplayerRoomState)] + state: MultiplayerRoomState settings: MultiplayerRoomSettings users: list[MultiplayerRoomUser] = Field(default_factory=list) host: MultiplayerRoomUser | None = None match_state: MatchRoomState | None = None playlist: list[PlaylistItem] = Field(default_factory=list) - active_cooldowns: list[MultiplayerCountdown] = Field(default_factory=list) + active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list) channel_id: int - @field_validator("match_state", mode="before") - def union_validate(v: Any): - if isinstance(v, list): - return msgpack_union(v) - return v - - @field_serializer("match_state") - def union_serialize(v: Any): - return msgpack_union_dump(v) - class MultiplayerQueue: - def __init__(self, room: "ServerMultiplayerRoom", hub: "MultiplayerHub"): + def __init__(self, room: "ServerMultiplayerRoom"): self.server_room = room - self.hub = hub self.current_index = 0 + @property + def hub(self) -> "MultiplayerHub": + return self.server_room.hub + @property def upcoming_items(self): return sorted( (item for item in self.room.playlist if not item.expired), - key=lambda i: i.order, + key=lambda i: i.playlist_order, ) @property @@ -323,9 +315,9 @@ class MultiplayerQueue: ) async with AsyncSession(engine) as session: for idx, item in enumerate(ordered_active_items): - if item.order == idx: + if item.playlist_order == idx: continue - item.order = idx + item.playlist_order = idx await Playlist.update(item, self.room.room_id, session) await self.hub.playlist_changed( self.server_room, item, beatmap_changed=False @@ -338,7 +330,7 @@ class MultiplayerQueue: if upcoming_items else max( self.room.playlist, - key=lambda i: i.played_at or datetime.datetime.min, + key=lambda i: i.played_at or datetime.min, ) ) self.current_index = self.room.playlist.index(next_item) @@ -356,14 +348,7 @@ class MultiplayerQueue: limit = HOST_LIMIT if is_host else PER_USER_LIMIT if ( - len( - list( - filter( - lambda x: x.owner_id == user.user_id, - self.room.playlist, - ) - ) - ) + len([True for u in self.room.playlist if u.owner_id == user.user_id]) >= limit ): raise InvokeException(f"You can only have {limit} items in the queue") @@ -376,11 +361,11 @@ class MultiplayerQueue: beatmap = await session.get(Beatmap, item.beatmap_id) if beatmap is None: raise InvokeException("Beatmap not found") - if item.checksum != beatmap.checksum: + if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") # TODO: mods validation item.owner_id = user.user_id - item.star = float( + item.star_rating = float( beatmap.difficulty_rating ) # FIXME: beatmap use decimal await Playlist.add_to_db(item, self.room.room_id, session) @@ -400,7 +385,7 @@ class MultiplayerQueue: beatmap = await session.get(Beatmap, item.beatmap_id) if beatmap is None: raise InvokeException("Beatmap not found") - if item.checksum != beatmap.checksum: + if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") existing_item = next( @@ -423,8 +408,8 @@ class MultiplayerQueue: # TODO: mods validation item.owner_id = user.user_id - item.star = float(beatmap.difficulty_rating) - item.order = existing_item.order + item.star_rating = float(beatmap.difficulty_rating) + item.playlist_order = existing_item.playlist_order await Playlist.update(item, self.room.room_id, session) @@ -437,7 +422,8 @@ class MultiplayerQueue: await self.hub.playlist_changed( self.server_room, item, - beatmap_changed=item.checksum != existing_item.checksum, + beatmap_changed=item.beatmap_checksum + != existing_item.beatmap_checksum, ) async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): @@ -477,12 +463,46 @@ class MultiplayerQueue: await self.update_current_item() await self.hub.playlist_removed(self.server_room, item.id) + async def finish_current_item(self): + from app.database import Playlist + + async with AsyncSession(engine) as session: + played_at = datetime.now(UTC) + await session.execute( + update(Playlist) + .where( + col(Playlist.id) == self.current_item.id, + col(Playlist.room_id) == self.room.room_id, + ) + .values(expired=True, played_at=played_at) + ) + self.room.playlist[self.current_index].expired = True + self.room.playlist[self.current_index].played_at = played_at + await self.hub.playlist_changed(self.server_room, self.current_item, True) + await self.update_order() + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( + playitem.expired for playitem in self.room.playlist + ): + assert self.room.host + await self.add_item(self.current_item.clone(), self.room.host) + @property def current_item(self): - """Get the current playlist item""" - current_id = self.room.settings.playlist_item_id - return next( - (item for item in self.room.playlist if item.id == current_id), + return self.room.playlist[self.current_index] + + +@dataclass +class CountdownInfo: + countdown: MultiplayerCountdown + duration: timedelta + task: asyncio.Task | None = None + + def __init__(self, countdown: MultiplayerCountdown): + self.countdown = countdown + self.duration = ( + countdown.remaining + if countdown.remaining > timedelta(seconds=0) + else timedelta(seconds=0) ) @@ -491,5 +511,79 @@ class ServerMultiplayerRoom: room: MultiplayerRoom category: RoomCategory status: RoomStatus - start_at: datetime.datetime + start_at: datetime + hub: "MultiplayerHub" queue: MultiplayerQueue | None = None + _next_countdown_id: int = 0 + _countdown_id_lock: asyncio.Lock = field(default_factory=asyncio.Lock) + _tracked_countdown: dict[int, CountdownInfo] = field(default_factory=dict) + + async def get_next_countdown_id(self) -> int: + async with self._countdown_id_lock: + self._next_countdown_id += 1 + return self._next_countdown_id + + async def start_countdown( + self, + countdown: MultiplayerCountdown, + on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None, + ): + async def _countdown_task(self: "ServerMultiplayerRoom"): + await asyncio.sleep(info.duration.total_seconds()) + await self.stop_countdown(countdown) + if on_complete is not None: + await on_complete(self) + + if countdown.is_exclusive: + await self.stop_all_countdowns() + + countdown.id = await self.get_next_countdown_id() + info = CountdownInfo(countdown) + self.room.active_countdowns.append(info.countdown) + self._tracked_countdown[countdown.id] = info + await self.hub.send_match_event( + self, CountdownStartedEvent(countdown=info.countdown) + ) + info.task = asyncio.create_task(_countdown_task(self)) + + async def stop_countdown(self, countdown: MultiplayerCountdown): + info = next( + ( + info + for info in self._tracked_countdown.values() + if info.countdown.id == countdown.id + ), + None, + ) + if info is None: + return + if info.task is not None and not info.task.done(): + info.task.cancel() + del self._tracked_countdown[countdown.id] + self.room.active_countdowns.remove(countdown) + await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id)) + + async def stop_all_countdowns(self): + for countdown in list(self._tracked_countdown.values()): + await self.stop_countdown(countdown.countdown) + + self._tracked_countdown.clear() + self.room.active_countdowns.clear() + + +class _MatchServerEvent(BaseModel): ... + + +class CountdownStartedEvent(_MatchServerEvent): + countdown: MultiplayerCountdown + + type: Literal[0] = Field(default=0, exclude=True) + + +class CountdownStoppedEvent(_MatchServerEvent): + id: int + + type: Literal[1] = Field(default=1, exclude=True) + + +MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent diff --git a/app/models/room.py b/app/models/room.py index 42f897c..392562a 100644 --- a/app/models/room.py +++ b/app/models/room.py @@ -53,6 +53,15 @@ class MultiplayerUserState(str, Enum): RESULTS = "results" SPECTATING = "spectating" + @property + def is_playing(self) -> bool: + return self in { + self.WAITING_FOR_LOAD, + self.PLAYING, + self.READY_FOR_GAMEPLAY, + self.LOADED, + } + class DownloadState(str, Enum): UNKNOWN = "unknown" diff --git a/app/models/signalr.py b/app/models/signalr.py index de66e30..7116ea0 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -14,6 +14,7 @@ class SignalRMeta: member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute json_ignore: bool = False # implement of JsonIgnore (json) attribute use_upper_case: bool = False # use upper CamelCase for field names + use_abbr: bool = True class SignalRUnionMessage(BaseModel): diff --git a/app/router/score.py b/app/router/score.py index 2f1303e..b50911d 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,10 +1,19 @@ from __future__ import annotations -from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User +from app.database import ( + Beatmap, + Playlist, + Score, + ScoreResp, + ScoreToken, + ScoreTokenResp, + User, +) from app.database.score import get_leaderboard, process_score, process_user from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user +from app.fetcher import Fetcher from app.models.beatmap import BeatmapRankStatus from app.models.score import ( INT_TO_MODE, @@ -13,6 +22,7 @@ from app.models.score import ( Rank, SoloScoreSubmissionInfo, ) +from app.signalr.hub import MultiplayerHubs from .api_router import router @@ -24,6 +34,68 @@ from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession +async def submit_score( + info: SoloScoreSubmissionInfo, + beatmap: int, + token: int, + current_user: User, + db: AsyncSession, + redis: Redis, + fetcher: Fetcher, + item_id: int | None = None, + room_id: int | None = None, +): + if not info.passed: + info.rank = Rank.F + score_token = ( + await db.exec( + select(ScoreToken) + .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] + .where(ScoreToken.id == token) + ) + ).first() + if not score_token or score_token.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Score token not found") + if score_token.score_id: + score = ( + await db.exec( + select(Score).where( + Score.id == score_token.score_id, + Score.user_id == current_user.id, + ) + ) + ).first() + if not score: + raise HTTPException(status_code=404, detail="Score not found") + else: + beatmap_status = ( + await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)) + ).first() + if beatmap_status is None: + raise HTTPException(status_code=404, detail="Beatmap not found") + ranked = beatmap_status in { + BeatmapRankStatus.RANKED, + BeatmapRankStatus.APPROVED, + } + score = await process_score( + current_user, + beatmap, + ranked, + score_token, + info, + fetcher, + db, + redis, + ) + await db.refresh(current_user) + score_id = score.id + score_token.score_id = score_id + await process_user(db, current_user, score, ranked) + score = (await db.exec(select(Score).where(Score.id == score_id))).first() + assert score is not None + return await ScoreResp.from_db(db, score) + + class BeatmapScores(BaseModel): scores: list[ScoreResp] userScore: ScoreResp | None = None @@ -97,9 +169,10 @@ async def get_user_beatmap_score( status_code=404, detail=f"Cannot find user {user}'s score on this beatmap" ) else: + resp = await ScoreResp.from_db(db, user_score) return BeatmapUserScore( - position=user_score.position if user_score.position is not None else 0, - score=await ScoreResp.from_db(db, user_score), + position=resp.rank_global or 0, + score=resp, ) @@ -173,55 +246,95 @@ async def submit_solo_score( redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), ): - if not info.passed: - info.rank = Rank.F - async with db: - score_token = ( - await db.exec( - select(ScoreToken) - .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] - .where(ScoreToken.id == token, ScoreToken.user_id == current_user.id) + return await submit_score(info, beatmap, token, current_user, db, redis, fetcher) + + +@router.post( + "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp +) +async def create_playlist_score( + room_id: int, + playlist_id: int, + beatmap_id: int = Form(), + beatmap_hash: str = Form(), + ruleset_id: int = Form(..., ge=0, le=3), + version_hash: str = Form(""), + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + room = MultiplayerHubs.rooms[room_id] + if not room: + raise HTTPException(status_code=404, detail="Room not found") + item = ( + await session.exec( + select(Playlist).where( + Playlist.id == playlist_id, Playlist.room_id == room_id ) - ).first() - if not score_token or score_token.user_id != current_user.id: - raise HTTPException(status_code=404, detail="Score token not found") - if score_token.score_id: - score = ( - await db.exec( - select(Score).where( - Score.id == score_token.score_id, - Score.user_id == current_user.id, - ) - ) - ).first() - if not score: - raise HTTPException(status_code=404, detail="Score not found") - else: - beatmap_status = ( - await db.exec( - select(Beatmap.beatmap_status).where(Beatmap.id == beatmap) - ) - ).first() - if beatmap_status is None: - raise HTTPException(status_code=404, detail="Beatmap not found") - ranked = beatmap_status in { - BeatmapRankStatus.RANKED, - BeatmapRankStatus.APPROVED, - } - score = await process_score( - current_user, - beatmap, - ranked, - score_token, - info, - fetcher, - db, - redis, + ) + ).first() + if not item: + raise HTTPException(status_code=404, detail="Playlist not found") + + # validate + if not item.freestyle: + if item.ruleset_id != ruleset_id: + raise HTTPException( + status_code=400, detail="Ruleset mismatch in playlist item" ) - await db.refresh(current_user) - score_id = score.id - score_token.score_id = score_id - await process_user(db, current_user, score, ranked) - score = (await db.exec(select(Score).where(Score.id == score_id))).first() - assert score is not None - return await ScoreResp.from_db(db, score) + if item.beatmap_id != beatmap_id: + raise HTTPException( + status_code=400, detail="Beatmap ID mismatch in playlist item" + ) + # TODO: max attempts + if item.expired: + raise HTTPException(status_code=400, detail="Playlist item has expired") + if item.played_at: + raise HTTPException( + status_code=400, detail="Playlist item has already been played" + ) + # 这里应该不用验证mod了吧。。。 + + score_token = ScoreToken( + user_id=current_user.id, + beatmap_id=beatmap_id, + ruleset_id=INT_TO_MODE[ruleset_id], + playlist_item_id=playlist_id, + ) + session.add(score_token) + await session.commit() + await session.refresh(score_token) + return ScoreTokenResp.from_db(score_token) + + +@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}") +async def submit_playlist_score( + room_id: int, + playlist_id: int, + token: int, + info: SoloScoreSubmissionInfo, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), + fetcher: Fetcher = Depends(get_fetcher), +): + item = ( + await session.exec( + select(Playlist).where( + Playlist.id == playlist_id, Playlist.room_id == room_id + ) + ) + ).first() + if not item: + raise HTTPException(status_code=404, detail="Playlist item not found") + score_resp = await submit_score( + info, + item.beatmap_id, + token, + current_user, + session, + redis, + fetcher, + item.id, + room_id, + ) + return score_resp diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index bd34be0..21f192c 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +from datetime import timedelta from typing import override from app.database import Room @@ -8,8 +10,11 @@ from app.database.playlists import Playlist from app.dependencies.database import engine from app.exception import InvokeException from app.log import logger +from app.models.mods import APIMod from app.models.multiplayer_hub import ( BeatmapAvailability, + ForceGameplayStartCountdown, + MatchServerEvent, MultiplayerClientState, MultiplayerQueue, MultiplayerRoom, @@ -17,16 +22,22 @@ from app.models.multiplayer_hub import ( PlaylistItem, ServerMultiplayerRoom, ) -from app.models.room import RoomCategory, RoomStatus +from app.models.room import ( + DownloadState, + MultiplayerRoomState, + MultiplayerUserState, + RoomCategory, + RoomStatus, +) from app.models.score import GameMode -from app.models.signalr import serialize_to_list from .hub import Client, Hub -from msgpack_lazer_api import APIMod from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +GAMEPLAY_LOAD_TIMEOUT = 30 + class MultiplayerHub(Hub[MultiplayerClientState]): @override @@ -58,7 +69,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): type=room.settings.match_type, queue_mode=room.settings.queue_mode, auto_skip=room.settings.auto_skip, - auto_start_duration=room.settings.auto_start_duration, + auto_start_duration=int( + room.settings.auto_start_duration.total_seconds() + ), host_id=client.user_id, status=RoomStatus.IDLE, ) @@ -75,10 +88,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): category=RoomCategory.NORMAL, status=RoomStatus.IDLE, start_at=starts_at, + hub=self, ) queue = MultiplayerQueue( room=server_room, - hub=self, ) server_room.queue = queue self.rooms[room.room_id] = server_room @@ -86,6 +99,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): client, room.room_id, room.settings.password ) + async def JoinRoom(self, client: Client, room_id: int): + return self.JoinRoomWithPassword(client, room_id, "") + async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str): logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}") store = self.get_or_create_state(client) @@ -105,12 +121,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): # from CreateRoom room.host = user store.room_id = room_id - await self.broadcast_group_call( - self.group_id(room_id), "UserJoined", serialize_to_list(user) - ) + await self.broadcast_group_call(self.group_id(room_id), "UserJoined", user) room.users.append(user) self.add_to_group(client, self.group_id(room_id)) - return serialize_to_list(room) + return room async def ChangeBeatmapAvailability( self, client: Client, beatmap_availability: BeatmapAvailability @@ -132,12 +146,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]): and availability.progress == beatmap_availability.progress ): return - user.availability = availability + user.availability = beatmap_availability await self.broadcast_group_call( self.group_id(store.room_id), "UserBeatmapAvailabilityChanged", user.user_id, - serialize_to_list(beatmap_availability), + (beatmap_availability), ) async def AddPlaylistItem(self, client: Client, item: PlaylistItem): @@ -198,14 +212,14 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await self.broadcast_group_call( self.group_id(room.room.room_id), "SettingsChanged", - serialize_to_list(room.room.settings), + (room.room.settings), ) async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem): await self.broadcast_group_call( self.group_id(room.room.room_id), "PlaylistItemAdded", - serialize_to_list(item), + (item), ) async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int): @@ -221,7 +235,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await self.broadcast_group_call( self.group_id(room.room.room_id), "PlaylistItemChanged", - serialize_to_list(item), + (item), ) async def ChangeUserStyle( @@ -378,7 +392,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) if not is_valid: incompatible_mods = [ - mod.acronym for mod in new_mods if mod not in valid_mods + mod["acronym"] for mod in new_mods if mod not in valid_mods ] raise InvokeException( f"Incompatible mods were selected: {','.join(incompatible_mods)}" @@ -395,3 +409,221 @@ class MultiplayerHub(Hub[MultiplayerClientState]): user.user_id, valid_mods, ) + + async def validate_user_stare( + self, + room: ServerMultiplayerRoom, + old: MultiplayerUserState, + new: MultiplayerUserState, + ): + assert room.queue + match new: + case MultiplayerUserState.IDLE: + if old.is_playing: + raise InvokeException( + "Cannot return to idle without aborting gameplay." + ) + case MultiplayerUserState.READY: + if old != MultiplayerUserState.IDLE: + raise InvokeException(f"Cannot change state from {old} to {new}") + if room.queue.current_item.expired: + raise InvokeException( + "Cannot ready up while all items have been played." + ) + case MultiplayerUserState.WAITING_FOR_LOAD: + raise InvokeException("Cannot change state from {old} to {new}") + case MultiplayerUserState.LOADED: + if old != MultiplayerUserState.WAITING_FOR_LOAD: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.READY_FOR_GAMEPLAY: + if old != MultiplayerUserState.LOADED: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.PLAYING: + raise InvokeException("State is managed by the server.") + case MultiplayerUserState.FINISHED_PLAY: + if old != MultiplayerUserState.PLAYING: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.RESULTS: + raise InvokeException("Cannot change state from {old} to {new}") + case MultiplayerUserState.SPECTATING: + if old not in (MultiplayerUserState.IDLE, MultiplayerUserState.READY): + raise InvokeException(f"Cannot change state from {old} to {new}") + + async def ChangeState(self, client: Client, state: MultiplayerUserState): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if user.state == state: + return + match state: + case MultiplayerUserState.IDLE: + if user.state.is_playing: + return + case MultiplayerUserState.LOADED | MultiplayerUserState.READY_FOR_GAMEPLAY: + if not user.state.is_playing: + return + await self.validate_user_stare( + server_room, + user.state, + state, + ) + await self.change_user_state(server_room, user, state) + if state == MultiplayerUserState.SPECTATING and ( + room.state == MultiplayerRoomState.PLAYING + or room.state == MultiplayerRoomState.WAITING_FOR_LOAD + ): + await self.call_noblock(client, "LoadRequested") + await self.update_room_state(server_room) + + async def change_user_state( + self, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + state: MultiplayerUserState, + ): + user.state = state + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserStateChanged", + user.user_id, + user.state, + ) + + async def update_room_state(self, room: ServerMultiplayerRoom): + match room.room.state: + case MultiplayerRoomState.WAITING_FOR_LOAD: + played_count = len( + [True for user in room.room.users if user.state.is_playing] + ) + ready_count = len( + [ + True + for user in room.room.users + if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY + ] + ) + if played_count == ready_count: + await self.start_gameplay(room) + case MultiplayerRoomState.PLAYING: + assert room.queue + if all( + u.state != MultiplayerUserState.PLAYING for u in room.room.users + ): + for u in filter( + lambda u: u.state == MultiplayerUserState.FINISHED_PLAY, + room.room.users, + ): + await self.change_user_state( + room, u, MultiplayerUserState.RESULTS + ) + await self.change_room_state(room, MultiplayerRoomState.OPEN) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "ResultsReady", + ) + await room.queue.finish_current_item() + + async def change_room_state( + self, room: ServerMultiplayerRoom, state: MultiplayerRoomState + ): + room.room.state = state + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "RoomStateChanged", + state, + ) + + async def StartMatch(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + if any(u.state != MultiplayerUserState.READY for u in room.users): + raise InvokeException("Not all users are ready") + + await self.start_match(server_room) + + async def start_match(self, room: ServerMultiplayerRoom): + assert room.queue + if room.room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Can't start match when already in a running state.") + if room.queue.current_item.expired: + raise InvokeException("Current playlist item is expired") + ready_users = [ + u + for u in room.room.users + if u.availability.state == DownloadState.LOCALLY_AVAILABLE + and ( + u.state == MultiplayerUserState.READY + or u.state == MultiplayerUserState.IDLE + ) + ] + await asyncio.gather( + *[ + self.change_user_state(room, u, MultiplayerUserState.WAITING_FOR_LOAD) + for u in ready_users + ] + ) + await self.change_room_state( + room, + MultiplayerRoomState.WAITING_FOR_LOAD, + ) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "LoadRequested", + ) + await room.start_countdown( + ForceGameplayStartCountdown( + remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) + ), + self.start_gameplay, + ) + + async def start_gameplay(self, room: ServerMultiplayerRoom): + assert room.queue + if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD: + raise InvokeException("Room is not ready for gameplay") + if room.queue.current_item.expired: + raise InvokeException("Current playlist item is expired") + playing = False + for user in room.room.users: + client = self.get_client_by_id(str(user.user_id)) + if client is None: + continue + + if user.state in ( + MultiplayerUserState.READY_FOR_GAMEPLAY, + MultiplayerUserState.LOADED, + ): + playing = True + await self.change_user_state(room, user, MultiplayerUserState.PLAYING) + await self.call_noblock(client, "GameplayStarted") + await self.change_room_state( + room, + (MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN), + ) + + async def send_match_event( + self, room: ServerMultiplayerRoom, event: MatchServerEvent + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchEvent", + event, + ) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 70c2276..9afb78d 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -97,6 +97,8 @@ class MsgpackProtocol: return [cls.serialize_msgpack(item) for item in v] elif issubclass(typ, datetime.datetime): return [v, 0] + elif issubclass(typ, datetime.timedelta): + return int(v.total_seconds()) elif isinstance(v, dict): return { cls.serialize_msgpack(k): cls.serialize_msgpack(value) @@ -213,6 +215,8 @@ class MsgpackProtocol: return typ.model_validate(obj=cls.process_object(v, typ)) elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): return v[0] + elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): + return datetime.timedelta(seconds=int(v)) elif isinstance(v, list): return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): @@ -296,21 +300,30 @@ class MsgpackProtocol: class JSONProtocol: @classmethod - def serialize_to_json(cls, v: Any): + def serialize_to_json(cls, v: Any, dict_key: bool = False): typ = v.__class__ if issubclass(typ, BaseModel): return cls.serialize_model(v) elif isinstance(v, dict): return { - cls.serialize_to_json(k): cls.serialize_to_json(value) + cls.serialize_to_json(k, True): cls.serialize_to_json(value) for k, value in v.items() } elif isinstance(v, list): return [cls.serialize_to_json(item) for item in v] elif isinstance(v, datetime.datetime): return v.isoformat() - elif isinstance(v, Enum): + elif isinstance(v, datetime.timedelta): + # d.hh:mm:ss + total_seconds = int(v.total_seconds()) + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + return f"{hours:02}:{minutes:02}:{seconds:02}" + elif isinstance(v, Enum) and dict_key: return v.value + elif isinstance(v, Enum): + list_ = list(typ) + return list_.index(v) return v @classmethod @@ -322,9 +335,13 @@ class JSONProtocol: ) if metadata and metadata.json_ignore: continue - d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = ( - cls.serialize_to_json(getattr(v, field)) - ) + d[ + snake_to_camel( + field, + metadata.use_upper_case if metadata else False, + metadata.use_abbr if metadata else True, + ) + ] = cls.serialize_to_json(getattr(v, field)) if issubclass(v.__class__, SignalRUnionMessage): return { "$dtype": v.__class__.__name__, @@ -343,7 +360,11 @@ class JSONProtocol: ) if metadata and metadata.json_ignore: continue - value = v.get(snake_to_camel(field, not from_union)) + value = v.get( + snake_to_camel( + field, not from_union, metadata.use_abbr if metadata else True + ) + ) anno = typ.model_fields[field].annotation if anno is None: d[field] = value @@ -401,6 +422,17 @@ class JSONProtocol: return typ.model_validate(JSONProtocol.process_object(v, typ, from_union)) elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): return datetime.datetime.fromisoformat(v) + elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): + # d.hh:mm:ss + parts = v.split(":") + if len(parts) == 3: + return datetime.timedelta( + hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2]) + ) + elif len(parts) == 2: + return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1])) + elif len(parts) == 1: + return datetime.timedelta(seconds=int(parts[0])) elif isinstance(v, list): return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): diff --git a/app/utils.py b/app/utils.py index 0d759a1..ac51b90 100644 --- a/app/utils.py +++ b/app/utils.py @@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str: return "".join(result) -def snake_to_camel(name: str, lower_case: bool = True) -> str: +def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> str: """Convert a snake_case string to camelCase.""" if not name: return name @@ -47,7 +47,7 @@ def snake_to_camel(name: str, lower_case: bool = True) -> str: result = [] for part in parts: - if part.lower() in abbreviations: + if part.lower() in abbreviations and use_abbr: result.append(part.upper()) else: if result or not lower_case: From c2579e86ebc5595a2a8d8b65a646411953e568e8 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 3 Aug 2025 13:50:59 +0000 Subject: [PATCH 25/45] feat(multiplayer): supoort manage user (kick, transfer host, leave) --- app/signalr/hub/multiplayer.py | 116 ++++++++++++++++++++++++++++++++- 1 file changed, 114 insertions(+), 2 deletions(-) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 21f192c..9350c34 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from datetime import timedelta +from datetime import UTC, datetime, timedelta from typing import override from app.database import Room @@ -33,7 +33,8 @@ from app.models.score import GameMode from .hub import Client, Hub -from sqlmodel import select +from sqlalchemy import update +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession GAMEPLAY_LOAD_TIMEOUT = 30 @@ -627,3 +628,114 @@ class MultiplayerHub(Hub[MultiplayerClientState]): "MatchEvent", event, ) + + async def make_user_leave( + self, + client: Client, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + kicked: bool = False, + ): + self.remove_from_group(client, self.group_id(room.room.room_id)) + room.room.users.remove(user) + + if len(room.room.users) == 0: + await self.end_room(room) + await self.update_room_state(room) + if room.room.host and room.room.host.user_id == user.user_id: + next_host = room.room.users[0] + await self.set_host(room, next_host) + + if kicked: + await self.call_noblock(client, "UserKicked", user) + await self.broadcast_group_call( + self.group_id(room.room.room_id), "UserKicked", user + ) + else: + await self.broadcast_group_call( + self.group_id(room.room.room_id), "UserLeft", user + ) + + target_store = self.state.get(user.user_id) + if target_store: + target_store.room_id = 0 + + async def end_room(self, room: ServerMultiplayerRoom): + assert room.room.host + async with AsyncSession(engine) as session: + await session.execute( + update(Room) + .where(col(Room.id) == room.room.room_id) + .values( + name=room.room.settings.name, + ended_at=datetime.now(UTC), + type=room.room.settings.match_type, + queue_mode=room.room.settings.queue_mode, + auto_skip=room.room.settings.auto_skip, + auto_start_duration=int( + room.room.settings.auto_start_duration.total_seconds() + ), + host_id=room.room.host.user_id, + ) + ) + del self.rooms[room.room.room_id] + + async def LeaveRoom(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + return + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.make_user_leave(client, server_room, user) + + async def KickUser(self, client: Client, user_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + user = next((u for u in room.users if u.user_id == user_id), None) + if user is None: + raise InvokeException("User not found in this room") + + target_client = self.get_client_by_id(str(user.user_id)) + if target_client is None: + return + await self.make_user_leave(target_client, server_room, user, kicked=True) + + async def set_host(self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser): + room.room.host = user + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "HostChanged", + user.user_id, + ) + + async def TransferHost(self, client: Client, user_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + new_host = next((u for u in room.users if u.user_id == user_id), None) + if new_host is None: + raise InvokeException("User not found in this room") + await self.set_host(server_room, new_host) From 1e304542bd4d61da3082a9d58aa029a47cca3fb2 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 3 Aug 2025 14:00:49 +0000 Subject: [PATCH 26/45] feat(multiplayer): supoort abort match --- app/models/multiplayer_hub.py | 6 ++++ app/signalr/hub/multiplayer.py | 62 ++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index ba8a050..9c30523 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -4,6 +4,7 @@ import asyncio from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta +from enum import IntEnum from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal from app.database.beatmap import Beatmap @@ -587,3 +588,8 @@ class CountdownStoppedEvent(_MatchServerEvent): MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent + + +class GameplayAbortReason(IntEnum): + LOAD_TOOK_TOO_LONG = 0 + HOST_ABORTED = 1 diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 9350c34..1a2b332 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -14,6 +14,7 @@ from app.models.mods import APIMod from app.models.multiplayer_hub import ( BeatmapAvailability, ForceGameplayStartCountdown, + GameplayAbortReason, MatchServerEvent, MultiplayerClientState, MultiplayerQueue, @@ -615,6 +616,13 @@ class MultiplayerHub(Hub[MultiplayerClientState]): playing = True await self.change_user_state(room, user, MultiplayerUserState.PLAYING) await self.call_noblock(client, "GameplayStarted") + elif user.state == MultiplayerUserState.WAITING_FOR_LOAD: + await self.change_user_state(room, user, MultiplayerUserState.IDLE) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "GameplayAborted", + GameplayAbortReason.LOAD_TOOK_TOO_LONG, + ) await self.change_room_state( room, (MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN), @@ -739,3 +747,57 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if new_host is None: raise InvokeException("User not found in this room") await self.set_host(server_room, new_host) + + async def AbortGameplay(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if not user.state.is_playing: + raise InvokeException("Cannot abort gameplay while not in a gameplay state") + + await self.change_user_state( + server_room, + user, + MultiplayerUserState.IDLE, + ) + await self.update_room_state(server_room) + + async def AbortMatch(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + if ( + room.state != MultiplayerRoomState.PLAYING + or room.state == MultiplayerRoomState.WAITING_FOR_LOAD + ): + raise InvokeException("Room is not in a playable state") + + await asyncio.gather( + *[ + self.change_user_state(server_room, u, MultiplayerUserState.IDLE) + for u in room.users + if u.state.is_playing + ] + ) + await self.broadcast_group_call( + self.group_id(room.room_id), + "GameplayAborted", + GameplayAbortReason.HOST_ABORTED, + ) + await self.update_room_state(server_room) From 34bf2c6b324c3f16da4eb0bbd511d2e0fffa268e Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 3 Aug 2025 15:14:30 +0000 Subject: [PATCH 27/45] feat(multiplayer): support change settings --- app/models/multiplayer_hub.py | 172 +++++++++++++++++++++++++++++++-- app/signalr/hub/multiplayer.py | 92 +++++++++++++++--- 2 files changed, 240 insertions(+), 24 deletions(-) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 9c30523..97148db 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -1,11 +1,12 @@ from __future__ import annotations +from abc import ABC, abstractmethod import asyncio from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import UTC, datetime, timedelta from enum import IntEnum -from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, override from app.database.beatmap import Beatmap from app.dependencies.database import engine @@ -291,7 +292,9 @@ class MultiplayerQueue: current_set.append(items[i]) if is_first_set: - current_set.sort(key=lambda item: (item.order, item.id)) + current_set.sort( + key=lambda item: (item.playlist_order, item.id) + ) ordered_active_items.extend(current_set) first_set_order_by_user_id = { item.owner_id: idx @@ -308,7 +311,7 @@ class MultiplayerQueue: is_first_set = False for idx, item in enumerate(ordered_active_items): - item.order = idx + item.playlist_order = idx case _: ordered_active_items = sorted( (item for item in self.room.playlist if not item.expired), @@ -487,6 +490,15 @@ class MultiplayerQueue: assert self.room.host await self.add_item(self.current_item.clone(), self.room.host) + async def update_queue_mode(self): + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( + playitem.expired for playitem in self.room.playlist + ): + assert self.room.host + await self.add_item(self.current_item.clone(), self.room.host) + await self.update_order() + await self.update_current_item() + @property def current_item(self): return self.room.playlist[self.current_index] @@ -507,6 +519,125 @@ class CountdownInfo: ) +class _MatchRequest(SignalRUnionMessage): ... + + +class ChangeTeamRequest(_MatchRequest): + union_type: ClassVar[Literal[0]] = 0 + team_id: int + + +class StartMatchCountdownRequest(_MatchRequest): + union_type: ClassVar[Literal[1]] = 1 + duration: timedelta + + +class StopCountdownRequest(_MatchRequest): + union_type: ClassVar[Literal[2]] = 2 + id: int + + +MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest + + +class MatchTypeHandler(ABC): + def __init__(self, room: "ServerMultiplayerRoom"): + self.room = room + self.hub = room.hub + + @abstractmethod + async def handle_join(self, user: MultiplayerRoomUser): ... + + @abstractmethod + async def handle_request( + self, user: MultiplayerRoomUser, request: MatchRequest + ): ... + + @abstractmethod + async def handle_leave(self, user: MultiplayerRoomUser): ... + + +class HeadToHeadHandler(MatchTypeHandler): + @override + async def handle_join(self, user: MultiplayerRoomUser): + if user.match_state is not None: + user.match_state = None + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_request( + self, user: MultiplayerRoomUser, request: MatchRequest + ): ... + + @override + async def handle_leave(self, user: MultiplayerRoomUser): ... + + +class TeamVersusHandler(MatchTypeHandler): + @override + def __init__(self, room: "ServerMultiplayerRoom"): + super().__init__(room) + self.state = TeamVersusRoomState() + room.room.match_state = self.state + task = asyncio.create_task(self.hub.change_room_match_state(self.room)) + self.hub.tasks.add(task) + task.add_done_callback(self.hub.tasks.discard) + + def _get_best_available_team(self) -> int: + for team in self.state.teams: + if all( + ( + user.match_state is None + or not isinstance(user.match_state, TeamVersusUserState) + or user.match_state.team_id != team.id + ) + for user in self.room.room.users + ): + return team.id + + from collections import defaultdict + + team_counts = defaultdict(int) + for user in self.room.room.users: + if user.match_state is not None and isinstance( + user.match_state, TeamVersusUserState + ): + team_counts[user.match_state.team_id] += 1 + + if team_counts: + min_count = min(team_counts.values()) + for team_id, count in team_counts.items(): + if count == min_count: + return team_id + return self.state.teams[0].id if self.state.teams else 0 + + @override + async def handle_join(self, user: MultiplayerRoomUser): + best_team_id = self._get_best_available_team() + user.match_state = TeamVersusUserState(team_id=best_team_id) + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): + if not isinstance(request, ChangeTeamRequest): + return + + if request.team_id not in [team.id for team in self.state.teams]: + raise InvokeException("Invalid team ID") + + user.match_state = TeamVersusUserState(team_id=request.team_id) + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_leave(self, user: MultiplayerRoomUser): ... + + +MATCH_TYPE_HANDLERS = { + MatchType.HEAD_TO_HEAD: HeadToHeadHandler, + MatchType.TEAM_VERSUS: TeamVersusHandler, +} + + @dataclass class ServerMultiplayerRoom: room: MultiplayerRoom @@ -514,10 +645,35 @@ class ServerMultiplayerRoom: status: RoomStatus start_at: datetime hub: "MultiplayerHub" - queue: MultiplayerQueue | None = None - _next_countdown_id: int = 0 - _countdown_id_lock: asyncio.Lock = field(default_factory=asyncio.Lock) - _tracked_countdown: dict[int, CountdownInfo] = field(default_factory=dict) + match_type_handler: MatchTypeHandler + queue: MultiplayerQueue + _next_countdown_id: int + _countdown_id_lock: asyncio.Lock + _tracked_countdown: dict[int, CountdownInfo] + + def __init__( + self, + room: MultiplayerRoom, + category: RoomCategory, + start_at: datetime, + hub: "MultiplayerHub", + ): + self.room = room + self.category = category + self.status = RoomStatus.IDLE + self.start_at = start_at + self.hub = hub + self.queue = MultiplayerQueue(self) + self._next_countdown_id = 0 + self._countdown_id_lock = asyncio.Lock() + self._tracked_countdown = {} + + async def set_handler(self): + self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type]( + self + ) + for i in self.room.users: + await self.match_type_handler.handle_join(i) async def get_next_countdown_id(self) -> int: async with self._countdown_id_lock: diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 1a2b332..1a27497 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -15,16 +15,20 @@ from app.models.multiplayer_hub import ( BeatmapAvailability, ForceGameplayStartCountdown, GameplayAbortReason, + MatchRequest, MatchServerEvent, MultiplayerClientState, - MultiplayerQueue, MultiplayerRoom, + MultiplayerRoomSettings, MultiplayerRoomUser, PlaylistItem, ServerMultiplayerRoom, + StartMatchCountdownRequest, + StopCountdownRequest, ) from app.models.room import ( DownloadState, + MatchType, MultiplayerRoomState, MultiplayerUserState, RoomCategory, @@ -88,15 +92,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]): server_room = ServerMultiplayerRoom( room=room, category=RoomCategory.NORMAL, - status=RoomStatus.IDLE, start_at=starts_at, hub=self, ) - queue = MultiplayerQueue( - room=server_room, - ) - server_room.queue = queue self.rooms[room.room_id] = server_room + await server_room.set_handler() return await self.JoinRoomWithPassword( client, room.room_id, room.settings.password ) @@ -126,6 +126,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await self.broadcast_group_call(self.group_id(room_id), "UserJoined", user) room.users.append(user) self.add_to_group(client, self.group_id(room_id)) + await server_room.match_type_handler.handle_join(user) return room async def ChangeBeatmapAvailability( @@ -164,7 +165,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("Room does not exist") server_room = self.rooms[store.room_id] room = server_room.room - assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) if user is None: raise InvokeException("You are not in this room") @@ -182,7 +183,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("Room does not exist") server_room = self.rooms[store.room_id] room = server_room.room - assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) if user is None: raise InvokeException("You are not in this room") @@ -200,7 +201,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("Room does not exist") server_room = self.rooms[store.room_id] room = server_room.room - assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) if user is None: raise InvokeException("You are not in this room") @@ -262,7 +263,6 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async def validate_styles(self, room: ServerMultiplayerRoom): - assert room.queue if not room.queue.current_item.freestyle: for user in room.room.users: await self.change_user_style( @@ -323,7 +323,6 @@ class MultiplayerHub(Hub[MultiplayerClientState]): return if beatmap_id is not None or ruleset_id is not None: - assert room.queue if not room.queue.current_item.freestyle: raise InvokeException("Current item does not allow free user styles.") @@ -388,7 +387,6 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room: ServerMultiplayerRoom, user: MultiplayerRoomUser, ): - assert room.queue is_valid, valid_mods = room.queue.current_item.validate_user_mods( user, new_mods ) @@ -418,7 +416,6 @@ class MultiplayerHub(Hub[MultiplayerClientState]): old: MultiplayerUserState, new: MultiplayerUserState, ): - assert room.queue match new: case MultiplayerUserState.IDLE: if old.is_playing: @@ -515,7 +512,6 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if played_count == ready_count: await self.start_gameplay(room) case MultiplayerRoomState.PLAYING: - assert room.queue if all( u.state != MultiplayerUserState.PLAYING for u in room.room.users ): @@ -562,7 +558,6 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await self.start_match(server_room) async def start_match(self, room: ServerMultiplayerRoom): - assert room.queue if room.room.state != MultiplayerRoomState.OPEN: raise InvokeException("Can't start match when already in a running state.") if room.queue.current_item.expired: @@ -598,7 +593,6 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async def start_gameplay(self, room: ServerMultiplayerRoom): - assert room.queue if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD: raise InvokeException("Room is not ready for gameplay") if room.queue.current_item.expired: @@ -801,3 +795,69 @@ class MultiplayerHub(Hub[MultiplayerClientState]): GameplayAbortReason.HOST_ABORTED, ) await self.update_room_state(server_room) + + async def change_user_match_state( + self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchUserStateChanged", + user.user_id, + user.match_state, + ) + + async def change_room_match_state(self, room: ServerMultiplayerRoom): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchRoomStateChanged", + room.room.match_state, + ) + + async def ChangeSettings(self, client: Client, settings: MultiplayerRoomSettings): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + if room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Cannot change settings while playing") + + if settings.match_type == MatchType.PLAYLISTS: + raise InvokeException("Invalid match type selected") + + previous_settings = room.settings + room.settings = settings + + if previous_settings.match_type != settings.match_type: + await server_room.set_handler() + if previous_settings.queue_mode != settings.queue_mode: + await server_room.queue.update_queue_mode() + + await self.setting_changed(server_room, beatmap_changed=False) + await self.update_room_state(server_room) + + async def SendMatchRequest(self, client: Client, request: MatchRequest): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if isinstance(request, StartMatchCountdownRequest): + # TODO: countdown + ... + elif isinstance(request, StopCountdownRequest): + ... + else: + await server_room.match_type_handler.handle_request(user, request) From f82a1bb3c0be2c8c8f33ec51e74e93c66546725e Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 4 Aug 2025 01:31:24 +0000 Subject: [PATCH 28/45] feat(multiplayer): support invite player --- app/signalr/hub/multiplayer.py | 72 ++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 1a27497..eb602fe 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -6,7 +6,9 @@ from typing import override from app.database import Room from app.database.beatmap import Beatmap +from app.database.lazer_user import User from app.database.playlists import Playlist +from app.database.relationship import Relationship, RelationshipType from app.dependencies.database import engine from app.exception import InvokeException from app.log import logger @@ -861,3 +863,73 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ... else: await server_room.match_type_handler.handle_request(user, request) + + async def InvitePlayer(self, client: Client, user_id: int): + print(f"Inviting player... {client.user_id} {user_id}") + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + async with AsyncSession(engine) as session: + db_user = await session.get(User, user_id) + target_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == user_id, + Relationship.target_id == client.user_id, + ) + ) + ).first() + inviter_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == client.user_id, + Relationship.target_id == user_id, + ) + ) + ).first() + if db_user is None: + raise InvokeException("User not found") + if db_user.id == client.user_id: + raise InvokeException("You cannot invite yourself") + if db_user.id in [u.user_id for u in room.users]: + raise InvokeException("User already invited") + if db_user.is_restricted: + raise InvokeException("User is restricted") + if ( + inviter_relationship + and inviter_relationship.type == RelationshipType.BLOCK + ): + raise InvokeException("Cannot perform action due to user being blocked") + if ( + target_relationship + and target_relationship.type == RelationshipType.BLOCK + ): + raise InvokeException("Cannot perform action due to user being blocked") + if ( + db_user.pm_friends_only + and target_relationship is not None + and target_relationship.type != RelationshipType.FOLLOW + ): + raise InvokeException( + "Cannot perform action " + "because user has disabled non-friend communications" + ) + + target_client = self.get_client_by_id(str(user_id)) + if target_client is None: + raise InvokeException("User is not online") + await self.call_noblock( + target_client, + "Invited", + client.user_id, + room.room_id, + room.settings.password, + ) From 9da9f27febcccabbded8dd736c8376976fe66c4e Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 4 Aug 2025 02:20:14 +0000 Subject: [PATCH 29/45] feat(multiplayer): complete validation --- app/models/multiplayer_hub.py | 103 ++++++++++++++++++++++++++++++--- app/signalr/hub/multiplayer.py | 38 ++++++++++-- 2 files changed, 129 insertions(+), 12 deletions(-) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 97148db..e2f4edf 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import UTC, datetime, timedelta from enum import IntEnum -from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, override +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, cast, override from app.database.beatmap import Beatmap from app.dependencies.database import engine @@ -107,6 +107,97 @@ class PlaylistItem(BaseModel): star_rating: float freestyle: bool + def _get_api_mods(self): + from app.models.mods import API_MODS, init_mods + + if not API_MODS: + init_mods() + return API_MODS + + def _validate_mod_for_ruleset( + self, mod: APIMod, ruleset_key: int, context: str = "mod" + ) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + + # Check if mod is valid for ruleset + if ( + typed_ruleset_key not in API_MODS + or mod["acronym"] not in API_MODS[typed_ruleset_key] + ): + raise InvokeException( + f"{context} {mod['acronym']} is invalid for this ruleset" + ) + + mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]] + + # Check if mod is unplayable in multiplayer + if mod_settings.get("UserPlayable", True) is False: + raise InvokeException( + f"{context} {mod['acronym']} is not playable by users" + ) + + if mod_settings.get("ValidForMultiplayer", True) is False: + raise InvokeException( + f"{context} {mod['acronym']} is not valid for multiplayer" + ) + + def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + + for i, mod1 in enumerate(mods): + mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"]) + if mod1_settings: + incompatible = set(mod1_settings.get("IncompatibleMods", [])) + for mod2 in mods[i + 1 :]: + if mod2["acronym"] in incompatible: + raise InvokeException( + f"Mods {mod1['acronym']} and " + f"{mod2['acronym']} are incompatible" + ) + + def _check_required_allowed_compatibility(self, ruleset_key: int) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} + + for req_mod in self.required_mods: + req_acronym = req_mod["acronym"] + req_settings = API_MODS[typed_ruleset_key].get(req_acronym) + if req_settings: + incompatible = set(req_settings.get("IncompatibleMods", [])) + conflicting_allowed = allowed_acronyms & incompatible + if conflicting_allowed: + conflict_list = ", ".join(conflicting_allowed) + raise InvokeException( + f"Required mod {req_acronym} conflicts with " + f"allowed mods: {conflict_list}" + ) + + def validate_playlist_item_mods(self) -> None: + ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id) + + # Validate required mods + for mod in self.required_mods: + self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod") + + # Validate allowed mods + for mod in self.allowed_mods: + self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod") + + # Check internal compatibility of required mods + self._check_mod_compatibility(self.required_mods, ruleset_key) + + # Check compatibility between required and allowed mods + self._check_required_allowed_compatibility(ruleset_key) + def validate_user_mods( self, user: "MultiplayerRoomUser", @@ -118,10 +209,7 @@ class PlaylistItem(BaseModel): """ from typing import Literal, cast - from app.models.mods import API_MODS, init_mods - - if not API_MODS: - init_mods() + API_MODS = self._get_api_mods() ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id) @@ -367,7 +455,8 @@ class MultiplayerQueue: raise InvokeException("Beatmap not found") if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") - # TODO: mods validation + + item.validate_playlist_item_mods() item.owner_id = user.user_id item.star_rating = float( beatmap.difficulty_rating @@ -410,7 +499,7 @@ class MultiplayerQueue: "Attempted to change an item which has already been played" ) - # TODO: mods validation + item.validate_playlist_item_mods() item.owner_id = user.user_id item.star_rating = float(beatmap.difficulty_rating) item.playlist_order = existing_item.playlist_order diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index eb602fe..3be8024 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -64,6 +64,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]): connection_token=client.connection_token, ) + @override + async def _clean_state(self, state: MultiplayerClientState): + user_id = int(state.connection_id) + if state.room_id != 0 and state.room_id in self.rooms: + server_room = self.rooms[state.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == user_id), None) + if user is not None: + await self.make_user_leave( + self.get_client_by_id(str(user_id)), server_room, user + ) + async def CreateRoom(self, client: Client, room: MultiplayerRoom): logger.info(f"[MultiplayerHub] {client.user_id} creating room") store = self.get_or_create_state(client) @@ -554,8 +566,17 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("You are not in this room") if room.host is None or room.host.user_id != client.user_id: raise InvokeException("You are not the host of this room") - if any(u.state != MultiplayerUserState.READY for u in room.users): - raise InvokeException("Not all users are ready") + + # Check host state - host must be ready or spectating + if room.host.state not in ( + MultiplayerUserState.SPECTATING, + MultiplayerUserState.READY, + ): + raise InvokeException("Can't start match when the host is not ready.") + + # Check if any users are ready + if all(u.state != MultiplayerUserState.READY for u in room.users): + raise InvokeException("Can't start match when no users are ready.") await self.start_match(server_room) @@ -646,7 +667,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if len(room.room.users) == 0: await self.end_room(room) await self.update_room_state(room) - if room.room.host and room.room.host.user_id == user.user_id: + if ( + len(room.room.users) != 0 + and room.room.host + and room.room.host.user_id == user.user_id + ): next_host = room.room.users[0] await self.set_host(room, next_host) @@ -710,6 +735,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if room.host is None or room.host.user_id != client.user_id: raise InvokeException("You are not the host of this room") + if user_id == client.user_id: + raise InvokeException("Can't kick self") + user = next((u for u in room.users if u.user_id == user_id), None) if user is None: raise InvokeException("User not found in this room") @@ -780,9 +808,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if ( room.state != MultiplayerRoomState.PLAYING - or room.state == MultiplayerRoomState.WAITING_FOR_LOAD + and room.state != MultiplayerRoomState.WAITING_FOR_LOAD ): - raise InvokeException("Room is not in a playable state") + raise InvokeException("Cannot abort a match that hasn't started.") await asyncio.gather( *[ From cfcf9ad03457da3cd3e7c8a92bcc8e37f5462812 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 4 Aug 2025 02:21:40 +0000 Subject: [PATCH 30/45] chore(mods): update mod definitions catch: add MF --- static/README.md | 2 +- static/mods.json | 27 ++++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/static/README.md b/static/README.md index 16ece63..77b54fe 100644 --- a/static/README.md +++ b/static/README.md @@ -2,4 +2,4 @@ - `mods.json`: 包含了游戏中的所有可用mod的详细信息。 - Origin: https://github.com/ppy/osu-web/blob/master/database/mods.json - - Version: 2025/6/10 `b68c920b1db3d443b9302fdc3f86010c875fe380` + - Version: 2025/7/30 `ff49b66b27a2850aea4b6b3ba563cfe936cb6082` diff --git a/static/mods.json b/static/mods.json index defb57f..0a8449b 100644 --- a/static/mods.json +++ b/static/mods.json @@ -2438,7 +2438,8 @@ "Settings": [], "IncompatibleMods": [ "CN", - "RX" + "RX", + "MF" ], "RequiresConfiguration": false, "UserPlayable": false, @@ -2460,7 +2461,8 @@ "AC", "AT", "CN", - "RX" + "RX", + "MF" ], "RequiresConfiguration": false, "UserPlayable": false, @@ -2477,7 +2479,8 @@ "Settings": [], "IncompatibleMods": [ "AT", - "CN" + "CN", + "MF" ], "RequiresConfiguration": false, "UserPlayable": true, @@ -2638,6 +2641,24 @@ "ValidForMultiplayerAsFreeMod": true, "AlwaysValidForSubmission": false }, + { + "Acronym": "MF", + "Name": "Moving Fast", + "Description": "Dashing by default, slow down!", + "Type": "Fun", + "Settings": [], + "IncompatibleMods": [ + "AT", + "CN", + "RX" + ], + "RequiresConfiguration": false, + "UserPlayable": true, + "ValidForMultiplayer": true, + "ValidForFreestyleAsRequiredMod": false, + "ValidForMultiplayerAsFreeMod": true, + "AlwaysValidForSubmission": false + }, { "Acronym": "SV2", "Name": "Score V2", From 082883599e3a743c18231976bc8078fc08b84803 Mon Sep 17 00:00:00 2001 From: chenjintang-shrimp Date: Tue, 5 Aug 2025 07:29:41 +0000 Subject: [PATCH 31/45] =?UTF-8?q?chore:=20=E6=9B=B4=E6=96=B0gitignore,?= =?UTF-8?q?=E6=96=B9=E4=BE=BF=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 369e759..05622b7 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports +test-cert/ htmlcov/ .tox/ .nox/ From 0988f1fc0c40350030ec746ab9216fe77c10e7cc Mon Sep 17 00:00:00 2001 From: chenjintang-shrimp Date: Tue, 5 Aug 2025 16:17:33 +0000 Subject: [PATCH 32/45] feat(multiplayer): partital support for multiplayer rooms' filtering --- app/router/room.py | 79 +++++++++++++--------------------------------- 1 file changed, 22 insertions(+), 57 deletions(-) diff --git a/app/router/room.py b/app/router/room.py index ba909c6..cfaaf56 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -2,9 +2,11 @@ from __future__ import annotations from typing import Literal +from app.database.lazer_user import User from app.database.room import RoomResp from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher +from app.dependencies.user import get_current_user from app.fetcher import Fetcher from app.models.room import RoomStatus from app.signalr.hub import MultiplayerHubs @@ -19,68 +21,31 @@ from sqlmodel.ext.asyncio.session import AsyncSession @router.get("/rooms", tags=["rooms"], response_model=list[RoomResp]) async def get_all_rooms( mode: Literal["open", "ended", "participated", "owned", None] = Query( - None + default="open" ), # TODO: 对房间根据状态进行筛选 category: str = Query(default="realtime"), # TODO status: RoomStatus | None = Query(None), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), redis: Redis = Depends(get_redis), + current_user: User = Depends(get_current_user), ): rooms = MultiplayerHubs.rooms.values() - return [await RoomResp.from_hub(room) for room in rooms] - - -# @router.get("/rooms/{room}", tags=["room"], response_model=Room) -# async def get_room( -# room: int, -# db: AsyncSession = Depends(get_db), -# fetcher: Fetcher = Depends(get_fetcher), -# ): -# redis = get_redis() -# if redis: -# dumped_room = str(redis.get(str(room))) -# if dumped_room is not None: -# resp = await Room.from_mpRoom( -# MultiplayerRoom.model_validate_json(str(dumped_room)), db, fetcher -# ) -# return resp -# else: -# raise HTTPException(status_code=404, detail="Room Not Found") -# else: -# raise HTTPException(status_code=500, detail="Redis error") - - -# class APICreatedRoom(Room): -# error: str | None - - -# @router.post("/rooms", tags=["beatmap"], response_model=APICreatedRoom) -# async def create_room( -# room: Room, -# db: AsyncSession = Depends(get_db), -# fetcher: Fetcher = Depends(get_fetcher), -# ): -# redis = get_redis() -# if redis: -# room_index = RoomIndex() -# db.add(room_index) -# await db.commit() -# await db.refresh(room_index) -# server_room = await MultiplayerRoom.from_apiRoom(room, db, fetcher) -# redis.set(str(room_index.id), server_room.model_dump_json()) -# room.room_id = room_index.id -# return APICreatedRoom(**room.model_dump(), error=None) -# else: -# raise HTTPException(status_code=500, detail="redis error") - - -# @router.delete("/rooms/{room}", tags=["room"]) -# async def remove_room(room: int, db: AsyncSession = Depends(get_db)): -# redis = get_redis() -# if redis: -# redis.delete(str(room)) -# room_index = await db.get(RoomIndex, room) -# if room_index: -# await db.delete(room_index) -# await db.commit() + resp_list: list[RoomResp] = [] + for room in rooms: + if category != "realtime": # 歌单模式的处理逻辑 + if room.category == category: + if mode == "owned": + if ( + room.room.host.user_id if room.room.host is not None else 0 + ) != current_user.id: + continue + else: + if ( + room.room.host.user_id if room.room.host is not None else 0 + ) != current_user.id: + continue + if room.status != status: + continue + resp_list.append(await RoomResp.from_hub(room)) + return resp_list From 0a80c5051cb7e5c1651322c174466474ebd8c7ea Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Tue, 5 Aug 2025 17:21:45 +0000 Subject: [PATCH 33/45] feat(multiplayer): support countdown --- app/models/metadata_hub.py | 15 ++++------ app/models/multiplayer_hub.py | 42 +++++++++++++-------------- app/models/signalr.py | 1 - app/signalr/hub/multiplayer.py | 52 ++++++++++++++++++++++++++++++---- app/signalr/packet.py | 39 ++++++++++++++----------- app/utils.py | 38 +++++++++++++++++++++++-- 6 files changed, 131 insertions(+), 56 deletions(-) diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index a678d7f..684ab54 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -1,11 +1,11 @@ from __future__ import annotations from enum import IntEnum -from typing import Annotated, ClassVar, Literal +from typing import ClassVar, Literal -from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState +from app.models.signalr import SignalRUnionMessage, UserState -from pydantic import BaseModel, Field +from pydantic import BaseModel class _UserActivity(SignalRUnionMessage): ... @@ -100,12 +100,9 @@ UserActivity = ( class UserPresence(BaseModel): - activity: Annotated[ - UserActivity | None, Field(default=None), SignalRMeta(use_upper_case=True) - ] - status: Annotated[ - OnlineStatus | None, Field(default=None), SignalRMeta(use_upper_case=True) - ] + activity: UserActivity | None = None + + status: OnlineStatus | None = None @property def pushable(self) -> bool: diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index e2f4edf..9d78282 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -53,10 +53,14 @@ class MultiplayerRoomSettings(BaseModel): auto_start_duration: timedelta = timedelta(seconds=0) auto_skip: bool = False + @property + def auto_start_enabled(self) -> bool: + return self.auto_start_duration != timedelta(seconds=0) + class BeatmapAvailability(BaseModel): state: DownloadState = DownloadState.UNKNOWN - progress: float | None = None + download_progress: float | None = None class _MatchUserState(SignalRUnionMessage): ... @@ -283,10 +287,12 @@ class PlaylistItem(BaseModel): return copy -class _MultiplayerCountdown(BaseModel): +class _MultiplayerCountdown(SignalRUnionMessage): id: int = 0 - remaining: timedelta - is_exclusive: bool = False + time_remaining: timedelta + is_exclusive: Annotated[ + bool, Field(default=True), SignalRMeta(member_ignore=True) + ] = True class MatchStartCountdown(_MultiplayerCountdown): @@ -310,7 +316,7 @@ class MultiplayerRoomUser(BaseModel): user_id: int state: MultiplayerUserState = MultiplayerUserState.IDLE availability: BeatmapAvailability = BeatmapAvailability( - state=DownloadState.UNKNOWN, progress=None + state=DownloadState.UNKNOWN, download_progress=None ) mods: list[APIMod] = Field(default_factory=list) match_state: MatchUserState | None = None @@ -602,8 +608,8 @@ class CountdownInfo: def __init__(self, countdown: MultiplayerCountdown): self.countdown = countdown self.duration = ( - countdown.remaining - if countdown.remaining > timedelta(seconds=0) + countdown.time_remaining + if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0) ) @@ -776,13 +782,12 @@ class ServerMultiplayerRoom: ): async def _countdown_task(self: "ServerMultiplayerRoom"): await asyncio.sleep(info.duration.total_seconds()) - await self.stop_countdown(countdown) if on_complete is not None: await on_complete(self) + await self.stop_countdown(countdown) if countdown.is_exclusive: await self.stop_all_countdowns() - countdown.id = await self.get_next_countdown_id() info = CountdownInfo(countdown) self.room.active_countdowns.append(info.countdown) @@ -793,21 +798,14 @@ class ServerMultiplayerRoom: info.task = asyncio.create_task(_countdown_task(self)) async def stop_countdown(self, countdown: MultiplayerCountdown): - info = next( - ( - info - for info in self._tracked_countdown.values() - if info.countdown.id == countdown.id - ), - None, - ) + info = self._tracked_countdown.get(countdown.id) if info is None: return - if info.task is not None and not info.task.done(): - info.task.cancel() del self._tracked_countdown[countdown.id] self.room.active_countdowns.remove(countdown) await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id)) + if info.task is not None and not info.task.done(): + info.task.cancel() async def stop_all_countdowns(self): for countdown in list(self._tracked_countdown.values()): @@ -817,19 +815,19 @@ class ServerMultiplayerRoom: self.room.active_countdowns.clear() -class _MatchServerEvent(BaseModel): ... +class _MatchServerEvent(SignalRUnionMessage): ... class CountdownStartedEvent(_MatchServerEvent): countdown: MultiplayerCountdown - type: Literal[0] = Field(default=0, exclude=True) + union_type: ClassVar[Literal[0]] = 0 class CountdownStoppedEvent(_MatchServerEvent): id: int - type: Literal[1] = Field(default=1, exclude=True) + union_type: ClassVar[Literal[1]] = 1 MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent diff --git a/app/models/signalr.py b/app/models/signalr.py index 7116ea0..ffbaf6b 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -13,7 +13,6 @@ from pydantic import ( class SignalRMeta: member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute json_ignore: bool = False # implement of JsonIgnore (json) attribute - use_upper_case: bool = False # use upper CamelCase for field names use_abbr: bool = True diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 3be8024..ef3dfcd 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -19,12 +19,14 @@ from app.models.multiplayer_hub import ( GameplayAbortReason, MatchRequest, MatchServerEvent, + MatchStartCountdown, MultiplayerClientState, MultiplayerRoom, MultiplayerRoomSettings, MultiplayerRoomUser, PlaylistItem, ServerMultiplayerRoom, + ServerShuttingDownCountdown, StartMatchCountdownRequest, StopCountdownRequest, ) @@ -160,7 +162,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): availability = user.availability if ( availability.state == beatmap_availability.state - and availability.progress == beatmap_availability.progress + and availability.download_progress == beatmap_availability.download_progress ): return user.availability = beatmap_availability @@ -512,6 +514,25 @@ class MultiplayerHub(Hub[MultiplayerClientState]): async def update_room_state(self, room: ServerMultiplayerRoom): match room.room.state: + case MultiplayerRoomState.OPEN: + if room.room.settings.auto_start_enabled: + if ( + not room.queue.current_item.expired + and any( + u.state == MultiplayerUserState.READY + for u in room.room.users + ) + and not any( + isinstance(countdown, MatchStartCountdown) + for countdown in room.room.active_countdowns + ) + ): + await room.start_countdown( + MatchStartCountdown( + time_remaining=room.room.settings.auto_start_duration + ), + self.start_match, + ) case MultiplayerRoomState.WAITING_FOR_LOAD: played_count = len( [True for user in room.room.users if user.state.is_playing] @@ -610,7 +631,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) await room.start_countdown( ForceGameplayStartCountdown( - remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) + time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) ), self.start_gameplay, ) @@ -885,15 +906,34 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("You are not in this room") if isinstance(request, StartMatchCountdownRequest): - # TODO: countdown - ... + if room.host and room.host.user_id != user.user_id: + raise InvokeException("You are not the host of this room") + if room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Cannot start a countdown during ongoing play") + await server_room.start_countdown( + MatchStartCountdown(time_remaining=request.duration), + self.start_match, + ) elif isinstance(request, StopCountdownRequest): - ... + countdown = next( + (c for c in room.active_countdowns if c.id == request.id), + None, + ) + if countdown is None: + return + if ( + isinstance(countdown, MatchStartCountdown) + and room.settings.auto_start_enabled + ) or isinstance( + countdown, (ForceGameplayStartCountdown | ServerShuttingDownCountdown) + ): + raise InvokeException("Cannot stop the requested countdown") + + await server_room.stop_countdown(countdown) else: await server_room.match_type_handler.handle_request(user, request) async def InvitePlayer(self, client: Client, user_id: int): - print(f"Inviting player... {client.user_id} {user_id}") store = self.get_or_create_state(client) if store.room_id == 0: raise InvokeException("You are not in a room") diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 9afb78d..09a36bd 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -15,7 +15,7 @@ from typing import ( ) from app.models.signalr import SignalRMeta, SignalRUnionMessage -from app.utils import camel_to_snake, snake_to_camel +from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal import msgpack_lazer_api as m from pydantic import BaseModel @@ -98,7 +98,7 @@ class MsgpackProtocol: elif issubclass(typ, datetime.datetime): return [v, 0] elif issubclass(typ, datetime.timedelta): - return int(v.total_seconds()) + return int(v.total_seconds() * 10_000_000) elif isinstance(v, dict): return { cls.serialize_msgpack(k): cls.serialize_msgpack(value) @@ -216,8 +216,8 @@ class MsgpackProtocol: elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): return v[0] elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): - return datetime.timedelta(seconds=int(v)) - elif isinstance(v, list): + return datetime.timedelta(seconds=int(v / 10_000_000)) + elif get_origin(typ) is list: return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): list_ = list(typ) @@ -300,10 +300,10 @@ class MsgpackProtocol: class JSONProtocol: @classmethod - def serialize_to_json(cls, v: Any, dict_key: bool = False): + def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False): typ = v.__class__ if issubclass(typ, BaseModel): - return cls.serialize_model(v) + return cls.serialize_model(v, in_union) elif isinstance(v, dict): return { cls.serialize_to_json(k, True): cls.serialize_to_json(value) @@ -327,22 +327,28 @@ class JSONProtocol: return v @classmethod - def serialize_model(cls, v: BaseModel) -> dict[str, Any]: + def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]: d = {} + is_union = issubclass(v.__class__, SignalRUnionMessage) for field, info in v.__class__.model_fields.items(): metadata = next( (m for m in info.metadata if isinstance(m, SignalRMeta)), None ) if metadata and metadata.json_ignore: continue - d[ + name = ( snake_to_camel( field, - metadata.use_upper_case if metadata else False, metadata.use_abbr if metadata else True, ) - ] = cls.serialize_to_json(getattr(v, field)) - if issubclass(v.__class__, SignalRUnionMessage): + if not is_union + else snake_to_pascal( + field, + metadata.use_abbr if metadata else True, + ) + ) + d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union) + if is_union and not in_union: return { "$dtype": v.__class__.__name__, "$value": d, @@ -360,11 +366,12 @@ class JSONProtocol: ) if metadata and metadata.json_ignore: continue - value = v.get( - snake_to_camel( - field, not from_union, metadata.use_abbr if metadata else True - ) + name = ( + snake_to_camel(field, metadata.use_abbr if metadata else True) + if not from_union + else snake_to_pascal(field, metadata.use_abbr if metadata else True) ) + value = v.get(name) anno = typ.model_fields[field].annotation if anno is None: d[field] = value @@ -433,7 +440,7 @@ class JSONProtocol: return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1])) elif len(parts) == 1: return datetime.timedelta(seconds=int(parts[0])) - elif isinstance(v, list): + elif get_origin(typ) is list: return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): list_ = list(typ) diff --git a/app/utils.py b/app/utils.py index ac51b90..22f06dd 100644 --- a/app/utils.py +++ b/app/utils.py @@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str: return "".join(result) -def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> str: +def snake_to_camel(name: str, use_abbr: bool = True) -> str: """Convert a snake_case string to camelCase.""" if not name: return name @@ -50,9 +50,43 @@ def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> if part.lower() in abbreviations and use_abbr: result.append(part.upper()) else: - if result or not lower_case: + if result: result.append(part.capitalize()) else: result.append(part.lower()) return "".join(result) + + +def snake_to_pascal(name: str, use_abbr: bool = True) -> str: + """Convert a snake_case string to PascalCase.""" + if not name: + return name + + parts = name.split("_") + if not parts: + return name + + # 常见缩写词列表 + abbreviations = { + "id", + "url", + "api", + "http", + "https", + "xml", + "json", + "css", + "html", + "sql", + "db", + } + + result = [] + for part in parts: + if part.lower() in abbreviations and use_abbr: + result.append(part.upper()) + else: + result.append(part.capitalize()) + + return "".join(result) From 2b4d366e3e8160f28fcfd9f7029d417f2942b1c6 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Tue, 5 Aug 2025 17:21:53 +0000 Subject: [PATCH 34/45] fix(score): remove foreign key to fix missing index error --- app/database/best_score.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/app/database/best_score.py b/app/database/best_score.py index 42b0024..8688d5b 100644 --- a/app/database/best_score.py +++ b/app/database/best_score.py @@ -29,9 +29,7 @@ class BestScore(SQLModel, table=True): ) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) gamemode: GameMode = Field(index=True) - total_score: int = Field( - default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score")) - ) + total_score: int = Field(default=0, sa_column=Column(BigInteger)) mods: list[str] = Field( default_factory=list, sa_column=Column(JSON), From 84dac34a05ec3270ce3c86549ae7a3ff4746216a Mon Sep 17 00:00:00 2001 From: chenjintang-shrimp Date: Wed, 6 Aug 2025 06:55:45 +0000 Subject: [PATCH 35/45] fix(multiplayer): fix fliters --- app/router/room.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/app/router/room.py b/app/router/room.py index cfaaf56..476eaf2 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -40,11 +40,14 @@ async def get_all_rooms( room.room.host.user_id if room.room.host is not None else 0 ) != current_user.id: continue - else: - if ( - room.room.host.user_id if room.room.host is not None else 0 - ) != current_user.id: + else: continue + else: + if mode == "owned": + if ( + room.room.host.user_id if room.room.host is not None else 0 + ) != current_user.id: + continue if room.status != status: continue resp_list.append(await RoomResp.from_hub(room)) From 87bb74d1caa0affba79b2f32e9fed22bc147ba19 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Wed, 6 Aug 2025 10:51:37 +0000 Subject: [PATCH 36/45] feat(multiplayer): support leaderboard --- app/database/__init__.py | 6 + app/database/playlist_best_score.py | 107 +++++++++++ app/database/playlists.py | 2 +- app/database/score.py | 19 +- app/models/model.py | 7 + app/router/score.py | 177 +++++++++++++++++- app/signalr/hub/multiplayer.py | 11 +- ...d0c1b2cefe91_playlist_index_playlist_id.py | 89 +++++++++ 8 files changed, 411 insertions(+), 7 deletions(-) create mode 100644 app/database/playlist_best_score.py create mode 100644 migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py diff --git a/app/database/__init__.py b/app/database/__init__.py index 2c01f7a..b3e65f9 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -16,12 +16,15 @@ from .lazer_user import ( UserResp, ) from .playlist_attempts import ItemAttemptsCount +from .playlist_best_score import PlaylistBestScore from .playlists import Playlist, PlaylistResp from .pp_best_score import PPBestScore from .relationship import Relationship, RelationshipResp, RelationshipType from .room import Room, RoomResp from .score import ( + MultiplayerScores, Score, + ScoreAround, ScoreBase, ScoreResp, ScoreStatistics, @@ -47,9 +50,11 @@ __all__ = [ "DailyChallengeStatsResp", "FavouriteBeatmapset", "ItemAttemptsCount", + "MultiplayerScores", "OAuthToken", "PPBestScore", "Playlist", + "PlaylistBestScore", "PlaylistResp", "Relationship", "RelationshipResp", @@ -57,6 +62,7 @@ __all__ = [ "Room", "RoomResp", "Score", + "ScoreAround", "ScoreBase", "ScoreResp", "ScoreStatistics", diff --git a/app/database/playlist_best_score.py b/app/database/playlist_best_score.py new file mode 100644 index 0000000..49fb459 --- /dev/null +++ b/app/database/playlist_best_score.py @@ -0,0 +1,107 @@ +from typing import TYPE_CHECKING + +from .lazer_user import User + +from redis.asyncio import Redis +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, + col, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from .score import Score + + +class PlaylistBestScore(SQLModel, table=True): + __tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType] + + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + score_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) + ) + room_id: int = Field(foreign_key="rooms.id", index=True) + playlist_id: int = Field(foreign_key="room_playlists.id", index=True) + total_score: int = Field(default=0, sa_column=Column(BigInteger)) + + user: User = Relationship() + score: "Score" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[PlaylistBestScore.score_id]", + "lazy": "joined", + } + ) + + +async def process_playlist_best_score( + room_id: int, + playlist_id: int, + user_id: int, + score_id: int, + total_score: int, + session: AsyncSession, + redis: Redis, +): + previous = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == room_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.user_id == user_id, + ) + ) + ).first() + if previous is None: + score = PlaylistBestScore( + user_id=user_id, + score_id=score_id, + room_id=room_id, + playlist_id=playlist_id, + total_score=total_score, + ) + session.add(score) + else: + previous.score_id = score_id + previous.total_score = total_score + await session.commit() + await redis.decr(f"multiplayer:{room_id}:gameplay:players") + + +async def get_position( + room_id: int, + playlist_id: int, + score_id: int, + session: AsyncSession, +) -> int: + rownum = ( + func.row_number() + .over( + partition_by=( + col(PlaylistBestScore.playlist_id), + col(PlaylistBestScore.room_id), + ), + order_by=col(PlaylistBestScore.total_score).desc(), + ) + .label("row_number") + ) + subq = ( + select(PlaylistBestScore, rownum) + .where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ) + .subquery() + ) + stmt = select(subq.c.row_number).where(subq.c.score_id == score_id) + result = await session.exec(stmt) + s = result.one_or_none() + return s if s else 0 diff --git a/app/database/playlists.py b/app/database/playlists.py index 328f17d..432c3b0 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: class PlaylistBase(SQLModel, UTCBaseModel): - id: int = 0 + id: int = Field(index=True) owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) ruleset_id: int = Field(ge=0, le=3) expired: bool = Field(default=False) diff --git a/app/database/score.py b/app/database/score.py index abc3d75..37b96a3 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from datetime import UTC, date, datetime import json import math -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from app.calculator import ( calculate_pp, @@ -14,7 +14,7 @@ from app.calculator import ( clamp, ) from app.database.team import TeamMember -from app.models.model import UTCBaseModel +from app.models.model import RespWithCursor, UTCBaseModel from app.models.mods import APIMod, mods_can_get_pp from app.models.score import ( INT_TO_MODE, @@ -88,6 +88,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): default=0, sa_column=Column(BigInteger), exclude=True ) type: str + beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") # optional # TODO: current_user_attributes @@ -99,7 +100,6 @@ class Score(ScoreBase, table=True): id: int | None = Field( default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True) ) - beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") user_id: int = Field( default=None, sa_column=Column( @@ -162,7 +162,8 @@ class ScoreResp(ScoreBase): maximum_statistics: ScoreStatistics | None = None rank_global: int | None = None rank_country: int | None = None - position: int = 1 # TODO + position: int | None = None + scores_around: "ScoreAround | None" = None @classmethod async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": @@ -234,6 +235,16 @@ class ScoreResp(ScoreBase): return s +class MultiplayerScores(RespWithCursor): + scores: list[ScoreResp] = Field(default_factory=list) + params: dict[str, Any] = Field(default_factory=dict) + + +class ScoreAround(SQLModel): + higher: MultiplayerScores | None = None + lower: MultiplayerScores | None = None + + async def get_best_id(session: AsyncSession, score_id: int) -> None: rownum = ( func.row_number() diff --git a/app/models/model.py b/app/models/model.py index bc00585..5ba8093 100644 --- a/app/models/model.py +++ b/app/models/model.py @@ -13,3 +13,10 @@ class UTCBaseModel(BaseModel): v = v.replace(tzinfo=UTC) return v.astimezone(UTC).isoformat() return v + + +Cursor = dict[str, int] + + +class RespWithCursor(BaseModel): + cursor: Cursor | None = None diff --git a/app/router/score.py b/app/router/score.py index b50911d..818155d 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,5 +1,8 @@ from __future__ import annotations +import time + +from app.calculator import clamp from app.database import ( Beatmap, Playlist, @@ -9,7 +12,18 @@ from app.database import ( ScoreTokenResp, User, ) -from app.database.score import get_leaderboard, process_score, process_user +from app.database.playlist_best_score import ( + PlaylistBestScore, + get_position, + process_playlist_best_score, +) +from app.database.score import ( + MultiplayerScores, + ScoreAround, + get_leaderboard, + process_score, + process_user, +) from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -33,6 +47,8 @@ from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession +READ_SCORE_TIMEOUT = 10 + async def submit_score( info: SoloScoreSubmissionInfo, @@ -337,4 +353,163 @@ async def submit_playlist_score( item.id, room_id, ) + await process_playlist_best_score( + room_id, + playlist_id, + current_user.id, + score_resp.id, + score_resp.total_score, + session, + redis, + ) return score_resp + + +class IndexedScoreResp(MultiplayerScores): + total: int + user_score: ScoreResp | None = None + + +@router.get( + "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp +) +async def index_playlist_scores( + room_id: int, + playlist_id: int, + limit: int = 50, + cursor: int = Query(2000000, alias="cursor[total_score]"), + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + limit = clamp(limit, 1, 50) + + scores = ( + await session.exec( + select(PlaylistBestScore) + .where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + PlaylistBestScore.total_score < cursor, + ) + .order_by(col(PlaylistBestScore.total_score).desc()) + .limit(limit + 1) + ) + ).all() + has_more = len(scores) > limit + if has_more: + scores = scores[:-1] + + user_score = None + score_resp = [await ScoreResp.from_db(session, score.score) for score in scores] + for score in score_resp: + score.position = await get_position(room_id, playlist_id, score.id, session) + if score.user_id == current_user.id: + user_score = score + resp = IndexedScoreResp( + scores=score_resp, + user_score=user_score, + total=len(scores), + params={ + "limit": limit, + }, + ) + if has_more: + resp.cursor = { + "total_score": scores[-1].total_score, + } + return resp + + +@router.get( + "/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}", + response_model=ScoreResp, +) +async def show_playlist_score( + room_id: int, + playlist_id: int, + score_id: int, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + start_time = time.time() + score_record = None + completed = False + while time.time() - start_time < READ_SCORE_TIMEOUT: + if score_record is None: + score_record = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.score_id == score_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ) + ) + ).first() + if completed_players := await redis.get( + f"multiplayer:{room_id}:gameplay:players" + ): + completed = completed_players == "0" + if score_record and completed: + break + if not score_record: + raise HTTPException(status_code=404, detail="Score not found") + resp = await ScoreResp.from_db(session, score_record.score) + resp.position = await get_position(room_id, playlist_id, score_id, session) + if completed: + scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ) + ) + ).all() + higher_scores = [] + lower_scores = [] + for score in scores: + if score.total_score > resp.total_score: + higher_scores.append(await ScoreResp.from_db(session, score.score)) + elif score.total_score < resp.total_score: + lower_scores.append(await ScoreResp.from_db(session, score.score)) + resp.scores_around = ScoreAround( + higher=MultiplayerScores(scores=higher_scores), + lower=MultiplayerScores(scores=lower_scores), + ) + + return resp + + +@router.get( + "rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}", + response_model=ScoreResp, +) +async def get_user_playlist_score( + room_id: int, + playlist_id: int, + user_id: int, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + score_record = None + start_time = time.time() + while time.time() - start_time < READ_SCORE_TIMEOUT: + score_record = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.user_id == user_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ) + ) + ).first() + if score_record: + break + if not score_record: + raise HTTPException(status_code=404, detail="Score not found") + + resp = await ScoreResp.from_db(session, score_record.score) + resp.position = await get_position( + room_id, playlist_id, score_record.score_id, session + ) + return resp diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index ef3dfcd..af28d26 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -9,7 +9,7 @@ from app.database.beatmap import Beatmap from app.database.lazer_user import User from app.database.playlists import Playlist from app.database.relationship import Relationship, RelationshipType -from app.dependencies.database import engine +from app.dependencies.database import engine, get_redis from app.exception import InvokeException from app.log import logger from app.models.mods import APIMod @@ -642,6 +642,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if room.queue.current_item.expired: raise InvokeException("Current playlist item is expired") playing = False + played_user = 0 for user in room.room.users: client = self.get_client_by_id(str(user.user_id)) if client is None: @@ -652,6 +653,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): MultiplayerUserState.LOADED, ): playing = True + played_user += 1 await self.change_user_state(room, user, MultiplayerUserState.PLAYING) await self.call_noblock(client, "GameplayStarted") elif user.state == MultiplayerUserState.WAITING_FOR_LOAD: @@ -665,6 +667,13 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room, (MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN), ) + if playing: + redis = get_redis() + await redis.set( + f"multiplayer:{room.room.room_id}:gameplay:players", + played_user, + ex=3600, + ) async def send_match_event( self, room: ServerMultiplayerRoom, event: MatchServerEvent diff --git a/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py b/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py new file mode 100644 index 0000000..74f2e56 --- /dev/null +++ b/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py @@ -0,0 +1,89 @@ +"""playlist: index playlist id + +Revision ID: d0c1b2cefe91 +Revises: 58a11441d302 +Create Date: 2025-08-06 06:02:10.512616 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "d0c1b2cefe91" +down_revision: str | Sequence[str] | None = "58a11441d302" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_index( + op.f("ix_room_playlists_id"), "room_playlists", ["id"], unique=False + ) + op.create_table( + "playlist_best_scores", + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("score_id", sa.BigInteger(), nullable=False), + sa.Column("room_id", sa.Integer(), nullable=False), + sa.Column("playlist_id", sa.Integer(), nullable=False), + sa.Column("total_score", sa.BigInteger(), nullable=True), + sa.ForeignKeyConstraint( + ["playlist_id"], + ["room_playlists.id"], + ), + sa.ForeignKeyConstraint( + ["room_id"], + ["rooms.id"], + ), + sa.ForeignKeyConstraint( + ["score_id"], + ["scores.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("score_id"), + ) + op.create_index( + op.f("ix_playlist_best_scores_playlist_id"), + "playlist_best_scores", + ["playlist_id"], + unique=False, + ) + op.create_index( + op.f("ix_playlist_best_scores_room_id"), + "playlist_best_scores", + ["room_id"], + unique=False, + ) + op.create_index( + op.f("ix_playlist_best_scores_user_id"), + "playlist_best_scores", + ["user_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + op.f("ix_playlist_best_scores_user_id"), table_name="playlist_best_scores" + ) + op.drop_index( + op.f("ix_playlist_best_scores_room_id"), table_name="playlist_best_scores" + ) + op.drop_index( + op.f("ix_playlist_best_scores_playlist_id"), table_name="playlist_best_scores" + ) + op.drop_table("playlist_best_scores") + op.drop_index(op.f("ix_room_playlists_id"), table_name="room_playlists") + # ### end Alembic commands ### From 47d02e4e9c94d30e4cbe9fe3caba1ea577117023 Mon Sep 17 00:00:00 2001 From: chenjintang-shrimp Date: Thu, 7 Aug 2025 06:28:07 +0000 Subject: [PATCH 37/45] feat(room): add POST /room API --- app/database/room.py | 2 +- app/models/multiplayer_hub.py | 52 ++++++++++++++++++++ app/router/room.py | 86 +++++++++++++++++++++++++++------- app/signalr/hub/multiplayer.py | 2 +- 4 files changed, 122 insertions(+), 20 deletions(-) diff --git a/app/database/room.py b/app/database/room.py index 80457b6..08f1466 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -29,7 +29,7 @@ class RoomBase(SQLModel): name: str = Field(index=True) category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True) duration: int | None = Field(default=None) # minutes - starts_at: datetime = Field( + starts_at: datetime | None = Field( sa_column=Column( DateTime(timezone=True), ), diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 9d78282..ed37b98 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -335,6 +335,58 @@ class MultiplayerRoom(BaseModel): active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list) channel_id: int + @classmethod + def from_db(cls, room) -> "MultiplayerRoom": + """ + 将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型) + """ + + # 用户列表 + users = [MultiplayerRoomUser(user_id=room.host_id)] + host_user = MultiplayerRoomUser(user_id=room.host_id) + # playlist 转换 + playlist = [] + if hasattr(room, "playlist"): + for item in room.playlist: + playlist.append( + PlaylistItem( + id=item.id, + owner_id=item.owner_id, + beatmap_id=item.beatmap_id, + beatmap_checksum=item.beatmap.checksum if item.beatmap else "", + ruleset_id=item.ruleset_id, + required_mods=item.required_mods, + allowed_mods=item.allowed_mods, + expired=item.expired, + playlist_order=item.playlist_order, + played_at=item.played_at, + star_rating=item.beatmap.difficulty_rating + if item.beatmap is not None + else 0.0, + freestyle=item.freestyle, + ) + ) + + return cls( + room_id=room.id, + state=getattr(room, "state", MultiplayerRoomState.OPEN), + settings=MultiplayerRoomSettings( + name=room.name, + playlist_item_id=playlist[0].id if playlist else 0, + password=getattr(room, "password", ""), + match_type=room.type, + queue_mode=room.queue_mode, + auto_start_duration=timedelta(seconds=room.auto_start_duration), + auto_skip=room.auto_skip, + ), + users=users, + host=host_user, + match_state=None, + playlist=playlist, + active_countdowns=[], + channel_id=getattr(room, "channel_id", 0), + ) + class MultiplayerQueue: def __init__(self, room: "ServerMultiplayerRoom"): diff --git a/app/router/room.py b/app/router/room.py index 476eaf2..800c861 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -1,13 +1,17 @@ from __future__ import annotations +from datetime import UTC, datetime +from time import timezone from typing import Literal from app.database.lazer_user import User -from app.database.room import RoomResp +from app.database.playlists import Playlist +from app.database.room import Room, RoomBase, RoomResp from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user from app.fetcher import Fetcher +from app.models.multiplayer_hub import MultiplayerRoom, ServerMultiplayerRoom from app.models.room import RoomStatus from app.signalr.hub import MultiplayerHubs @@ -33,22 +37,68 @@ async def get_all_rooms( rooms = MultiplayerHubs.rooms.values() resp_list: list[RoomResp] = [] for room in rooms: - if category != "realtime": # 歌单模式的处理逻辑 - if room.category == category: - if mode == "owned": - if ( - room.room.host.user_id if room.room.host is not None else 0 - ) != current_user.id: - continue - else: - continue - else: - if mode == "owned": - if ( - room.room.host.user_id if room.room.host is not None else 0 - ) != current_user.id: - continue - if room.status != status: - continue + if category == "realtime" and room.category != "normal": + continue + elif category != room.category: + continue resp_list.append(await RoomResp.from_hub(room)) return resp_list + + +class APICreatedRoom(RoomResp): + error: str = "" + + +class APIUploadedRoom(RoomBase): + def to_room(self) -> Room: + """ + 将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。 + """ + room_dict = self.model_dump() + room_dict.pop("playlist", None) + # host_id 已在字段中 + return Room(**room_dict) + + id: int | None + host_id: int | None = None + playlist: list[Playlist] + + +@router.post("/rooms", tags=["room"], response_model=APICreatedRoom) +async def create_room( + room: APIUploadedRoom, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + # db_room = Room.from_resp(room) + await db.refresh(current_user) + user_id = current_user.id + db_room = room.to_room() + db_room.host_id = current_user.id if current_user.id else 1 + db.add(db_room) + await db.commit() + await db.refresh(db_room) + + playlist: list[Playlist] = [] + # 处理 APIUploadedRoom 里的 playlist 字段 + for item in room.playlist: + # 确保 room_id 正确赋值 + item.id = await Playlist.get_next_id_for_room(db_room.id, db) + item.room_id = db_room.id + item.owner_id = user_id if user_id else 1 + db.add(item) + await db.commit() + await db.refresh(item) + playlist.append(item) + await db.refresh(db_room) + db_room.playlist = playlist + server_room = ServerMultiplayerRoom( + room=MultiplayerRoom.from_db(db_room), + category=db_room.category, + start_at=datetime.now(UTC), + hub=MultiplayerHubs, + ) + MultiplayerHubs.rooms[db_room.id] = server_room + created_room = APICreatedRoom.model_validate(db_room) + created_room.error = "" + return created_room diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index af28d26..fa869b6 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -103,7 +103,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): item = room.playlist[0] item.owner_id = client.user_id room.room_id = db_room.id - starts_at = db_room.starts_at + starts_at = db_room.starts_at or datetime.now(UTC) await Playlist.add_to_db(item, db_room.id, session) server_room = ServerMultiplayerRoom( room=room, From ff25e58696780157a1523151ce7d6493d47a8069 Mon Sep 17 00:00:00 2001 From: chenjintang-shrimp Date: Thu, 7 Aug 2025 07:37:24 +0000 Subject: [PATCH 38/45] fix(room): solve 500 in API POST /rooms --- app/database/playlists.py | 5 +++-- app/router/room.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/app/database/playlists.py b/app/database/playlists.py index 432c3b0..3ecb75f 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -134,6 +134,7 @@ class PlaylistResp(PlaylistBase): @classmethod async def from_db(cls, playlist: Playlist) -> "PlaylistResp": - resp = cls.model_validate(playlist) - resp.beatmap = await BeatmapResp.from_db(playlist.beatmap) + data = playlist.model_dump() + data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap, from_set=True) + resp = cls.model_validate(data) return resp diff --git a/app/router/room.py b/app/router/room.py index 800c861..5f3a684 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -37,10 +37,10 @@ async def get_all_rooms( rooms = MultiplayerHubs.rooms.values() resp_list: list[RoomResp] = [] for room in rooms: - if category == "realtime" and room.category != "normal": - continue - elif category != room.category: - continue + # if category == "realtime" and room.category != "normal": + # continue + # elif category != room.category and category != "": + # continue resp_list.append(await RoomResp.from_hub(room)) return resp_list @@ -99,6 +99,6 @@ async def create_room( hub=MultiplayerHubs, ) MultiplayerHubs.rooms[db_room.id] = server_room - created_room = APICreatedRoom.model_validate(db_room) + created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room)) created_room.error = "" return created_room From bf04ea02d861da9afc84a8724edc86ddbaed2c65 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 7 Aug 2025 08:11:26 +0000 Subject: [PATCH 39/45] fix(multiplayer): don't re-add the last item when `HOST_ONLY` --- app/models/multiplayer_hub.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index ed37b98..25e359c 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -284,6 +284,8 @@ class PlaylistItem(BaseModel): copy = self.model_copy() copy.required_mods = list(self.required_mods) copy.allowed_mods = list(self.allowed_mods) + copy.expired = False + copy.played_at = None return copy From d130915b4a4e194cca704784a682701ad61a61cc Mon Sep 17 00:00:00 2001 From: chenjintang-shrimp Date: Thu, 7 Aug 2025 11:16:28 +0000 Subject: [PATCH 40/45] feat(rooms): add API GET /rooms/{room} --- app/router/room.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/app/router/room.py b/app/router/room.py index 5f3a684..b992e5b 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -102,3 +102,12 @@ async def create_room( created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room)) created_room.error = "" return created_room + + +@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp) +async def get_room( + room: int, + db: AsyncSession = Depends(get_db), +): + server_room = MultiplayerHubs.rooms[room] + return await RoomResp.from_hub(server_room) From 18d16e2542997d5eb39e6b8b94ab8bb130be4e40 Mon Sep 17 00:00:00 2001 From: chenjintang-shrimp Date: Thu, 7 Aug 2025 12:00:19 +0000 Subject: [PATCH 41/45] feat(rooms): add router PUT /rooms/{room}/users/{user} --- app/router/room.py | 39 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/app/router/room.py b/app/router/room.py index b992e5b..1c51753 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -5,21 +5,27 @@ from time import timezone from typing import Literal from app.database.lazer_user import User -from app.database.playlists import Playlist +from app.database.playlists import Playlist, PlaylistResp from app.database.room import Room, RoomBase, RoomResp from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user from app.fetcher import Fetcher -from app.models.multiplayer_hub import MultiplayerRoom, ServerMultiplayerRoom +from app.models.multiplayer_hub import ( + MultiplayerRoom, + MultiplayerRoomUser, + ServerMultiplayerRoom, +) from app.models.room import RoomStatus from app.signalr.hub import MultiplayerHubs from .api_router import router -from fastapi import Depends, Query +from fastapi import Depends, HTTPException, Query from redis.asyncio import Redis +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +from starlette.status import HTTP_417_EXPECTATION_FAILED @router.get("/rooms", tags=["rooms"], response_model=list[RoomResp]) @@ -111,3 +117,30 @@ async def get_room( ): server_room = MultiplayerHubs.rooms[room] return await RoomResp.from_hub(server_room) + + +@router.delete("/rooms/{room}", tags=["room"]) +async def delete_room(room: int, db: AsyncSession = Depends(get_db)): + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is None: + raise HTTPException(404, "Room not found") + else: + await db.delete(db_room) + return None + + +@router.put("/rooms/{room}/users/{user}", tags=["room"]) +async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)): + server_room = MultiplayerHubs.rooms[room] + server_room.room.users.append(MultiplayerRoomUser(user_id=user)) + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is not None: + db_room.participant_count += 1 + await db.commit() + resp = await RoomResp.from_hub(server_room) + await db.refresh(db_room) + for item in db_room.playlist: + resp.playlist.append(await PlaylistResp.from_db(item)) + return resp + else: + raise HTTPException(404, "room not found0") From bc2961de1094a4453b2f66e8ddaa349afb805ca2 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 7 Aug 2025 14:52:02 +0000 Subject: [PATCH 42/45] feat(playlist): support leaderboard **UNTESTED** --- app/database/__init__.py | 3 +- app/database/playlist_attempts.py | 117 ++++++++++++++++++++++++++-- app/database/playlist_best_score.py | 2 + app/database/room.py | 10 --- app/router/room.py | 40 +++++++++- app/router/score.py | 26 ++++++- 6 files changed, 175 insertions(+), 23 deletions(-) diff --git a/app/database/__init__.py b/app/database/__init__.py index b3e65f9..dbfd3b8 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -15,7 +15,7 @@ from .lazer_user import ( User, UserResp, ) -from .playlist_attempts import ItemAttemptsCount +from .playlist_attempts import ItemAttemptsCount, ItemAttemptsResp from .playlist_best_score import PlaylistBestScore from .playlists import Playlist, PlaylistResp from .pp_best_score import PPBestScore @@ -50,6 +50,7 @@ __all__ = [ "DailyChallengeStatsResp", "FavouriteBeatmapset", "ItemAttemptsCount", + "ItemAttemptsResp", "MultiplayerScores", "OAuthToken", "PPBestScore", diff --git a/app/database/playlist_attempts.py b/app/database/playlist_attempts.py index 5b4710a..da49981 100644 --- a/app/database/playlist_attempts.py +++ b/app/database/playlist_attempts.py @@ -1,9 +1,116 @@ -from sqlmodel import Field, SQLModel +from .lazer_user import User, UserResp +from .playlist_best_score import PlaylistBestScore + +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, + col, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession -class ItemAttemptsCount(SQLModel, table=True): - __tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType] - id: int = Field(foreign_key="room_playlists.db_id", primary_key=True, index=True) +class ItemAttemptsCountBase(SQLModel): room_id: int = Field(foreign_key="rooms.id", index=True) attempts: int = Field(default=0) - passed: int = Field(default=0) + completed: int = Field(default=0) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + accuracy: float = 0.0 + pp: float = 0 + total_score: int = 0 + + +class ItemAttemptsCount(ItemAttemptsCountBase, table=True): + __tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, foreign_key="room_playlists.db_id", primary_key=True + ) + + user: User = Relationship() + + async def get_position(self, session: AsyncSession) -> int: + rownum = ( + func.row_number() + .over( + partition_by=col(ItemAttemptsCountBase.room_id), + order_by=col(ItemAttemptsCountBase.total_score).desc(), + ) + .label("rn") + ) + subq = select(ItemAttemptsCountBase, rownum).subquery() + stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id) + result = await session.exec(stmt) + return result.one() + + async def update(self, session: AsyncSession): + playlist_scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == self.room_id, + PlaylistBestScore.user_id == self.user_id, + ) + ) + ).all() + self.attempts = sum(score.attempts for score in playlist_scores) + self.total_score = sum(score.total_score for score in playlist_scores) + self.pp = sum(score.score.pp for score in playlist_scores) + self.completed = len(playlist_scores) + self.accuracy = ( + sum(score.score.accuracy * score.attempts for score in playlist_scores) + / self.completed + if self.completed > 0 + else 0.0 + ) + await session.commit() + await session.refresh(self) + + @classmethod + async def get_or_create( + cls, + room_id: int, + user_id: int, + session: AsyncSession, + ) -> "ItemAttemptsCount": + item_attempts = await session.exec( + select(cls).where( + cls.room_id == room_id, + cls.user_id == user_id, + ) + ) + item_attempts = item_attempts.first() + if item_attempts is None: + item_attempts = cls(room_id=room_id, user_id=user_id) + session.add(item_attempts) + await session.commit() + await session.refresh(item_attempts) + await item_attempts.update(session) + return item_attempts + + +class ItemAttemptsResp(ItemAttemptsCountBase): + user: UserResp | None = None + position: int | None = None + + @classmethod + async def from_db( + cls, + item_attempts: ItemAttemptsCount, + session: AsyncSession, + include: list[str] = [], + ) -> "ItemAttemptsResp": + resp = cls.model_validate(item_attempts) + resp.user = await UserResp.from_db( + item_attempts.user, + session=session, + include=["statistics", "team", "daily_challenge_user_stats"], + ) + if "position" in include: + resp.position = await item_attempts.get_position(session) + return resp diff --git a/app/database/playlist_best_score.py b/app/database/playlist_best_score.py index 49fb459..46bbfba 100644 --- a/app/database/playlist_best_score.py +++ b/app/database/playlist_best_score.py @@ -32,6 +32,7 @@ class PlaylistBestScore(SQLModel, table=True): room_id: int = Field(foreign_key="rooms.id", index=True) playlist_id: int = Field(foreign_key="room_playlists.id", index=True) total_score: int = Field(default=0, sa_column=Column(BigInteger)) + attempts: int = Field(default=0) # playlist user: User = Relationship() score: "Score" = Relationship( @@ -72,6 +73,7 @@ async def process_playlist_best_score( else: previous.score_id = score_id previous.total_score = total_score + previous.attempts += 1 await session.commit() await redis.decr(f"multiplayer:{room_id}:gameplay:players") diff --git a/app/database/room.py b/app/database/room.py index 08f1466..e01dece 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -11,7 +11,6 @@ from app.models.room import ( ) from .lazer_user import User, UserResp -from .playlist_attempts import ItemAttemptsCount from .playlists import Playlist, PlaylistResp from sqlmodel import ( @@ -67,13 +66,6 @@ class Room(RoomBase, table=True): "overlaps": "room", } ) - # playlist_item_attempts: list["ItemAttemptsCount"] = Relationship( - # sa_relationship_kwargs={ - # "lazy": "joined", - # "cascade": "all, delete-orphan", - # "primaryjoin": "ItemAttemptsCount.room_id == Room.id", - # } - # ) class RoomResp(RoomBase): @@ -84,7 +76,6 @@ class RoomResp(RoomBase): playlist_item_stats: RoomPlaylistItemStats | None = None difficulty_range: RoomDifficultyRange | None = None current_playlist_item: PlaylistResp | None = None - playlist_item_attempts: list[ItemAttemptsCount] = [] @classmethod async def from_db(cls, room: Room) -> "RoomResp": @@ -112,7 +103,6 @@ class RoomResp(RoomBase): resp.playlist_item_stats = stats resp.difficulty_range = difficulty_range resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None - # resp.playlist_item_attempts = room.playlist_item_attempts return resp diff --git a/app/router/room.py b/app/router/room.py index 1c51753..d5bc713 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -1,10 +1,10 @@ from __future__ import annotations from datetime import UTC, datetime -from time import timezone from typing import Literal from app.database.lazer_user import User +from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp from app.database.playlists import Playlist, PlaylistResp from app.database.room import Room, RoomBase, RoomResp from app.dependencies.database import get_db, get_redis @@ -22,10 +22,10 @@ from app.signalr.hub import MultiplayerHubs from .api_router import router from fastapi import Depends, HTTPException, Query +from pydantic import BaseModel, Field from redis.asyncio import Redis -from sqlmodel import select +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession -from starlette.status import HTTP_417_EXPECTATION_FAILED @router.get("/rooms", tags=["rooms"], response_model=list[RoomResp]) @@ -144,3 +144,37 @@ async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_ return resp else: raise HTTPException(404, "room not found0") + + +class APILeaderboard(BaseModel): + leaderboard: list[ItemAttemptsResp] = Field(default_factory=list) + user_score: ItemAttemptsResp | None = None + + +@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard) +async def get_room_leaderboard( + room: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + server_room = MultiplayerHubs.rooms[room] + if not server_room: + raise HTTPException(404, "Room not found") + + aggs = await db.exec( + select(ItemAttemptsCount) + .where(ItemAttemptsCount.room_id == room) + .order_by(col(ItemAttemptsCount.total_score).desc()) + ) + aggs_resp = [] + user_agg = None + for i, agg in enumerate(aggs): + resp = await ItemAttemptsResp.from_db(agg, db) + resp.position = i + 1 + aggs_resp.append(resp) + if agg.user_id == current_user.id: + user_agg = resp + return APILeaderboard( + leaderboard=aggs_resp, + user_score=user_agg, + ) diff --git a/app/router/score.py b/app/router/score.py index 818155d..5db171d 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,17 +1,20 @@ from __future__ import annotations +from datetime import UTC, datetime import time from app.calculator import clamp from app.database import ( Beatmap, Playlist, + Room, Score, ScoreResp, ScoreToken, ScoreTokenResp, User, ) +from app.database.playlist_attempts import ItemAttemptsCount from app.database.playlist_best_score import ( PlaylistBestScore, get_position, @@ -36,7 +39,6 @@ from app.models.score import ( Rank, SoloScoreSubmissionInfo, ) -from app.signalr.hub import MultiplayerHubs from .api_router import router @@ -278,9 +280,11 @@ async def create_playlist_score( current_user: User = Depends(get_current_user), session: AsyncSession = Depends(get_db), ): - room = MultiplayerHubs.rooms[room_id] + room = await session.get(Room, room_id) if not room: raise HTTPException(status_code=404, detail="Room not found") + if room.ended_at and room.ended_at < datetime.now(UTC): + raise HTTPException(status_code=400, detail="Room has ended") item = ( await session.exec( select(Playlist).where( @@ -301,7 +305,18 @@ async def create_playlist_score( raise HTTPException( status_code=400, detail="Beatmap ID mismatch in playlist item" ) - # TODO: max attempts + agg = await session.exec( + select(ItemAttemptsCount).where( + ItemAttemptsCount.room_id == room_id, + ItemAttemptsCount.user_id == current_user.id, + ) + ) + agg = agg.first() + if agg and room.max_attempts and agg.attempts >= room.max_attempts: + raise HTTPException( + status_code=422, + detail="You have reached the maximum attempts for this room", + ) if item.expired: raise HTTPException(status_code=400, detail="Playlist item has expired") if item.played_at: @@ -342,6 +357,8 @@ async def submit_playlist_score( ).first() if not item: raise HTTPException(status_code=404, detail="Playlist item not found") + + user_id = current_user.id score_resp = await submit_score( info, item.beatmap_id, @@ -356,12 +373,13 @@ async def submit_playlist_score( await process_playlist_best_score( room_id, playlist_id, - current_user.id, + user_id, score_resp.id, score_resp.total_score, session, redis, ) + await ItemAttemptsCount.get_or_create(room_id, user_id, session) return score_resp From 7a2c8c1fb4f8cebd065b846ede1331ec69103d76 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 7 Aug 2025 16:18:54 +0000 Subject: [PATCH 43/45] feat(multiplayer): support multiplayer events --- app/database/__init__.py | 3 + app/database/multiplayer_event.py | 53 +++++++++++++ app/database/playlist_attempts.py | 4 +- app/database/playlists.py | 7 +- app/database/room.py | 2 +- app/models/multiplayer_hub.py | 35 ++++++++- app/router/room.py | 120 +++++++++++++++++++++++++++- app/signalr/hub/multiplayer.py | 126 ++++++++++++++++++++++++++++++ 8 files changed, 341 insertions(+), 9 deletions(-) create mode 100644 app/database/multiplayer_event.py diff --git a/app/database/__init__.py b/app/database/__init__.py index dbfd3b8..0ee253b 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -15,6 +15,7 @@ from .lazer_user import ( User, UserResp, ) +from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp from .playlist_attempts import ItemAttemptsCount, ItemAttemptsResp from .playlist_best_score import PlaylistBestScore from .playlists import Playlist, PlaylistResp @@ -51,6 +52,8 @@ __all__ = [ "FavouriteBeatmapset", "ItemAttemptsCount", "ItemAttemptsResp", + "MultiplayerEvent", + "MultiplayerEventResp", "MultiplayerScores", "OAuthToken", "PPBestScore", diff --git a/app/database/multiplayer_event.py b/app/database/multiplayer_event.py new file mode 100644 index 0000000..b80f957 --- /dev/null +++ b/app/database/multiplayer_event.py @@ -0,0 +1,53 @@ +from datetime import UTC, datetime +from typing import Any + +from app.models.model import UTCBaseModel + +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + SQLModel, +) + + +class MultiplayerEventBase(SQLModel, UTCBaseModel): + playlist_item_id: int | None = None + user_id: int | None = Field( + default=None, + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True), + ) + created_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + event_type: str = Field(index=True) + + +class MultiplayerEvent(MultiplayerEventBase, table=True): + __tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType] + id: int | None = Field(default=None, primary_key=True) + room_id: int = Field(foreign_key="rooms.id", index=True) + updated_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + event_detail: dict[str, Any] | None = Field( + sa_column=Column(JSON), + default_factory=dict, + ) + + +class MultiplayerEventResp(MultiplayerEventBase): + id: int + + @classmethod + def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp": + return cls.model_validate(event) diff --git a/app/database/playlist_attempts.py b/app/database/playlist_attempts.py index da49981..93bc8c5 100644 --- a/app/database/playlist_attempts.py +++ b/app/database/playlist_attempts.py @@ -29,9 +29,7 @@ class ItemAttemptsCountBase(SQLModel): class ItemAttemptsCount(ItemAttemptsCountBase, table=True): __tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType] - id: int | None = Field( - default=None, foreign_key="room_playlists.db_id", primary_key=True - ) + id: int | None = Field(default=None, primary_key=True) user: User = Relationship() diff --git a/app/database/playlists.py b/app/database/playlists.py index 3ecb75f..3f7ae40 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -133,8 +133,11 @@ class PlaylistResp(PlaylistBase): beatmap: BeatmapResp | None = None @classmethod - async def from_db(cls, playlist: Playlist) -> "PlaylistResp": + async def from_db( + cls, playlist: Playlist, include: list[str] = [] + ) -> "PlaylistResp": data = playlist.model_dump() - data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap, from_set=True) + if "beatmap" in include: + data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap, from_set=True) resp = cls.model_validate(data) return resp diff --git a/app/database/room.py b/app/database/room.py index e01dece..7817805 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -98,7 +98,7 @@ class RoomResp(RoomBase): difficulty_range.max = max( difficulty_range.max, playlist.beatmap.difficulty_rating ) - resp.playlist.append(await PlaylistResp.from_db(playlist)) + resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"])) stats.ruleset_ids = list(rulesets) resp.playlist_item_stats = stats resp.difficulty_range = difficulty_range diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 25e359c..09d8900 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -6,7 +6,16 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import UTC, datetime, timedelta from enum import IntEnum -from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, cast, override +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + ClassVar, + Literal, + TypedDict, + cast, + override, +) from app.database.beatmap import Beatmap from app.dependencies.database import engine @@ -705,6 +714,9 @@ class MatchTypeHandler(ABC): @abstractmethod async def handle_leave(self, user: MultiplayerRoomUser): ... + @abstractmethod + def get_details(self) -> MatchStartedEventDetail: ... + class HeadToHeadHandler(MatchTypeHandler): @override @@ -721,6 +733,11 @@ class HeadToHeadHandler(MatchTypeHandler): @override async def handle_leave(self, user: MultiplayerRoomUser): ... + @override + def get_details(self) -> MatchStartedEventDetail: + detail = MatchStartedEventDetail(room_type="head_to_head", team=None) + return detail + class TeamVersusHandler(MatchTypeHandler): @override @@ -780,6 +797,17 @@ class TeamVersusHandler(MatchTypeHandler): @override async def handle_leave(self, user: MultiplayerRoomUser): ... + @override + def get_details(self) -> MatchStartedEventDetail: + teams: dict[int, Literal["blue", "red"]] = {} + for user in self.room.room.users: + if user.match_state is not None and isinstance( + user.match_state, TeamVersusUserState + ): + teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red" + detail = MatchStartedEventDetail(room_type="team_versus", team=teams) + return detail + MATCH_TYPE_HANDLERS = { MatchType.HEAD_TO_HEAD: HeadToHeadHandler, @@ -890,3 +918,8 @@ MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent class GameplayAbortReason(IntEnum): LOAD_TOOK_TOO_LONG = 0 HOST_ABORTED = 1 + + +class MatchStartedEventDetail(TypedDict): + room_type: Literal["playlists", "head_to_head", "team_versus"] + team: dict[int, Literal["blue", "red"]] | None diff --git a/app/router/room.py b/app/router/room.py index d5bc713..2677b75 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -3,10 +3,14 @@ from __future__ import annotations from datetime import UTC, datetime from typing import Literal -from app.database.lazer_user import User +from app.database.beatmap import Beatmap, BeatmapResp +from app.database.beatmapset import BeatmapsetResp +from app.database.lazer_user import User, UserResp +from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp from app.database.playlists import Playlist, PlaylistResp from app.database.room import Room, RoomBase, RoomResp +from app.database.score import Score from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -140,7 +144,7 @@ async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_ resp = await RoomResp.from_hub(server_room) await db.refresh(db_room) for item in db_room.playlist: - resp.playlist.append(await PlaylistResp.from_db(item)) + resp.playlist.append(await PlaylistResp.from_db(item, ["beatmap"])) return resp else: raise HTTPException(404, "room not found0") @@ -178,3 +182,115 @@ async def get_room_leaderboard( leaderboard=aggs_resp, user_score=user_agg, ) + + +class RoomEvents(BaseModel): + beatmaps: list[BeatmapResp] = Field(default_factory=list) + beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict) + current_playlist_item_id: int = 0 + events: list[MultiplayerEventResp] = Field(default_factory=list) + first_event_id: int = 0 + last_event_id: int = 0 + playlist_items: list[PlaylistResp] = Field(default_factory=list) + room: RoomResp + user: list[UserResp] = Field(default_factory=list) + + +@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"]) +async def get_room_events( + room_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), + limit: int = Query(100, ge=1, le=1000), + after: int | None = Query(None, ge=0), + before: int | None = Query(None, ge=0), +): + events = ( + await db.exec( + select(MultiplayerEvent) + .where( + MultiplayerEvent.room_id == room_id, + col(MultiplayerEvent.id) > after if after is not None else True, + col(MultiplayerEvent.id) < before if before is not None else True, + ) + .order_by(col(MultiplayerEvent.id).desc()) + .limit(limit) + ) + ).all() + + user_ids = set() + playlist_items = {} + beatmap_ids = set() + + event_resps = [] + first_event_id = 0 + last_event_id = 0 + + current_playlist_item_id = 0 + for event in events: + event_resps.append(MultiplayerEventResp.from_db(event)) + + if event.user_id: + user_ids.add(event.user_id) + + if event.playlist_item_id is not None and ( + playitem := ( + await db.exec( + select(Playlist).where( + Playlist.id == event.playlist_item_id, + Playlist.room_id == room_id, + ) + ) + ).first() + ): + current_playlist_item_id = playitem.id + playlist_items[event.playlist_item_id] = playitem + beatmap_ids.add(playitem.beatmap_id) + scores = await db.exec( + select(Score).where( + Score.playlist_item_id == event.playlist_item_id, + Score.room_id == room_id, + ) + ) + for score in scores: + user_ids.add(score.user_id) + beatmap_ids.add(score.beatmap_id) + + assert event.id is not None + first_event_id = min(first_event_id, event.id) + last_event_id = max(last_event_id, event.id) + + if room := MultiplayerHubs.rooms.get(room_id): + current_playlist_item_id = room.queue.current_item.id + room_resp = await RoomResp.from_hub(room) + else: + room = (await db.exec(select(Room).where(Room.id == room_id))).first() + if room is None: + raise HTTPException(404, "Room not found") + room_resp = await RoomResp.from_db(room) + + users = await db.exec(select(User).where(col(User.id).in_(user_ids))) + user_resps = [await UserResp.from_db(user, db) for user in users] + beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids))) + beatmap_resps = [ + await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps + ] + beatmapset_resps = {} + for beatmap_resp in beatmap_resps: + beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset + + playlist_items_resps = [ + await PlaylistResp.from_db(item) for item in playlist_items.values() + ] + + return RoomEvents( + beatmaps=beatmap_resps, + beatmapsets=beatmapset_resps, + current_playlist_item_id=current_playlist_item_id, + events=event_resps, + first_event_id=first_event_id, + last_event_id=last_event_id, + playlist_items=playlist_items_resps, + room=room_resp, + user=user_resps, + ) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index fa869b6..3688efa 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -7,6 +7,7 @@ from typing import override from app.database import Room from app.database.beatmap import Beatmap from app.database.lazer_user import User +from app.database.multiplayer_event import MultiplayerEvent from app.database.playlists import Playlist from app.database.relationship import Relationship, RelationshipType from app.dependencies.database import engine, get_redis @@ -20,6 +21,7 @@ from app.models.multiplayer_hub import ( MatchRequest, MatchServerEvent, MatchStartCountdown, + MatchStartedEventDetail, MultiplayerClientState, MultiplayerRoom, MultiplayerRoomSettings, @@ -49,11 +51,100 @@ from sqlmodel.ext.asyncio.session import AsyncSession GAMEPLAY_LOAD_TIMEOUT = 30 +class MultiplayerEventLogger: + def __init__(self): + pass + + async def log_event(self, event: MultiplayerEvent): + try: + async with AsyncSession(engine) as session: + session.add(event) + await session.commit() + except Exception as e: + logger.warning(f"Failed to log multiplayer room event to database: {e}") + + async def room_created(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="room_created", + ) + await self.log_event(event) + + async def room_disbanded(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="room_disbanded", + ) + await self.log_event(event) + + async def player_joined(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_joined", + ) + await self.log_event(event) + + async def player_left(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_left", + ) + await self.log_event(event) + + async def player_kicked(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_kicked", + ) + await self.log_event(event) + + async def host_changed(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="host_changed", + ) + await self.log_event(event) + + async def game_started( + self, room_id: int, playlist_item_id: int, details: MatchStartedEventDetail + ): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_started", + event_detail=details, # pyright: ignore[reportArgumentType] + ) + await self.log_event(event) + + async def game_aborted(self, room_id: int, playlist_item_id: int): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_aborted", + ) + await self.log_event(event) + + async def game_completed(self, room_id: int, playlist_item_id: int): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_completed", + ) + await self.log_event(event) + + class MultiplayerHub(Hub[MultiplayerClientState]): @override def __init__(self): super().__init__() self.rooms: dict[int, ServerMultiplayerRoom] = {} + self.event_logger = MultiplayerEventLogger() @staticmethod def group_id(room: int) -> str: @@ -113,6 +204,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) self.rooms[room.room_id] = server_room await server_room.set_handler() + await self.event_logger.room_created(room.room_id, client.user_id) return await self.JoinRoomWithPassword( client, room.room_id, room.settings.password ) @@ -143,6 +235,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room.users.append(user) self.add_to_group(client, self.group_id(room_id)) await server_room.match_type_handler.handle_join(user) + await self.event_logger.player_joined(room_id, user.user_id) return room async def ChangeBeatmapAvailability( @@ -550,10 +643,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if all( u.state != MultiplayerUserState.PLAYING for u in room.room.users ): + any_user_finished_playing = False for u in filter( lambda u: u.state == MultiplayerUserState.FINISHED_PLAY, room.room.users, ): + any_user_finished_playing = True await self.change_user_state( room, u, MultiplayerUserState.RESULTS ) @@ -562,6 +657,16 @@ class MultiplayerHub(Hub[MultiplayerClientState]): self.group_id(room.room.room_id), "ResultsReady", ) + if any_user_finished_playing: + await self.event_logger.game_completed( + room.room.room_id, + room.queue.current_item.id, + ) + else: + await self.event_logger.game_aborted( + room.room.room_id, + room.queue.current_item.id, + ) await room.queue.finish_current_item() async def change_room_state( @@ -635,6 +740,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ), self.start_gameplay, ) + await self.event_logger.game_started( + room.room.room_id, + room.queue.current_item.id, + details=room.match_type_handler.get_details(), + ) async def start_gameplay(self, room: ServerMultiplayerRoom): if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD: @@ -737,6 +847,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): host_id=room.room.host.user_id, ) ) + await self.event_logger.room_disbanded( + room.room.room_id, + room.room.host.user_id, + ) del self.rooms[room.room.room_id] async def LeaveRoom(self, client: Client): @@ -751,6 +865,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if user is None: raise InvokeException("You are not in this room") + await self.event_logger.player_left( + room.room_id, + user.user_id, + ) await self.make_user_leave(client, server_room, user) async def KickUser(self, client: Client, user_id: int): @@ -772,6 +890,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if user is None: raise InvokeException("User not found in this room") + await self.event_logger.player_kicked( + room.room_id, + user.user_id, + ) target_client = self.get_client_by_id(str(user.user_id)) if target_client is None: return @@ -800,6 +922,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): new_host = next((u for u in room.users if u.user_id == user_id), None) if new_host is None: raise InvokeException("User not found in this room") + await self.event_logger.host_changed( + room.room_id, + new_host.user_id, + ) await self.set_host(server_room, new_host) async def AbortGameplay(self, client: Client): From 2bb1e4bad2d6806f83cbf59b17a757b4868903ab Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 7 Aug 2025 16:21:56 +0000 Subject: [PATCH 44/45] fix(multiplayer): use bigint for `event.id` --- app/database/multiplayer_event.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/app/database/multiplayer_event.py b/app/database/multiplayer_event.py index b80f957..904fbe4 100644 --- a/app/database/multiplayer_event.py +++ b/app/database/multiplayer_event.py @@ -31,7 +31,10 @@ class MultiplayerEventBase(SQLModel, UTCBaseModel): class MultiplayerEvent(MultiplayerEventBase, table=True): __tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType] - id: int | None = Field(default=None, primary_key=True) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), + ) room_id: int = Field(foreign_key="rooms.id", index=True) updated_at: datetime = Field( sa_column=Column( From fb0bba1a6eeae8b0076cd0557dfeaeac9080503c Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 8 Aug 2025 06:25:31 +0000 Subject: [PATCH 45/45] fix(signalr): fail to parse `MessagePack-CSharp-Union | None` type when protocol is msgpack --- app/signalr/packet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 09a36bd..8949f4b 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -242,7 +242,9 @@ class MsgpackProtocol: # except `X (Other Type) | None` if NoneType in args and v is None: return None - if not all(issubclass(arg, SignalRUnionMessage) for arg in args): + if not all( + issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args + ): raise ValueError( f"Cannot validate {v} to {typ}, " "only SignalRUnionMessage subclasses are supported"