diff --git a/.gitignore b/.gitignore index 369e759..05622b7 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports +test-cert/ htmlcov/ .tox/ .nox/ diff --git a/README.md b/README.md index a4e1e22..267e2b5 100644 --- a/README.md +++ b/README.md @@ -1,205 +1,218 @@ -# osu! API 模拟服务器 - -这是一个使用 FastAPI + MySQL + Redis 实现的 osu! API 模拟服务器,提供了完整的用户认证和数据管理功能。 - -## 功能特性 - -- **OAuth 2.0 认证**: 支持密码流和刷新令牌流 -- **用户数据管理**: 完整的用户信息、统计数据、成就等 -- **多游戏模式支持**: osu!, taiko, fruits, mania -- **数据库持久化**: MySQL 存储用户数据 -- **缓存支持**: Redis 缓存令牌和会话信息 -- **容器化部署**: Docker 和 Docker Compose 支持 - -## API 端点 - -### 认证端点 -- `POST /oauth/token` - OAuth 令牌获取/刷新 - -### 用户端点 -- `GET /api/v2/me/{ruleset}` - 获取当前用户信息 - -### 其他端点 -- `GET /` - 根端点 -- `GET /health` - 健康检查 - -## 快速开始 - -### 使用 Docker Compose (推荐) - -1. 克隆项目 -```bash -git clone -cd osu_lazer_api -``` - -2. 启动服务 -```bash -docker-compose up -d -``` - -3. 创建示例数据 -```bash -docker-compose exec api python create_sample_data.py -``` - -4. 测试 API -```bash -# 获取访问令牌 -curl -X POST http://localhost:8000/oauth/token \ - -H "Content-Type: application/x-www-form-urlencoded" \ - -d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*" - -# 使用令牌获取用户信息 -curl -X GET http://localhost:8000/api/v2/me/osu \ - -H "Authorization: Bearer YOUR_ACCESS_TOKEN" -``` - -### 本地开发 - -1. 安装依赖 -```bash -pip install -r requirements.txt -``` - -2. 配置环境变量 -```bash -# 复制服务器配置文件 -cp .env .env.local - -# 复制客户端配置文件(用于测试脚本) -cp .env.client .env.client.local -``` - -3. 启动 MySQL 和 Redis -```bash -# 使用 Docker -docker run -d --name mysql -e MYSQL_ROOT_PASSWORD=password -e MYSQL_DATABASE=osu_api -p 3306:3306 mysql:8.0 -docker run -d --name redis -p 6379:6379 redis:7-alpine -``` - - -4. 启动应用 -```bash -uvicorn main:app --reload -``` - -## 项目结构 - -``` -osu_lazer_api/ -├── app/ -│ ├── __init__.py -│ ├── models.py # Pydantic 数据模型 -│ ├── database.py # SQLAlchemy 数据库模型 -│ ├── config.py # 配置设置 -│ ├── dependencies.py # 依赖注入 -│ ├── auth.py # 认证和令牌管理 -│ └── utils.py # 工具函数 -├── main.py # FastAPI 应用主文件 -├── create_sample_data.py # 示例数据创建脚本 -├── requirements.txt # Python 依赖 -├── .env # 环境变量配置 -├── docker-compose.yml # Docker Compose 配置 -├── Dockerfile # Docker 镜像配置 -└── README.md # 项目说明 -``` - -## 示例用户 - -创建示例数据后,您可以使用以下凭据进行测试: - -- **用户名**: `Googujiang` -- **密码**: `password123` -- **用户ID**: `15651670` - -## 环境变量配置 - -项目包含两个环境配置文件: - -### 服务器配置 (`.env`) -用于配置 FastAPI 服务器的运行参数: - -| 变量名 | 描述 | 默认值 | -|--------|------|--------| -| `DATABASE_URL` | MySQL 数据库连接字符串 | `mysql+pymysql://root:password@localhost:3306/osu_api` | -| `REDIS_URL` | Redis 连接字符串 | `redis://localhost:6379/0` | -| `SECRET_KEY` | JWT 签名密钥 | `your-secret-key-here` | -| `ACCESS_TOKEN_EXPIRE_MINUTES` | 访问令牌过期时间(分钟) | `1440` | -| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | -| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | -| `HOST` | 服务器监听地址 | `0.0.0.0` | -| `PORT` | 服务器监听端口 | `8000` | -| `DEBUG` | 调试模式 | `True` | - -### 客户端配置 (`.env.client`) -用于配置客户端脚本的 API 连接参数: - -| 变量名 | 描述 | 默认值 | -|--------|------|--------| -| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | -| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | -| `OSU_API_URL` | API 服务器地址 | `http://localhost:8000` | - -> **注意**: 在生产环境中,请务必更改默认的密钥和密码! - -## API 使用示例 - -### 获取访问令牌 - -```bash -curl -X POST http://localhost:8000/oauth/token \ - -H "Content-Type: application/x-www-form-urlencoded" \ - -d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*" -``` - -响应: -```json -{ - "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...", - "token_type": "Bearer", - "expires_in": 86400, - "refresh_token": "abc123...", - "scope": "*" -} -``` - -### 获取用户信息 - -```bash -curl -X GET http://localhost:8000/api/v2/me/osu \ - -H "Authorization: Bearer YOUR_ACCESS_TOKEN" -``` - -### 刷新令牌 - -```bash -curl -X POST http://localhost:8000/oauth/token \ - -H "Content-Type: application/x-www-form-urlencoded" \ - -d "grant_type=refresh_token&refresh_token=YOUR_REFRESH_TOKEN&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" -``` - -## 开发 - -### 添加新用户 - -您可以通过修改 `create_sample_data.py` 文件来添加更多示例用户,或者扩展 API 来支持用户注册功能。 - -### 扩展功能 - -- 添加更多 API 端点(排行榜、谱面信息等) -- 实现实时功能(WebSocket) -- 添加管理面板 -- 实现数据导入/导出功能 - -### 迁移数据库 - -参考[数据库迁移指南](./MIGRATE_GUIDE.md) - -## 许可证 - -MIT License - -## 贡献 - -欢迎提交 Issue 和 Pull Request! +# osu! API 模拟服务器 + +这是一个使用 FastAPI + MySQL + Redis 实现的 osu! API 模拟服务器,提供了完整的用户认证和数据管理功能。 + +## 功能特性 + +- **OAuth 2.0 认证**: 支持密码流和刷新令牌流 +- **用户数据管理**: 完整的用户信息、统计数据、成就等 +- **多游戏模式支持**: osu!, taiko, fruits, mania +- **数据库持久化**: MySQL 存储用户数据 +- **缓存支持**: Redis 缓存令牌和会话信息 +- **容器化部署**: Docker 和 Docker Compose 支持 + +## API 端点 + +### 认证端点 +- `POST /oauth/token` - OAuth 令牌获取/刷新 + +### 用户端点 +- `GET /api/v2/me/{ruleset}` - 获取当前用户信息 + +### 其他端点 +- `GET /` - 根端点 +- `GET /health` - 健康检查 + +## 快速开始 + +### 使用 Docker Compose (推荐) + +1. 克隆项目 +```bash +git clone +cd osu_lazer_api +``` + +2. 启动服务 +```bash +docker-compose up -d +``` + +3. 创建示例数据 +```bash +docker-compose exec api python create_sample_data.py +``` + +4. 测试 API +```bash +# 获取访问令牌 +curl -X POST http://localhost:8000/oauth/token \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*" + +# 使用令牌获取用户信息 +curl -X GET http://localhost:8000/api/v2/me/osu \ + -H "Authorization: Bearer YOUR_ACCESS_TOKEN" +``` + +### 本地开发 + +1. 安装依赖 +```bash +pip install -r requirements.txt +``` + +2. 配置环境变量 +```bash +# 复制服务器配置文件 +cp .env .env.local + +# 复制客户端配置文件(用于测试脚本) +cp .env.client .env.client.local +``` + +3. 启动 MySQL 和 Redis +```bash +# 使用 Docker +docker run -d --name mysql -e MYSQL_ROOT_PASSWORD=password -e MYSQL_DATABASE=osu_api -p 3306:3306 mysql:8.0 +docker run -d --name redis -p 6379:6379 redis:7-alpine +``` + +4. 创建示例数据 +```bash +python create_sample_data.py +``` + +5. 启动应用 +```bash +uvicorn main:app --reload +``` + +6. 测试 API +```bash +# 使用测试脚本(会自动加载 .env 文件) +python test_api.py + +# 或使用原始示例脚本 +python osu_api_example.py +``` + +## 项目结构 + +``` +osu_lazer_api/ +├── app/ +│ ├── __init__.py +│ ├── models.py # Pydantic 数据模型 +│ ├── database.py # SQLAlchemy 数据库模型 +│ ├── config.py # 配置设置 +│ ├── dependencies.py # 依赖注入 +│ ├── auth.py # 认证和令牌管理 +│ └── utils.py # 工具函数 +├── main.py # FastAPI 应用主文件 +├── create_sample_data.py # 示例数据创建脚本 +├── requirements.txt # Python 依赖 +├── .env # 环境变量配置 +├── docker-compose.yml # Docker Compose 配置 +├── Dockerfile # Docker 镜像配置 +└── README.md # 项目说明 +``` + +## 示例用户 + +创建示例数据后,您可以使用以下凭据进行测试: + +- **用户名**: `Googujiang` +- **密码**: `password123` +- **用户ID**: `15651670` + +## 环境变量配置 + +项目包含两个环境配置文件: + +### 服务器配置 (`.env`) +用于配置 FastAPI 服务器的运行参数: + +| 变量名 | 描述 | 默认值 | +|--------|------|--------| +| `DATABASE_URL` | MySQL 数据库连接字符串 | `mysql+pymysql://root:password@localhost:3306/osu_api` | +| `REDIS_URL` | Redis 连接字符串 | `redis://localhost:6379/0` | +| `SECRET_KEY` | JWT 签名密钥 | `your-secret-key-here` | +| `ACCESS_TOKEN_EXPIRE_MINUTES` | 访问令牌过期时间(分钟) | `1440` | +| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | +| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | +| `HOST` | 服务器监听地址 | `0.0.0.0` | +| `PORT` | 服务器监听端口 | `8000` | +| `DEBUG` | 调试模式 | `True` | + +### 客户端配置 (`.env.client`) +用于配置客户端脚本的 API 连接参数: + +| 变量名 | 描述 | 默认值 | +|--------|------|--------| +| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | +| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | +| `OSU_API_URL` | API 服务器地址 | `http://localhost:8000` | + +> **注意**: 在生产环境中,请务必更改默认的密钥和密码! + +## API 使用示例 + +### 获取访问令牌 + +```bash +curl -X POST http://localhost:8000/oauth/token \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*" +``` + +响应: +```json +{ + "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...", + "token_type": "Bearer", + "expires_in": 86400, + "refresh_token": "abc123...", + "scope": "*" +} +``` + +### 获取用户信息 + +```bash +curl -X GET http://localhost:8000/api/v2/me/osu \ + -H "Authorization: Bearer YOUR_ACCESS_TOKEN" +``` + +### 刷新令牌 + +```bash +curl -X POST http://localhost:8000/oauth/token \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "grant_type=refresh_token&refresh_token=YOUR_REFRESH_TOKEN&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" +``` + +## 开发 + +### 添加新用户 + +您可以通过修改 `create_sample_data.py` 文件来添加更多示例用户,或者扩展 API 来支持用户注册功能。 + +### 扩展功能 + +- 添加更多 API 端点(排行榜、谱面信息等) +- 实现实时功能(WebSocket) +- 添加管理面板 +- 实现数据导入/导出功能 + +### 迁移数据库 + +参考[数据库迁移指南](./MIGRATE_GUIDE.md) + +## 许可证 + +MIT License + +## 贡献 + +欢迎提交 Issue 和 Pull Request! diff --git a/app/database/__init__.py b/app/database/__init__.py index 568104b..7b5c228 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -16,10 +16,22 @@ from .lazer_user import ( User, UserResp, ) +from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp +from .playlist_attempts import ( + ItemAttemptsCount, + ItemAttemptsResp, + PlaylistAggregateScore, +) +from .playlist_best_score import PlaylistBestScore +from .playlists import Playlist, PlaylistResp from .pp_best_score import PPBestScore from .relationship import Relationship, RelationshipResp, RelationshipType +from .room import APIUploadedRoom, Room, RoomResp +from .room_participated_user import RoomParticipatedUser from .score import ( + MultiplayerScores, Score, + ScoreAround, ScoreBase, ScoreResp, ScoreStatistics, @@ -37,6 +49,7 @@ from .user_account_history import ( ) __all__ = [ + "APIUploadedRoom", "Beatmap", "BeatmapPlaycounts", "BeatmapPlaycountsResp", @@ -46,12 +59,25 @@ __all__ = [ "DailyChallengeStats", "DailyChallengeStatsResp", "FavouriteBeatmapset", + "ItemAttemptsCount", + "ItemAttemptsResp", + "MultiplayerEvent", + "MultiplayerEventResp", + "MultiplayerScores", "OAuthToken", "PPBestScore", + "Playlist", + "PlaylistAggregateScore", + "PlaylistBestScore", + "PlaylistResp", "Relationship", "RelationshipResp", "RelationshipType", + "Room", + "RoomParticipatedUser", + "RoomResp", "Score", + "ScoreAround", "ScoreBase", "ScoreResp", "ScoreStatistics", diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 5470277..192ac71 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import TYPE_CHECKING from app.models.beatmap import BeatmapRankStatus -from app.models.model import UTCBaseModel from app.models.score import MODE_TO_INT, GameMode from .beatmap_playcounts import BeatmapPlaycounts @@ -23,7 +22,7 @@ class BeatmapOwner(SQLModel): username: str -class BeatmapBase(SQLModel, UTCBaseModel): +class BeatmapBase(SQLModel): # Beatmap url: str mode: GameMode @@ -63,7 +62,7 @@ class BeatmapBase(SQLModel, UTCBaseModel): class Beatmap(BeatmapBase, table=True): __tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType] - id: int | None = Field(default=None, primary_key=True, index=True) + id: int = Field(primary_key=True, index=True) beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmap_status: BeatmapRankStatus # optional diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 65a3b2a..8a95017 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -2,7 +2,6 @@ from datetime import datetime from typing import TYPE_CHECKING, TypedDict, cast from app.models.beatmap import BeatmapRankStatus, Genre, Language -from app.models.model import UTCBaseModel from app.models.score import GameMode from .lazer_user import BASE_INCLUDES, User, UserResp @@ -14,6 +13,8 @@ from sqlmodel import Field, Relationship, SQLModel, col, func, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: + from app.fetcher import Fetcher + from .beatmap import Beatmap, BeatmapResp from .favourite_beatmapset import FavouriteBeatmapset @@ -87,7 +88,7 @@ class BeatmapTranslationText(BaseModel): id: int | None = None -class BeatmapsetBase(SQLModel, UTCBaseModel): +class BeatmapsetBase(SQLModel): # Beatmapset artist: str = Field(index=True) artist_unicode: str = Field(index=True) @@ -186,6 +187,16 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) return beatmapset + @classmethod + async def get_or_fetch( + cls, session: AsyncSession, fetcher: "Fetcher", sid: int + ) -> "Beatmapset": + beatmapset = await session.get(Beatmapset, sid) + if not beatmapset: + resp = await fetcher.get_beatmapset(sid) + beatmapset = await cls.from_resp(session, resp) + return beatmapset + class BeatmapsetResp(BeatmapsetBase): id: int diff --git a/app/database/best_score.py b/app/database/best_score.py index 42b0024..8688d5b 100644 --- a/app/database/best_score.py +++ b/app/database/best_score.py @@ -29,9 +29,7 @@ class BestScore(SQLModel, table=True): ) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) gamemode: GameMode = Field(index=True) - total_score: int = Field( - default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score")) - ) + total_score: int = Field(default=0, sa_column=Column(BigInteger)) mods: list[str] = Field( default_factory=list, sa_column=Column(JSON), diff --git a/app/database/multiplayer_event.py b/app/database/multiplayer_event.py new file mode 100644 index 0000000..904fbe4 --- /dev/null +++ b/app/database/multiplayer_event.py @@ -0,0 +1,56 @@ +from datetime import UTC, datetime +from typing import Any + +from app.models.model import UTCBaseModel + +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + SQLModel, +) + + +class MultiplayerEventBase(SQLModel, UTCBaseModel): + playlist_item_id: int | None = None + user_id: int | None = Field( + default=None, + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True), + ) + created_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + event_type: str = Field(index=True) + + +class MultiplayerEvent(MultiplayerEventBase, table=True): + __tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), + ) + room_id: int = Field(foreign_key="rooms.id", index=True) + updated_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + event_detail: dict[str, Any] | None = Field( + sa_column=Column(JSON), + default_factory=dict, + ) + + +class MultiplayerEventResp(MultiplayerEventBase): + id: int + + @classmethod + def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp": + return cls.model_validate(event) diff --git a/app/database/playlist_attempts.py b/app/database/playlist_attempts.py new file mode 100644 index 0000000..5d580cf --- /dev/null +++ b/app/database/playlist_attempts.py @@ -0,0 +1,151 @@ +from .lazer_user import User, UserResp +from .playlist_best_score import PlaylistBestScore + +from pydantic import BaseModel +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, + col, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + + +class ItemAttemptsCountBase(SQLModel): + room_id: int = Field(foreign_key="rooms.id", index=True) + attempts: int = Field(default=0) + completed: int = Field(default=0) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + accuracy: float = 0.0 + pp: float = 0 + total_score: int = 0 + + +class ItemAttemptsCount(ItemAttemptsCountBase, table=True): + __tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType] + id: int | None = Field(default=None, primary_key=True) + + user: User = Relationship() + + async def get_position(self, session: AsyncSession) -> int: + rownum = ( + func.row_number() + .over( + partition_by=col(ItemAttemptsCountBase.room_id), + order_by=col(ItemAttemptsCountBase.total_score).desc(), + ) + .label("rn") + ) + subq = select(ItemAttemptsCountBase, rownum).subquery() + stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id) + result = await session.exec(stmt) + return result.one() + + async def update(self, session: AsyncSession): + playlist_scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == self.room_id, + PlaylistBestScore.user_id == self.user_id, + ) + ) + ).all() + self.attempts = sum(score.attempts for score in playlist_scores) + self.total_score = sum(score.total_score for score in playlist_scores) + self.pp = sum(score.score.pp for score in playlist_scores) + self.completed = len(playlist_scores) + self.accuracy = ( + sum(score.score.accuracy for score in playlist_scores) / self.completed + if self.completed > 0 + else 0.0 + ) + await session.commit() + await session.refresh(self) + + @classmethod + async def get_or_create( + cls, + room_id: int, + user_id: int, + session: AsyncSession, + ) -> "ItemAttemptsCount": + item_attempts = await session.exec( + select(cls).where( + cls.room_id == room_id, + cls.user_id == user_id, + ) + ) + item_attempts = item_attempts.first() + if item_attempts is None: + item_attempts = cls(room_id=room_id, user_id=user_id) + session.add(item_attempts) + await session.commit() + await session.refresh(item_attempts) + await item_attempts.update(session) + return item_attempts + + +class ItemAttemptsResp(ItemAttemptsCountBase): + user: UserResp | None = None + position: int | None = None + + @classmethod + async def from_db( + cls, + item_attempts: ItemAttemptsCount, + session: AsyncSession, + include: list[str] = [], + ) -> "ItemAttemptsResp": + resp = cls.model_validate(item_attempts.model_dump()) + resp.user = await UserResp.from_db( + item_attempts.user, + session=session, + include=["statistics", "team", "daily_challenge_user_stats"], + ) + if "position" in include: + resp.position = await item_attempts.get_position(session) + # resp.accuracy *= 100 + return resp + + +class ItemAttemptsCountForItem(BaseModel): + id: int + attempts: int + passed: bool + + +class PlaylistAggregateScore(BaseModel): + playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list) + + @classmethod + async def from_db( + cls, + room_id: int, + user_id: int, + session: AsyncSession, + ) -> "PlaylistAggregateScore": + playlist_scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == room_id, + PlaylistBestScore.user_id == user_id, + ) + ) + ).all() + playlist_item_attempts = [] + for score in playlist_scores: + playlist_item_attempts.append( + ItemAttemptsCountForItem( + id=score.playlist_id, + attempts=score.attempts, + passed=score.score.passed, + ) + ) + return cls(playlist_item_attempts=playlist_item_attempts) diff --git a/app/database/playlist_best_score.py b/app/database/playlist_best_score.py new file mode 100644 index 0000000..6ecb18a --- /dev/null +++ b/app/database/playlist_best_score.py @@ -0,0 +1,110 @@ +from typing import TYPE_CHECKING + +from .lazer_user import User + +from redis.asyncio import Redis +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, + col, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from .score import Score + + +class PlaylistBestScore(SQLModel, table=True): + __tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType] + + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + score_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) + ) + room_id: int = Field(foreign_key="rooms.id", index=True) + playlist_id: int = Field(foreign_key="room_playlists.id", index=True) + total_score: int = Field(default=0, sa_column=Column(BigInteger)) + attempts: int = Field(default=0) # playlist + + user: User = Relationship() + score: "Score" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[PlaylistBestScore.score_id]", + "lazy": "joined", + } + ) + + +async def process_playlist_best_score( + room_id: int, + playlist_id: int, + user_id: int, + score_id: int, + total_score: int, + session: AsyncSession, + redis: Redis, +): + previous = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == room_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.user_id == user_id, + ) + ) + ).first() + if previous is None: + previous = PlaylistBestScore( + user_id=user_id, + score_id=score_id, + room_id=room_id, + playlist_id=playlist_id, + total_score=total_score, + ) + session.add(previous) + elif not previous.score.passed or previous.total_score < total_score: + previous.score_id = score_id + previous.total_score = total_score + previous.attempts += 1 + await session.commit() + if await redis.exists(f"multiplayer:{room_id}:gameplay:players"): + await redis.decr(f"multiplayer:{room_id}:gameplay:players") + + +async def get_position( + room_id: int, + playlist_id: int, + score_id: int, + session: AsyncSession, +) -> int: + rownum = ( + func.row_number() + .over( + partition_by=( + col(PlaylistBestScore.playlist_id), + col(PlaylistBestScore.room_id), + ), + order_by=col(PlaylistBestScore.total_score).desc(), + ) + .label("row_number") + ) + subq = ( + select(PlaylistBestScore, rownum) + .where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ) + .subquery() + ) + stmt = select(subq.c.row_number).where(subq.c.score_id == score_id) + result = await session.exec(stmt) + s = result.one_or_none() + return s if s else 0 diff --git a/app/database/playlists.py b/app/database/playlists.py new file mode 100644 index 0000000..c177432 --- /dev/null +++ b/app/database/playlists.py @@ -0,0 +1,143 @@ +from datetime import datetime +from typing import TYPE_CHECKING + +from app.models.model import UTCBaseModel +from app.models.mods import APIMod +from app.models.multiplayer_hub import PlaylistItem + +from .beatmap import Beatmap, BeatmapResp + +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from .room import Room + + +class PlaylistBase(SQLModel, UTCBaseModel): + id: int = Field(index=True) + owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) + ruleset_id: int = Field(ge=0, le=3) + expired: bool = Field(default=False) + playlist_order: int = Field(default=0) + played_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True)), + default=None, + ) + allowed_mods: list[APIMod] = Field( + default_factory=list, + sa_column=Column(JSON), + ) + required_mods: list[APIMod] = Field( + default_factory=list, + sa_column=Column(JSON), + ) + beatmap_id: int = Field( + foreign_key="beatmaps.id", + ) + freestyle: bool = Field(default=False) + + +class Playlist(PlaylistBase, table=True): + __tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType] + db_id: int = Field(default=None, primary_key=True, index=True, exclude=True) + room_id: int = Field(foreign_key="rooms.id", exclude=True) + + beatmap: Beatmap = Relationship( + sa_relationship_kwargs={ + "lazy": "joined", + } + ) + room: "Room" = Relationship() + + @classmethod + async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int: + stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where( + cls.room_id == room_id + ) + result = await session.exec(stmt) + return result.one() + + @classmethod + async def from_hub( + cls, playlist: PlaylistItem, room_id: int, session: AsyncSession + ) -> "Playlist": + next_id = await cls.get_next_id_for_room(room_id, session=session) + return cls( + id=next_id, + owner_id=playlist.owner_id, + ruleset_id=playlist.ruleset_id, + beatmap_id=playlist.beatmap_id, + required_mods=playlist.required_mods, + allowed_mods=playlist.allowed_mods, + expired=playlist.expired, + playlist_order=playlist.playlist_order, + played_at=playlist.played_at, + freestyle=playlist.freestyle, + room_id=room_id, + ) + + @classmethod + async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): + db_playlist = await session.exec( + select(cls).where(cls.id == playlist.id, cls.room_id == room_id) + ) + db_playlist = db_playlist.first() + if db_playlist is None: + raise ValueError("Playlist item not found") + db_playlist.owner_id = playlist.owner_id + db_playlist.ruleset_id = playlist.ruleset_id + db_playlist.beatmap_id = playlist.beatmap_id + db_playlist.required_mods = playlist.required_mods + db_playlist.allowed_mods = playlist.allowed_mods + db_playlist.expired = playlist.expired + db_playlist.playlist_order = playlist.playlist_order + db_playlist.played_at = playlist.played_at + db_playlist.freestyle = playlist.freestyle + await session.commit() + + @classmethod + async def add_to_db( + cls, playlist: PlaylistItem, room_id: int, session: AsyncSession + ): + db_playlist = await cls.from_hub(playlist, room_id, session) + session.add(db_playlist) + await session.commit() + await session.refresh(db_playlist) + playlist.id = db_playlist.id + + @classmethod + async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession): + db_playlist = await session.exec( + select(cls).where(cls.id == item_id, cls.room_id == room_id) + ) + db_playlist = db_playlist.first() + if db_playlist is None: + raise ValueError("Playlist item not found") + await session.delete(db_playlist) + await session.commit() + + +class PlaylistResp(PlaylistBase): + beatmap: BeatmapResp | None = None + + @classmethod + async def from_db( + cls, playlist: Playlist, include: list[str] = [] + ) -> "PlaylistResp": + data = playlist.model_dump() + if "beatmap" in include: + data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap) + resp = cls.model_validate(data) + return resp diff --git a/app/database/room.py b/app/database/room.py index 7a1aff8..368a04a 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -1,6 +1,177 @@ -from sqlmodel import Field, SQLModel +from datetime import UTC, datetime + +from app.database.playlist_attempts import PlaylistAggregateScore +from app.database.room_participated_user import RoomParticipatedUser +from app.models.model import UTCBaseModel +from app.models.multiplayer_hub import ServerMultiplayerRoom +from app.models.room import ( + MatchType, + QueueMode, + RoomCategory, + RoomDifficultyRange, + RoomPlaylistItemStats, + RoomStatus, +) + +from .lazer_user import User, UserResp +from .playlists import Playlist, PlaylistResp + +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, + col, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession -class RoomIndex(SQLModel, table=True): - __tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType] - id: int | None = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue] +class RoomBase(SQLModel, UTCBaseModel): + name: str = Field(index=True) + category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True) + duration: int | None = Field(default=None) # minutes + starts_at: datetime | None = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + ends_at: datetime | None = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=None, + ) + participant_count: int = Field(default=0) + max_attempts: int | None = Field(default=None) # playlists + type: MatchType + queue_mode: QueueMode + auto_skip: bool + auto_start_duration: int + status: RoomStatus + # TODO: channel_id + + +class Room(AsyncAttrs, RoomBase, table=True): + __tablename__ = "rooms" # pyright: ignore[reportAssignmentType] + id: int = Field(default=None, primary_key=True, index=True) + host_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + + host: User = Relationship() + playlist: list[Playlist] = Relationship( + sa_relationship_kwargs={ + "lazy": "selectin", + "cascade": "all, delete-orphan", + "overlaps": "room", + } + ) + + +class RoomResp(RoomBase): + id: int + has_password: bool = False + host: UserResp | None = None + playlist: list[PlaylistResp] = [] + playlist_item_stats: RoomPlaylistItemStats | None = None + difficulty_range: RoomDifficultyRange | None = None + current_playlist_item: PlaylistResp | None = None + current_user_score: PlaylistAggregateScore | None = None + recent_participants: list[UserResp] = Field(default_factory=list) + + @classmethod + async def from_db( + cls, + room: Room, + session: AsyncSession, + include: list[str] = [], + user: User | None = None, + ) -> "RoomResp": + resp = cls.model_validate(room.model_dump()) + + stats = RoomPlaylistItemStats(count_active=0, count_total=0) + difficulty_range = RoomDifficultyRange( + min=0, + max=0, + ) + rulesets = set() + for playlist in room.playlist: + stats.count_total += 1 + if not playlist.expired: + stats.count_active += 1 + rulesets.add(playlist.ruleset_id) + difficulty_range.min = min( + difficulty_range.min, playlist.beatmap.difficulty_rating + ) + difficulty_range.max = max( + difficulty_range.max, playlist.beatmap.difficulty_rating + ) + resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"])) + stats.ruleset_ids = list(rulesets) + resp.playlist_item_stats = stats + resp.difficulty_range = difficulty_range + resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None + resp.recent_participants = [] + for recent_participant in await session.exec( + select(RoomParticipatedUser) + .where( + RoomParticipatedUser.room_id == room.id, + col(RoomParticipatedUser.left_at).is_(None), + ) + .limit(8) + .order_by(col(RoomParticipatedUser.joined_at).desc()) + ): + resp.recent_participants.append( + await UserResp.from_db( + await recent_participant.awaitable_attrs.user, + session, + include=["statistics"], + ) + ) + resp.host = await UserResp.from_db( + await room.awaitable_attrs.host, session, include=["statistics"] + ) + if "current_user_score" in include and user: + resp.current_user_score = await PlaylistAggregateScore.from_db( + room.id, user.id, session + ) + return resp + + @classmethod + async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp": + room = server_room.room + resp = cls( + id=room.room_id, + name=room.settings.name, + type=room.settings.match_type, + queue_mode=room.settings.queue_mode, + auto_skip=room.settings.auto_skip, + auto_start_duration=int(room.settings.auto_start_duration.total_seconds()), + status=server_room.status, + category=server_room.category, + # duration = room.settings.duration, + starts_at=server_room.start_at, + participant_count=len(room.users), + ) + return resp + + +class APIUploadedRoom(RoomBase): + def to_room(self) -> Room: + """ + 将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。 + """ + room_dict = self.model_dump() + room_dict.pop("playlist", None) + # host_id 已在字段中 + return Room(**room_dict) + + id: int | None + host_id: int | None = None + playlist: list[Playlist] = Field(default_factory=list) diff --git a/app/database/room_participated_user.py b/app/database/room_participated_user.py new file mode 100644 index 0000000..18b0aeb --- /dev/null +++ b/app/database/room_participated_user.py @@ -0,0 +1,39 @@ +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + from .room import Room + + +class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True): + __tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType] + + id: int | None = Field( + default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True) + ) + room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False)) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False) + ) + joined_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False), + default=datetime.now(UTC), + ) + left_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True), nullable=True), default=None + ) + + room: "Room" = Relationship() + user: "User" = Relationship() diff --git a/app/database/score.py b/app/database/score.py index bbef7ab..adeeec9 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from datetime import UTC, date, datetime import json import math -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from app.calculator import ( calculate_pp, @@ -14,7 +14,7 @@ from app.calculator import ( clamp, ) from app.database.team import TeamMember -from app.models.model import UTCBaseModel +from app.models.model import RespWithCursor, UTCBaseModel from app.models.mods import APIMod, mods_can_get_pp from app.models.score import ( INT_TO_MODE, @@ -89,10 +89,11 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): default=0, sa_column=Column(BigInteger), exclude=True ) type: str + beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") # optional # TODO: current_user_attributes - position: int | None = Field(default=None) # multiplayer + # position: int | None = Field(default=None) # multiplayer class Score(ScoreBase, table=True): @@ -100,7 +101,6 @@ class Score(ScoreBase, table=True): id: int | None = Field( default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True) ) - beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") user_id: int = Field( default=None, sa_column=Column( @@ -163,6 +163,8 @@ class ScoreResp(ScoreBase): maximum_statistics: ScoreStatistics | None = None rank_global: int | None = None rank_country: int | None = None + position: int | None = None + scores_around: "ScoreAround | None" = None @classmethod async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": @@ -234,6 +236,16 @@ class ScoreResp(ScoreBase): return s +class MultiplayerScores(RespWithCursor): + scores: list[ScoreResp] = Field(default_factory=list) + params: dict[str, Any] = Field(default_factory=dict) + + +class ScoreAround(SQLModel): + higher: MultiplayerScores | None = None + lower: MultiplayerScores | None = None + + async def get_best_id(session: AsyncSession, score_id: int) -> None: rownum = ( func.row_number() @@ -329,6 +341,10 @@ async def get_leaderboard( self_query = ( select(BestScore) .where(BestScore.user_id == user.id) + .where( + col(BestScore.beatmap_id) == beatmap, + col(BestScore.gamemode) == mode, + ) .order_by(col(BestScore.total_score).desc()) .limit(1) ) @@ -616,6 +632,8 @@ async def process_score( fetcher: "Fetcher", session: AsyncSession, redis: Redis, + item_id: int | None = None, + room_id: int | None = None, ) -> Score: assert user.id can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) @@ -647,6 +665,8 @@ async def process_score( nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0), nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0), nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0), + playlist_item_id=item_id, + room_id=room_id, ) if can_get_pp: beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) @@ -678,4 +698,5 @@ async def process_score( await session.refresh(score) await session.refresh(score_token) await session.refresh(user) + await redis.publish("score:processed", score.id) return score diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 77b15c3..e74af93 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -38,3 +38,7 @@ async def create_tables(): # Redis 依赖 def get_redis(): return redis_client + + +def get_redis_pubsub(): + return redis_client.pubsub() diff --git a/app/dependencies/scheduler.py b/app/dependencies/scheduler.py new file mode 100644 index 0000000..fa20396 --- /dev/null +++ b/app/dependencies/scheduler.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from datetime import UTC + +from apscheduler.schedulers.asyncio import AsyncIOScheduler + +scheduler: AsyncIOScheduler | None = None + + +def init_scheduler(): + global scheduler + scheduler = AsyncIOScheduler(timezone=UTC) + scheduler.start() + + +def get_scheduler() -> AsyncIOScheduler: + global scheduler + if scheduler is None: + init_scheduler() + return scheduler # pyright: ignore[reportReturnType] + + +def stop_scheduler(): + global scheduler + if scheduler: + scheduler.shutdown() diff --git a/app/signalr/exception.py b/app/exception.py similarity index 100% rename from app/signalr/exception.py rename to app/exception.py diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index 3206d03..8bf237d 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -3,10 +3,12 @@ from __future__ import annotations from enum import IntEnum from typing import ClassVar, Literal -from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState +from app.models.signalr import SignalRUnionMessage, UserState from pydantic import BaseModel, Field +TOTAL_SCORE_DISTRIBUTION_BINS = 13 + class _UserActivity(SignalRUnionMessage): ... @@ -96,16 +98,14 @@ UserActivity = ( | ModdingBeatmap | TestingBeatmap | InDailyChallengeLobby + | PlayingDailyChallenge ) class UserPresence(BaseModel): - activity: UserActivity | None = Field( - default=None, metadata=SignalRMeta(use_upper_case=True) - ) - status: OnlineStatus | None = Field( - default=None, metadata=SignalRMeta(use_upper_case=True) - ) + activity: UserActivity | None = None + + status: OnlineStatus | None = None @property def pushable(self) -> bool: @@ -126,3 +126,34 @@ class OnlineStatus(IntEnum): OFFLINE = 0 # 隐身 DO_NOT_DISTURB = 1 ONLINE = 2 + + +class DailyChallengeInfo(BaseModel): + room_id: int + + +class MultiplayerPlaylistItemStats(BaseModel): + playlist_item_id: int = 0 + total_score_distribution: list[int] = Field( + default_factory=list, + min_length=TOTAL_SCORE_DISTRIBUTION_BINS, + max_length=TOTAL_SCORE_DISTRIBUTION_BINS, + ) + cumulative_score: int = 0 + last_processed_score_id: int = 0 + + +class MultiplayerRoomStats(BaseModel): + room_id: int + playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field( + default_factory=dict + ) + + +class MultiplayerRoomScoreSetEvent(BaseModel): + room_id: int + playlist_item_id: int + score_id: int + user_id: int + total_score: int + new_rank: int | None = None diff --git a/app/models/model.py b/app/models/model.py index bc00585..5ba8093 100644 --- a/app/models/model.py +++ b/app/models/model.py @@ -13,3 +13,10 @@ class UTCBaseModel(BaseModel): v = v.replace(tzinfo=UTC) return v.astimezone(UTC).isoformat() return v + + +Cursor = dict[str, int] + + +class RespWithCursor(BaseModel): + cursor: Cursor | None = None diff --git a/app/models/mods.py b/app/models/mods.py index abcd2cd..299a05f 100644 --- a/app/models/mods.py +++ b/app/models/mods.py @@ -8,7 +8,7 @@ from app.path import STATIC_DIR class APIMod(TypedDict): acronym: str - settings: NotRequired[dict[str, bool | float | str]] + settings: NotRequired[dict[str, bool | float | str | int]] # https://github.com/ppy/osu-api/wiki#mods diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py new file mode 100644 index 0000000..35b94c1 --- /dev/null +++ b/app/models/multiplayer_hub.py @@ -0,0 +1,926 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from enum import IntEnum +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + ClassVar, + Literal, + TypedDict, + cast, + override, +) + +from app.database.beatmap import Beatmap +from app.dependencies.database import engine +from app.dependencies.fetcher import get_fetcher +from app.exception import InvokeException + +from .mods import APIMod +from .room import ( + DownloadState, + MatchType, + MultiplayerRoomState, + MultiplayerUserState, + QueueMode, + RoomCategory, + RoomStatus, +) +from .signalr import ( + SignalRMeta, + SignalRUnionMessage, + UserState, +) + +from pydantic import BaseModel, Field +from sqlalchemy import update +from sqlmodel import col +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from app.signalr.hub import MultiplayerHub + +HOST_LIMIT = 50 +PER_USER_LIMIT = 3 + + +class MultiplayerClientState(UserState): + room_id: int = 0 + + +class MultiplayerRoomSettings(BaseModel): + name: str = "Unnamed Room" + playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] + password: str = "" + match_type: MatchType = MatchType.HEAD_TO_HEAD + queue_mode: QueueMode = QueueMode.HOST_ONLY + auto_start_duration: timedelta = timedelta(seconds=0) + auto_skip: bool = False + + @property + def auto_start_enabled(self) -> bool: + return self.auto_start_duration != timedelta(seconds=0) + + +class BeatmapAvailability(BaseModel): + state: DownloadState = DownloadState.UNKNOWN + download_progress: float | None = None + + +class _MatchUserState(SignalRUnionMessage): ... + + +class TeamVersusUserState(_MatchUserState): + team_id: int + + union_type: ClassVar[Literal[0]] = 0 + + +MatchUserState = TeamVersusUserState + + +class _MatchRoomState(SignalRUnionMessage): ... + + +class MultiplayerTeam(BaseModel): + id: int + name: str + + +class TeamVersusRoomState(_MatchRoomState): + teams: list[MultiplayerTeam] = Field( + default_factory=lambda: [ + MultiplayerTeam(id=0, name="Team Red"), + MultiplayerTeam(id=1, name="Team Blue"), + ] + ) + + union_type: ClassVar[Literal[0]] = 0 + + +MatchRoomState = TeamVersusRoomState + + +class PlaylistItem(BaseModel): + id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] + owner_id: int + beatmap_id: int + beatmap_checksum: str + ruleset_id: int + required_mods: list[APIMod] = Field(default_factory=list) + allowed_mods: list[APIMod] = Field(default_factory=list) + expired: bool + playlist_order: int + played_at: datetime | None = None + star_rating: float + freestyle: bool + + def _get_api_mods(self): + from app.models.mods import API_MODS, init_mods + + if not API_MODS: + init_mods() + return API_MODS + + def _validate_mod_for_ruleset( + self, mod: APIMod, ruleset_key: int, context: str = "mod" + ) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + + # Check if mod is valid for ruleset + if ( + typed_ruleset_key not in API_MODS + or mod["acronym"] not in API_MODS[typed_ruleset_key] + ): + raise InvokeException( + f"{context} {mod['acronym']} is invalid for this ruleset" + ) + + mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]] + + # Check if mod is unplayable in multiplayer + if mod_settings.get("UserPlayable", True) is False: + raise InvokeException( + f"{context} {mod['acronym']} is not playable by users" + ) + + if mod_settings.get("ValidForMultiplayer", True) is False: + raise InvokeException( + f"{context} {mod['acronym']} is not valid for multiplayer" + ) + + def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + + for i, mod1 in enumerate(mods): + mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"]) + if mod1_settings: + incompatible = set(mod1_settings.get("IncompatibleMods", [])) + for mod2 in mods[i + 1 :]: + if mod2["acronym"] in incompatible: + raise InvokeException( + f"Mods {mod1['acronym']} and " + f"{mod2['acronym']} are incompatible" + ) + + def _check_required_allowed_compatibility(self, ruleset_key: int) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} + + for req_mod in self.required_mods: + req_acronym = req_mod["acronym"] + req_settings = API_MODS[typed_ruleset_key].get(req_acronym) + if req_settings: + incompatible = set(req_settings.get("IncompatibleMods", [])) + conflicting_allowed = allowed_acronyms & incompatible + if conflicting_allowed: + conflict_list = ", ".join(conflicting_allowed) + raise InvokeException( + f"Required mod {req_acronym} conflicts with " + f"allowed mods: {conflict_list}" + ) + + def validate_playlist_item_mods(self) -> None: + ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id) + + # Validate required mods + for mod in self.required_mods: + self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod") + + # Validate allowed mods + for mod in self.allowed_mods: + self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod") + + # Check internal compatibility of required mods + self._check_mod_compatibility(self.required_mods, ruleset_key) + + # Check compatibility between required and allowed mods + self._check_required_allowed_compatibility(ruleset_key) + + def validate_user_mods( + self, + user: "MultiplayerRoomUser", + proposed_mods: list[APIMod], + ) -> tuple[bool, list[APIMod]]: + """ + Validates user mods against playlist item rules and returns valid mods. + Returns (is_valid, valid_mods). + """ + from typing import Literal, cast + + API_MODS = self._get_api_mods() + + ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id + ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id) + + valid_mods = [] + all_proposed_valid = True + + # Check if mods are valid for the ruleset + for mod in proposed_mods: + if ( + ruleset_key not in API_MODS + or mod["acronym"] not in API_MODS[ruleset_key] + ): + all_proposed_valid = False + continue + valid_mods.append(mod) + + # Check mod compatibility within user mods + incompatible_mods = set() + final_valid_mods = [] + for mod in valid_mods: + if mod["acronym"] in incompatible_mods: + all_proposed_valid = False + continue + setting_mods = API_MODS[ruleset_key].get(mod["acronym"]) + if setting_mods: + incompatible_mods.update(setting_mods["IncompatibleMods"]) + final_valid_mods.append(mod) + + # If not freestyle, check against allowed mods + if not self.freestyle: + allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} + filtered_valid_mods = [] + for mod in final_valid_mods: + if mod["acronym"] not in allowed_acronyms: + all_proposed_valid = False + else: + filtered_valid_mods.append(mod) + final_valid_mods = filtered_valid_mods + + # Check compatibility with required mods + required_mod_acronyms = {mod["acronym"] for mod in self.required_mods} + all_mod_acronyms = { + mod["acronym"] for mod in final_valid_mods + } | required_mod_acronyms + + # Check for incompatibility between required and user mods + filtered_valid_mods = [] + for mod in final_valid_mods: + mod_acronym = mod["acronym"] + is_compatible = True + + for other_acronym in all_mod_acronyms: + if other_acronym == mod_acronym: + continue + setting_mods = API_MODS[ruleset_key].get(mod_acronym) + if setting_mods and other_acronym in setting_mods["IncompatibleMods"]: + is_compatible = False + all_proposed_valid = False + break + + if is_compatible: + filtered_valid_mods.append(mod) + + return all_proposed_valid, filtered_valid_mods + + def clone(self) -> "PlaylistItem": + copy = self.model_copy() + copy.required_mods = list(self.required_mods) + copy.allowed_mods = list(self.allowed_mods) + copy.expired = False + copy.played_at = None + return copy + + +class _MultiplayerCountdown(SignalRUnionMessage): + id: int = 0 + time_remaining: timedelta + is_exclusive: Annotated[ + bool, Field(default=True), SignalRMeta(member_ignore=True) + ] = True + + +class MatchStartCountdown(_MultiplayerCountdown): + union_type: ClassVar[Literal[0]] = 0 + + +class ForceGameplayStartCountdown(_MultiplayerCountdown): + union_type: ClassVar[Literal[1]] = 1 + + +class ServerShuttingDownCountdown(_MultiplayerCountdown): + union_type: ClassVar[Literal[2]] = 2 + + +MultiplayerCountdown = ( + MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown +) + + +class MultiplayerRoomUser(BaseModel): + user_id: int + state: MultiplayerUserState = MultiplayerUserState.IDLE + availability: BeatmapAvailability = BeatmapAvailability( + state=DownloadState.UNKNOWN, download_progress=None + ) + mods: list[APIMod] = Field(default_factory=list) + match_state: MatchUserState | None = None + ruleset_id: int | None = None # freestyle + beatmap_id: int | None = None # freestyle + + +class MultiplayerRoom(BaseModel): + room_id: int + state: MultiplayerRoomState + settings: MultiplayerRoomSettings + users: list[MultiplayerRoomUser] = Field(default_factory=list) + host: MultiplayerRoomUser | None = None + match_state: MatchRoomState | None = None + playlist: list[PlaylistItem] = Field(default_factory=list) + active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list) + channel_id: int + + @classmethod + def from_db(cls, room) -> "MultiplayerRoom": + """ + 将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型) + """ + + # 用户列表 + users = [MultiplayerRoomUser(user_id=room.host_id)] + host_user = MultiplayerRoomUser(user_id=room.host_id) + # playlist 转换 + playlist = [] + if hasattr(room, "playlist"): + for item in room.playlist: + playlist.append( + PlaylistItem( + id=item.id, + owner_id=item.owner_id, + beatmap_id=item.beatmap_id, + beatmap_checksum=item.beatmap.checksum if item.beatmap else "", + ruleset_id=item.ruleset_id, + required_mods=item.required_mods, + allowed_mods=item.allowed_mods, + expired=item.expired, + playlist_order=item.playlist_order, + played_at=item.played_at, + star_rating=item.beatmap.difficulty_rating + if item.beatmap is not None + else 0.0, + freestyle=item.freestyle, + ) + ) + + return cls( + room_id=room.id, + state=getattr(room, "state", MultiplayerRoomState.OPEN), + settings=MultiplayerRoomSettings( + name=room.name, + playlist_item_id=playlist[0].id if playlist else 0, + password=getattr(room, "password", ""), + match_type=room.type, + queue_mode=room.queue_mode, + auto_start_duration=timedelta(seconds=room.auto_start_duration), + auto_skip=room.auto_skip, + ), + users=users, + host=host_user, + match_state=None, + playlist=playlist, + active_countdowns=[], + channel_id=getattr(room, "channel_id", 0), + ) + + +class MultiplayerQueue: + def __init__(self, room: "ServerMultiplayerRoom"): + self.server_room = room + self.current_index = 0 + + @property + def hub(self) -> "MultiplayerHub": + return self.server_room.hub + + @property + def upcoming_items(self): + return sorted( + (item for item in self.room.playlist if not item.expired), + key=lambda i: i.playlist_order, + ) + + @property + def room(self): + return self.server_room.room + + async def update_order(self): + from app.database import Playlist + + match self.room.settings.queue_mode: + case QueueMode.ALL_PLAYERS_ROUND_ROBIN: + ordered_active_items = [] + + is_first_set = True + first_set_order_by_user_id = {} + + active_items = [item for item in self.room.playlist if not item.expired] + active_items.sort(key=lambda x: x.id) + + user_item_groups = {} + for item in active_items: + if item.owner_id not in user_item_groups: + user_item_groups[item.owner_id] = [] + user_item_groups[item.owner_id].append(item) + + max_items = max( + (len(items) for items in user_item_groups.values()), default=0 + ) + + for i in range(max_items): + current_set = [] + for user_id, items in user_item_groups.items(): + if i < len(items): + current_set.append(items[i]) + + if is_first_set: + current_set.sort( + key=lambda item: (item.playlist_order, item.id) + ) + ordered_active_items.extend(current_set) + first_set_order_by_user_id = { + item.owner_id: idx + for idx, item in enumerate(ordered_active_items) + } + else: + current_set.sort( + key=lambda item: first_set_order_by_user_id.get( + item.owner_id, 0 + ) + ) + ordered_active_items.extend(current_set) + + is_first_set = False + case _: + ordered_active_items = sorted( + (item for item in self.room.playlist if not item.expired), + key=lambda x: x.id, + ) + async with AsyncSession(engine) as session: + for idx, item in enumerate(ordered_active_items): + if item.playlist_order == idx: + continue + item.playlist_order = idx + await Playlist.update(item, self.room.room_id, session) + await self.hub.playlist_changed( + self.server_room, item, beatmap_changed=False + ) + + async def update_current_item(self): + upcoming_items = self.upcoming_items + next_item = ( + upcoming_items[0] + if upcoming_items + else max( + self.room.playlist, + key=lambda i: i.played_at or datetime.min, + ) + ) + self.current_index = self.room.playlist.index(next_item) + last_id = self.room.settings.playlist_item_id + self.room.settings.playlist_item_id = next_item.id + if last_id != next_item.id: + await self.hub.setting_changed(self.server_room, True) + + async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser): + from app.database import Playlist + + is_host = self.room.host and self.room.host.user_id == user.user_id + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host: + raise InvokeException("You are not the host") + + limit = HOST_LIMIT if is_host else PER_USER_LIMIT + if ( + len([True for u in self.room.playlist if u.owner_id == user.user_id]) + >= limit + ): + raise InvokeException(f"You can only have {limit} items in the queue") + + if item.freestyle and len(item.allowed_mods) > 0: + raise InvokeException("Freestyle items cannot have allowed mods") + + async with AsyncSession(engine) as session: + fetcher = await get_fetcher() + async with session: + beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=item.beatmap_id + ) + if beatmap is None: + raise InvokeException("Beatmap not found") + if item.beatmap_checksum != beatmap.checksum: + raise InvokeException("Checksum mismatch") + + item.validate_playlist_item_mods() + item.owner_id = user.user_id + item.star_rating = float( + beatmap.difficulty_rating + ) # FIXME: beatmap use decimal + await Playlist.add_to_db(item, self.room.room_id, session) + self.room.playlist.append(item) + await self.hub.playlist_added(self.server_room, item) + await self.update_order() + await self.update_current_item() + + async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser): + from app.database import Playlist + + if item.freestyle and len(item.allowed_mods) > 0: + raise InvokeException("Freestyle items cannot have allowed mods") + + async with AsyncSession(engine) as session: + fetcher = await get_fetcher() + async with session: + beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=item.beatmap_id + ) + if item.beatmap_checksum != beatmap.checksum: + raise InvokeException("Checksum mismatch") + + existing_item = next( + (i for i in self.room.playlist if i.id == item.id), None + ) + if existing_item is None: + raise InvokeException( + "Attempted to change an item that doesn't exist" + ) + + if existing_item.owner_id != user.user_id and self.room.host != user: + raise InvokeException( + "Attempted to change an item which is not owned by the user" + ) + + if existing_item.expired: + raise InvokeException( + "Attempted to change an item which has already been played" + ) + + item.validate_playlist_item_mods() + item.owner_id = user.user_id + item.star_rating = float(beatmap.difficulty_rating) + item.playlist_order = existing_item.playlist_order + + await Playlist.update(item, self.room.room_id, session) + + # Update item in playlist + for idx, playlist_item in enumerate(self.room.playlist): + if playlist_item.id == item.id: + self.room.playlist[idx] = item + break + + await self.hub.playlist_changed( + self.server_room, + item, + beatmap_changed=item.beatmap_checksum + != existing_item.beatmap_checksum, + ) + + async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): + from app.database import Playlist + + item = next( + (i for i in self.room.playlist if i.id == playlist_item_id), + None, + ) + + if item is None: + raise InvokeException("Item does not exist in the room") + + # Check if it's the only item and current item + if item == self.current_item: + upcoming_items = [i for i in self.room.playlist if not i.expired] + if len(upcoming_items) == 1: + raise InvokeException("The only item in the room cannot be removed") + + if item.owner_id != user.user_id and self.room.host != user: + raise InvokeException( + "Attempted to remove an item which is not owned by the user" + ) + + if item.expired: + raise InvokeException( + "Attempted to remove an item which has already been played" + ) + + async with AsyncSession(engine) as session: + await Playlist.delete_item(item.id, self.room.room_id, session) + + self.room.playlist.remove(item) + self.current_index = self.room.playlist.index(self.upcoming_items[0]) + + await self.update_order() + await self.update_current_item() + await self.hub.playlist_removed(self.server_room, item.id) + + async def finish_current_item(self): + from app.database import Playlist + + async with AsyncSession(engine) as session: + played_at = datetime.now(UTC) + await session.execute( + update(Playlist) + .where( + col(Playlist.id) == self.current_item.id, + col(Playlist.room_id) == self.room.room_id, + ) + .values(expired=True, played_at=played_at) + ) + self.room.playlist[self.current_index].expired = True + self.room.playlist[self.current_index].played_at = played_at + await self.hub.playlist_changed(self.server_room, self.current_item, True) + await self.update_order() + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( + playitem.expired for playitem in self.room.playlist + ): + assert self.room.host + await self.add_item(self.current_item.clone(), self.room.host) + await self.update_current_item() + + async def update_queue_mode(self): + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( + playitem.expired for playitem in self.room.playlist + ): + assert self.room.host + await self.add_item(self.current_item.clone(), self.room.host) + await self.update_order() + await self.update_current_item() + + @property + def current_item(self): + return self.room.playlist[self.current_index] + + +@dataclass +class CountdownInfo: + countdown: MultiplayerCountdown + duration: timedelta + task: asyncio.Task | None = None + + def __init__(self, countdown: MultiplayerCountdown): + self.countdown = countdown + self.duration = ( + countdown.time_remaining + if countdown.time_remaining > timedelta(seconds=0) + else timedelta(seconds=0) + ) + + +class _MatchRequest(SignalRUnionMessage): ... + + +class ChangeTeamRequest(_MatchRequest): + union_type: ClassVar[Literal[0]] = 0 + team_id: int + + +class StartMatchCountdownRequest(_MatchRequest): + union_type: ClassVar[Literal[1]] = 1 + duration: timedelta + + +class StopCountdownRequest(_MatchRequest): + union_type: ClassVar[Literal[2]] = 2 + id: int + + +MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest + + +class MatchTypeHandler(ABC): + def __init__(self, room: "ServerMultiplayerRoom"): + self.room = room + self.hub = room.hub + + @abstractmethod + async def handle_join(self, user: MultiplayerRoomUser): ... + + @abstractmethod + async def handle_request( + self, user: MultiplayerRoomUser, request: MatchRequest + ): ... + + @abstractmethod + async def handle_leave(self, user: MultiplayerRoomUser): ... + + @abstractmethod + def get_details(self) -> MatchStartedEventDetail: ... + + +class HeadToHeadHandler(MatchTypeHandler): + @override + async def handle_join(self, user: MultiplayerRoomUser): + if user.match_state is not None: + user.match_state = None + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_request( + self, user: MultiplayerRoomUser, request: MatchRequest + ): ... + + @override + async def handle_leave(self, user: MultiplayerRoomUser): ... + + @override + def get_details(self) -> MatchStartedEventDetail: + detail = MatchStartedEventDetail(room_type="head_to_head", team=None) + return detail + + +class TeamVersusHandler(MatchTypeHandler): + @override + def __init__(self, room: "ServerMultiplayerRoom"): + super().__init__(room) + self.state = TeamVersusRoomState() + room.room.match_state = self.state + task = asyncio.create_task(self.hub.change_room_match_state(self.room)) + self.hub.tasks.add(task) + task.add_done_callback(self.hub.tasks.discard) + + def _get_best_available_team(self) -> int: + for team in self.state.teams: + if all( + ( + user.match_state is None + or not isinstance(user.match_state, TeamVersusUserState) + or user.match_state.team_id != team.id + ) + for user in self.room.room.users + ): + return team.id + + from collections import defaultdict + + team_counts = defaultdict(int) + for user in self.room.room.users: + if user.match_state is not None and isinstance( + user.match_state, TeamVersusUserState + ): + team_counts[user.match_state.team_id] += 1 + + if team_counts: + min_count = min(team_counts.values()) + for team_id, count in team_counts.items(): + if count == min_count: + return team_id + return self.state.teams[0].id if self.state.teams else 0 + + @override + async def handle_join(self, user: MultiplayerRoomUser): + best_team_id = self._get_best_available_team() + user.match_state = TeamVersusUserState(team_id=best_team_id) + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): + if not isinstance(request, ChangeTeamRequest): + return + + if request.team_id not in [team.id for team in self.state.teams]: + raise InvokeException("Invalid team ID") + + user.match_state = TeamVersusUserState(team_id=request.team_id) + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_leave(self, user: MultiplayerRoomUser): ... + + @override + def get_details(self) -> MatchStartedEventDetail: + teams: dict[int, Literal["blue", "red"]] = {} + for user in self.room.room.users: + if user.match_state is not None and isinstance( + user.match_state, TeamVersusUserState + ): + teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red" + detail = MatchStartedEventDetail(room_type="team_versus", team=teams) + return detail + + +MATCH_TYPE_HANDLERS = { + MatchType.HEAD_TO_HEAD: HeadToHeadHandler, + MatchType.TEAM_VERSUS: TeamVersusHandler, +} + + +@dataclass +class ServerMultiplayerRoom: + room: MultiplayerRoom + category: RoomCategory + status: RoomStatus + start_at: datetime + hub: "MultiplayerHub" + match_type_handler: MatchTypeHandler + queue: MultiplayerQueue + _next_countdown_id: int + _countdown_id_lock: asyncio.Lock + _tracked_countdown: dict[int, CountdownInfo] + + def __init__( + self, + room: MultiplayerRoom, + category: RoomCategory, + start_at: datetime, + hub: "MultiplayerHub", + ): + self.room = room + self.category = category + self.status = RoomStatus.IDLE + self.start_at = start_at + self.hub = hub + self.queue = MultiplayerQueue(self) + self._next_countdown_id = 0 + self._countdown_id_lock = asyncio.Lock() + self._tracked_countdown = {} + + async def set_handler(self): + self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type]( + self + ) + for i in self.room.users: + await self.match_type_handler.handle_join(i) + + async def get_next_countdown_id(self) -> int: + async with self._countdown_id_lock: + self._next_countdown_id += 1 + return self._next_countdown_id + + async def start_countdown( + self, + countdown: MultiplayerCountdown, + on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None, + ): + async def _countdown_task(self: "ServerMultiplayerRoom"): + await asyncio.sleep(info.duration.total_seconds()) + if on_complete is not None: + await on_complete(self) + await self.stop_countdown(countdown) + + if countdown.is_exclusive: + await self.stop_all_countdowns(countdown.__class__) + countdown.id = await self.get_next_countdown_id() + info = CountdownInfo(countdown) + self.room.active_countdowns.append(info.countdown) + self._tracked_countdown[countdown.id] = info + await self.hub.send_match_event( + self, CountdownStartedEvent(countdown=info.countdown) + ) + info.task = asyncio.create_task(_countdown_task(self)) + + async def stop_countdown(self, countdown: MultiplayerCountdown): + info = self._tracked_countdown.get(countdown.id) + if info is None: + return + del self._tracked_countdown[countdown.id] + self.room.active_countdowns.remove(countdown) + await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id)) + if info.task is not None and not info.task.done(): + info.task.cancel() + + async def stop_all_countdowns(self, typ: type[MultiplayerCountdown]): + for countdown in list(self._tracked_countdown.values()): + if isinstance(countdown.countdown, typ): + await self.stop_countdown(countdown.countdown) + + +class _MatchServerEvent(SignalRUnionMessage): ... + + +class CountdownStartedEvent(_MatchServerEvent): + countdown: MultiplayerCountdown + + union_type: ClassVar[Literal[0]] = 0 + + +class CountdownStoppedEvent(_MatchServerEvent): + id: int + + union_type: ClassVar[Literal[1]] = 1 + + +MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent + + +class GameplayAbortReason(IntEnum): + LOAD_TOOK_TOO_LONG = 0 + HOST_ABORTED = 1 + + +class MatchStartedEventDetail(TypedDict): + room_type: Literal["playlists", "head_to_head", "team_versus"] + team: dict[int, Literal["blue", "red"]] | None diff --git a/app/models/oauth.py b/app/models/oauth.py index 22fcf63..6665965 100644 --- a/app/models/oauth.py +++ b/app/models/oauth.py @@ -1,7 +1,6 @@ # OAuth 相关模型 from __future__ import annotations -from typing import List from pydantic import BaseModel @@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel): class RegistrationErrorResponse(BaseModel): """注册错误响应模型""" + form_error: dict class UserRegistrationErrors(BaseModel): """用户注册错误模型""" - username: List[str] = [] - user_email: List[str] = [] - password: List[str] = [] + + username: list[str] = [] + user_email: list[str] = [] + password: list[str] = [] class RegistrationRequestErrors(BaseModel): """注册请求错误模型""" + message: str | None = None redirect: str | None = None user: UserRegistrationErrors | None = None diff --git a/app/models/room.py b/app/models/room.py index 85aae24..3cba32f 100644 --- a/app/models/room.py +++ b/app/models/room.py @@ -1,15 +1,8 @@ from __future__ import annotations -from datetime import datetime from enum import Enum -from app.database import User -from app.database.beatmap import Beatmap -from app.models.mods import APIMod - -from .model import UTCBaseModel - -from pydantic import BaseModel, Field +from pydantic import BaseModel class RoomCategory(str, Enum): @@ -17,6 +10,7 @@ class RoomCategory(str, Enum): SPOTLIGHT = "spotlight" FEATURED_ARTIST = "featured_artist" DAILY_CHALLENGE = "daily_challenge" + REALTIME = "realtime" # INTERNAL USE ONLY, DO NOT USE IN API class MatchType(str, Enum): @@ -42,18 +36,40 @@ class RoomStatus(str, Enum): PLAYING = "playing" -class PlaylistItem(UTCBaseModel): - id: int | None - owner_id: int - ruleset_id: int - expired: bool - playlist_order: int | None - played_at: datetime | None - allowed_mods: list[APIMod] = Field(default_factory=list) - required_mods: list[APIMod] = Field(default_factory=list) - beatmap_id: int - beatmap: Beatmap | None - freestyle: bool +class MultiplayerRoomState(str, Enum): + OPEN = "open" + WAITING_FOR_LOAD = "waiting_for_load" + PLAYING = "playing" + CLOSED = "closed" + + +class MultiplayerUserState(str, Enum): + IDLE = "idle" + READY = "ready" + WAITING_FOR_LOAD = "waiting_for_load" + LOADED = "loaded" + READY_FOR_GAMEPLAY = "ready_for_gameplay" + PLAYING = "playing" + FINISHED_PLAY = "finished_play" + RESULTS = "results" + SPECTATING = "spectating" + + @property + def is_playing(self) -> bool: + return self in { + self.WAITING_FOR_LOAD, + self.PLAYING, + self.READY_FOR_GAMEPLAY, + self.LOADED, + } + + +class DownloadState(str, Enum): + UNKNOWN = "unknown" + NOT_DOWNLOADED = "not_downloaded" + DOWNLOADING = "downloading" + IMPORTING = "importing" + LOCALLY_AVAILABLE = "locally_available" class RoomPlaylistItemStats(BaseModel): @@ -67,39 +83,7 @@ class RoomDifficultyRange(BaseModel): max: float -class ItemAttemptsCount(BaseModel): - id: int - attempts: int - passed: bool - - -class PlaylistAggregateScore(BaseModel): - playlist_item_attempts: list[ItemAttemptsCount] - - -class Room(UTCBaseModel): - id: int | None - name: str = "" - password: str | None - has_password: bool = False - host: User | None - category: RoomCategory = RoomCategory.NORMAL - duration: int | None - starts_at: datetime | None - ends_at: datetime | None - participant_count: int = 0 - recent_participants: list[User] = Field(default_factory=list) - max_attempts: int | None - playlist: list[PlaylistItem] = Field(default_factory=list) - playlist_item_stats: RoomPlaylistItemStats | None - difficulty_range: RoomDifficultyRange | None - type: MatchType = MatchType.PLAYLISTS - queue_mode: QueueMode = QueueMode.HOST_ONLY - auto_skip: bool = False - auto_start_duration: int = 0 - current_user_score: PlaylistAggregateScore | None - current_playlist_item: PlaylistItem | None - channel_id: int = 0 - status: RoomStatus = RoomStatus.IDLE - # availability 字段在当前序列化中未包含,但可能在某些场景下需要 - availability: RoomAvailability | None +class PlaylistStatus(BaseModel): + count_active: int + count_total: int + ruleset_ids: list[int] diff --git a/app/models/signalr.py b/app/models/signalr.py index 90ef95f..ffbaf6b 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,12 +1,10 @@ from __future__ import annotations from dataclasses import dataclass -from enum import Enum -from typing import Any, ClassVar +from typing import ClassVar from pydantic import ( BaseModel, - BeforeValidator, Field, ) @@ -15,23 +13,7 @@ from pydantic import ( class SignalRMeta: member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute json_ignore: bool = False # implement of JsonIgnore (json) attribute - use_upper_case: bool = False # use upper CamelCase for field names - - -def _by_index(v: Any, class_: type[Enum]): - enum_list = list(class_) - if not isinstance(v, int): - return v - if 0 <= v < len(enum_list): - return enum_list[v] - raise ValueError( - f"Value {v} is out of range for enum " - f"{class_.__name__} with {len(enum_list)} items" - ) - - -def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator: - return BeforeValidator(lambda v: _by_index(v, enum_class)) + use_abbr: bool = True class SignalRUnionMessage(BaseModel): diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index a9e9042..9f35932 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime from enum import IntEnum -from typing import Any +from typing import Annotated, Any from app.models.beatmap import BeatmapRankStatus from app.models.mods import APIMod @@ -89,9 +89,9 @@ class LegacyReplayFrame(BaseModel): mouse_y: float | None = None button_state: int - header: FrameHeader | None = Field( - default=None, metadata=[SignalRMeta(member_ignore=True)] - ) + header: Annotated[ + FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True) + ] class FrameDataBundle(BaseModel): diff --git a/app/router/__init__.py b/app/router/__init__.py index 1e87343..22f6c70 100644 --- a/app/router/__init__.py +++ b/app/router/__init__.py @@ -7,6 +7,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401 beatmapset, me, relationship, + room, score, user, ) @@ -14,4 +15,9 @@ from .api_router import router as api_router from .auth import router as auth_router from .fetcher import fetcher_router as fetcher_router -__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"] +__all__ = [ + "api_router", + "auth_router", + "fetcher_router", + "signalr_router", +] diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 9574bdb..6800246 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -74,9 +74,10 @@ class BatchGetResp(BaseModel): @router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp) @router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp) async def batch_get_beatmaps( - b_ids: list[int] = Query(alias="id", default_factory=list), + b_ids: list[int] = Query(alias="ids[]", default_factory=list), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), + fetcher: Fetcher = Depends(get_fetcher), ): if not b_ids: # select 50 beatmaps by last_updated @@ -86,9 +87,27 @@ async def batch_get_beatmaps( ) ).all() else: - beatmaps = ( - await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) - ).all() + beatmaps = list( + ( + await db.exec( + select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50) + ) + ).all() + ) + not_found_beatmaps = [ + bid for bid in b_ids if bid not in [bm.id for bm in beatmaps] + ] + beatmaps.extend( + beatmap + for beatmap in await asyncio.gather( + *[ + Beatmap.get_or_fetch(db, fetcher, bid=bid) + for bid in not_found_beatmaps + ], + return_exceptions=True, + ) + if isinstance(beatmap, Beatmap) + ) return BatchGetResp( beatmaps=[ diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index c02b559..bebd178 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Literal -from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User +from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.dependencies.database import get_db from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -12,11 +12,25 @@ from .api_router import router from fastapi import Depends, Form, HTTPException, Query from fastapi.responses import RedirectResponse -from httpx import HTTPStatusError +from httpx import HTTPError from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp) +async def lookup_beatmapset( + beatmap_id: int = Query(), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), + fetcher: Fetcher = Depends(get_fetcher), +): + beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) + resp = await BeatmapsetResp.from_db( + beatmap.beatmapset, session=db, user=current_user + ) + return resp + + @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) async def get_beatmapset( sid: int, @@ -24,18 +38,13 @@ async def get_beatmapset( db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): - beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first() - if not beatmapset: - try: - resp = await fetcher.get_beatmapset(sid) - await Beatmapset.from_resp(db, resp) - except HTTPStatusError: - raise HTTPException(status_code=404, detail="Beatmapset not found") - else: - resp = await BeatmapsetResp.from_db( + try: + beatmapset = await Beatmapset.get_or_fetch(db, fetcher, sid) + return await BeatmapsetResp.from_db( beatmapset, session=db, include=["recent_favourites"], user=current_user ) - return resp + except HTTPError: + raise HTTPException(status_code=404, detail="Beatmapset not found") @router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"]) diff --git a/app/router/relationship.py b/app/router/relationship.py index 02292c9..0832d09 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -96,9 +96,7 @@ async def add_relationship( ) ).first() assert relationship, "Relationship should exist after commit" - return AddFriendResp( - user_relation=await RelationshipResp.from_db(db, relationship) - ) + return await RelationshipResp.from_db(db, relationship) @router.delete("/friends/{target}", tags=["relationship"]) diff --git a/app/router/room.py b/app/router/room.py index 3a65617..6918364 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -1,33 +1,346 @@ from __future__ import annotations -from app.database.room import RoomIndex +from datetime import UTC, datetime +from typing import Literal + +from app.database.beatmap import Beatmap, BeatmapResp +from app.database.beatmapset import BeatmapsetResp +from app.database.lazer_user import User, UserResp +from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp +from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp +from app.database.playlists import Playlist, PlaylistResp +from app.database.room import APIUploadedRoom, Room, RoomResp +from app.database.room_participated_user import RoomParticipatedUser +from app.database.score import Score from app.dependencies.database import get_db, get_redis -from app.models.room import Room +from app.dependencies.user import get_current_user +from app.models.room import RoomCategory, RoomStatus +from app.service.room import create_playlist_room_from_api +from app.signalr.hub import MultiplayerHubs from .api_router import router -from fastapi import Depends, Query +from fastapi import Depends, HTTPException, Query +from pydantic import BaseModel, Field from redis.asyncio import Redis -from sqlmodel import select +from sqlalchemy.sql.elements import ColumnElement +from sqlmodel import col, exists, select from sqlmodel.ext.asyncio.session import AsyncSession -@router.get("/rooms", tags=["rooms"], response_model=list[Room]) +@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp]) async def get_all_rooms( - mode: str = Query( - None - ), # TODO: lazer源码显示房间不会是除了open以外的其他状态,先放在这里 - status: str = Query(None), - category: str = Query(None), + mode: Literal["open", "ended", "participated", "owned", None] = Query( + default="open" + ), + category: RoomCategory = Query(RoomCategory.NORMAL), + status: RoomStatus | None = Query(None), db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + resp_list: list[RoomResp] = [] + where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category] + now = datetime.now(UTC) + if status is not None: + where_clauses.append(col(Room.status) == status) + if mode == "open": + where_clauses.append( + (col(Room.ends_at).is_(None)) + | (col(Room.ends_at) > now.replace(tzinfo=UTC)) + ) + if category == RoomCategory.REALTIME: + where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys())) + if mode == "participated": + where_clauses.append( + exists().where( + col(RoomParticipatedUser.room_id) == Room.id, + col(RoomParticipatedUser.user_id) == current_user.id, + ) + ) + if mode == "owned": + where_clauses.append(col(Room.host_id) == current_user.id) + if mode == "ended": + where_clauses.append( + (col(Room.ends_at).is_not(None)) + & (col(Room.ends_at) < now.replace(tzinfo=UTC)) + ) + + db_rooms = ( + ( + await db.exec( + select(Room).where( + *where_clauses, + ) + ) + ) + .unique() + .all() + ) + + for room in db_rooms: + resp = await RoomResp.from_db(room, db) + if category == RoomCategory.REALTIME: + resp.has_password = bool( + MultiplayerHubs.rooms[room.id].room.settings.password.strip() + ) + resp.category = RoomCategory.NORMAL + resp_list.append(resp) + + return resp_list + + +class APICreatedRoom(RoomResp): + error: str = "" + + +async def _participate_room( + room_id: int, user_id: int, db_room: Room, session: AsyncSession +): + participated_user = ( + await session.exec( + select(RoomParticipatedUser).where( + RoomParticipatedUser.room_id == room_id, + RoomParticipatedUser.user_id == user_id, + ) + ) + ).first() + if participated_user is None: + participated_user = RoomParticipatedUser( + room_id=room_id, + user_id=user_id, + joined_at=datetime.now(UTC), + ) + session.add(participated_user) + else: + participated_user.left_at = None + participated_user.joined_at = datetime.now(UTC) + db_room.participant_count += 1 + + +@router.post("/rooms", tags=["room"], response_model=APICreatedRoom) +async def create_room( + room: APIUploadedRoom, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + user_id = current_user.id + db_room = await create_playlist_room_from_api(db, room, user_id) + await _participate_room(db_room.id, user_id, db_room, db) + # await db.commit() + # await db.refresh(db_room) + created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db)) + created_room.error = "" + return created_room + + +@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp) +async def get_room( + room: int, + category: str = Query(default=""), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), redis: Redis = Depends(get_redis), ): - all_room_ids = (await db.exec(select(RoomIndex).where(True))).all() - roomsList: list[Room] = [] - for room_index in all_room_ids: - dumped_room = await redis.get(str(room_index.id)) - if dumped_room: - actual_room = Room.model_validate_json(str(dumped_room)) - if actual_room.status == status and actual_room.category == category: - roomsList.append(actual_room) - return roomsList + # 直接从db获取信息,毕竟都一样 + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is None: + raise HTTPException(404, "Room not found") + resp = await RoomResp.from_db( + db_room, include=["current_user_score"], session=db, user=current_user + ) + return resp + + +@router.delete("/rooms/{room}", tags=["room"]) +async def delete_room(room: int, db: AsyncSession = Depends(get_db)): + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is None: + raise HTTPException(404, "Room not found") + else: + db_room.ends_at = datetime.now(UTC) + await db.commit() + return None + + +@router.put("/rooms/{room}/users/{user}", tags=["room"]) +async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)): + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is not None: + await _participate_room(room, user, db_room, db) + await db.commit() + await db.refresh(db_room) + resp = await RoomResp.from_db(db_room, db) + + return resp + else: + raise HTTPException(404, "room not found0") + + +@router.delete("/rooms/{room}/users/{user}", tags=["room"]) +async def remove_user_from_room( + room: int, user: int, db: AsyncSession = Depends(get_db) +): + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is not None: + participated_user = ( + await db.exec( + select(RoomParticipatedUser).where( + RoomParticipatedUser.room_id == room, + RoomParticipatedUser.user_id == user, + ) + ) + ).first() + if participated_user is not None: + participated_user.left_at = datetime.now(UTC) + db_room.participant_count -= 1 + await db.commit() + return None + else: + raise HTTPException(404, "Room not found") + + +class APILeaderboard(BaseModel): + leaderboard: list[ItemAttemptsResp] = Field(default_factory=list) + user_score: ItemAttemptsResp | None = None + + +@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard) +async def get_room_leaderboard( + room: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is None: + raise HTTPException(404, "Room not found") + + aggs = await db.exec( + select(ItemAttemptsCount) + .where(ItemAttemptsCount.room_id == room) + .order_by(col(ItemAttemptsCount.total_score).desc()) + ) + aggs_resp = [] + user_agg = None + for i, agg in enumerate(aggs): + resp = await ItemAttemptsResp.from_db(agg, db) + resp.position = i + 1 + # resp.accuracy *= 100 + aggs_resp.append(resp) + if agg.user_id == current_user.id: + user_agg = resp + return APILeaderboard( + leaderboard=aggs_resp, + user_score=user_agg, + ) + + +class RoomEvents(BaseModel): + beatmaps: list[BeatmapResp] = Field(default_factory=list) + beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict) + current_playlist_item_id: int = 0 + events: list[MultiplayerEventResp] = Field(default_factory=list) + first_event_id: int = 0 + last_event_id: int = 0 + playlist_items: list[PlaylistResp] = Field(default_factory=list) + room: RoomResp + user: list[UserResp] = Field(default_factory=list) + + +@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"]) +async def get_room_events( + room_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), + limit: int = Query(100, ge=1, le=1000), + after: int | None = Query(None, ge=0), + before: int | None = Query(None, ge=0), +): + events = ( + await db.exec( + select(MultiplayerEvent) + .where( + MultiplayerEvent.room_id == room_id, + col(MultiplayerEvent.id) > after if after is not None else True, + col(MultiplayerEvent.id) < before if before is not None else True, + ) + .order_by(col(MultiplayerEvent.id).desc()) + .limit(limit) + ) + ).all() + + user_ids = set() + playlist_items = {} + beatmap_ids = set() + + event_resps = [] + first_event_id = 0 + last_event_id = 0 + + current_playlist_item_id = 0 + for event in events: + event_resps.append(MultiplayerEventResp.from_db(event)) + + if event.user_id: + user_ids.add(event.user_id) + + if event.playlist_item_id is not None and ( + playitem := ( + await db.exec( + select(Playlist).where( + Playlist.id == event.playlist_item_id, + Playlist.room_id == room_id, + ) + ) + ).first() + ): + current_playlist_item_id = playitem.id + playlist_items[event.playlist_item_id] = playitem + beatmap_ids.add(playitem.beatmap_id) + scores = await db.exec( + select(Score).where( + Score.playlist_item_id == event.playlist_item_id, + Score.room_id == room_id, + ) + ) + for score in scores: + user_ids.add(score.user_id) + beatmap_ids.add(score.beatmap_id) + + assert event.id is not None + first_event_id = min(first_event_id, event.id) + last_event_id = max(last_event_id, event.id) + + if room := MultiplayerHubs.rooms.get(room_id): + current_playlist_item_id = room.queue.current_item.id + room_resp = await RoomResp.from_hub(room) + else: + room = (await db.exec(select(Room).where(Room.id == room_id))).first() + if room is None: + raise HTTPException(404, "Room not found") + room_resp = await RoomResp.from_db(room, db) + + users = await db.exec(select(User).where(col(User.id).in_(user_ids))) + user_resps = [await UserResp.from_db(user, db) for user in users] + beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids))) + beatmap_resps = [ + await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps + ] + beatmapset_resps = {} + for beatmap_resp in beatmap_resps: + beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset + + playlist_items_resps = [ + await PlaylistResp.from_db(item) for item in playlist_items.values() + ] + + return RoomEvents( + beatmaps=beatmap_resps, + beatmapsets=beatmapset_resps, + current_playlist_item_id=current_playlist_item_id, + events=event_resps, + first_event_id=first_event_id, + last_event_id=last_event_id, + playlist_items=playlist_items_resps, + room=room_resp, + user=user_resps, + ) diff --git a/app/router/score.py b/app/router/score.py index 2f1303e..d826fd0 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,11 +1,38 @@ from __future__ import annotations -from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User -from app.database.score import get_leaderboard, process_score, process_user +from datetime import UTC, datetime +import time + +from app.calculator import clamp +from app.database import ( + Beatmap, + Playlist, + Room, + Score, + ScoreResp, + ScoreToken, + ScoreTokenResp, + User, +) +from app.database.playlist_attempts import ItemAttemptsCount +from app.database.playlist_best_score import ( + PlaylistBestScore, + get_position, + process_playlist_best_score, +) +from app.database.score import ( + MultiplayerScores, + ScoreAround, + get_leaderboard, + process_score, + process_user, +) from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user +from app.fetcher import Fetcher from app.models.beatmap import BeatmapRankStatus +from app.models.room import RoomCategory from app.models.score import ( INT_TO_MODE, GameMode, @@ -17,12 +44,78 @@ from app.models.score import ( from .api_router import router from fastapi import Depends, Form, HTTPException, Query +from httpx import HTTPError from pydantic import BaseModel from redis.asyncio import Redis from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession +READ_SCORE_TIMEOUT = 10 + + +async def submit_score( + info: SoloScoreSubmissionInfo, + beatmap: int, + token: int, + current_user: User, + db: AsyncSession, + redis: Redis, + fetcher: Fetcher, + item_id: int | None = None, + room_id: int | None = None, +): + if not info.passed: + info.rank = Rank.F + score_token = ( + await db.exec( + select(ScoreToken) + .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] + .where(ScoreToken.id == token) + ) + ).first() + if not score_token or score_token.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Score token not found") + if score_token.score_id: + score = ( + await db.exec( + select(Score).where( + Score.id == score_token.score_id, + Score.user_id == current_user.id, + ) + ) + ).first() + if not score: + raise HTTPException(status_code=404, detail="Score not found") + else: + try: + db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap) + except HTTPError: + raise HTTPException(status_code=404, detail="Beatmap not found") + ranked = db_beatmap.beatmap_status in { + BeatmapRankStatus.RANKED, + BeatmapRankStatus.APPROVED, + } + score = await process_score( + current_user, + beatmap, + ranked, + score_token, + info, + fetcher, + db, + redis, + item_id, + room_id, + ) + await db.refresh(current_user) + score_id = score.id + score_token.score_id = score_id + await process_user(db, current_user, score, ranked) + score = (await db.exec(select(Score).where(Score.id == score_id))).first() + assert score is not None + return await ScoreResp.from_db(db, score) + class BeatmapScores(BaseModel): scores: list[ScoreResp] @@ -97,9 +190,10 @@ async def get_user_beatmap_score( status_code=404, detail=f"Cannot find user {user}'s score on this beatmap" ) else: + resp = await ScoreResp.from_db(db, user_score) return BeatmapUserScore( - position=user_score.position if user_score.position is not None else 0, - score=await ScoreResp.from_db(db, user_score), + position=resp.rank_global or 0, + score=resp, ) @@ -173,55 +267,285 @@ async def submit_solo_score( redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), ): - if not info.passed: - info.rank = Rank.F - async with db: - score_token = ( - await db.exec( - select(ScoreToken) - .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] - .where(ScoreToken.id == token, ScoreToken.user_id == current_user.id) + return await submit_score(info, beatmap, token, current_user, db, redis, fetcher) + + +@router.post( + "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp +) +async def create_playlist_score( + room_id: int, + playlist_id: int, + beatmap_id: int = Form(), + beatmap_hash: str = Form(), + ruleset_id: int = Form(..., ge=0, le=3), + version_hash: str = Form(""), + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + room = await session.get(Room, room_id) + if not room: + raise HTTPException(status_code=404, detail="Room not found") + db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None + if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC): + raise HTTPException(status_code=400, detail="Room has ended") + item = ( + await session.exec( + select(Playlist).where( + Playlist.id == playlist_id, Playlist.room_id == room_id ) - ).first() - if not score_token or score_token.user_id != current_user.id: - raise HTTPException(status_code=404, detail="Score token not found") - if score_token.score_id: - score = ( - await db.exec( - select(Score).where( - Score.id == score_token.score_id, - Score.user_id == current_user.id, + ) + ).first() + if not item: + raise HTTPException(status_code=404, detail="Playlist not found") + + # validate + if not item.freestyle: + if item.ruleset_id != ruleset_id: + raise HTTPException( + status_code=400, detail="Ruleset mismatch in playlist item" + ) + if item.beatmap_id != beatmap_id: + raise HTTPException( + status_code=400, detail="Beatmap ID mismatch in playlist item" + ) + agg = await session.exec( + select(ItemAttemptsCount).where( + ItemAttemptsCount.room_id == room_id, + ItemAttemptsCount.user_id == current_user.id, + ) + ) + agg = agg.first() + if agg and room.max_attempts and agg.attempts >= room.max_attempts: + raise HTTPException( + status_code=422, + detail="You have reached the maximum attempts for this room", + ) + if item.expired: + raise HTTPException(status_code=400, detail="Playlist item has expired") + if item.played_at: + raise HTTPException( + status_code=400, detail="Playlist item has already been played" + ) + # 这里应该不用验证mod了吧。。。 + + score_token = ScoreToken( + user_id=current_user.id, + beatmap_id=beatmap_id, + ruleset_id=INT_TO_MODE[ruleset_id], + playlist_item_id=playlist_id, + ) + session.add(score_token) + await session.commit() + await session.refresh(score_token) + return ScoreTokenResp.from_db(score_token) + + +@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}") +async def submit_playlist_score( + room_id: int, + playlist_id: int, + token: int, + info: SoloScoreSubmissionInfo, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), + fetcher: Fetcher = Depends(get_fetcher), +): + item = ( + await session.exec( + select(Playlist).where( + Playlist.id == playlist_id, Playlist.room_id == room_id + ) + ) + ).first() + if not item: + raise HTTPException(status_code=404, detail="Playlist item not found") + + user_id = current_user.id + score_resp = await submit_score( + info, + item.beatmap_id, + token, + current_user, + session, + redis, + fetcher, + item.id, + room_id, + ) + await process_playlist_best_score( + room_id, + playlist_id, + user_id, + score_resp.id, + score_resp.total_score, + session, + redis, + ) + await ItemAttemptsCount.get_or_create(room_id, user_id, session) + return score_resp + + +class IndexedScoreResp(MultiplayerScores): + total: int + user_score: ScoreResp | None = None + + +@router.get( + "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp +) +async def index_playlist_scores( + room_id: int, + playlist_id: int, + limit: int = 50, + cursor: int = Query(2000000, alias="cursor[total_score]"), + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + room = await session.get(Room, room_id) + if not room: + raise HTTPException(status_code=404, detail="Room not found") + + limit = clamp(limit, 1, 50) + + scores = ( + await session.exec( + select(PlaylistBestScore) + .where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + PlaylistBestScore.total_score < cursor, + ) + .order_by(col(PlaylistBestScore.total_score).desc()) + .limit(limit + 1) + ) + ).all() + has_more = len(scores) > limit + if has_more: + scores = scores[:-1] + + user_score = None + score_resp = [await ScoreResp.from_db(session, score.score) for score in scores] + for score in score_resp: + score.position = await get_position(room_id, playlist_id, score.id, session) + if score.user_id == current_user.id: + user_score = score + + if room.category == RoomCategory.DAILY_CHALLENGE: + score_resp = [s for s in score_resp if s.passed] + if user_score and not user_score.passed: + user_score = None + + resp = IndexedScoreResp( + scores=score_resp, + user_score=user_score, + total=len(scores), + params={ + "limit": limit, + }, + ) + if has_more: + resp.cursor = { + "total_score": scores[-1].total_score, + } + return resp + + +@router.get( + "/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}", + response_model=ScoreResp, +) +async def show_playlist_score( + room_id: int, + playlist_id: int, + score_id: int, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + room = await session.get(Room, room_id) + if not room: + raise HTTPException(status_code=404, detail="Room not found") + + start_time = time.time() + score_record = None + completed = room.category != RoomCategory.REALTIME + while time.time() - start_time < READ_SCORE_TIMEOUT: + if score_record is None: + score_record = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.score_id == score_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, ) ) ).first() - if not score: - raise HTTPException(status_code=404, detail="Score not found") - else: - beatmap_status = ( - await db.exec( - select(Beatmap.beatmap_status).where(Beatmap.id == beatmap) + if completed_players := await redis.get( + f"multiplayer:{room_id}:gameplay:players" + ): + completed = completed_players == "0" + if score_record and completed: + break + if not score_record: + raise HTTPException(status_code=404, detail="Score not found") + resp = await ScoreResp.from_db(session, score_record.score) + resp.position = await get_position(room_id, playlist_id, score_id, session) + if completed: + scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, ) - ).first() - if beatmap_status is None: - raise HTTPException(status_code=404, detail="Beatmap not found") - ranked = beatmap_status in { - BeatmapRankStatus.RANKED, - BeatmapRankStatus.APPROVED, - } - score = await process_score( - current_user, - beatmap, - ranked, - score_token, - info, - fetcher, - db, - redis, ) - await db.refresh(current_user) - score_id = score.id - score_token.score_id = score_id - await process_user(db, current_user, score, ranked) - score = (await db.exec(select(Score).where(Score.id == score_id))).first() - assert score is not None - return await ScoreResp.from_db(db, score) + ).all() + higher_scores = [] + lower_scores = [] + for score in scores: + if score.total_score > resp.total_score: + higher_scores.append(await ScoreResp.from_db(session, score.score)) + elif score.total_score < resp.total_score: + lower_scores.append(await ScoreResp.from_db(session, score.score)) + resp.scores_around = ScoreAround( + higher=MultiplayerScores(scores=higher_scores), + lower=MultiplayerScores(scores=lower_scores), + ) + + return resp + + +@router.get( + "rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}", + response_model=ScoreResp, +) +async def get_user_playlist_score( + room_id: int, + playlist_id: int, + user_id: int, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + score_record = None + start_time = time.time() + while time.time() - start_time < READ_SCORE_TIMEOUT: + score_record = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.user_id == user_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ) + ) + ).first() + if score_record: + break + if not score_record: + raise HTTPException(status_code=404, detail="Score not found") + + resp = await ScoreResp.from_db(session, score_record.score) + resp.position = await get_position( + room_id, playlist_id, score_record.score_id, session + ) + return resp diff --git a/app/service/__init__.py b/app/service/__init__.py new file mode 100644 index 0000000..cbb83a2 --- /dev/null +++ b/app/service/__init__.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from .daily_challenge import create_daily_challenge_room +from .room import create_playlist_room, create_playlist_room_from_api + +__all__ = [ + "create_daily_challenge_room", + "create_playlist_room", + "create_playlist_room_from_api", +] diff --git a/app/service/daily_challenge.py b/app/service/daily_challenge.py new file mode 100644 index 0000000..ec7f9d0 --- /dev/null +++ b/app/service/daily_challenge.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +import json + +from app.database.playlists import Playlist +from app.database.room import Room +from app.dependencies.database import engine, get_redis +from app.dependencies.scheduler import get_scheduler +from app.log import logger +from app.models.metadata_hub import DailyChallengeInfo +from app.models.mods import APIMod +from app.models.room import RoomCategory + +from .room import create_playlist_room + +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +async def create_daily_challenge_room( + beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = [] +) -> Room: + async with AsyncSession(engine) as session: + today = datetime.now(UTC).date() + return await create_playlist_room( + session=session, + name=str(today), + host_id=3, + playlist=[ + Playlist( + id=0, + room_id=0, + owner_id=3, + ruleset_id=ruleset_id, + beatmap_id=beatmap, + required_mods=required_mods, + ) + ], + category=RoomCategory.DAILY_CHALLENGE, + duration=duration, + ) + + +@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="daily_challenge") +async def daily_challenge_job(): + from app.signalr.hub import MetadataHubs + + now = datetime.now(UTC) + redis = get_redis() + key = f"daily_challenge:{now.date()}" + if not await redis.exists(key): + return + async with AsyncSession(engine) as session: + room = ( + await session.exec( + select(Room).where( + Room.category == RoomCategory.DAILY_CHALLENGE, + col(Room.ends_at) > datetime.now(UTC), + ) + ) + ).first() + if room: + return + + try: + beatmap = await redis.hget(key, "beatmap") # pyright: ignore[reportGeneralTypeIssues] + ruleset_id = await redis.hget(key, "ruleset_id") # pyright: ignore[reportGeneralTypeIssues] + required_mods = await redis.hget(key, "required_mods") # pyright: ignore[reportGeneralTypeIssues] + + if beatmap is None or ruleset_id is None: + logger.warning( + f"[DailyChallenge] Missing required data for daily challenge {now}." + " Will try again in 5 minutes." + ) + get_scheduler().add_job( + daily_challenge_job, + "date", + run_date=datetime.now(UTC) + timedelta(minutes=5), + ) + return + + beatmap_int = int(beatmap) + ruleset_id_int = int(ruleset_id) + + mods_list = [] + if required_mods: + mods_list = json.loads(required_mods) + + next_day = (now + timedelta(days=1)).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + room = await create_daily_challenge_room( + beatmap=beatmap_int, + ruleset_id=ruleset_id_int, + required_mods=mods_list, + duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60), + ) + await MetadataHubs.broadcast_call( + "DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id) + ) + logger.success( + "[DailyChallenge] Added today's daily challenge: " + f"{beatmap=}, {ruleset_id=}, {required_mods=}" + ) + return + except (ValueError, json.JSONDecodeError) as e: + logger.warning( + f"[DailyChallenge] Error processing daily challenge data: {e}" + " Will try again in 5 minutes." + ) + except Exception as e: + logger.exception( + f"[DailyChallenge] Unexpected error in daily challenge job: {e}" + " Will try again in 5 minutes." + ) + get_scheduler().add_job( + daily_challenge_job, + "date", + run_date=datetime.now(UTC) + timedelta(minutes=5), + ) diff --git a/app/service/room.py b/app/service/room.py new file mode 100644 index 0000000..d11dced --- /dev/null +++ b/app/service/room.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +from app.database.beatmap import Beatmap +from app.database.playlists import Playlist +from app.database.room import APIUploadedRoom, Room +from app.dependencies.fetcher import get_fetcher +from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus + +from sqlalchemy import exists +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +async def create_playlist_room_from_api( + session: AsyncSession, room: APIUploadedRoom, host_id: int +) -> Room: + db_room = room.to_room() + db_room.host_id = host_id + db_room.starts_at = datetime.now(UTC) + db_room.ends_at = db_room.starts_at + timedelta( + minutes=db_room.duration if db_room.duration is not None else 0 + ) + session.add(db_room) + await session.commit() + await session.refresh(db_room) + await add_playlists_to_room(session, db_room.id, room.playlist, host_id) + await session.refresh(db_room) + return db_room + + +async def create_playlist_room( + session: AsyncSession, + name: str, + host_id: int, + category: RoomCategory = RoomCategory.NORMAL, + duration: int = 30, + max_attempts: int | None = None, + playlist: list[Playlist] = [], +) -> Room: + db_room = Room( + name=name, + category=category, + duration=duration, + starts_at=datetime.now(UTC), + ends_at=datetime.now(UTC) + timedelta(minutes=duration), + participant_count=0, + max_attempts=max_attempts, + type=MatchType.PLAYLISTS, + queue_mode=QueueMode.HOST_ONLY, + auto_skip=False, + auto_start_duration=0, + status=RoomStatus.IDLE, + host_id=host_id, + ) + session.add(db_room) + await session.commit() + await session.refresh(db_room) + await add_playlists_to_room(session, db_room.id, playlist, host_id) + await session.refresh(db_room) + return db_room + + +async def add_playlists_to_room( + session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int +): + for item in playlist: + if not ( + await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap))) + ).first(): + fetcher = await get_fetcher() + await Beatmap.get_or_fetch(session, fetcher, item.beatmap_id) + item.id = await Playlist.get_next_id_for_room(room_id, session) + item.room_id = room_id + item.owner_id = owner_id + session.add(item) + await session.commit() diff --git a/app/service/subscribers/base.py b/app/service/subscribers/base.py new file mode 100644 index 0000000..144dfd0 --- /dev/null +++ b/app/service/subscribers/base.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Any + +from app.dependencies.database import get_redis_pubsub + + +class RedisSubscriber: + def __init__(self): + self.pubsub = get_redis_pubsub() + self.handlers: dict[str, list[Callable[[str, str], Awaitable[Any]]]] = {} + self.task: asyncio.Task | None = None + + async def subscribe(self, channel: str): + await self.pubsub.subscribe(channel) + if channel not in self.handlers: + self.handlers[channel] = [] + + async def unsubscribe(self, channel: str): + if channel in self.handlers: + del self.handlers[channel] + await self.pubsub.unsubscribe(channel) + + async def listen(self): + while True: + message = await self.pubsub.get_message( + ignore_subscribe_messages=True, timeout=None + ) + if message is not None and message["type"] == "message": + method = self.handlers.get(message["channel"]) + if method: + await asyncio.gather( + *[ + handler(message["channel"], message["data"]) + for handler in method + ] + ) + + def start(self): + if self.task is None or self.task.done(): + self.task = asyncio.create_task(self.listen()) + + def stop(self): + if self.task is not None and not self.task.done(): + self.task.cancel() + self.task = None diff --git a/app/service/subscribers/score_processed.py b/app/service/subscribers/score_processed.py new file mode 100644 index 0000000..b1bc5bd --- /dev/null +++ b/app/service/subscribers/score_processed.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from app.database import PlaylistBestScore, Score +from app.database.playlist_best_score import get_position +from app.dependencies.database import engine +from app.models.metadata_hub import MultiplayerRoomScoreSetEvent + +from .base import RedisSubscriber + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from app.signalr.hub import MetadataHub + + +CHANNEL = "score:processed" + + +class ScoreSubscriber(RedisSubscriber): + def __init__(self): + super().__init__() + self.room_subscriber: dict[int, list[int]] = {} + self.metadata_hub: "MetadataHub | None " = None + self.subscribed = False + self.handlers[CHANNEL] = [self._handler] + + async def subscribe_room_score(self, room_id: int, user_id: int): + if room_id not in self.room_subscriber: + await self.subscribe(CHANNEL) + self.start() + self.room_subscriber.setdefault(room_id, []).append(user_id) + + async def unsubscribe_room_score(self, room_id: int, user_id: int): + if room_id in self.room_subscriber: + self.room_subscriber[room_id].remove(user_id) + if not self.room_subscriber[room_id]: + del self.room_subscriber[room_id] + + async def _notify_room_score_processed(self, score_id: int): + if not self.metadata_hub: + return + async with AsyncSession(engine) as session: + score = await session.get(Score, score_id) + if ( + not score + or not score.passed + or score.room_id is None + or score.playlist_item_id is None + ): + return + if not self.room_subscriber.get(score.room_id, []): + return + + new_rank = None + user_best = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.user_id == score.user_id, + PlaylistBestScore.room_id == score.room_id, + ) + ) + ).first() + if user_best and user_best.score_id == score_id: + new_rank = await get_position( + user_best.room_id, + user_best.playlist_id, + user_best.score_id, + session, + ) + + event = MultiplayerRoomScoreSetEvent( + room_id=score.room_id, + playlist_item_id=score.playlist_item_id, + score_id=score_id, + user_id=score.user_id, + total_score=score.total_score, + new_rank=new_rank, + ) + await self.metadata_hub.notify_room_score_processed(event) + + async def _handler(self, channel: str, data: str): + score_id = int(data) + if self.metadata_hub: + await self._notify_room_score_processed(score_id) diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index f3c5b29..4bab451 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -6,9 +6,9 @@ import time from typing import Any from app.config import settings +from app.exception import InvokeException from app.log import logger from app.models.signalr import UserState -from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, CompletionPacket, diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 64232c0..f81aefa 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -1,18 +1,34 @@ from __future__ import annotations import asyncio +from collections import defaultdict from collections.abc import Coroutine from datetime import UTC, datetime +import math from typing import override -from app.database import Relationship, RelationshipType -from app.database.lazer_user import User +from app.calculator import clamp +from app.database import Relationship, RelationshipType, User +from app.database.playlist_best_score import PlaylistBestScore +from app.database.playlists import Playlist +from app.database.room import Room from app.dependencies.database import engine, get_redis -from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity +from app.models.metadata_hub import ( + TOTAL_SCORE_DISTRIBUTION_BINS, + DailyChallengeInfo, + MetadataClientState, + MultiplayerPlaylistItemStats, + MultiplayerRoomScoreSetEvent, + MultiplayerRoomStats, + OnlineStatus, + UserActivity, +) +from app.models.room import RoomCategory +from app.service.subscribers.score_processed import ScoreSubscriber from .hub import Client, Hub -from sqlmodel import select +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers" @@ -21,11 +37,33 @@ ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers" class MetadataHub(Hub[MetadataClientState]): def __init__(self) -> None: super().__init__() + self.subscriber = ScoreSubscriber() + self.subscriber.metadata_hub = self + self._daily_challenge_stats: MultiplayerRoomStats | None = None + self._today = datetime.now(UTC).date() + self._lock = asyncio.Lock() + + def get_daily_challenge_stats( + self, daily_challenge_room: int + ) -> MultiplayerRoomStats: + if ( + self._daily_challenge_stats is None + or self._today != datetime.now(UTC).date() + ): + self._daily_challenge_stats = MultiplayerRoomStats( + room_id=daily_challenge_room, + playlist_item_stats={}, + ) + return self._daily_challenge_stats @staticmethod def online_presence_watchers_group() -> str: return ONLINE_PRESENCE_WATCHERS_GROUP + @staticmethod + def room_watcher_group(room_id: int) -> str: + return f"metadata:multiplayer-room-watchers:{room_id}" + def broadcast_tasks( self, user_id: int, store: MetadataClientState | None ) -> set[Coroutine]: @@ -102,10 +140,29 @@ class MetadataHub(Hub[MetadataClientState]): self.friend_presence_watchers_group(friend_id), "FriendPresenceUpdated", friend_id, - friend_state if friend_state.pushable else None, + friend_state.for_push + if friend_state.pushable + else None, ) ) await asyncio.gather(*tasks) + + daily_challenge_room = ( + await session.exec( + select(Room).where( + col(Room.ends_at) > datetime.now(UTC), + Room.category == RoomCategory.DAILY_CHALLENGE, + ) + ) + ).first() + if daily_challenge_room: + await self.call_noblock( + client, + "DailyChallengeUpdated", + DailyChallengeInfo( + room_id=daily_challenge_room.id, + ), + ) redis = get_redis() await redis.set(f"metadata:online:{user_id}", "") @@ -161,3 +218,76 @@ class MetadataHub(Hub[MetadataClientState]): async def EndWatchingUserPresence(self, client: Client) -> None: self.remove_from_group(client, self.online_presence_watchers_group()) + + async def notify_room_score_processed(self, event: MultiplayerRoomScoreSetEvent): + await self.broadcast_group_call( + self.room_watcher_group(event.room_id), "MultiplayerRoomScoreSet", event + ) + + async def BeginWatchingMultiplayerRoom(self, client: Client, room_id: int): + self.add_to_group(client, self.room_watcher_group(room_id)) + await self.subscriber.subscribe_room_score(room_id, client.user_id) + stats = self.get_daily_challenge_stats(room_id) + await self.update_daily_challenge_stats(stats) + return list(stats.playlist_item_stats.values()) + + async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None: + async with AsyncSession(engine) as session: + playlist_ids = ( + await session.exec( + select(Playlist.id).where( + Playlist.room_id == stats.room_id, + ) + ) + ).all() + for playlist_id in playlist_ids: + item = stats.playlist_item_stats.get(playlist_id, None) + if item is None: + item = MultiplayerPlaylistItemStats( + playlist_item_id=playlist_id, + total_score_distribution=[0] * TOTAL_SCORE_DISTRIBUTION_BINS, + cumulative_score=0, + last_processed_score_id=0, + ) + stats.playlist_item_stats[playlist_id] = item + last_processed_score_id = item.last_processed_score_id + scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == stats.room_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.score_id > last_processed_score_id, + ) + ) + ).all() + if len(scores) == 0: + continue + + async with self._lock: + if item.last_processed_score_id == last_processed_score_id: + totals = defaultdict(int) + for score in scores: + bin_index = int( + clamp( + math.floor(score.total_score / 100000), + 0, + TOTAL_SCORE_DISTRIBUTION_BINS - 1, + ) + ) + totals[bin_index] += 1 + + item.cumulative_score += sum( + score.total_score for score in scores + ) + + for j in range(TOTAL_SCORE_DISTRIBUTION_BINS): + item.total_score_distribution[j] += totals.get(j, 0) + + if scores: + item.last_processed_score_id = max( + score.score_id for score in scores + ) + + async def EndWatchingMultiplayerRoom(self, client: Client, room_id: int): + self.remove_from_group(client, self.room_watcher_group(room_id)) + await self.subscriber.unsubscribe_room_score(room_id, client.user_id) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 72b4a52..e397031 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -1,6 +1,1247 @@ from __future__ import annotations -from .hub import Hub +import asyncio +from datetime import UTC, datetime, timedelta +from typing import override + +from app.database import Room +from app.database.beatmap import Beatmap +from app.database.lazer_user import User +from app.database.multiplayer_event import MultiplayerEvent +from app.database.playlists import Playlist +from app.database.relationship import Relationship, RelationshipType +from app.database.room_participated_user import RoomParticipatedUser +from app.dependencies.database import engine, get_redis +from app.dependencies.fetcher import get_fetcher +from app.exception import InvokeException +from app.log import logger +from app.models.mods import APIMod +from app.models.multiplayer_hub import ( + BeatmapAvailability, + ForceGameplayStartCountdown, + GameplayAbortReason, + MatchRequest, + MatchServerEvent, + MatchStartCountdown, + MatchStartedEventDetail, + MultiplayerClientState, + MultiplayerRoom, + MultiplayerRoomSettings, + MultiplayerRoomUser, + PlaylistItem, + ServerMultiplayerRoom, + ServerShuttingDownCountdown, + StartMatchCountdownRequest, + StopCountdownRequest, +) +from app.models.room import ( + DownloadState, + MatchType, + MultiplayerRoomState, + MultiplayerUserState, + RoomCategory, + RoomStatus, +) +from app.models.score import GameMode + +from .hub import Client, Hub + +from httpx import HTTPError +from sqlalchemy import update +from sqlmodel import col, exists, select +from sqlmodel.ext.asyncio.session import AsyncSession + +GAMEPLAY_LOAD_TIMEOUT = 30 -class MultiplayerHub(Hub): ... +class MultiplayerEventLogger: + def __init__(self): + pass + + async def log_event(self, event: MultiplayerEvent): + try: + async with AsyncSession(engine) as session: + session.add(event) + await session.commit() + except Exception as e: + logger.warning(f"Failed to log multiplayer room event to database: {e}") + + async def room_created(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="room_created", + ) + await self.log_event(event) + + async def room_disbanded(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="room_disbanded", + ) + await self.log_event(event) + + async def player_joined(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_joined", + ) + await self.log_event(event) + + async def player_left(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_left", + ) + await self.log_event(event) + + async def player_kicked(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_kicked", + ) + await self.log_event(event) + + async def host_changed(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="host_changed", + ) + await self.log_event(event) + + async def game_started( + self, room_id: int, playlist_item_id: int, details: MatchStartedEventDetail + ): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_started", + event_detail=details, # pyright: ignore[reportArgumentType] + ) + await self.log_event(event) + + async def game_aborted(self, room_id: int, playlist_item_id: int): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_aborted", + ) + await self.log_event(event) + + async def game_completed(self, room_id: int, playlist_item_id: int): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_completed", + ) + await self.log_event(event) + + +class MultiplayerHub(Hub[MultiplayerClientState]): + @override + def __init__(self): + super().__init__() + self.rooms: dict[int, ServerMultiplayerRoom] = {} + self.event_logger = MultiplayerEventLogger() + + @staticmethod + def group_id(room: int) -> str: + return f"room:{room}" + + @override + def create_state(self, client: Client) -> MultiplayerClientState: + return MultiplayerClientState( + connection_id=client.connection_id, + connection_token=client.connection_token, + ) + + @override + async def _clean_state(self, state: MultiplayerClientState): + user_id = int(state.connection_id) + if state.room_id != 0 and state.room_id in self.rooms: + server_room = self.rooms[state.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == user_id), None) + if user is not None: + await self.make_user_leave( + self.get_client_by_id(str(user_id)), server_room, user + ) + + async def CreateRoom(self, client: Client, room: MultiplayerRoom): + logger.info(f"[MultiplayerHub] {client.user_id} creating room") + store = self.get_or_create_state(client) + if store.room_id != 0: + raise InvokeException("You are already in a room") + async with AsyncSession(engine) as session: + async with session: + db_room = Room( + name=room.settings.name, + category=RoomCategory.REALTIME, + type=room.settings.match_type, + queue_mode=room.settings.queue_mode, + auto_skip=room.settings.auto_skip, + auto_start_duration=int( + room.settings.auto_start_duration.total_seconds() + ), + host_id=client.user_id, + status=RoomStatus.IDLE, + ) + session.add(db_room) + await session.commit() + await session.refresh(db_room) + + item = room.playlist[0] + item.owner_id = client.user_id + room.room_id = db_room.id + starts_at = db_room.starts_at or datetime.now(UTC) + beatmap_exists = await session.exec( + select(exists().where(col(Beatmap.id) == item.beatmap_id)) + ) + if not beatmap_exists.one(): + fetcher = await get_fetcher() + try: + await Beatmap.get_or_fetch( + session, fetcher, bid=item.beatmap_id + ) + except HTTPError: + raise InvokeException( + "Failed to fetch beatmap, please retry later" + ) + await Playlist.add_to_db(item, room.room_id, session) + + server_room = ServerMultiplayerRoom( + room=room, + category=RoomCategory.NORMAL, + start_at=starts_at, + hub=self, + ) + self.rooms[room.room_id] = server_room + await server_room.set_handler() + await self.event_logger.room_created(room.room_id, client.user_id) + return await self.JoinRoomWithPassword( + client, room.room_id, room.settings.password + ) + + async def JoinRoom(self, client: Client, room_id: int): + return self.JoinRoomWithPassword(client, room_id, "") + + async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str): + logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}") + store = self.get_or_create_state(client) + if store.room_id != 0: + raise InvokeException("You are already in a room") + user = MultiplayerRoomUser(user_id=client.user_id) + if room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[room_id] + room = server_room.room + for u in room.users: + if u.user_id == client.user_id: + raise InvokeException("You are already in this room") + if room.settings.password != password: + raise InvokeException("Incorrect password") + if room.host is None: + # from CreateRoom + room.host = user + store.room_id = room_id + await self.broadcast_group_call(self.group_id(room_id), "UserJoined", user) + room.users.append(user) + self.add_to_group(client, self.group_id(room_id)) + await server_room.match_type_handler.handle_join(user) + await self.event_logger.player_joined(room_id, user.user_id) + + async with AsyncSession(engine) as session: + async with session.begin(): + if ( + participated_user := ( + await session.exec( + select(RoomParticipatedUser).where( + RoomParticipatedUser.room_id == room_id, + RoomParticipatedUser.user_id == client.user_id, + ) + ) + ).first() + ) is None: + participated_user = RoomParticipatedUser( + room_id=room_id, + user_id=client.user_id, + ) + session.add(participated_user) + else: + participated_user.left_at = None + participated_user.joined_at = datetime.now(UTC) + + db_room = await session.get(Room, room_id) + if db_room is None: + raise InvokeException("Room does not exist in database") + db_room.participant_count += 1 + return room + + async def change_beatmap_availability( + self, + room_id: int, + user: MultiplayerRoomUser, + beatmap_availability: BeatmapAvailability, + ): + availability = user.availability + if ( + availability.state == beatmap_availability.state + and availability.download_progress == beatmap_availability.download_progress + ): + return + user.availability = beatmap_availability + await self.broadcast_group_call( + self.group_id(room_id), + "UserBeatmapAvailabilityChanged", + user.user_id, + beatmap_availability, + ) + + async def ChangeBeatmapAvailability( + self, client: Client, beatmap_availability: BeatmapAvailability + ): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + await self.change_beatmap_availability( + room.room_id, + user, + beatmap_availability, + ) + + async def AddPlaylistItem(self, client: Client, item: PlaylistItem): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.add_item( + item, + user, + ) + + async def EditPlaylistItem(self, client: Client, item: PlaylistItem): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.edit_item( + item, + user, + ) + + async def RemovePlaylistItem(self, client: Client, item_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.remove_item( + item_id, + user, + ) + + async def setting_changed(self, room: ServerMultiplayerRoom, beatmap_changed: bool): + await self.validate_styles(room) + await self.unready_all_users(room, beatmap_changed) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "SettingsChanged", + room.room.settings, + ) + + async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemAdded", + item, + ) + + async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemRemoved", + item_id, + ) + + async def playlist_changed( + self, room: ServerMultiplayerRoom, item: PlaylistItem, beatmap_changed: bool + ): + if item.id == room.room.settings.playlist_item_id: + await self.validate_styles(room) + await self.unready_all_users(room, beatmap_changed) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemChanged", + item, + ) + + async def ChangeUserStyle( + self, client: Client, beatmap_id: int | None, ruleset_id: int | None + ): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.change_user_style( + beatmap_id, + ruleset_id, + server_room, + user, + ) + + async def validate_styles(self, room: ServerMultiplayerRoom): + fetcher = await get_fetcher() + if not room.queue.current_item.freestyle: + for user in room.room.users: + await self.change_user_style( + None, + None, + room, + user, + ) + async with AsyncSession(engine) as session: + try: + beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=room.queue.current_item.beatmap_id + ) + except HTTPError: + raise InvokeException("Current item beatmap not found") + beatmap_ids = ( + await session.exec( + select(Beatmap.id, Beatmap.mode).where( + Beatmap.beatmapset_id == beatmap.beatmapset_id, + ) + ) + ).all() + for user in room.room.users: + beatmap_id = user.beatmap_id + ruleset_id = user.ruleset_id + user_beatmap = next( + (b for b in beatmap_ids if b[0] == beatmap_id), + None, + ) + if beatmap_id is not None and user_beatmap is None: + beatmap_id = None + beatmap_ruleset = user_beatmap[1] if user_beatmap else beatmap.mode + if ( + ruleset_id is not None + and beatmap_ruleset != GameMode.OSU + and ruleset_id != beatmap_ruleset + ): + ruleset_id = None + await self.change_user_style( + beatmap_id, + ruleset_id, + room, + user, + ) + + for user in room.room.users: + is_valid, valid_mods = room.queue.current_item.validate_user_mods( + user, user.mods + ) + if not is_valid: + await self.change_user_mods(valid_mods, room, user) + + async def change_user_style( + self, + beatmap_id: int | None, + ruleset_id: int | None, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + ): + if user.beatmap_id == beatmap_id and user.ruleset_id == ruleset_id: + return + + if beatmap_id is not None or ruleset_id is not None: + if not room.queue.current_item.freestyle: + raise InvokeException("Current item does not allow free user styles.") + + async with AsyncSession(engine) as session: + item_beatmap = await session.get( + Beatmap, room.queue.current_item.beatmap_id + ) + if item_beatmap is None: + raise InvokeException("Item beatmap not found") + + user_beatmap = ( + item_beatmap + if beatmap_id is None + else await session.get(Beatmap, beatmap_id) + ) + + if user_beatmap is None: + raise InvokeException("Invalid beatmap selected.") + + if user_beatmap.beatmapset_id != item_beatmap.beatmapset_id: + raise InvokeException( + "Selected beatmap is not from the same beatmap set." + ) + + if ( + ruleset_id is not None + and user_beatmap.mode != GameMode.OSU + and ruleset_id != user_beatmap.mode + ): + raise InvokeException( + "Selected ruleset is not supported for the given beatmap." + ) + + user.beatmap_id = beatmap_id + user.ruleset_id = ruleset_id + + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserStyleChanged", + user.user_id, + beatmap_id, + ruleset_id, + ) + + async def ChangeUserMods(self, client: Client, new_mods: list[APIMod]): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.change_user_mods(new_mods, server_room, user) + + async def change_user_mods( + self, + new_mods: list[APIMod], + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + ): + is_valid, valid_mods = room.queue.current_item.validate_user_mods( + user, new_mods + ) + if not is_valid: + incompatible_mods = [ + mod["acronym"] for mod in new_mods if mod not in valid_mods + ] + raise InvokeException( + f"Incompatible mods were selected: {','.join(incompatible_mods)}" + ) + + if user.mods == valid_mods: + return + + user.mods = valid_mods + + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserModsChanged", + user.user_id, + valid_mods, + ) + + async def validate_user_stare( + self, + room: ServerMultiplayerRoom, + old: MultiplayerUserState, + new: MultiplayerUserState, + ): + match new: + case MultiplayerUserState.IDLE: + if old.is_playing: + raise InvokeException( + "Cannot return to idle without aborting gameplay." + ) + case MultiplayerUserState.READY: + if old != MultiplayerUserState.IDLE: + raise InvokeException(f"Cannot change state from {old} to {new}") + if room.queue.current_item.expired: + raise InvokeException( + "Cannot ready up while all items have been played." + ) + case MultiplayerUserState.WAITING_FOR_LOAD: + raise InvokeException("Cannot change state from {old} to {new}") + case MultiplayerUserState.LOADED: + if old != MultiplayerUserState.WAITING_FOR_LOAD: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.READY_FOR_GAMEPLAY: + if old != MultiplayerUserState.LOADED: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.PLAYING: + raise InvokeException("State is managed by the server.") + case MultiplayerUserState.FINISHED_PLAY: + if old != MultiplayerUserState.PLAYING: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.RESULTS: + raise InvokeException("Cannot change state from {old} to {new}") + case MultiplayerUserState.SPECTATING: + if old not in (MultiplayerUserState.IDLE, MultiplayerUserState.READY): + raise InvokeException(f"Cannot change state from {old} to {new}") + + async def ChangeState(self, client: Client, state: MultiplayerUserState): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if user.state == state: + return + match state: + case MultiplayerUserState.IDLE: + if user.state.is_playing: + return + case MultiplayerUserState.LOADED | MultiplayerUserState.READY_FOR_GAMEPLAY: + if not user.state.is_playing: + return + await self.validate_user_stare( + server_room, + user.state, + state, + ) + await self.change_user_state(server_room, user, state) + if state == MultiplayerUserState.SPECTATING and ( + room.state == MultiplayerRoomState.PLAYING + or room.state == MultiplayerRoomState.WAITING_FOR_LOAD + ): + await self.call_noblock(client, "LoadRequested") + await self.update_room_state(server_room) + + async def change_user_state( + self, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + state: MultiplayerUserState, + ): + user.state = state + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserStateChanged", + user.user_id, + user.state, + ) + + async def update_room_state(self, room: ServerMultiplayerRoom): + match room.room.state: + case MultiplayerRoomState.OPEN: + if room.room.settings.auto_start_enabled: + if ( + not room.queue.current_item.expired + and any( + u.state == MultiplayerUserState.READY + for u in room.room.users + ) + and not any( + isinstance(countdown, MatchStartCountdown) + for countdown in room.room.active_countdowns + ) + ): + await room.start_countdown( + MatchStartCountdown( + time_remaining=room.room.settings.auto_start_duration + ), + self.start_match, + ) + case MultiplayerRoomState.WAITING_FOR_LOAD: + played_count = len( + [True for user in room.room.users if user.state.is_playing] + ) + ready_count = len( + [ + True + for user in room.room.users + if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY + ] + ) + if played_count == ready_count: + await self.start_gameplay(room) + case MultiplayerRoomState.PLAYING: + if all( + u.state != MultiplayerUserState.PLAYING for u in room.room.users + ): + any_user_finished_playing = False + for u in filter( + lambda u: u.state == MultiplayerUserState.FINISHED_PLAY, + room.room.users, + ): + any_user_finished_playing = True + await self.change_user_state( + room, u, MultiplayerUserState.RESULTS + ) + await self.change_room_state(room, MultiplayerRoomState.OPEN) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "ResultsReady", + ) + if any_user_finished_playing: + await self.event_logger.game_completed( + room.room.room_id, + room.queue.current_item.id, + ) + else: + await self.event_logger.game_aborted( + room.room.room_id, + room.queue.current_item.id, + ) + await room.queue.finish_current_item() + + async def change_room_state( + self, room: ServerMultiplayerRoom, state: MultiplayerRoomState + ): + room.room.state = state + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "RoomStateChanged", + state, + ) + + async def StartMatch(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + # Check host state - host must be ready or spectating + if room.host.state not in ( + MultiplayerUserState.SPECTATING, + MultiplayerUserState.READY, + ): + raise InvokeException("Can't start match when the host is not ready.") + + # Check if any users are ready + if all(u.state != MultiplayerUserState.READY for u in room.users): + raise InvokeException("Can't start match when no users are ready.") + + await self.start_match(server_room) + + async def start_match(self, room: ServerMultiplayerRoom): + if room.room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Can't start match when already in a running state.") + if room.queue.current_item.expired: + raise InvokeException("Current playlist item is expired") + ready_users = [ + u + for u in room.room.users + if u.availability.state == DownloadState.LOCALLY_AVAILABLE + and ( + u.state == MultiplayerUserState.READY + or u.state == MultiplayerUserState.IDLE + ) + ] + await asyncio.gather( + *[ + self.change_user_state(room, u, MultiplayerUserState.WAITING_FOR_LOAD) + for u in ready_users + ] + ) + await self.change_room_state( + room, + MultiplayerRoomState.WAITING_FOR_LOAD, + ) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "LoadRequested", + ) + await room.start_countdown( + ForceGameplayStartCountdown( + time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) + ), + self.start_gameplay, + ) + await self.event_logger.game_started( + room.room.room_id, + room.queue.current_item.id, + details=room.match_type_handler.get_details(), + ) + + async def start_gameplay(self, room: ServerMultiplayerRoom): + if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD: + raise InvokeException("Room is not ready for gameplay") + if room.queue.current_item.expired: + raise InvokeException("Current playlist item is expired") + playing = False + played_user = 0 + for user in room.room.users: + client = self.get_client_by_id(str(user.user_id)) + if client is None: + continue + + if user.state in ( + MultiplayerUserState.READY_FOR_GAMEPLAY, + MultiplayerUserState.LOADED, + ): + playing = True + played_user += 1 + await self.change_user_state(room, user, MultiplayerUserState.PLAYING) + await self.call_noblock(client, "GameplayStarted") + elif user.state == MultiplayerUserState.WAITING_FOR_LOAD: + await self.change_user_state(room, user, MultiplayerUserState.IDLE) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "GameplayAborted", + GameplayAbortReason.LOAD_TOOK_TOO_LONG, + ) + await self.change_room_state( + room, + (MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN), + ) + if playing: + redis = get_redis() + await redis.set( + f"multiplayer:{room.room.room_id}:gameplay:players", + played_user, + ex=3600, + ) + + async def send_match_event( + self, room: ServerMultiplayerRoom, event: MatchServerEvent + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchEvent", + event, + ) + + async def make_user_leave( + self, + client: Client, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + kicked: bool = False, + ): + self.remove_from_group(client, self.group_id(room.room.room_id)) + room.room.users.remove(user) + + if len(room.room.users) == 0: + await self.end_room(room) + await self.update_room_state(room) + if ( + len(room.room.users) != 0 + and room.room.host + and room.room.host.user_id == user.user_id + ): + next_host = room.room.users[0] + await self.set_host(room, next_host) + + if kicked: + await self.call_noblock(client, "UserKicked", user) + await self.broadcast_group_call( + self.group_id(room.room.room_id), "UserKicked", user + ) + else: + await self.broadcast_group_call( + self.group_id(room.room.room_id), "UserLeft", user + ) + + async with AsyncSession(engine) as session: + async with session.begin(): + participated_user = ( + await session.exec( + select(RoomParticipatedUser).where( + RoomParticipatedUser.room_id == room.room.room_id, + RoomParticipatedUser.user_id == user.user_id, + ) + ) + ).first() + if participated_user is not None: + participated_user.left_at = datetime.now(UTC) + + db_room = await session.get(Room, room.room.room_id) + if db_room is None: + raise InvokeException("Room does not exist in database") + db_room.participant_count -= 1 + + target_store = self.state.get(user.user_id) + if target_store: + target_store.room_id = 0 + + async def end_room(self, room: ServerMultiplayerRoom): + assert room.room.host + async with AsyncSession(engine) as session: + await session.execute( + update(Room) + .where(col(Room.id) == room.room.room_id) + .values( + name=room.room.settings.name, + ends_at=datetime.now(UTC), + type=room.room.settings.match_type, + queue_mode=room.room.settings.queue_mode, + auto_skip=room.room.settings.auto_skip, + auto_start_duration=int( + room.room.settings.auto_start_duration.total_seconds() + ), + host_id=room.room.host.user_id, + ) + ) + await self.event_logger.room_disbanded( + room.room.room_id, + room.room.host.user_id, + ) + del self.rooms[room.room.room_id] + + async def LeaveRoom(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + return + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.event_logger.player_left( + room.room_id, + user.user_id, + ) + await self.make_user_leave(client, server_room, user) + + async def KickUser(self, client: Client, user_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + if user_id == client.user_id: + raise InvokeException("Can't kick self") + + user = next((u for u in room.users if u.user_id == user_id), None) + if user is None: + raise InvokeException("User not found in this room") + + await self.event_logger.player_kicked( + room.room_id, + user.user_id, + ) + target_client = self.get_client_by_id(str(user.user_id)) + if target_client is None: + return + await self.make_user_leave(target_client, server_room, user, kicked=True) + + async def set_host(self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser): + room.room.host = user + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "HostChanged", + user.user_id, + ) + + async def TransferHost(self, client: Client, user_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + new_host = next((u for u in room.users if u.user_id == user_id), None) + if new_host is None: + raise InvokeException("User not found in this room") + await self.event_logger.host_changed( + room.room_id, + new_host.user_id, + ) + await self.set_host(server_room, new_host) + + async def AbortGameplay(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if not user.state.is_playing: + raise InvokeException("Cannot abort gameplay while not in a gameplay state") + + await self.change_user_state( + server_room, + user, + MultiplayerUserState.IDLE, + ) + await self.update_room_state(server_room) + + async def AbortMatch(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + if ( + room.state != MultiplayerRoomState.PLAYING + and room.state != MultiplayerRoomState.WAITING_FOR_LOAD + ): + raise InvokeException("Cannot abort a match that hasn't started.") + + await asyncio.gather( + *[ + self.change_user_state(server_room, u, MultiplayerUserState.IDLE) + for u in room.users + if u.state.is_playing + ] + ) + await self.broadcast_group_call( + self.group_id(room.room_id), + "GameplayAborted", + GameplayAbortReason.HOST_ABORTED, + ) + await self.update_room_state(server_room) + + async def change_user_match_state( + self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchUserStateChanged", + user.user_id, + user.match_state, + ) + + async def change_room_match_state(self, room: ServerMultiplayerRoom): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchRoomStateChanged", + room.room.match_state, + ) + + async def ChangeSettings(self, client: Client, settings: MultiplayerRoomSettings): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + if room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Cannot change settings while playing") + + if settings.match_type == MatchType.PLAYLISTS: + raise InvokeException("Invalid match type selected") + + previous_settings = room.settings + room.settings = settings + + if previous_settings.match_type != settings.match_type: + await server_room.set_handler() + if previous_settings.queue_mode != settings.queue_mode: + await server_room.queue.update_queue_mode() + + await self.setting_changed(server_room, beatmap_changed=False) + await self.update_room_state(server_room) + + async def SendMatchRequest(self, client: Client, request: MatchRequest): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if isinstance(request, StartMatchCountdownRequest): + if room.host and room.host.user_id != user.user_id: + raise InvokeException("You are not the host of this room") + if room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Cannot start a countdown during ongoing play") + await server_room.start_countdown( + MatchStartCountdown(time_remaining=request.duration), + self.start_match, + ) + elif isinstance(request, StopCountdownRequest): + countdown = next( + (c for c in room.active_countdowns if c.id == request.id), + None, + ) + if countdown is None: + return + if ( + isinstance(countdown, MatchStartCountdown) + and room.settings.auto_start_enabled + ) or isinstance( + countdown, (ForceGameplayStartCountdown | ServerShuttingDownCountdown) + ): + raise InvokeException("Cannot stop the requested countdown") + + await server_room.stop_countdown(countdown) + else: + await server_room.match_type_handler.handle_request(user, request) + + async def InvitePlayer(self, client: Client, user_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + async with AsyncSession(engine) as session: + db_user = await session.get(User, user_id) + target_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == user_id, + Relationship.target_id == client.user_id, + ) + ) + ).first() + inviter_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == client.user_id, + Relationship.target_id == user_id, + ) + ) + ).first() + if db_user is None: + raise InvokeException("User not found") + if db_user.id == client.user_id: + raise InvokeException("You cannot invite yourself") + if db_user.id in [u.user_id for u in room.users]: + raise InvokeException("User already invited") + if db_user.is_restricted: + raise InvokeException("User is restricted") + if ( + inviter_relationship + and inviter_relationship.type == RelationshipType.BLOCK + ): + raise InvokeException("Cannot perform action due to user being blocked") + if ( + target_relationship + and target_relationship.type == RelationshipType.BLOCK + ): + raise InvokeException("Cannot perform action due to user being blocked") + if ( + db_user.pm_friends_only + and target_relationship is not None + and target_relationship.type != RelationshipType.FOLLOW + ): + raise InvokeException( + "Cannot perform action " + "because user has disabled non-friend communications" + ) + + target_client = self.get_client_by_id(str(user_id)) + if target_client is None: + raise InvokeException("User is not online") + await self.call_noblock( + target_client, + "Invited", + client.user_id, + room.room_id, + room.settings.password, + ) + + async def unready_all_users( + self, room: ServerMultiplayerRoom, reset_beatmap_availability: bool + ): + await asyncio.gather( + *[ + self.change_user_state( + room, + user, + MultiplayerUserState.IDLE, + ) + for user in room.room.users + if user.state == MultiplayerUserState.READY + ] + ) + if reset_beatmap_availability: + await asyncio.gather( + *[ + self.change_beatmap_availability( + room.room.room_id, + user, + BeatmapAvailability(state=DownloadState.UNKNOWN), + ) + for user in room.room.users + ] + ) + await room.stop_all_countdowns(MatchStartCountdown) diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index b9a3c99..d5a12ff 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -11,6 +11,7 @@ from app.database import Beatmap, User from app.database.score import Score from app.database.score_token import ScoreToken from app.dependencies.database import engine +from app.dependencies.fetcher import get_fetcher from app.models.beatmap import BeatmapRankStatus from app.models.mods import mods_to_int from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics @@ -179,15 +180,13 @@ class SpectatorHub(Hub[StoreClientState]): return if state.beatmap_id is None or state.ruleset_id is None: return + + fetcher = await get_fetcher() async with AsyncSession(engine) as session: async with session.begin(): - beatmap = ( - await session.exec( - select(Beatmap).where(Beatmap.id == state.beatmap_id) - ) - ).first() - if not beatmap: - return + beatmap = await Beatmap.get_or_fetch( + session, fetcher, bid=state.beatmap_id + ) user = ( await session.exec(select(User).where(User.id == user_id)) ).first() @@ -237,16 +236,16 @@ class SpectatorHub(Hub[StoreClientState]): user_id = int(client.connection_id) store = self.get_or_create_state(client) score = store.score - assert store.beatmap_status is not None - assert store.state is not None - assert store.score is not None - if not score or not store.score_token: + if ( + score is None + or store.score_token is None + or store.beatmap_status is None + or store.state is None + ): return if ( BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED - ) and any( - k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items() - ): + ) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()): await self._process_score(store, client) store.state = None store.beatmap_status = None diff --git a/app/signalr/packet.py b/app/signalr/packet.py index be98c39..8949f4b 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -15,7 +15,7 @@ from typing import ( ) from app.models.signalr import SignalRMeta, SignalRUnionMessage -from app.utils import camel_to_snake, snake_to_camel +from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal import msgpack_lazer_api as m from pydantic import BaseModel @@ -97,6 +97,8 @@ class MsgpackProtocol: return [cls.serialize_msgpack(item) for item in v] elif issubclass(typ, datetime.datetime): return [v, 0] + elif issubclass(typ, datetime.timedelta): + return int(v.total_seconds() * 10_000_000) elif isinstance(v, dict): return { cls.serialize_msgpack(k): cls.serialize_msgpack(value) @@ -126,15 +128,19 @@ class MsgpackProtocol: def process_object(v: Any, typ: type[BaseModel]) -> Any: if isinstance(v, list): d = {} - for i, f in enumerate(typ.model_fields.items()): - field, info = f - if info.exclude: + i = 0 + for field, info in typ.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.member_ignore: continue anno = info.annotation if anno is None: d[camel_to_snake(field)] = v[i] - continue - d[field] = MsgpackProtocol.validate_object(v[i], anno) + else: + d[field] = MsgpackProtocol.validate_object(v[i], anno) + i += 1 return d return v @@ -209,7 +215,9 @@ class MsgpackProtocol: return typ.model_validate(obj=cls.process_object(v, typ)) elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): return v[0] - elif isinstance(v, list): + elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): + return datetime.timedelta(seconds=int(v / 10_000_000)) + elif get_origin(typ) is list: return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): list_ = list(typ) @@ -234,7 +242,9 @@ class MsgpackProtocol: # except `X (Other Type) | None` if NoneType in args and v is None: return None - if not all(issubclass(arg, SignalRUnionMessage) for arg in args): + if not all( + issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args + ): raise ValueError( f"Cannot validate {v} to {typ}, " "only SignalRUnionMessage subclasses are supported" @@ -292,36 +302,55 @@ class MsgpackProtocol: class JSONProtocol: @classmethod - def serialize_to_json(cls, v: Any): + def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False): typ = v.__class__ if issubclass(typ, BaseModel): - return cls.serialize_model(v) + return cls.serialize_model(v, in_union) elif isinstance(v, dict): return { - cls.serialize_to_json(k): cls.serialize_to_json(value) + cls.serialize_to_json(k, True): cls.serialize_to_json(value) for k, value in v.items() } elif isinstance(v, list): return [cls.serialize_to_json(item) for item in v] elif isinstance(v, datetime.datetime): return v.isoformat() - elif isinstance(v, Enum): + elif isinstance(v, datetime.timedelta): + # d.hh:mm:ss + total_seconds = int(v.total_seconds()) + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + return f"{hours:02}:{minutes:02}:{seconds:02}" + elif isinstance(v, Enum) and dict_key: return v.value + elif isinstance(v, Enum): + list_ = list(typ) + return list_.index(v) return v @classmethod - def serialize_model(cls, v: BaseModel) -> dict[str, Any]: + def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]: d = {} + is_union = issubclass(v.__class__, SignalRUnionMessage) for field, info in v.__class__.model_fields.items(): metadata = next( (m for m in info.metadata if isinstance(m, SignalRMeta)), None ) if metadata and metadata.json_ignore: continue - d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = ( - cls.serialize_to_json(getattr(v, field)) + name = ( + snake_to_camel( + field, + metadata.use_abbr if metadata else True, + ) + if not is_union + else snake_to_pascal( + field, + metadata.use_abbr if metadata else True, + ) ) - if issubclass(v.__class__, SignalRUnionMessage): + d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union) + if is_union and not in_union: return { "$dtype": v.__class__.__name__, "$value": d, @@ -339,7 +368,12 @@ class JSONProtocol: ) if metadata and metadata.json_ignore: continue - value = v.get(snake_to_camel(field, not from_union)) + name = ( + snake_to_camel(field, metadata.use_abbr if metadata else True) + if not from_union + else snake_to_pascal(field, metadata.use_abbr if metadata else True) + ) + value = v.get(name) anno = typ.model_fields[field].annotation if anno is None: d[field] = value @@ -397,7 +431,18 @@ class JSONProtocol: return typ.model_validate(JSONProtocol.process_object(v, typ, from_union)) elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): return datetime.datetime.fromisoformat(v) - elif isinstance(v, list): + elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): + # d.hh:mm:ss + parts = v.split(":") + if len(parts) == 3: + return datetime.timedelta( + hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2]) + ) + elif len(parts) == 2: + return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1])) + elif len(parts) == 1: + return datetime.timedelta(seconds=int(parts[0])) + elif get_origin(typ) is list: return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): list_ = list(typ) diff --git a/app/signalr/router.py b/app/signalr/router.py index 72b22ac..237a575 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -6,7 +6,7 @@ import time from typing import Literal import uuid -from app.database import User +from app.database import User as DBUser from app.dependencies import get_current_user from app.dependencies.database import get_db from app.dependencies.user import get_current_user_by_token @@ -25,7 +25,7 @@ router = APIRouter() async def negotiate( hub: Literal["spectator", "multiplayer", "metadata"], negotiate_version: int = Query(1, alias="negotiateVersion"), - user: User = Depends(get_current_user), + user: DBUser = Depends(get_current_user), ): connectionId = str(user.id) connectionToken = f"{connectionId}:{uuid.uuid4()}" diff --git a/app/utils.py b/app/utils.py index 0d759a1..22f06dd 100644 --- a/app/utils.py +++ b/app/utils.py @@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str: return "".join(result) -def snake_to_camel(name: str, lower_case: bool = True) -> str: +def snake_to_camel(name: str, use_abbr: bool = True) -> str: """Convert a snake_case string to camelCase.""" if not name: return name @@ -47,12 +47,46 @@ def snake_to_camel(name: str, lower_case: bool = True) -> str: result = [] for part in parts: - if part.lower() in abbreviations: + if part.lower() in abbreviations and use_abbr: result.append(part.upper()) else: - if result or not lower_case: + if result: result.append(part.capitalize()) else: result.append(part.lower()) return "".join(result) + + +def snake_to_pascal(name: str, use_abbr: bool = True) -> str: + """Convert a snake_case string to PascalCase.""" + if not name: + return name + + parts = name.split("_") + if not parts: + return name + + # 常见缩写词列表 + abbreviations = { + "id", + "url", + "api", + "http", + "https", + "xml", + "json", + "css", + "html", + "sql", + "db", + } + + result = [] + for part in parts: + if part.lower() in abbreviations and use_abbr: + result.append(part.upper()) + else: + result.append(part.capitalize()) + + return "".join(result) diff --git a/create_sample_data.py b/create_sample_data.py new file mode 100644 index 0000000..5dcd79a --- /dev/null +++ b/create_sample_data.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +""" +osu! API 模拟服务器的示例数据填充脚本 +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime +import random + +from app.auth import get_password_hash +from app.database import ( + User, +) +from app.database.beatmap import Beatmap +from app.database.beatmapset import Beatmapset +from app.database.score import Score +from app.dependencies.database import create_tables, engine +from app.models.beatmap import BeatmapRankStatus, Genre, Language +from app.models.mods import APIMod +from app.models.score import GameMode, Rank + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + + +async def create_sample_user(): + """创建示例用户数据""" + async with AsyncSession(engine) as session: + async with session.begin(): + # 检查用户是否已存在 + result = await session.exec(select(User).where(User.name == "Googujiang")) + result2 = await session.exec( + select(User).where(User.name == "MingxuanGame") + ) + existing_user = result.first() + existing_user2 = result2.first() + if existing_user is not None and existing_user2 is not None: + print("示例用户已存在,跳过创建") + return + + # 当前时间戳 + # current_timestamp = int(time.time()) + join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp()) + last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp()) + + # 创建用户 + user = User( + name="Googujiang", + safe_name="googujiang", # 安全用户名(小写) + email="googujiang@example.com", + priv=1, # 默认权限 + pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式 + country="JP", + silence_end=0, + donor_end=0, + creation_time=join_timestamp, + latest_activity=last_visit_timestamp, + clan_id=0, + clan_priv=0, + preferred_mode=0, # 0 = osu! + play_style=0, + custom_badge_name=None, + custom_badge_icon=None, + userpage_content="「世界に忘れられた」", + api_key=None, + ) + user2 = User( + name="MingxuanGame", + safe_name="mingxuangame", # 安全用户名(小写) + email="mingxuangame@example.com", + priv=1, # 默认权限 + pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式 + country="US", + silence_end=0, + donor_end=0, + creation_time=join_timestamp, + latest_activity=last_visit_timestamp, + clan_id=0, + clan_priv=0, + preferred_mode=0, # 0 = osu! + play_style=0, + custom_badge_name=None, + custom_badge_icon=None, + userpage_content="For love and fun!", + api_key=None, + ) + + session.add(user) + session.add(user2) + print(f"成功创建示例用户: {user.name} (ID: {user.id})") + print(f"安全用户名: {user.safe_name}") + print(f"邮箱: {user.email}") + print(f"国家: {user.country}") + print(f"成功创建示例用户: {user2.name} (ID: {user2.id})") + print(f"安全用户名: {user2.safe_name}") + print(f"邮箱: {user2.email}") + print(f"国家: {user2.country}") + + +async def create_sample_beatmap_data(): + """创建示例谱面数据""" + async with AsyncSession(engine) as session: + async with session.begin(): + user_id = random.randint(1, 1000) + # 检查谱面集是否已存在 + statement = select(Beatmapset).where(Beatmapset.id == 1) + result = await session.exec(statement) + existing_beatmapset = result.first() + if existing_beatmapset: + print("示例谱面集已存在,跳过创建") + return existing_beatmapset + + # 创建谱面集 + beatmapset = Beatmapset( + id=1, + artist="Example Artist", + artist_unicode="Example Artist", + covers=None, + creator="Googujiang", + favourite_count=0, + hype_current=0, + hype_required=0, + nsfw=False, + play_count=0, + preview_url="", + source="", + spotlight=False, + title="Example Song", + title_unicode="Example Song", + user_id=user_id, + video=False, + availability_info=None, + download_disabled=False, + bpm=180.0, + can_be_hyped=False, + discussion_locked=False, + last_updated=datetime.now(), + ranked_date=datetime.now(), + storyboard=False, + submitted_date=datetime.now(), + current_nominations=[], + beatmap_status=BeatmapRankStatus.RANKED, + beatmap_genre=Genre.ANY, # 使用整数表示Genre枚举 + beatmap_language=Language.ANY, # 使用整数表示Language枚举 + nominations_required=0, + nominations_current=0, + pack_tags=[], + ratings=[], + ) + session.add(beatmapset) + + # 创建谱面 + beatmap = Beatmap( + id=1, + url="", + mode=GameMode.OSU, + beatmapset_id=1, + difficulty_rating=5.5, + beatmap_status=BeatmapRankStatus.RANKED, + total_length=195, + user_id=user_id, + version="Example Difficulty", + checksum="example_checksum", + current_user_playcount=0, + max_combo=1200, + ar=9.0, + cs=4.0, + drain=5.0, + accuracy=8.0, + bpm=180.0, + count_circles=1000, + count_sliders=200, + count_spinners=1, + deleted_at=None, + hit_length=180, + last_updated=datetime.now(), + passcount=10, + playcount=50, + ) + session.add(beatmap) + + # 创建成绩 + score = Score( + id=1, + accuracy=0.9876, + map_md5="example_checksum", + user_id=1, + best_id=1, + build_id=None, + classic_total_score=1234567, + ended_at=datetime.now(), + has_replay=True, + max_combo=1100, + mods=[ + APIMod(acronym="HD", settings={}), + APIMod(acronym="DT", settings={}), + ], + passed=True, + playlist_item_id=None, + pp=250.5, + preserve=True, + rank=Rank.S, + room_id=None, + gamemode=GameMode.OSU, + started_at=datetime.now(), + total_score=1234567, + type="solo_score", + position=None, + beatmap_id=1, + n300=950, + n100=30, + n50=20, + nmiss=5, + ngeki=150, + nkatu=50, + nlarge_tick_miss=None, + nslider_tail_hit=None, + ) + session.add(score) + + print(f"成功创建示例谱面集: {beatmapset.title} (ID: {beatmapset.id})") + print(f"成功创建示例谱面: {beatmap.version} (ID: {beatmap.id})") + print(f"成功创建示例成绩: ID {score.id}") + return beatmapset + + +async def main(): + print("开始创建示例数据...") + await create_tables() + await create_sample_user() + await create_sample_beatmap_data() + print("示例数据创建完成!") + # print(f"用户名: {user.name}") + # print("密码: password123") + # print("现在您可以使用这些凭据来测试API了。") + await engine.dispose() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/main.py b/main.py index f5d20c1..8569afb 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,14 @@ from datetime import datetime from app.config import settings from app.dependencies.database import create_tables, engine, redis_client from app.dependencies.fetcher import get_fetcher -from app.router import api_router, auth_router, fetcher_router, signalr_router +from app.dependencies.scheduler import init_scheduler, stop_scheduler +from app.router import ( + api_router, + auth_router, + fetcher_router, + signalr_router, +) +from app.service.daily_challenge import daily_challenge_job from fastapi import FastAPI @@ -16,8 +23,11 @@ async def lifespan(app: FastAPI): # on startup await create_tables() await get_fetcher() # 初始化 fetcher + init_scheduler() + await daily_challenge_job() # on shutdown yield + stop_scheduler() await engine.dispose() await redis_client.aclose() @@ -41,104 +51,6 @@ async def health_check(): return {"status": "ok", "timestamp": datetime.utcnow().isoformat()} -# @app.get("/api/v2/friends") -# async def get_friends(): -# return JSONResponse( -# content=[ -# { -# "id": 123456, -# "username": "BestFriend", -# "is_online": True, -# "is_supporter": False, -# "country": {"code": "US", "name": "United States"}, -# } -# ] -# ) - - -# @app.get("/api/v2/notifications") -# async def get_notifications(): -# return JSONResponse(content={"notifications": [], "unread_count": 0}) - - -# @app.post("/api/v2/chat/ack") -# async def chat_ack(): -# return JSONResponse(content={"status": "ok"}) - - -# @app.get("/api/v2/users/{user_id}/{mode}") -# async def get_user_mode(user_id: int, mode: str): -# return JSONResponse( -# content={ -# "id": user_id, -# "username": "测试测试测", -# "statistics": { -# "level": {"current": 97, "progress": 96}, -# "pp": 114514, -# "global_rank": 666, -# "country_rank": 1, -# "hit_accuracy": 100, -# }, -# "country": {"code": "JP", "name": "Japan"}, -# } -# ) - - -# @app.get("/api/v2/me") -# async def get_me(): -# return JSONResponse( -# content={ -# "id": 15651670, -# "username": "Googujiang", -# "is_online": True, -# "country": {"code": "JP", "name": "Japan"}, -# "statistics": { -# "level": {"current": 97, "progress": 96}, -# "pp": 2826.26, -# "global_rank": 298026, -# "country_rank": 11220, -# "hit_accuracy": 95.7168, -# }, -# } -# ) - - -# @app.post("/signalr/metadata/negotiate") -# async def metadata_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "abc123", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - -# @app.post("/signalr/spectator/negotiate") -# async def spectator_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "spec456", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - -# @app.post("/signalr/multiplayer/negotiate") -# async def multiplayer_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "multi789", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - if __name__ == "__main__": from app.log import logger # noqa: F401 diff --git a/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py b/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py new file mode 100644 index 0000000..74f2e56 --- /dev/null +++ b/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py @@ -0,0 +1,89 @@ +"""playlist: index playlist id + +Revision ID: d0c1b2cefe91 +Revises: 58a11441d302 +Create Date: 2025-08-06 06:02:10.512616 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "d0c1b2cefe91" +down_revision: str | Sequence[str] | None = "58a11441d302" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_index( + op.f("ix_room_playlists_id"), "room_playlists", ["id"], unique=False + ) + op.create_table( + "playlist_best_scores", + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("score_id", sa.BigInteger(), nullable=False), + sa.Column("room_id", sa.Integer(), nullable=False), + sa.Column("playlist_id", sa.Integer(), nullable=False), + sa.Column("total_score", sa.BigInteger(), nullable=True), + sa.ForeignKeyConstraint( + ["playlist_id"], + ["room_playlists.id"], + ), + sa.ForeignKeyConstraint( + ["room_id"], + ["rooms.id"], + ), + sa.ForeignKeyConstraint( + ["score_id"], + ["scores.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("score_id"), + ) + op.create_index( + op.f("ix_playlist_best_scores_playlist_id"), + "playlist_best_scores", + ["playlist_id"], + unique=False, + ) + op.create_index( + op.f("ix_playlist_best_scores_room_id"), + "playlist_best_scores", + ["room_id"], + unique=False, + ) + op.create_index( + op.f("ix_playlist_best_scores_user_id"), + "playlist_best_scores", + ["user_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + op.f("ix_playlist_best_scores_user_id"), table_name="playlist_best_scores" + ) + op.drop_index( + op.f("ix_playlist_best_scores_room_id"), table_name="playlist_best_scores" + ) + op.drop_index( + op.f("ix_playlist_best_scores_playlist_id"), table_name="playlist_best_scores" + ) + op.drop_table("playlist_best_scores") + op.drop_index(op.f("ix_room_playlists_id"), table_name="room_playlists") + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index cd90947..3ab61c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.12" dependencies = [ "aiomysql>=0.2.0", "alembic>=1.12.1", + "apscheduler>=3.11.0", "bcrypt>=4.1.2", "cryptography>=41.0.7", "fastapi>=0.104.1", diff --git a/remove_ansi.py b/remove_ansi.py new file mode 100644 index 0000000..1720888 --- /dev/null +++ b/remove_ansi.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +""" +Script to remove ANSI escape codes from log files +""" + +from __future__ import annotations + +import re +import sys + + +def remove_ansi_codes(text): + """ + Remove ANSI escape codes from text + """ + # Regular expression to match ANSI escape codes + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) + + +def process_log_file(input_file, output_file=None): + """ + Process log file and remove ANSI escape codes + """ + if output_file is None: + output_file = ( + input_file.replace(".log", "_clean.log") + if ".log" in input_file + else input_file + "_clean" + ) + + with open(input_file, "r", encoding="utf-8") as infile: + content = infile.read() + + # Remove ANSI escape codes + clean_content = remove_ansi_codes(content) + + with open(output_file, "w", encoding="utf-8") as outfile: + outfile.write(clean_content) + + print(f"Processed {input_file} -> {output_file}") + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python remove_ansi.py [output_file]") + sys.exit(1) + + input_file = sys.argv[1] + output_file = sys.argv[2] if len(sys.argv) > 2 else None + + process_log_file(input_file, output_file) diff --git a/static/README.md b/static/README.md index 16ece63..77b54fe 100644 --- a/static/README.md +++ b/static/README.md @@ -2,4 +2,4 @@ - `mods.json`: 包含了游戏中的所有可用mod的详细信息。 - Origin: https://github.com/ppy/osu-web/blob/master/database/mods.json - - Version: 2025/6/10 `b68c920b1db3d443b9302fdc3f86010c875fe380` + - Version: 2025/7/30 `ff49b66b27a2850aea4b6b3ba563cfe936cb6082` diff --git a/static/mods.json b/static/mods.json index defb57f..0a8449b 100644 --- a/static/mods.json +++ b/static/mods.json @@ -2438,7 +2438,8 @@ "Settings": [], "IncompatibleMods": [ "CN", - "RX" + "RX", + "MF" ], "RequiresConfiguration": false, "UserPlayable": false, @@ -2460,7 +2461,8 @@ "AC", "AT", "CN", - "RX" + "RX", + "MF" ], "RequiresConfiguration": false, "UserPlayable": false, @@ -2477,7 +2479,8 @@ "Settings": [], "IncompatibleMods": [ "AT", - "CN" + "CN", + "MF" ], "RequiresConfiguration": false, "UserPlayable": true, @@ -2638,6 +2641,24 @@ "ValidForMultiplayerAsFreeMod": true, "AlwaysValidForSubmission": false }, + { + "Acronym": "MF", + "Name": "Moving Fast", + "Description": "Dashing by default, slow down!", + "Type": "Fun", + "Settings": [], + "IncompatibleMods": [ + "AT", + "CN", + "RX" + ], + "RequiresConfiguration": false, + "UserPlayable": true, + "ValidForMultiplayer": true, + "ValidForFreestyleAsRequiredMod": false, + "ValidForMultiplayerAsFreeMod": true, + "AlwaysValidForSubmission": false + }, { "Acronym": "SV2", "Name": "Score V2", diff --git a/uv.lock b/uv.lock index 3fc7d3c..22a4f1d 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" [manifest] @@ -57,6 +57,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916, upload-time = "2025-03-17T00:02:52.713Z" }, ] +[[package]] +name = "apscheduler" +version = "3.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzlocal" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/00/6d6814ddc19be2df62c8c898c4df6b5b1914f3bd024b780028caa392d186/apscheduler-3.11.0.tar.gz", hash = "sha256:4c622d250b0955a65d5d0eb91c33e6d43fd879834bf541e0a18661ae60460133", size = 107347, upload-time = "2024-11-24T19:39:26.463Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/ae/9a053dd9229c0fde6b1f1f33f609ccff1ee79ddda364c756a924c6d8563b/APScheduler-3.11.0-py3-none-any.whl", hash = "sha256:fc134ca32e50f5eadcc4938e3a4545ab19131435e851abb40b34d63d5141c6da", size = 64004, upload-time = "2024-11-24T19:39:24.442Z" }, +] + [[package]] name = "bcrypt" version = "4.3.0" @@ -493,6 +505,7 @@ source = { virtual = "." } dependencies = [ { name = "aiomysql" }, { name = "alembic" }, + { name = "apscheduler" }, { name = "bcrypt" }, { name = "cryptography" }, { name = "fastapi" }, @@ -522,6 +535,7 @@ dev = [ requires-dist = [ { name = "aiomysql", specifier = ">=0.2.0" }, { name = "alembic", specifier = ">=1.12.1" }, + { name = "apscheduler", specifier = ">=3.11.0" }, { name = "bcrypt", specifier = ">=4.1.2" }, { name = "cryptography", specifier = ">=41.0.7" }, { name = "fastapi", specifier = ">=0.104.1" }, @@ -904,6 +918,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, ] +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, +] + +[[package]] +name = "tzlocal" +version = "5.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzdata", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/2e/c14812d3d4d9cd1773c6be938f89e5735a1f11a9f184ac3639b93cef35d5/tzlocal-5.3.1.tar.gz", hash = "sha256:cceffc7edecefea1f595541dbd6e990cb1ea3d19bf01b2809f362a03dd7921fd", size = 30761, upload-time = "2025-03-05T21:17:41.549Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/14/e2a54fabd4f08cd7af1c07030603c3356b74da07f7cc056e600436edfa17/tzlocal-5.3.1-py3-none-any.whl", hash = "sha256:eb1a66c3ef5847adf7a834f1be0800581b683b5608e74f86ecbcef8ab91bb85d", size = 18026, upload-time = "2025-03-05T21:17:39.857Z" }, +] + [[package]] name = "uvicorn" version = "0.35.0"