This commit is contained in:
jimmy-sketch
2025-08-12 05:29:04 +00:00
122 changed files with 8837 additions and 4219 deletions

View File

@@ -4,6 +4,13 @@
"service": "devcontainer",
"shutdownAction": "stopCompose",
"workspaceFolder": "/workspaces/osu_lazer_api",
"containerEnv": {
"MYSQL_DATABASE": "osu_api",
"MYSQL_USER": "osu_user",
"MYSQL_PASSWORD": "osu_password",
"MYSQL_HOST": "mysql",
"MYSQL_PORT": "3306"
},
"customizations": {
"vscode": {
"extensions": [
@@ -66,6 +73,6 @@
3306,
6379
],
"postCreateCommand": "uv sync --dev && uv run pre-commit install && cd packages/msgpack_lazer_api && cargo check",
"postCreateCommand": "uv sync --dev && uv pip install rosu-pp-py && uv run alembic upgrade head && uv run pre-commit install && cd packages/msgpack_lazer_api && cargo check",
"remoteUser": "vscode"
}
}

5
.dockerignore Normal file
View File

@@ -0,0 +1,5 @@
.venv/
.ruff_cache/
.vscode/
storage/
replays/

View File

@@ -1,4 +0,0 @@
# osu! API 客户端配置
OSU_CLIENT_ID=5
OSU_CLIENT_SECRET=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk
OSU_API_URL=http://localhost:8000

78
.env.example Normal file
View File

@@ -0,0 +1,78 @@
# 数据库设置
MYSQL_HOST="localhost"
MYSQL_PORT=3306
MYSQL_DATABASE="osu_api"
MYSQL_USER="osu_api"
MYSQL_PASSWORD="password"
MYSQL_ROOT_PASSWORD="password"
# Redis URL
REDIS_URL="redis://127.0.0.1:6379/0"
# JWT 密钥,使用 openssl rand -hex 32 生成
JWT_SECRET_KEY="your_jwt_secret_here"
# JWT 算法
ALGORITHM="HS256"
# JWT 过期时间
ACCESS_TOKEN_EXPIRE_MINUTES=1440
# 服务器地址
HOST="0.0.0.0"
PORT=8000
# 服务器 URL
SERVER_URL="http://localhost:8000"
# 调试模式,生产环境请设置为 false
DEBUG=false
# 私有 API 密钥,用于前后端 API 调用,使用 openssl rand -hex 32 生成
PRIVATE_API_SECRET="your_private_api_secret_here"
# osu! 登录设置
OSU_CLIENT_ID=5 # lazer client ID
OSU_CLIENT_SECRET="FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" # lazer client secret
OSU_WEB_CLIENT_ID=6 # 网页端 client ID
OSU_WEB_CLIENT_SECRET="your_osu_web_client_secret_here" # 网页端 client secret使用 openssl rand -hex 40 生成
# SignalR 服务器设置
SIGNALR_NEGOTIATE_TIMEOUT=30
SIGNALR_PING_INTERVAL=15
# Fetcher 设置
FETCHER_CLIENT_ID=""
FETCHER_CLIENT_SECRET=""
FETCHER_SCOPES=public
# 日志设置
LOG_LEVEL="INFO"
# 游戏设置
ENABLE_OSU_RX=false # 启用 osu!RX 统计数据
ENABLE_OSU_AP=false # 启用 osu!AP 统计数据
ENABLE_ALL_MODS_PP=false # 启用所有 Mod 的 PP 计算
ENABLE_SUPPORTER_FOR_ALL_USERS=false # 启用所有新注册用户的支持者状态
ENABLE_ALL_BEATMAP_LEADERBOARD=false # 启用所有谱面的排行榜(没有排行榜的谱面会以 APPROVED 状态返回)
SEASONAL_BACKGROUNDS='[]' # 季节背景图 URL 列表
# 存储服务设置
# 支持的存储类型local本地存储、r2Cloudflare R2、s3AWS S3
STORAGE_SERVICE="local"
# 存储服务配置 (JSON 格式)
# 本地存储配置(当 STORAGE_SERVICE=local 时)
STORAGE_SETTINGS='{"local_storage_path": "./storage"}'
# Cloudflare R2 存储配置(当 STORAGE_SERVICE=r2 时)
# STORAGE_SETTINGS='{
# "r2_account_id": "your_cloudflare_r2_account_id",
# "r2_access_key_id": "your_r2_access_key_id",
# "r2_secret_access_key": "your_r2_secret_access_key",
# "r2_bucket_name": "your_r2_bucket_name",
# "r2_public_url_base": "https://your-custom-domain.com"
# }'
# AWS S3 存储配置(当 STORAGE_SERVICE=s3 时)
# STORAGE_SETTINGS='{
# "s3_access_key_id": "your_aws_access_key_id",
# "s3_secret_access_key": "your_aws_secret_access_key",
# "s3_bucket_name": "your_s3_bucket_name",
# "s3_region_name": "us-east-1",
# "s3_public_url_base": "https://your-custom-domain.com"
# }'

8
.gitignore vendored
View File

@@ -37,6 +37,7 @@ pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
test-cert/
htmlcov/
.tox/
.nox/
@@ -184,9 +185,9 @@ cython_debug/
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
@@ -211,5 +212,6 @@ bancho.py-master/*
.vscode/settings.json
# runtime file
storage/
replays/
osu-master/*
osu-master/*

3
.idea/.gitignore generated vendored
View File

@@ -1,3 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml

View File

@@ -1,17 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyCompatibilityInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ourVersions">
<value>
<list size="4">
<item index="0" class="java.lang.String" itemvalue="3.7" />
<item index="1" class="java.lang.String" itemvalue="3.11" />
<item index="2" class="java.lang.String" itemvalue="3.12" />
<item index="3" class="java.lang.String" itemvalue="3.13" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

10
.idea/misc.xml generated
View File

@@ -1,10 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="osu_lazer_api" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="uv (osu_lazer_api)" project-jdk-type="Python SDK" />
<component name="PythonCompatibilityInspectionAdvertiser">
<option name="version" value="3" />
</component>
</project>

8
.idea/modules.xml generated
View File

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/osu_lazer_api.iml" filepath="$PROJECT_DIR$/.idea/osu_lazer_api.iml" />
</modules>
</component>
</project>

View File

@@ -1,14 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
</content>
<orderEntry type="jdk" jdkName="uv (osu_lazer_api)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

6
.idea/vcs.xml generated
View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

@@ -1,140 +0,0 @@
# Lazer API 数据同步指南
本指南将帮助您将现有的 bancho.py 数据库数据同步到新的 Lazer API 专用表中。
## 文件说明
1. **`migrations_old/add_missing_fields.sql`** - 创建 Lazer API 专用表结构
2. **`migrations_old/sync_legacy_data.sql`** - 数据同步脚本
3. **`sync_data.py`** - 交互式数据同步工具
4. **`quick_sync.py`** - 快速同步脚本(使用项目配置)
## 使用方法
### 方法一:快速同步(推荐)
如果您已经配置好了项目的数据库连接,可以直接使用快速同步脚本:
```bash
python quick_sync.py
```
此脚本会:
1. 自动读取项目配置中的数据库连接信息
2. 创建 Lazer API 专用表结构
3. 同步现有数据到新表
### 方法二:交互式同步
如果需要使用不同的数据库连接配置:
```bash
python sync_data.py
```
此脚本会:
1. 交互式地询问数据库连接信息
2. 检查必要表是否存在
3. 显示详细的同步过程和结果
### 方法三:手动执行 SQL
如果您熟悉 SQL 操作,可以手动执行:
```bash
# 1. 创建表结构
mysql -u username -p database_name < migrations_old/add_missing_fields.sql
# 2. 同步数据
mysql -u username -p database_name < migrations_old/sync_legacy_data.sql
```
## 同步内容
### 创建的新表
- `lazer_user_profiles` - 用户扩展资料
- `lazer_user_countries` - 用户国家信息
- `lazer_user_kudosu` - 用户 Kudosu 统计
- `lazer_user_counts` - 用户各项计数统计
- `lazer_user_statistics` - 用户游戏统计(按模式)
- `lazer_user_achievements` - 用户成就
- `lazer_oauth_tokens` - OAuth 访问令牌
- 其他相关表...
### 同步的数据
1. **用户基本信息**
-`users` 表同步基本资料
- 自动转换时间戳格式
- 设置合理的默认值
2. **游戏统计**
-`stats` 表同步各模式的游戏数据
- 计算命中精度和其他衍生统计
3. **用户成就**
-`user_achievements` 表同步成就数据(如果存在)
## 注意事项
1. **安全性**
- 脚本只会创建新表和插入数据
- 不会修改或删除现有的原始表数据
- 使用 `ON DUPLICATE KEY UPDATE` 避免重复插入
2. **兼容性**
- 兼容现有的 bancho.py 数据库结构
- 支持标准的 osu! 数据格式
3. **性能**
- 大量数据可能需要较长时间
- 建议在维护窗口期间执行
## 故障排除
### 常见错误
1. **"Unknown column" 错误**
```
ERROR 1054: Unknown column 'users.is_active' in 'field list'
```
**解决方案**: 确保先执行了 `add_missing_fields.sql` 创建表结构
2. **"Table doesn't exist" 错误**
```
ERROR 1146: Table 'database.users' doesn't exist
```
**解决方案**: 确认数据库中存在 bancho.py 的原始表
3. **连接错误**
```
ERROR 2003: Can't connect to MySQL server
```
**解决方案**: 检查数据库连接配置和权限
### 验证同步结果
同步完成后,可以执行以下查询验证结果:
```sql
-- 检查同步的用户数量
SELECT COUNT(*) FROM lazer_user_profiles;
-- 查看样本数据
SELECT
u.id, u.name,
lup.playmode, lup.is_supporter,
lus.pp, lus.play_count
FROM users u
LEFT JOIN lazer_user_profiles lup ON u.id = lup.user_id
LEFT JOIN lazer_user_statistics lus ON u.id = lus.user_id AND lus.mode = 'osu'
LIMIT 5;
```
## 支持
如果遇到问题,请:
1. 检查日志文件 `data_sync.log`
2. 确认数据库权限
3. 验证原始表数据完整性

View File

@@ -1,28 +1,48 @@
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
WORKDIR /app
ENV UV_PROJECT_ENVIRONMENT=syncvenv
# 安装系统依赖
RUN apt-get update && apt-get install -y \
gcc \
pkg-config \
default-libmysqlclient-dev \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY uv.lock .
COPY pyproject.toml .
COPY requirements.txt .
# 安装Python依赖
RUN pip install -r requirements.txt
# 复制应用代码
COPY . .
# 暴露端口
EXPOSE 8000
# 启动命令
CMD ["uv", "run", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
WORKDIR /app
RUN apt-get update \
&& apt-get install -y gcc pkg-config default-libmysqlclient-dev \
&& rm -rf /var/lib/apt/lists/* \
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}" \
PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1 UV_PROJECT_ENVIRONMENT=/app/.venv
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
ENV UV_PROJECT_ENVIRONMENT=/app/.venv
COPY pyproject.toml uv.lock ./
COPY packages/ ./packages/
RUN uv sync --frozen --no-dev
RUN uv pip install rosu-pp-py
COPY . .
# ---
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
WORKDIR /app
RUN apt-get update \
&& apt-get install -y curl netcat-openbsd \
&& rm -rf /var/lib/apt/lists/*
ENV PATH="/app/.venv/bin:${PATH}" \
PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1
COPY --from=builder /app/.venv /app/.venv
COPY --from=builder /app /app
COPY docker-entrypoint.sh /app/docker-entrypoint.sh
RUN chmod +x /app/docker-entrypoint.sh
EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
ENTRYPOINT ["/app/docker-entrypoint.sh"]
CMD ["uv", "run", "--no-sync", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

48
Dockerfile-osurx Normal file
View File

@@ -0,0 +1,48 @@
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
WORKDIR /app
RUN apt-get update \
&& apt-get install -y gcc pkg-config default-libmysqlclient-dev git \
&& rm -rf /var/lib/apt/lists/* \
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}" \
PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1 UV_PROJECT_ENVIRONMENT=/app/.venv
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
ENV UV_PROJECT_ENVIRONMENT=/app/.venv
COPY pyproject.toml uv.lock ./
COPY packages/ ./packages/
RUN uv sync --frozen --no-dev
RUN uv pip install git+https://github.com/ppy-sb/rosu-pp-py.git
COPY . .
# ---
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
WORKDIR /app
RUN apt-get update \
&& apt-get install -y curl netcat-openbsd \
&& rm -rf /var/lib/apt/lists/*
ENV PATH="/app/.venv/bin:${PATH}" \
PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1
COPY --from=builder /app/.venv /app/.venv
COPY --from=builder /app /app
COPY docker-entrypoint.sh /app/docker-entrypoint.sh
RUN chmod +x /app/docker-entrypoint.sh
EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
ENTRYPOINT ["/app/docker-entrypoint.sh"]
CMD ["uv", "run", "--no-sync", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 GooGuTeam
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

268
README.md
View File

@@ -6,193 +6,159 @@
- **OAuth 2.0 认证**: 支持密码流和刷新令牌流
- **用户数据管理**: 完整的用户信息、统计数据、成就等
- **多游戏模式支持**: osu!, taiko, fruits, mania
- **多游戏模式支持**: osu! (osu!rx, osu!ap), taiko, fruits, mania
- **数据库持久化**: MySQL 存储用户数据
- **缓存支持**: Redis 缓存令牌和会话信息
- **多种存储后端**: 支持本地存储、Cloudflare R2、AWS S3
- **容器化部署**: Docker 和 Docker Compose 支持
## API 端点
### 认证端点
- `POST /oauth/token` - OAuth 令牌获取/刷新
### 用户端点
- `GET /api/v2/me/{ruleset}` - 获取当前用户信息
### 其他端点
- `GET /` - 根端点
- `GET /health` - 健康检查
## 快速开始
### 使用 Docker Compose (推荐)
1. 克隆项目
```bash
git clone <repository-url>
git clone https://github.com/GooGuTeam/osu_lazer_api.git
cd osu_lazer_api
```
2. 启动服务
2. 创建 `.env` 文件
请参考下方的服务器配置修改 .env 文件
```bash
docker-compose up -d
cp .env.example .env
```
3. 创建示例数据
3. 启动服务
```bash
docker-compose exec api python create_sample_data.py
# 标准服务器
docker-compose -f docker-compose.yml up -d
# 启用 osu!RX 和 osu!AP 模式 (偏偏要上班 pp 算法)
docker-compose -f docker-compose-osurx.yml up -d
```
4. 测试 API
```bash
# 获取访问令牌
curl -X POST http://localhost:8000/oauth/token \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*"
4. 通过游戏连接服务器
# 使用令牌获取用户信息
curl -X GET http://localhost:8000/api/v2/me/osu \
-H "Authorization: Bearer YOUR_ACCESS_TOKEN"
```
### 本地开发
1. 安装依赖
```bash
pip install -r requirements.txt
```
2. 配置环境变量
```bash
# 复制服务器配置文件
cp .env .env.local
# 复制客户端配置文件(用于测试脚本)
cp .env.client .env.client.local
```
3. 启动 MySQL 和 Redis
```bash
# 使用 Docker
docker run -d --name mysql -e MYSQL_ROOT_PASSWORD=password -e MYSQL_DATABASE=osu_api -p 3306:3306 mysql:8.0
docker run -d --name redis -p 6379:6379 redis:7-alpine
```
4. 启动应用
```bash
uvicorn main:app --reload
```
## 项目结构
```
osu_lazer_api/
├── app/
│ ├── __init__.py
│ ├── models.py # Pydantic 数据模型
│ ├── database.py # SQLAlchemy 数据库模型
│ ├── config.py # 配置设置
│ ├── dependencies.py # 依赖注入
│ ├── auth.py # 认证和令牌管理
│ └── utils.py # 工具函数
├── main.py # FastAPI 应用主文件
├── create_sample_data.py # 示例数据创建脚本
├── requirements.txt # Python 依赖
├── .env # 环境变量配置
├── docker-compose.yml # Docker Compose 配置
├── Dockerfile # Docker 镜像配置
└── README.md # 项目说明
```
## 示例用户
创建示例数据后,您可以使用以下凭据进行测试:
- **用户名**: `Googujiang`
- **密码**: `password123`
- **用户ID**: `15651670`
使用[自定义的 osu!lazer 客户端](https://github.com/GooGuTeam/osu),或者使用 [LazerAuthlibInjection](https://github.com/MingxuanGame/LazerAuthlibInjection),修改服务器设置为服务器的 IP
## 环境变量配置
项目包含两个环境配置文件:
### 服务器配置 (`.env`)
用于配置 FastAPI 服务器的运行参数:
### 数据库设置
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `DATABASE_URL` | MySQL 数据库连接字符串 | `mysql+pymysql://root:password@localhost:3306/osu_api` |
| `REDIS_URL` | Redis 连接字符串 | `redis://localhost:6379/0` |
| `SECRET_KEY` | JWT 签名密钥 | `your-secret-key-here` |
| `MYSQL_HOST` | MySQL 主机地址 | `localhost` |
| `MYSQL_PORT` | MySQL 端口 | `3306` |
| `MYSQL_DATABASE` | MySQL 数据库名 | `osu_api` |
| `MYSQL_USER` | MySQL 用户名 | `osu_api` |
| `MYSQL_PASSWORD` | MySQL 密码 | `password` |
| `MYSQL_ROOT_PASSWORD` | MySQL root 密码 | `password` |
| `REDIS_URL` | Redis 连接字符串 | `redis://127.0.0.1:6379/0` |
### JWT 设置
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `JWT_SECRET_KEY` | JWT 签名密钥 | `your_jwt_secret_here` |
| `ALGORITHM` | JWT 算法 | `HS256` |
| `ACCESS_TOKEN_EXPIRE_MINUTES` | 访问令牌过期时间(分钟) | `1440` |
| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` |
| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` |
### 服务器设置
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `HOST` | 服务器监听地址 | `0.0.0.0` |
| `PORT` | 服务器监听端口 | `8000` |
| `DEBUG` | 调试模式 | `True` |
### 客户端配置 (`.env.client`)
用于配置客户端脚本的 API 连接参数:
| `DEBUG` | 调试模式 | `false` |
| `SERVER_URL` | 服务器 URL | `http://localhost:8000` |
| `PRIVATE_API_SECRET` | 私有 API 密钥,用于前后端 API 调用 | `your_private_api_secret_here` |
### OAuth 设置
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` |
| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` |
| `OSU_API_URL` | API 服务器地址 | `http://localhost:8000` |
| `OSU_WEB_CLIENT_ID` | Web OAuth 客户端 ID | `6` |
| `OSU_WEB_CLIENT_SECRET` | Web OAuth 客户端密钥 | `your_osu_web_client_secret_here`
### SignalR 服务器设置
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `SIGNALR_NEGOTIATE_TIMEOUT` | SignalR 协商超时时间(秒) | `30` |
| `SIGNALR_PING_INTERVAL` | SignalR ping 间隔(秒) | `15` |
### Fetcher 设置
Fetcher 用于从 osu! 官方 API 获取数据,使用 osu! 官方 API 的 OAuth 2.0 认证
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `FETCHER_CLIENT_ID` | Fetcher 客户端 ID | `""` |
| `FETCHER_CLIENT_SECRET` | Fetcher 客户端密钥 | `""` |
| `FETCHER_SCOPES` | Fetcher 权限范围 | `public` |
### 日志设置
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `LOG_LEVEL` | 日志级别 | `INFO` |
### 游戏设置
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `ENABLE_OSU_RX` | 启用 osu!RX 统计数据 | `false` |
| `ENABLE_OSU_AP` | 启用 osu!AP 统计数据 | `false` |
| `ENABLE_ALL_MODS_PP` | 启用所有 Mod 的 PP 计算 | `false` |
| `ENABLE_SUPPORTER_FOR_ALL_USERS` | 启用所有新注册用户的支持者状态 | `false` |
| `ENABLE_ALL_BEATMAP_LEADERBOARD` | 启用所有谱面的排行榜 | `false` |
| `SEASONAL_BACKGROUNDS` | 季节背景图 URL 列表 | `[]` |
### 存储服务设置
用于存储回放文件、头像等静态资源。
| 变量名 | 描述 | 默认值 |
|--------|------|--------|
| `STORAGE_SERVICE` | 存储服务类型:`local``r2``s3` | `local` |
| `STORAGE_SETTINGS` | 存储服务配置 (JSON 格式),配置见下 | `{"local_storage_path": "./storage"}` |
## 存储服务配置
### 本地存储 (推荐用于开发环境)
本地存储将文件保存在服务器的本地文件系统中,适合开发和小规模部署。
```bash
STORAGE_SERVICE="local"
STORAGE_SETTINGS='{"local_storage_path": "./storage"}'
```
### Cloudflare R2 存储 (推荐用于生产环境)
```bash
STORAGE_SERVICE="r2"
STORAGE_SETTINGS='{
"r2_account_id": "your_cloudflare_account_id",
"r2_access_key_id": "your_r2_access_key_id",
"r2_secret_access_key": "your_r2_secret_access_key",
"r2_bucket_name": "your_bucket_name",
"r2_public_url_base": "https://your-custom-domain.com"
}'
```
### AWS S3 存储
```bash
STORAGE_SERVICE="s3"
STORAGE_SETTINGS='{
"s3_access_key_id": "your_aws_access_key_id",
"s3_secret_access_key": "your_aws_secret_access_key",
"s3_bucket_name": "your_s3_bucket_name",
"s3_region_name": "us-east-1",
"s3_public_url_base": "https://your-custom-domain.com"
}'
```
> **注意**: 在生产环境中,请务必更改默认的密钥和密码!
## API 使用示例
### 获取访问令牌
```bash
curl -X POST http://localhost:8000/oauth/token \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*"
```
响应:
```json
{
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...",
"token_type": "Bearer",
"expires_in": 86400,
"refresh_token": "abc123...",
"scope": "*"
}
```
### 获取用户信息
```bash
curl -X GET http://localhost:8000/api/v2/me/osu \
-H "Authorization: Bearer YOUR_ACCESS_TOKEN"
```
### 刷新令牌
```bash
curl -X POST http://localhost:8000/oauth/token \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "grant_type=refresh_token&refresh_token=YOUR_REFRESH_TOKEN&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
```
## 开发
### 添加新用户
您可以通过修改 `create_sample_data.py` 文件来添加更多示例用户,或者扩展 API 来支持用户注册功能。
### 扩展功能
- 添加更多 API 端点(排行榜、谱面信息等)
- 实现实时功能WebSocket
- 添加管理面板
- 实现数据导入/导出功能
### 迁移数据库
### 更新数据库
参考[数据库迁移指南](./MIGRATE_GUIDE.md)

View File

@@ -15,6 +15,7 @@ from app.log import logger
import bcrypt
from jose import JWTError, jwt
from passlib.context import CryptContext
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -125,12 +126,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
minutes=settings.access_token_expire_minutes
)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
to_encode, settings.secret_key, algorithm=settings.algorithm
)
return encoded_jwt
@@ -146,7 +147,7 @@ def verify_token(token: str) -> dict | None:
"""验证访问令牌"""
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
token, settings.secret_key, algorithms=[settings.algorithm]
)
return payload
except JWTError:
@@ -156,6 +157,8 @@ def verify_token(token: str) -> dict | None:
async def store_token(
db: AsyncSession,
user_id: int,
client_id: int,
scopes: list[str],
access_token: str,
refresh_token: str,
expires_in: int,
@@ -164,7 +167,9 @@ async def store_token(
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
statement = select(OAuthToken).where(
OAuthToken.user_id == user_id, OAuthToken.client_id == client_id
)
old_tokens = (await db.exec(statement)).all()
for token in old_tokens:
await db.delete(token)
@@ -179,7 +184,9 @@ async def store_token(
# 创建新令牌记录
token_record = OAuthToken(
user_id=user_id,
client_id=client_id,
access_token=access_token,
scope=",".join(scopes),
refresh_token=refresh_token,
expires_at=expires_at,
)
@@ -209,3 +216,18 @@ async def get_token_by_refresh_token(
OAuthToken.expires_at > datetime.utcnow(),
)
return (await db.exec(statement)).first()
async def get_user_by_authorization_code(
db: AsyncSession, redis: Redis, client_id: int, code: str
) -> tuple[User, list[str]] | None:
user_id = await redis.hget(f"oauth:code:{client_id}:{code}", "user_id") # pyright: ignore[reportGeneralTypeIssues]
scopes = await redis.hget(f"oauth:code:{client_id}:{code}", "scopes") # pyright: ignore[reportGeneralTypeIssues]
if not user_id or not scopes:
return None
await redis.hdel(f"oauth:code:{client_id}:{code}", "user_id", "scopes") # pyright: ignore[reportGeneralTypeIssues]
statement = select(User).where(User.id == int(user_id))
user = (await db.exec(statement)).first()
return (user, scopes.split(",")) if user else None

View File

@@ -7,7 +7,15 @@ from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod
from app.models.score import GameMode
import rosu_pp_py as rosu
try:
import rosu_pp_py as rosu
except ImportError:
raise ImportError(
"rosu-pp-py is not installed. "
"Please install it.\n"
" Official: uv add rosu-pp-py\n"
" ppy-sb: uv add git+https://github.com/ppy-sb/rosu-pp-py.git"
)
if TYPE_CHECKING:
from app.database.score import Score
@@ -51,8 +59,6 @@ def calculate_pp(
) -> float:
map = rosu.Beatmap(content=beatmap)
map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType]
if map.is_suspicious():
return 0.0
perf = rosu.Performance(
mods=score.mods,
lazer=True,
@@ -67,7 +73,6 @@ def calculate_pp(
n100=score.n100,
n50=score.n50,
misses=score.nmiss,
hitresult_priority=rosu.HitResultPriority.Fastest,
)
attrs = perf.calculate(map)
return attrs.pp

View File

@@ -1,51 +1,133 @@
from __future__ import annotations
import os
from enum import Enum
from typing import Annotated, Any
from dotenv import load_dotenv
load_dotenv()
from pydantic import Field, HttpUrl, ValidationInfo, field_validator
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
class Settings:
class AWSS3StorageSettings(BaseSettings):
s3_access_key_id: str
s3_secret_access_key: str
s3_bucket_name: str
s3_region_name: str
s3_public_url_base: str | None = None
class CloudflareR2Settings(BaseSettings):
r2_account_id: str
r2_access_key_id: str
r2_secret_access_key: str
r2_bucket_name: str
r2_public_url_base: str | None = None
class LocalStorageSettings(BaseSettings):
local_storage_path: str = "./storage"
class StorageServiceType(str, Enum):
LOCAL = "local"
CLOUDFLARE_R2 = "r2"
AWS_S3 = "s3"
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
# 数据库设置
DATABASE_URL: str = os.getenv(
"DATABASE_URL", "mysql+aiomysql://root:password@127.0.0.1:3306/osu_api"
)
REDIS_URL: str = os.getenv("REDIS_URL", "redis://127.0.0.1:6379/0")
mysql_host: str = "localhost"
mysql_port: int = 3306
mysql_database: str = "osu_api"
mysql_user: str = "osu_api"
mysql_password: str = "password"
mysql_root_password: str = "password"
redis_url: str = "redis://127.0.0.1:6379/0"
@property
def database_url(self) -> str:
return f"mysql+aiomysql://{self.mysql_user}:{self.mysql_password}@{self.mysql_host}:{self.mysql_port}/{self.mysql_database}"
# JWT 设置
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-here")
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")
)
secret_key: str = Field(default="your_jwt_secret_here", alias="jwt_secret_key")
algorithm: str = "HS256"
access_token_expire_minutes: int = 1440
# OAuth 设置
OSU_CLIENT_ID: str = os.getenv("OSU_CLIENT_ID", "5")
OSU_CLIENT_SECRET: str = os.getenv(
"OSU_CLIENT_SECRET", "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
)
osu_client_id: int = 5
osu_client_secret: str = "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
osu_web_client_id: int = 6
osu_web_client_secret: str = "your_osu_web_client_secret_here"
# 服务器设置
HOST: str = os.getenv("HOST", "0.0.0.0")
PORT: int = int(os.getenv("PORT", "8000"))
DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"
host: str = "0.0.0.0"
port: int = 8000
debug: bool = False
private_api_secret: str = "your_private_api_secret_here"
server_url: HttpUrl = HttpUrl("http://localhost:8000")
# SignalR 设置
SIGNALR_NEGOTIATE_TIMEOUT: int = int(os.getenv("SIGNALR_NEGOTIATE_TIMEOUT", "30"))
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "15"))
signalr_negotiate_timeout: int = 30
signalr_ping_interval: int = 15
# Fetcher 设置
FETCHER_CLIENT_ID: str = os.getenv("FETCHER_CLIENT_ID", "")
FETCHER_CLIENT_SECRET: str = os.getenv("FETCHER_CLIENT_SECRET", "")
FETCHER_SCOPES: list[str] = os.getenv("FETCHER_SCOPES", "public").split(",")
FETCHER_CALLBACK_URL: str = os.getenv(
"FETCHER_CALLBACK_URL", "http://localhost:8000/fetcher/callback"
)
fetcher_client_id: str = ""
fetcher_client_secret: str = ""
fetcher_scopes: Annotated[list[str], NoDecode] = ["public"]
@property
def fetcher_callback_url(self) -> str:
return f"{self.server_url}fetcher/callback"
# 日志设置
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO").upper()
log_level: str = "INFO"
# 游戏设置
enable_osu_rx: bool = False
enable_osu_ap: bool = False
enable_all_mods_pp: bool = False
enable_supporter_for_all_users: bool = False
enable_all_beatmap_leaderboard: bool = False
seasonal_backgrounds: list[str] = []
# 存储设置
storage_service: StorageServiceType = StorageServiceType.LOCAL
storage_settings: (
LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings
) = LocalStorageSettings()
@field_validator("fetcher_scopes", mode="before")
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
if isinstance(v, str):
return v.split(",")
return v
@field_validator("storage_settings", mode="after")
def validate_storage_settings(
cls,
v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
info: ValidationInfo,
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
if not isinstance(v, CloudflareR2Settings):
raise ValueError(
"When storage_service is 'r2', "
"storage_settings must be CloudflareR2Settings"
)
elif info.data.get("storage_service") == StorageServiceType.LOCAL:
if not isinstance(v, LocalStorageSettings):
raise ValueError(
"When storage_service is 'local', "
"storage_settings must be LocalStorageSettings"
)
elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
if not isinstance(v, AWSS3StorageSettings):
raise ValueError(
"When storage_service is 's3', "
"storage_settings must be AWSS3StorageSettings"
)
return v
settings = Settings()

View File

@@ -1,5 +1,5 @@
from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthToken
from .auth import OAuthClient, OAuthToken
from .beatmap import (
Beatmap as Beatmap,
BeatmapResp as BeatmapResp,
@@ -10,16 +10,33 @@ from .beatmapset import (
BeatmapsetResp as BeatmapsetResp,
)
from .best_score import BestScore
from .counts import (
CountResp,
MonthlyPlaycounts,
ReplayWatchedCount,
)
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .favourite_beatmapset import FavouriteBeatmapset
from .lazer_user import (
User,
UserResp,
)
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from .playlist_attempts import (
ItemAttemptsCount,
ItemAttemptsResp,
PlaylistAggregateScore,
)
from .playlist_best_score import PlaylistBestScore
from .playlists import Playlist, PlaylistResp
from .pp_best_score import PPBestScore
from .relationship import Relationship, RelationshipResp, RelationshipType
from .room import APIUploadedRoom, Room, RoomResp
from .room_participated_user import RoomParticipatedUser
from .score import (
MultiplayerScores,
Score,
ScoreAround,
ScoreBase,
ScoreResp,
ScoreStatistics,
@@ -37,21 +54,39 @@ from .user_account_history import (
)
__all__ = [
"APIUploadedRoom",
"Beatmap",
"BeatmapPlaycounts",
"BeatmapPlaycountsResp",
"Beatmapset",
"BeatmapsetResp",
"BestScore",
"CountResp",
"DailyChallengeStats",
"DailyChallengeStatsResp",
"FavouriteBeatmapset",
"ItemAttemptsCount",
"ItemAttemptsResp",
"MonthlyPlaycounts",
"MultiplayerEvent",
"MultiplayerEventResp",
"MultiplayerScores",
"OAuthClient",
"OAuthToken",
"PPBestScore",
"Playlist",
"PlaylistAggregateScore",
"PlaylistBestScore",
"PlaylistResp",
"Relationship",
"RelationshipResp",
"RelationshipType",
"ReplayWatchedCount",
"Room",
"RoomParticipatedUser",
"RoomResp",
"Score",
"ScoreAround",
"ScoreBase",
"ScoreResp",
"ScoreStatistics",

View File

@@ -1,10 +1,11 @@
from datetime import datetime
import secrets
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
from sqlmodel import JSON, BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .lazer_user import User
@@ -17,6 +18,7 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
client_id: int = Field(index=True)
access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True)
token_type: str = Field(default="Bearer", max_length=20)
@@ -27,3 +29,13 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
)
user: "User" = Relationship()
class OAuthClient(SQLModel, table=True):
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType]
client_id: int | None = Field(default=None, primary_key=True, index=True)
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
owner_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)

View File

@@ -1,14 +1,14 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.config import settings
from app.models.beatmap import BeatmapRankStatus
from app.models.model import UTCBaseModel
from app.models.score import MODE_TO_INT, GameMode
from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import Beatmapset, BeatmapsetResp
from sqlalchemy import DECIMAL, Column, DateTime
from sqlalchemy import Column, DateTime
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -23,14 +23,12 @@ class BeatmapOwner(SQLModel):
username: str
class BeatmapBase(SQLModel, UTCBaseModel):
class BeatmapBase(SQLModel):
# Beatmap
url: str
mode: GameMode
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
difficulty_rating: float = Field(
default=0.0, sa_column=Column(DECIMAL(precision=10, scale=6))
)
difficulty_rating: float = Field(default=0.0)
total_length: int
user_id: int
version: str
@@ -42,17 +40,11 @@ class BeatmapBase(SQLModel, UTCBaseModel):
# TODO: failtimes, owners
# BeatmapExtended
ar: float = Field(default=0.0, sa_column=Column(DECIMAL(precision=10, scale=2)))
cs: float = Field(default=0.0, sa_column=Column(DECIMAL(precision=10, scale=2)))
drain: float = Field(
default=0.0,
sa_column=Column(DECIMAL(precision=10, scale=2)),
) # hp
accuracy: float = Field(
default=0.0,
sa_column=Column(DECIMAL(precision=10, scale=2)),
) # od
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(precision=10, scale=2)))
ar: float = Field(default=0.0)
cs: float = Field(default=0.0)
drain: float = Field(default=0.0) # hp
accuracy: float = Field(default=0.0) # od
bpm: float = Field(default=0.0)
count_circles: int = Field(default=0)
count_sliders: int = Field(default=0)
count_spinners: int = Field(default=0)
@@ -63,7 +55,7 @@ class BeatmapBase(SQLModel, UTCBaseModel):
class Beatmap(BeatmapBase, table=True):
__tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
id: int = Field(primary_key=True, index=True)
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus
# optional
@@ -71,10 +63,6 @@ class Beatmap(BeatmapBase, table=True):
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
@property
def can_ranked(self) -> bool:
return self.beatmap_status > BeatmapRankStatus.PENDING
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump()
@@ -170,11 +158,19 @@ class BeatmapResp(BeatmapBase):
from .score import Score
beatmap_ = beatmap.model_dump()
beatmap_status = beatmap.beatmap_status
if query_mode is not None and beatmap.mode != query_mode:
beatmap_["convert"] = True
beatmap_["is_scoreable"] = beatmap.beatmap_status > BeatmapRankStatus.PENDING
beatmap_["status"] = beatmap.beatmap_status.name.lower()
beatmap_["ranked"] = beatmap.beatmap_status.value
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
if (
settings.enable_all_beatmap_leaderboard
and not beatmap_status.has_leaderboard()
):
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
else:
beatmap_["status"] = beatmap_status.name.lower()
beatmap_["ranked"] = beatmap_status.value
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
if not from_set:
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(

View File

@@ -1,58 +1,38 @@
from datetime import datetime
from typing import TYPE_CHECKING, TypedDict, cast
from typing import TYPE_CHECKING, NotRequired, TypedDict
from app.config import settings
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.model import UTCBaseModel
from app.models.score import GameMode
from .lazer_user import BASE_INCLUDES, User, UserResp
from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, func, select
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
class BeatmapCovers(SQLModel):
cover: str
card: str
list: str
slimcover: str
cover_2_x: str | None = Field(default=None, alias="cover@2x")
card_2_x: str | None = Field(default=None, alias="card@2x")
list_2_x: str | None = Field(default=None, alias="list@2x")
slimcover_2_x: str | None = Field(default=None, alias="slimcover@2x")
@model_serializer
def _(self) -> dict[str, str | None]:
self = cast(dict[str, str | None] | BeatmapCovers, self)
if isinstance(self, dict):
return {
"cover": self["cover"],
"card": self["card"],
"list": self["list"],
"slimcover": self["slimcover"],
"cover@2x": self.get("cover@2x"),
"card@2x": self.get("card@2x"),
"list@2x": self.get("list@2x"),
"slimcover@2x": self.get("slimcover@2x"),
}
else:
return {
"cover": self.cover,
"card": self.card,
"list": self.list,
"slimcover": self.slimcover,
"cover@2x": self.cover_2_x,
"card@2x": self.card_2_x,
"list@2x": self.list_2_x,
"slimcover@2x": self.slimcover_2_x,
}
BeatmapCovers = TypedDict(
"BeatmapCovers",
{
"cover": str,
"card": str,
"list": str,
"slimcover": str,
"cover@2x": NotRequired[str | None],
"card@2x": NotRequired[str | None],
"list@2x": NotRequired[str | None],
"slimcover@2x": NotRequired[str | None],
},
)
class BeatmapHype(BaseModel):
@@ -74,12 +54,12 @@ class BeatmapNomination(TypedDict):
beatmapset_id: int
reset: bool
user_id: int
rulesets: list[GameMode] | None
rulesets: NotRequired[list[GameMode] | None]
class BeatmapDescription(SQLModel):
bbcode: str | None = None
description: str | None = None
class BeatmapDescription(TypedDict):
bbcode: NotRequired[str | None]
description: NotRequired[str | None]
class BeatmapTranslationText(BaseModel):
@@ -87,7 +67,7 @@ class BeatmapTranslationText(BaseModel):
id: int | None = None
class BeatmapsetBase(SQLModel, UTCBaseModel):
class BeatmapsetBase(SQLModel):
# Beatmapset
artist: str = Field(index=True)
artist_unicode: str = Field(index=True)
@@ -121,7 +101,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel):
track_id: int | None = Field(default=None) # feature artist?
# BeatmapsetExtended
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2)))
bpm: float = Field(default=0.0)
can_be_hyped: bool = Field(default=False)
discussion_locked: bool = Field(default=False)
last_updated: datetime = Field(sa_column=Column(DateTime))
@@ -181,11 +161,24 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
"download_disabled": resp.availability.download_disabled or False,
}
)
session.add(beatmapset)
await session.commit()
if not (
await session.exec(select(exists()).where(Beatmapset.id == resp.id))
).first():
session.add(beatmapset)
await session.commit()
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
return beatmapset
@classmethod
async def get_or_fetch(
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
) -> "Beatmapset":
beatmapset = await session.get(Beatmapset, sid)
if not beatmapset:
resp = await fetcher.get_beatmapset(sid)
beatmapset = await cls.from_resp(session, resp)
return beatmapset
class BeatmapsetResp(BeatmapsetBase):
id: int
@@ -193,7 +186,7 @@ class BeatmapsetResp(BeatmapsetBase):
discussion_enabled: bool = True
status: str
ranked: int
legacy_thread_url: str = ""
legacy_thread_url: str | None = ""
is_scoreable: bool
hype: BeatmapHype | None = None
availability: BeatmapAvailability
@@ -239,11 +232,21 @@ class BeatmapsetResp(BeatmapsetBase):
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"status": beatmapset.beatmap_status.name.lower(),
"ranked": beatmapset.beatmap_status.value,
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
"is_scoreable": beatmapset.beatmap_status.has_leaderboard(),
**beatmapset.model_dump(),
}
beatmap_status = beatmapset.beatmap_status
if (
settings.enable_all_beatmap_leaderboard
and not beatmap_status.has_leaderboard()
):
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
update["ranked"] = BeatmapRankStatus.APPROVED.value
else:
update["status"] = beatmap_status.name.lower()
update["ranked"] = beatmap_status.value
if session and user:
existing_favourite = (
await session.exec(

View File

@@ -29,9 +29,7 @@ class BestScore(SQLModel, table=True):
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
total_score: int = Field(
default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score"))
)
total_score: int = Field(default=0, sa_column=Column(BigInteger))
mods: list[str] = Field(
default_factory=list,
sa_column=Column(JSON),

View File

@@ -14,7 +14,13 @@ if TYPE_CHECKING:
from .lazer_user import User
class MonthlyPlaycounts(SQLModel, table=True):
class CountBase(SQLModel):
year: int = Field(index=True)
month: int = Field(index=True)
count: int = Field(default=0)
class MonthlyPlaycounts(CountBase, table=True):
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
@@ -24,20 +30,29 @@ class MonthlyPlaycounts(SQLModel, table=True):
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
year: int = Field(index=True)
month: int = Field(index=True)
playcount: int = Field(default=0)
user: "User" = Relationship(back_populates="monthly_playcounts")
class MonthlyPlaycountsResp(SQLModel):
class ReplayWatchedCount(CountBase, table=True):
__tablename__ = "replays_watched_counts" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user: "User" = Relationship(back_populates="replays_watched_counts")
class CountResp(SQLModel):
start_date: date
count: int
@classmethod
def from_db(cls, db_model: MonthlyPlaycounts) -> "MonthlyPlaycountsResp":
def from_db(cls, db_model: CountBase) -> "CountResp":
return cls(
start_date=date(db_model.year, db_model.month, 1),
count=db_model.playcount,
count=db_model.count,
)

View File

@@ -1,4 +1,4 @@
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, NotRequired, TypedDict
from app.models.model import UTCBaseModel
@@ -6,8 +6,9 @@ from app.models.score import GameMode
from app.models.user import Country, Page, RankHistory
from .achievement import UserAchievement, UserAchievementResp
from .beatmap_playcounts import BeatmapPlaycounts
from .counts import CountResp, MonthlyPlaycounts, ReplayWatchedCount
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .monthly_playcounts import MonthlyPlaycounts, MonthlyPlaycountsResp
from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
@@ -21,6 +22,7 @@ from sqlmodel import (
Field,
Relationship,
SQLModel,
col,
func,
select,
)
@@ -74,7 +76,6 @@ class UserBase(UTCBaseModel, SQLModel):
username: str = Field(max_length=32, unique=True, index=True)
page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw=""))
previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON))
# TODO: replays_watched_counts
support_level: int = 0
badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON))
@@ -144,6 +145,9 @@ class User(AsyncAttrs, UserBase, table=True):
back_populates="user"
)
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
replays_watched_counts: list[ReplayWatchedCount] = Relationship(
back_populates="user"
)
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
back_populates="user"
)
@@ -164,7 +168,7 @@ class UserResp(UserBase):
is_online: bool = False
groups: list = [] # TODO
country: Country = Field(default_factory=lambda: Country(code="CN", name="China"))
favourite_beatmapset_count: int = 0 # TODO
favourite_beatmapset_count: int = 0
graveyard_beatmapset_count: int = 0 # TODO
guest_beatmapset_count: int = 0 # TODO
loved_beatmapset_count: int = 0 # TODO
@@ -176,13 +180,15 @@ class UserResp(UserBase):
follower_count: int = 0
friends: list["RelationshipResp"] | None = None
scores_best_count: int = 0
scores_first_count: int = 0
scores_first_count: int = 0 # TODO
scores_recent_count: int = 0
scores_pinned_count: int = 0
beatmap_playcounts_count: int = 0
account_history: list[UserAccountHistoryResp] = []
active_tournament_banners: list[dict] = [] # TODO
kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO
monthly_playcounts: list[MonthlyPlaycountsResp] = Field(default_factory=list)
monthly_playcounts: list[CountResp] = Field(default_factory=list)
replay_watched_counts: list[CountResp] = Field(default_factory=list)
unread_pm_count: int = 0 # TODO
rank_history: RankHistory | None = None # TODO
rank_highest: RankHighest | None = None # TODO
@@ -207,7 +213,11 @@ class UserResp(UserBase):
from app.dependencies.database import get_redis
from .best_score import BestScore
from .favourite_beatmapset import FavouriteBeatmapset
from .relationship import Relationship, RelationshipResp, RelationshipType
from .score import Score
ruleset = ruleset or obj.playmode
u = cls.model_validate(obj.model_dump())
u.id = obj.id
@@ -275,7 +285,7 @@ class UserResp(UserBase):
if "statistics" in include:
current_stattistics = None
for i in await obj.awaitable_attrs.statistics:
if i.mode == (ruleset or obj.playmode):
if i.mode == ruleset:
current_stattistics = i
break
u.statistics = (
@@ -292,16 +302,74 @@ class UserResp(UserBase):
if "monthly_playcounts" in include:
u.monthly_playcounts = [
MonthlyPlaycountsResp.from_db(pc)
CountResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
]
if "replays_watched_counts" in include:
u.replay_watched_counts = [
CountResp.from_db(rwc)
for rwc in await obj.awaitable_attrs.replays_watched_counts
]
if "achievements" in include:
u.user_achievements = [
UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
]
u.favourite_beatmapset_count = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.user_id == obj.id)
)
).one()
u.scores_pinned_count = (
await session.exec(
select(func.count())
.select_from(Score)
.where(
Score.user_id == obj.id,
Score.pinned_order > 0,
Score.gamemode == ruleset,
col(Score.passed).is_(True),
)
)
).one()
u.scores_best_count = (
await session.exec(
select(func.count())
.select_from(BestScore)
.where(
BestScore.user_id == obj.id,
BestScore.gamemode == ruleset,
)
.limit(200)
)
).one()
u.scores_recent_count = (
await session.exec(
select(func.count())
.select_from(Score)
.where(
Score.user_id == obj.id,
Score.gamemode == ruleset,
col(Score.passed).is_(True),
Score.ended_at > datetime.now(UTC) - timedelta(hours=24),
)
)
).one()
u.beatmap_playcounts_count = (
await session.exec(
select(func.count())
.select_from(BeatmapPlaycounts)
.where(
BeatmapPlaycounts.user_id == obj.id,
)
)
).one()
return u
@@ -314,6 +382,7 @@ ALL_INCLUDED = [
"statistics_rulesets",
"achievements",
"monthly_playcounts",
"replays_watched_counts",
]
@@ -324,6 +393,7 @@ SEARCH_INCLUDED = [
"statistics_rulesets",
"achievements",
"monthly_playcounts",
"replays_watched_counts",
]
BASE_INCLUDES = [

View File

@@ -0,0 +1,56 @@
from datetime import UTC, datetime
from typing import Any
from app.models.model import UTCBaseModel
from sqlmodel import (
JSON,
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
SQLModel,
)
class MultiplayerEventBase(SQLModel, UTCBaseModel):
playlist_item_id: int | None = None
user_id: int | None = Field(
default=None,
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),
)
created_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=datetime.now(UTC),
)
event_type: str = Field(index=True)
class MultiplayerEvent(MultiplayerEventBase, table=True):
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
)
room_id: int = Field(foreign_key="rooms.id", index=True)
updated_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=datetime.now(UTC),
)
event_detail: dict[str, Any] | None = Field(
sa_column=Column(JSON),
default_factory=dict,
)
class MultiplayerEventResp(MultiplayerEventBase):
id: int
@classmethod
def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp":
return cls.model_validate(event)

View File

@@ -0,0 +1,152 @@
from .lazer_user import User, UserResp
from .playlist_best_score import PlaylistBestScore
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
class ItemAttemptsCountBase(SQLModel):
room_id: int = Field(foreign_key="rooms.id", index=True)
attempts: int = Field(default=0)
completed: int = Field(default=0)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
accuracy: float = 0.0
pp: float = 0
total_score: int = 0
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user: User = Relationship()
async def get_position(self, session: AsyncSession) -> int:
rownum = (
func.row_number()
.over(
partition_by=col(ItemAttemptsCountBase.room_id),
order_by=col(ItemAttemptsCountBase.total_score).desc(),
)
.label("rn")
)
subq = select(ItemAttemptsCountBase, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
result = await session.exec(stmt)
return result.one()
async def update(self, session: AsyncSession):
playlist_scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == self.room_id,
PlaylistBestScore.user_id == self.user_id,
)
)
).all()
self.attempts = sum(score.attempts for score in playlist_scores)
self.total_score = sum(score.total_score for score in playlist_scores)
self.pp = sum(score.score.pp for score in playlist_scores)
self.completed = len(playlist_scores)
self.accuracy = (
sum(score.score.accuracy for score in playlist_scores) / self.completed
if self.completed > 0
else 0.0
)
await session.commit()
await session.refresh(self)
@classmethod
async def get_or_create(
cls,
room_id: int,
user_id: int,
session: AsyncSession,
) -> "ItemAttemptsCount":
item_attempts = await session.exec(
select(cls).where(
cls.room_id == room_id,
cls.user_id == user_id,
)
)
item_attempts = item_attempts.first()
if item_attempts is None:
item_attempts = cls(room_id=room_id, user_id=user_id)
session.add(item_attempts)
await session.commit()
await session.refresh(item_attempts)
await item_attempts.update(session)
return item_attempts
class ItemAttemptsResp(ItemAttemptsCountBase):
user: UserResp | None = None
position: int | None = None
@classmethod
async def from_db(
cls,
item_attempts: ItemAttemptsCount,
session: AsyncSession,
include: list[str] = [],
) -> "ItemAttemptsResp":
resp = cls.model_validate(item_attempts.model_dump())
resp.user = await UserResp.from_db(
await item_attempts.awaitable_attrs.user,
session=session,
include=["statistics", "team", "daily_challenge_user_stats"],
)
if "position" in include:
resp.position = await item_attempts.get_position(session)
# resp.accuracy *= 100
return resp
class ItemAttemptsCountForItem(BaseModel):
id: int
attempts: int
passed: bool
class PlaylistAggregateScore(BaseModel):
playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list)
@classmethod
async def from_db(
cls,
room_id: int,
user_id: int,
session: AsyncSession,
) -> "PlaylistAggregateScore":
playlist_scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == room_id,
PlaylistBestScore.user_id == user_id,
)
)
).all()
playlist_item_attempts = []
for score in playlist_scores:
playlist_item_attempts.append(
ItemAttemptsCountForItem(
id=score.playlist_id,
attempts=score.attempts,
passed=score.score.passed,
)
)
return cls(playlist_item_attempts=playlist_item_attempts)

View File

@@ -0,0 +1,110 @@
from typing import TYPE_CHECKING
from .lazer_user import User
from redis.asyncio import Redis
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .score import Score
class PlaylistBestScore(SQLModel, table=True):
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
room_id: int = Field(foreign_key="rooms.id", index=True)
playlist_id: int = Field(foreign_key="room_playlists.id", index=True)
total_score: int = Field(default=0, sa_column=Column(BigInteger))
attempts: int = Field(default=0) # playlist
user: User = Relationship()
score: "Score" = Relationship(
sa_relationship_kwargs={
"foreign_keys": "[PlaylistBestScore.score_id]",
"lazy": "joined",
}
)
async def process_playlist_best_score(
room_id: int,
playlist_id: int,
user_id: int,
score_id: int,
total_score: int,
session: AsyncSession,
redis: Redis,
):
previous = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == room_id,
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.user_id == user_id,
)
)
).first()
if previous is None:
previous = PlaylistBestScore(
user_id=user_id,
score_id=score_id,
room_id=room_id,
playlist_id=playlist_id,
total_score=total_score,
)
session.add(previous)
elif not previous.score.passed or previous.total_score < total_score:
previous.score_id = score_id
previous.total_score = total_score
previous.attempts += 1
await session.commit()
if await redis.exists(f"multiplayer:{room_id}:gameplay:players"):
await redis.decr(f"multiplayer:{room_id}:gameplay:players")
async def get_position(
room_id: int,
playlist_id: int,
score_id: int,
session: AsyncSession,
) -> int:
rownum = (
func.row_number()
.over(
partition_by=(
col(PlaylistBestScore.playlist_id),
col(PlaylistBestScore.room_id),
),
order_by=col(PlaylistBestScore.total_score).desc(),
)
.label("row_number")
)
subq = (
select(PlaylistBestScore, rownum)
.where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
)
.subquery()
)
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0

143
app/database/playlists.py Normal file
View File

@@ -0,0 +1,143 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.multiplayer_hub import PlaylistItem
from .beatmap import Beatmap, BeatmapResp
from sqlmodel import (
JSON,
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
func,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .room import Room
class PlaylistBase(SQLModel, UTCBaseModel):
id: int = Field(index=True)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
ruleset_id: int = Field(ge=0, le=3)
expired: bool = Field(default=False)
playlist_order: int = Field(default=0)
played_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
)
allowed_mods: list[APIMod] = Field(
default_factory=list,
sa_column=Column(JSON),
)
required_mods: list[APIMod] = Field(
default_factory=list,
sa_column=Column(JSON),
)
beatmap_id: int = Field(
foreign_key="beatmaps.id",
)
freestyle: bool = Field(default=False)
class Playlist(PlaylistBase, table=True):
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
room_id: int = Field(foreign_key="rooms.id", exclude=True)
beatmap: Beatmap = Relationship(
sa_relationship_kwargs={
"lazy": "joined",
}
)
room: "Room" = Relationship()
@classmethod
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
cls.room_id == room_id
)
result = await session.exec(stmt)
return result.one()
@classmethod
async def from_hub(
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls(
id=next_id,
owner_id=playlist.owner_id,
ruleset_id=playlist.ruleset_id,
beatmap_id=playlist.beatmap_id,
required_mods=playlist.required_mods,
allowed_mods=playlist.allowed_mods,
expired=playlist.expired,
playlist_order=playlist.playlist_order,
played_at=playlist.played_at,
freestyle=playlist.freestyle,
room_id=room_id,
)
@classmethod
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await session.exec(
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
)
db_playlist = db_playlist.first()
if db_playlist is None:
raise ValueError("Playlist item not found")
db_playlist.owner_id = playlist.owner_id
db_playlist.ruleset_id = playlist.ruleset_id
db_playlist.beatmap_id = playlist.beatmap_id
db_playlist.required_mods = playlist.required_mods
db_playlist.allowed_mods = playlist.allowed_mods
db_playlist.expired = playlist.expired
db_playlist.playlist_order = playlist.playlist_order
db_playlist.played_at = playlist.played_at
db_playlist.freestyle = playlist.freestyle
await session.commit()
@classmethod
async def add_to_db(
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
):
db_playlist = await cls.from_hub(playlist, room_id, session)
session.add(db_playlist)
await session.commit()
await session.refresh(db_playlist)
playlist.id = db_playlist.id
@classmethod
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
db_playlist = await session.exec(
select(cls).where(cls.id == item_id, cls.room_id == room_id)
)
db_playlist = db_playlist.first()
if db_playlist is None:
raise ValueError("Playlist item not found")
await session.delete(db_playlist)
await session.commit()
class PlaylistResp(PlaylistBase):
beatmap: BeatmapResp | None = None
@classmethod
async def from_db(
cls, playlist: Playlist, include: list[str] = []
) -> "PlaylistResp":
data = playlist.model_dump()
if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
resp = cls.model_validate(data)
return resp

View File

@@ -1,6 +1,177 @@
from sqlmodel import Field, SQLModel
from datetime import UTC, datetime
from app.database.playlist_attempts import PlaylistAggregateScore
from app.database.room_participated_user import RoomParticipatedUser
from app.models.model import UTCBaseModel
from app.models.multiplayer_hub import ServerMultiplayerRoom
from app.models.room import (
MatchType,
QueueMode,
RoomCategory,
RoomDifficultyRange,
RoomPlaylistItemStats,
RoomStatus,
)
from .lazer_user import User, UserResp
from .playlists import Playlist, PlaylistResp
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
class RoomIndex(SQLModel, table=True):
__tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue]
class RoomBase(SQLModel, UTCBaseModel):
name: str = Field(index=True)
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
duration: int | None = Field(default=None) # minutes
starts_at: datetime | None = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=datetime.now(UTC),
)
ends_at: datetime | None = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=None,
)
participant_count: int = Field(default=0)
max_attempts: int | None = Field(default=None) # playlists
type: MatchType
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
status: RoomStatus
# TODO: channel_id
class Room(AsyncAttrs, RoomBase, table=True):
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
id: int = Field(default=None, primary_key=True, index=True)
host_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
host: User = Relationship()
playlist: list[Playlist] = Relationship(
sa_relationship_kwargs={
"lazy": "selectin",
"cascade": "all, delete-orphan",
"overlaps": "room",
}
)
class RoomResp(RoomBase):
id: int
has_password: bool = False
host: UserResp | None = None
playlist: list[PlaylistResp] = []
playlist_item_stats: RoomPlaylistItemStats | None = None
difficulty_range: RoomDifficultyRange | None = None
current_playlist_item: PlaylistResp | None = None
current_user_score: PlaylistAggregateScore | None = None
recent_participants: list[UserResp] = Field(default_factory=list)
@classmethod
async def from_db(
cls,
room: Room,
session: AsyncSession,
include: list[str] = [],
user: User | None = None,
) -> "RoomResp":
resp = cls.model_validate(room.model_dump())
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
difficulty_range = RoomDifficultyRange(
min=0,
max=0,
)
rulesets = set()
for playlist in room.playlist:
stats.count_total += 1
if not playlist.expired:
stats.count_active += 1
rulesets.add(playlist.ruleset_id)
difficulty_range.min = min(
difficulty_range.min, playlist.beatmap.difficulty_rating
)
difficulty_range.max = max(
difficulty_range.max, playlist.beatmap.difficulty_rating
)
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
stats.ruleset_ids = list(rulesets)
resp.playlist_item_stats = stats
resp.difficulty_range = difficulty_range
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
resp.recent_participants = []
for recent_participant in await session.exec(
select(RoomParticipatedUser)
.where(
RoomParticipatedUser.room_id == room.id,
col(RoomParticipatedUser.left_at).is_(None),
)
.limit(8)
.order_by(col(RoomParticipatedUser.joined_at).desc())
):
resp.recent_participants.append(
await UserResp.from_db(
await recent_participant.awaitable_attrs.user,
session,
include=["statistics"],
)
)
resp.host = await UserResp.from_db(
await room.awaitable_attrs.host, session, include=["statistics"]
)
if "current_user_score" in include and user:
resp.current_user_score = await PlaylistAggregateScore.from_db(
room.id, user.id, session
)
return resp
@classmethod
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
room = server_room.room
resp = cls(
id=room.room_id,
name=room.settings.name,
type=room.settings.match_type,
queue_mode=room.settings.queue_mode,
auto_skip=room.settings.auto_skip,
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
status=server_room.status,
category=server_room.category,
# duration = room.settings.duration,
starts_at=server_room.start_at,
participant_count=len(room.users),
)
return resp
class APIUploadedRoom(RoomBase):
def to_room(self) -> Room:
"""
将 APIUploadedRoom 转换为 Room 对象playlist 字段需单独处理。
"""
room_dict = self.model_dump()
room_dict.pop("playlist", None)
# host_id 已在字段中
return Room(**room_dict)
id: int | None
host_id: int | None = None
playlist: list[Playlist] = Field(default_factory=list)

View File

@@ -0,0 +1,39 @@
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
)
if TYPE_CHECKING:
from .lazer_user import User
from .room import Room
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
__tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)
)
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False))
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)
)
joined_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False),
default=datetime.now(UTC),
)
left_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True), nullable=True), default=None
)
room: "Room" = Relationship()
user: "User" = Relationship()

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from datetime import UTC, date, datetime
import json
import math
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from app.calculator import (
calculate_pp,
@@ -13,8 +13,14 @@ from app.calculator import (
calculate_weighted_pp,
clamp,
)
from app.config import settings
from app.database.team import TeamMember
from app.models.model import UTCBaseModel
from app.models.model import (
CurrentUserAttributes,
PinAttributes,
RespWithCursor,
UTCBaseModel,
)
from app.models.mods import APIMod, mods_can_get_pp
from app.models.score import (
INT_TO_MODE,
@@ -31,8 +37,8 @@ from .beatmap import Beatmap, BeatmapResp
from .beatmap_playcounts import process_beatmap_playcount
from .beatmapset import BeatmapsetResp
from .best_score import BestScore
from .counts import MonthlyPlaycounts
from .lazer_user import User, UserResp
from .monthly_playcounts import MonthlyPlaycounts
from .pp_best_score import PPBestScore
from .relationship import (
Relationship as DBRelationship,
@@ -89,10 +95,11 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
default=0, sa_column=Column(BigInteger), exclude=True
)
type: str
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
# optional
# TODO: current_user_attributes
position: int | None = Field(default=None) # multiplayer
# position: int | None = Field(default=None) # multiplayer
class Score(ScoreBase, table=True):
@@ -100,7 +107,6 @@ class Score(ScoreBase, table=True):
id: int | None = Field(
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
)
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
user_id: int = Field(
default=None,
sa_column=Column(
@@ -121,6 +127,7 @@ class Score(ScoreBase, table=True):
nslider_tail_hit: int | None = Field(default=None, exclude=True)
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
gamemode: GameMode = Field(index=True)
pinned_order: int = Field(default=0, exclude=True)
# optional
beatmap: Beatmap = Relationship()
@@ -163,6 +170,9 @@ class ScoreResp(ScoreBase):
maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None
rank_country: int | None = None
position: int | None = None
scores_around: "ScoreAround | None" = None
current_user_attributes: CurrentUserAttributes | None = None
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
@@ -231,9 +241,22 @@ class ScoreResp(ScoreBase):
)
or None
)
s.current_user_attributes = CurrentUserAttributes(
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
)
return s
class MultiplayerScores(RespWithCursor):
scores: list[ScoreResp] = Field(default_factory=list)
params: dict[str, Any] = Field(default_factory=dict)
class ScoreAround(SQLModel):
higher: MultiplayerScores | None = None
lower: MultiplayerScores | None = None
async def get_best_id(session: AsyncSession, score_id: int) -> None:
rownum = (
func.row_number()
@@ -312,6 +335,13 @@ async def get_leaderboard(
user: User | None = None,
limit: int = 50,
) -> tuple[list[Score], Score | None]:
is_rx = "RX" in (mods or [])
is_ap = "AP" in (mods or [])
if settings.enable_osu_rx and is_rx:
mode = GameMode.OSURX
elif settings.enable_osu_ap and is_ap:
mode = GameMode.OSUAP
wheres = await _score_where(type, beatmap, mode, mods, user)
if wheres is None:
return [], None
@@ -329,6 +359,10 @@ async def get_leaderboard(
self_query = (
select(BestScore)
.where(BestScore.user_id == user.id)
.where(
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
)
.order_by(col(BestScore.total_score).desc())
.limit(1)
)
@@ -461,12 +495,13 @@ async def get_user_best_pp_in_beatmap(
async def get_user_best_pp(
session: AsyncSession,
user: int,
mode: GameMode,
limit: int = 200,
) -> Sequence[PPBestScore]:
return (
await session.exec(
select(PPBestScore)
.where(PPBestScore.user_id == user)
.where(PPBestScore.user_id == user, PPBestScore.gamemode == mode)
.order_by(col(PPBestScore.pp).desc())
.limit(limit)
)
@@ -474,7 +509,7 @@ async def get_user_best_pp(
async def process_user(
session: AsyncSession, user: User, score: Score, ranked: bool = False
session: AsyncSession, user: User, score: Score, length: int, ranked: bool = False
):
assert user.id
assert score.id
@@ -577,8 +612,8 @@ async def process_user(
)
)
statistics.play_count += 1
mouthly_playcount.playcount += 1
statistics.play_time += int((score.ended_at - score.started_at).total_seconds())
mouthly_playcount.count += 1
statistics.play_time += length
statistics.count_100 += score.n100 + score.nkatu
statistics.count_300 += score.n300 + score.ngeki
statistics.count_50 += score.n50
@@ -588,7 +623,7 @@ async def process_user(
)
if score.passed and ranked:
best_pp_scores = await get_user_best_pp(session, user.id)
best_pp_scores = await get_user_best_pp(session, user.id, score.gamemode)
pp_sum = 0.0
acc_sum = 0.0
for i, bp in enumerate(best_pp_scores):
@@ -616,9 +651,19 @@ async def process_score(
fetcher: "Fetcher",
session: AsyncSession,
redis: Redis,
item_id: int | None = None,
room_id: int | None = None,
) -> Score:
assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
acronyms = [mod["acronym"] for mod in info.mods]
is_rx = "RX" in acronyms
is_ap = "AP" in acronyms
gamemode = INT_TO_MODE[info.ruleset_id]
if settings.enable_osu_rx and is_rx and gamemode == GameMode.OSU:
gamemode = GameMode.OSURX
elif settings.enable_osu_ap and is_ap and gamemode == GameMode.OSU:
gamemode = GameMode.OSUAP
score = Score(
accuracy=info.accuracy,
max_combo=info.max_combo,
@@ -630,7 +675,7 @@ async def process_score(
total_score_without_mods=info.total_score_without_mods,
beatmap_id=beatmap_id,
ended_at=datetime.now(UTC),
gamemode=INT_TO_MODE[info.ruleset_id],
gamemode=gamemode,
started_at=score_token.created_at,
user_id=user.id,
preserve=info.passed,
@@ -647,6 +692,8 @@ async def process_score(
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
playlist_item_id=item_id,
room_id=room_id,
)
if can_get_pp:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
@@ -678,4 +725,5 @@ async def process_score(
await session.refresh(score)
await session.refresh(score_token)
await session.refresh(user)
await redis.publish("score:processed", score.id)
return score

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from contextvars import ContextVar
import json
from app.config import settings
@@ -18,23 +19,36 @@ def json_serializer(value):
# 数据库引擎
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
engine = create_async_engine(settings.database_url, json_serializer=json_serializer)
# Redis 连接
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
# 数据库依赖
db_session_context: ContextVar[AsyncSession | None] = ContextVar(
"db_session_context", default=None
)
async def get_db():
async with AsyncSession(engine) as session:
session = db_session_context.get()
if session is None:
session = AsyncSession(engine)
db_session_context.set(session)
try:
yield session
finally:
await session.close()
db_session_context.set(None)
else:
yield session
async def create_tables():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
# Redis 依赖
def get_redis():
return redis_client
def get_redis_pubsub():
return redis_client.pubsub()

View File

@@ -12,10 +12,10 @@ async def get_fetcher() -> Fetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
settings.FETCHER_CLIENT_ID,
settings.FETCHER_CLIENT_SECRET,
settings.FETCHER_SCOPES,
settings.FETCHER_CALLBACK_URL,
settings.fetcher_client_id,
settings.fetcher_client_secret,
settings.fetcher_scopes,
settings.fetcher_callback_url,
)
redis = get_redis()
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from datetime import UTC
from apscheduler.schedulers.asyncio import AsyncIOScheduler
scheduler: AsyncIOScheduler | None = None
def init_scheduler():
global scheduler
scheduler = AsyncIOScheduler(timezone=UTC)
scheduler.start()
def get_scheduler() -> AsyncIOScheduler:
global scheduler
if scheduler is None:
init_scheduler()
return scheduler # pyright: ignore[reportReturnType]
def stop_scheduler():
global scheduler
if scheduler:
scheduler.shutdown()

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
from typing import cast
from app.config import (
AWSS3StorageSettings,
CloudflareR2Settings,
LocalStorageSettings,
StorageServiceType,
settings,
)
from app.storage import StorageService
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
from app.storage.local import LocalStorageService
storage: StorageService | None = None
def init_storage_service():
global storage
if settings.storage_service == StorageServiceType.LOCAL:
storage_settings = cast(LocalStorageSettings, settings.storage_settings)
storage = LocalStorageService(
storage_path=storage_settings.local_storage_path,
)
elif settings.storage_service == StorageServiceType.CLOUDFLARE_R2:
storage_settings = cast(CloudflareR2Settings, settings.storage_settings)
storage = CloudflareR2StorageService(
account_id=storage_settings.r2_account_id,
access_key_id=storage_settings.r2_access_key_id,
secret_access_key=storage_settings.r2_secret_access_key,
bucket_name=storage_settings.r2_bucket_name,
public_url_base=storage_settings.r2_public_url_base,
)
elif settings.storage_service == StorageServiceType.AWS_S3:
storage_settings = cast(AWSS3StorageSettings, settings.storage_settings)
storage = AWSS3StorageService(
access_key_id=storage_settings.s3_access_key_id,
secret_access_key=storage_settings.s3_secret_access_key,
bucket_name=storage_settings.s3_bucket_name,
public_url_base=storage_settings.s3_public_url_base,
region_name=storage_settings.s3_region_name,
)
else:
raise ValueError(f"Unsupported storage service: {settings.storage_service}")
return storage
def get_storage_service():
if storage is None:
return init_storage_service()
return storage

View File

@@ -1,34 +1,84 @@
from __future__ import annotations
from typing import Annotated
from app.auth import get_token_by_access_token
from app.config import settings
from app.database import User
from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import (
HTTPBearer,
OAuth2AuthorizationCodeBearer,
OAuth2PasswordBearer,
SecurityScopes,
)
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
oauth2_password = OAuth2PasswordBearer(
tokenUrl="oauth/token",
scopes={"*": "Allows access to all scopes."},
)
oauth2_code = OAuth2AuthorizationCodeBearer(
authorizationUrl="oauth/authorize",
tokenUrl="oauth/token",
scopes={
"chat.read": "Allows read chat messages on a user's behalf.",
"chat.write": "Allows sending chat messages on a user's behalf.",
"chat.write_manage": (
"Allows joining and leaving chat channels on a user's behalf."
),
"delegate": (
"Allows acting as the owner of a client; "
"only available for Client Credentials Grant."
),
"forum.write": "Allows creating and editing forum posts on a user's behalf.",
"friends.read": "Allows reading of the user's friend list.",
"identify": "Allows reading of the public profile of the user (/me).",
"public": "Allows reading of publicly available data on behalf of the user.",
},
)
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
security_scopes: SecurityScopes,
db: Annotated[AsyncSession, Depends(get_db)],
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
) -> User:
"""获取当前认证用户"""
token = credentials.credentials
token = token_pw or token_code
if not token:
raise HTTPException(status_code=401, detail="Not authenticated")
user = await get_current_user_by_token(token, db)
token_record = await get_token_by_access_token(db, token)
if not token_record:
raise HTTPException(status_code=401, detail="Invalid or expired token")
is_client = token_record.client_id in (
settings.osu_client_id,
settings.osu_web_client_id,
)
if security_scopes.scopes == ["*"]:
# client/web only
if not token_pw or not is_client:
raise HTTPException(status_code=401, detail="Not authenticated")
elif not is_client:
for scope in security_scopes.scopes:
if scope not in token_record.scope.split(","):
raise HTTPException(
status_code=403, detail=f"Insufficient scope: {scope}"
)
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token")
return user
async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user

View File

@@ -38,6 +38,22 @@ class BaseFetcher:
"Content-Type": "application/json",
}
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
if self.is_token_expired():
await self.refresh_access_token()
header = kwargs.pop("headers", {})
header = self.header
async with AsyncClient() as client:
response = await client.request(
method,
url,
headers=header,
**kwargs,
)
response.raise_for_status()
return response.json()
def is_token_expired(self) -> bool:
return self.token_expiry <= int(time.time())

View File

@@ -5,8 +5,6 @@ from app.log import logger
from ._base import BaseFetcher
from httpx import AsyncClient
class BeatmapFetcher(BaseFetcher):
async def get_beatmap(
@@ -21,11 +19,10 @@ class BeatmapFetcher(BaseFetcher):
logger.opt(colors=True).debug(
f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>"
)
async with AsyncClient() as client:
response = await client.get(
return BeatmapResp.model_validate(
await self.request_api(
"https://osu.ppy.sh/api/v2/beatmaps/lookup",
headers=self.header,
params=params,
)
response.raise_for_status()
return BeatmapResp.model_validate(response.json())
)

View File

@@ -5,18 +5,15 @@ from app.log import logger
from ._base import BaseFetcher
from httpx import AsyncClient
class BeatmapsetFetcher(BaseFetcher):
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
logger.opt(colors=True).debug(
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
)
async with AsyncClient() as client:
response = await client.get(
f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}",
headers=self.header,
return BeatmapsetResp.model_validate(
await self.request_api(
f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}"
)
response.raise_for_status()
return BeatmapsetResp.model_validate(response.json())
)

View File

@@ -120,10 +120,10 @@ logger.add(
format=(
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"
),
level=settings.LOG_LEVEL,
diagnose=settings.DEBUG,
level=settings.log_level,
diagnose=settings.debug,
)
logging.basicConfig(handlers=[InterceptHandler()], level=settings.LOG_LEVEL, force=True)
logging.basicConfig(handlers=[InterceptHandler()], level=settings.log_level, force=True)
uvicorn_loggers = [
"uvicorn",

View File

@@ -14,6 +14,20 @@ class BeatmapRankStatus(IntEnum):
QUALIFIED = 3
LOVED = 4
def has_leaderboard(self) -> bool:
return self in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.LOVED,
}
def has_pp(self) -> bool:
return self in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}
class Genre(IntEnum):
ANY = 0

View File

@@ -3,10 +3,12 @@ from __future__ import annotations
from enum import IntEnum
from typing import ClassVar, Literal
from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState
from app.models.signalr import SignalRUnionMessage, UserState
from pydantic import BaseModel, Field
TOTAL_SCORE_DISTRIBUTION_BINS = 13
class _UserActivity(SignalRUnionMessage): ...
@@ -96,16 +98,14 @@ UserActivity = (
| ModdingBeatmap
| TestingBeatmap
| InDailyChallengeLobby
| PlayingDailyChallenge
)
class UserPresence(BaseModel):
activity: UserActivity | None = Field(
default=None, metadata=SignalRMeta(use_upper_case=True)
)
status: OnlineStatus | None = Field(
default=None, metadata=SignalRMeta(use_upper_case=True)
)
activity: UserActivity | None = None
status: OnlineStatus | None = None
@property
def pushable(self) -> bool:
@@ -126,3 +126,34 @@ class OnlineStatus(IntEnum):
OFFLINE = 0 # 隐身
DO_NOT_DISTURB = 1
ONLINE = 2
class DailyChallengeInfo(BaseModel):
room_id: int
class MultiplayerPlaylistItemStats(BaseModel):
playlist_item_id: int = 0
total_score_distribution: list[int] = Field(
default_factory=list,
min_length=TOTAL_SCORE_DISTRIBUTION_BINS,
max_length=TOTAL_SCORE_DISTRIBUTION_BINS,
)
cumulative_score: int = 0
last_processed_score_id: int = 0
class MultiplayerRoomStats(BaseModel):
room_id: int
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(
default_factory=dict
)
class MultiplayerRoomScoreSetEvent(BaseModel):
room_id: int
playlist_item_id: int
score_id: int
user_id: int
total_score: int
new_rank: int | None = None

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
from datetime import UTC, datetime
from app.models.score import GameMode
from pydantic import BaseModel, field_serializer
@@ -13,3 +15,41 @@ class UTCBaseModel(BaseModel):
v = v.replace(tzinfo=UTC)
return v.astimezone(UTC).isoformat()
return v
Cursor = dict[str, int]
class RespWithCursor(BaseModel):
cursor: Cursor | None = None
class PinAttributes(BaseModel):
is_pinned: bool
score_id: int
class CurrentUserAttributes(BaseModel):
can_beatmap_update_owner: bool | None = None
can_delete: bool | None = None
can_edit_metadata: bool | None = None
can_edit_tags: bool | None = None
can_hype: bool | None = None
can_hype_reason: str | None = None
can_love: bool | None = None
can_remove_from_loved: bool | None = None
is_watching: bool | None = None
new_hype_time: datetime | None = None
nomination_modes: list[GameMode] | None = None
remaining_hype: int | None = None
can_destroy: bool | None = None
can_reopen: bool | None = None
can_moderate_kudosu: bool | None = None
can_resolve: bool | None = None
vote_score: int | None = None
can_message: bool | None = None
can_message_error: str | None = None
last_read_id: int | None = None
can_new_comment: bool | None = None
can_new_comment_reason: str | None = None
pin: PinAttributes | None = None

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
from copy import deepcopy
import json
from typing import Literal, NotRequired, TypedDict
from app.config import settings as app_settings
from app.path import STATIC_DIR
class APIMod(TypedDict):
acronym: str
settings: NotRequired[dict[str, bool | float | str]]
settings: NotRequired[dict[str, bool | float | str | int]]
# https://github.com/ppy/osu-api/wiki#mods
@@ -129,10 +131,10 @@ COMMON_CONFIG: dict[str, dict] = {
}
RANKED_MODS: dict[int, dict[str, dict]] = {
0: COMMON_CONFIG,
1: COMMON_CONFIG,
2: COMMON_CONFIG,
3: COMMON_CONFIG,
0: deepcopy(COMMON_CONFIG),
1: deepcopy(COMMON_CONFIG),
2: deepcopy(COMMON_CONFIG),
3: deepcopy(COMMON_CONFIG),
}
# osu
RANKED_MODS[0]["HD"]["only_fade_approach_circles"] = False
@@ -154,8 +156,15 @@ for i in range(4, 10):
def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
if app_settings.enable_all_mods_pp:
return True
ranked_mods = RANKED_MODS[ruleset_id]
for mod in mods:
if app_settings.enable_osu_rx and mod["acronym"] == "RX" and ruleset_id == 0:
continue
if app_settings.enable_osu_ap and mod["acronym"] == "AP" and ruleset_id == 0:
continue
mod["settings"] = mod.get("settings", {})
if (settings := ranked_mods.get(mod["acronym"])) is None:
return False

View File

@@ -0,0 +1,924 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from enum import IntEnum
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Literal,
TypedDict,
cast,
override,
)
from app.database.beatmap import Beatmap
from app.dependencies.database import engine
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from .mods import APIMod
from .room import (
DownloadState,
MatchType,
MultiplayerRoomState,
MultiplayerUserState,
QueueMode,
RoomCategory,
RoomStatus,
)
from .signalr import (
SignalRMeta,
SignalRUnionMessage,
UserState,
)
from pydantic import BaseModel, Field
from sqlalchemy import update
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.signalr.hub import MultiplayerHub
HOST_LIMIT = 50
PER_USER_LIMIT = 3
class MultiplayerClientState(UserState):
room_id: int = 0
class MultiplayerRoomSettings(BaseModel):
name: str = "Unnamed Room"
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
password: str = ""
match_type: MatchType = MatchType.HEAD_TO_HEAD
queue_mode: QueueMode = QueueMode.HOST_ONLY
auto_start_duration: timedelta = timedelta(seconds=0)
auto_skip: bool = False
@property
def auto_start_enabled(self) -> bool:
return self.auto_start_duration != timedelta(seconds=0)
class BeatmapAvailability(BaseModel):
state: DownloadState = DownloadState.UNKNOWN
download_progress: float | None = None
class _MatchUserState(SignalRUnionMessage): ...
class TeamVersusUserState(_MatchUserState):
team_id: int
union_type: ClassVar[Literal[0]] = 0
MatchUserState = TeamVersusUserState
class _MatchRoomState(SignalRUnionMessage): ...
class MultiplayerTeam(BaseModel):
id: int
name: str
class TeamVersusRoomState(_MatchRoomState):
teams: list[MultiplayerTeam] = Field(
default_factory=lambda: [
MultiplayerTeam(id=0, name="Team Red"),
MultiplayerTeam(id=1, name="Team Blue"),
]
)
union_type: ClassVar[Literal[0]] = 0
MatchRoomState = TeamVersusRoomState
class PlaylistItem(BaseModel):
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
owner_id: int
beatmap_id: int
beatmap_checksum: str
ruleset_id: int
required_mods: list[APIMod] = Field(default_factory=list)
allowed_mods: list[APIMod] = Field(default_factory=list)
expired: bool
playlist_order: int
played_at: datetime | None = None
star_rating: float
freestyle: bool
def _get_api_mods(self):
from app.models.mods import API_MODS, init_mods
if not API_MODS:
init_mods()
return API_MODS
def _validate_mod_for_ruleset(
self, mod: APIMod, ruleset_key: int, context: str = "mod"
) -> None:
from typing import Literal, cast
API_MODS = self._get_api_mods()
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
# Check if mod is valid for ruleset
if (
typed_ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[typed_ruleset_key]
):
raise InvokeException(
f"{context} {mod['acronym']} is invalid for this ruleset"
)
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
# Check if mod is unplayable in multiplayer
if mod_settings.get("UserPlayable", True) is False:
raise InvokeException(
f"{context} {mod['acronym']} is not playable by users"
)
if mod_settings.get("ValidForMultiplayer", True) is False:
raise InvokeException(
f"{context} {mod['acronym']} is not valid for multiplayer"
)
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
from typing import Literal, cast
API_MODS = self._get_api_mods()
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
for i, mod1 in enumerate(mods):
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
if mod1_settings:
incompatible = set(mod1_settings.get("IncompatibleMods", []))
for mod2 in mods[i + 1 :]:
if mod2["acronym"] in incompatible:
raise InvokeException(
f"Mods {mod1['acronym']} and "
f"{mod2['acronym']} are incompatible"
)
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
from typing import Literal, cast
API_MODS = self._get_api_mods()
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
for req_mod in self.required_mods:
req_acronym = req_mod["acronym"]
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
if req_settings:
incompatible = set(req_settings.get("IncompatibleMods", []))
conflicting_allowed = allowed_acronyms & incompatible
if conflicting_allowed:
conflict_list = ", ".join(conflicting_allowed)
raise InvokeException(
f"Required mod {req_acronym} conflicts with "
f"allowed mods: {conflict_list}"
)
def validate_playlist_item_mods(self) -> None:
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
# Validate required mods
for mod in self.required_mods:
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
# Validate allowed mods
for mod in self.allowed_mods:
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
# Check internal compatibility of required mods
self._check_mod_compatibility(self.required_mods, ruleset_key)
# Check compatibility between required and allowed mods
self._check_required_allowed_compatibility(ruleset_key)
def validate_user_mods(
self,
user: "MultiplayerRoomUser",
proposed_mods: list[APIMod],
) -> tuple[bool, list[APIMod]]:
"""
Validates user mods against playlist item rules and returns valid mods.
Returns (is_valid, valid_mods).
"""
from typing import Literal, cast
API_MODS = self._get_api_mods()
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
valid_mods = []
all_proposed_valid = True
# Check if mods are valid for the ruleset
for mod in proposed_mods:
if (
ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[ruleset_key]
):
all_proposed_valid = False
continue
valid_mods.append(mod)
# Check mod compatibility within user mods
incompatible_mods = set()
final_valid_mods = []
for mod in valid_mods:
if mod["acronym"] in incompatible_mods:
all_proposed_valid = False
continue
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
if setting_mods:
incompatible_mods.update(setting_mods["IncompatibleMods"])
final_valid_mods.append(mod)
# If not freestyle, check against allowed mods
if not self.freestyle:
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
filtered_valid_mods = []
for mod in final_valid_mods:
if mod["acronym"] not in allowed_acronyms:
all_proposed_valid = False
else:
filtered_valid_mods.append(mod)
final_valid_mods = filtered_valid_mods
# Check compatibility with required mods
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
all_mod_acronyms = {
mod["acronym"] for mod in final_valid_mods
} | required_mod_acronyms
# Check for incompatibility between required and user mods
filtered_valid_mods = []
for mod in final_valid_mods:
mod_acronym = mod["acronym"]
is_compatible = True
for other_acronym in all_mod_acronyms:
if other_acronym == mod_acronym:
continue
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
is_compatible = False
all_proposed_valid = False
break
if is_compatible:
filtered_valid_mods.append(mod)
return all_proposed_valid, filtered_valid_mods
def clone(self) -> "PlaylistItem":
copy = self.model_copy()
copy.required_mods = list(self.required_mods)
copy.allowed_mods = list(self.allowed_mods)
copy.expired = False
copy.played_at = None
return copy
class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0
time_remaining: timedelta
is_exclusive: Annotated[
bool, Field(default=True), SignalRMeta(member_ignore=True)
] = True
class MatchStartCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[0]] = 0
class ForceGameplayStartCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[1]] = 1
class ServerShuttingDownCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[2]] = 2
MultiplayerCountdown = (
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
)
class MultiplayerRoomUser(BaseModel):
user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability(
state=DownloadState.UNKNOWN, download_progress=None
)
mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None
ruleset_id: int | None = None # freestyle
beatmap_id: int | None = None # freestyle
class MultiplayerRoom(BaseModel):
room_id: int
state: MultiplayerRoomState
settings: MultiplayerRoomSettings
users: list[MultiplayerRoomUser] = Field(default_factory=list)
host: MultiplayerRoomUser | None = None
match_state: MatchRoomState | None = None
playlist: list[PlaylistItem] = Field(default_factory=list)
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
channel_id: int
@classmethod
def from_db(cls, room) -> "MultiplayerRoom":
"""
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
"""
# 用户列表
users = [MultiplayerRoomUser(user_id=room.host_id)]
host_user = MultiplayerRoomUser(user_id=room.host_id)
# playlist 转换
playlist = []
if hasattr(room, "playlist"):
for item in room.playlist:
playlist.append(
PlaylistItem(
id=item.id,
owner_id=item.owner_id,
beatmap_id=item.beatmap_id,
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
ruleset_id=item.ruleset_id,
required_mods=item.required_mods,
allowed_mods=item.allowed_mods,
expired=item.expired,
playlist_order=item.playlist_order,
played_at=item.played_at,
star_rating=item.beatmap.difficulty_rating
if item.beatmap is not None
else 0.0,
freestyle=item.freestyle,
)
)
return cls(
room_id=room.id,
state=getattr(room, "state", MultiplayerRoomState.OPEN),
settings=MultiplayerRoomSettings(
name=room.name,
playlist_item_id=playlist[0].id if playlist else 0,
password=getattr(room, "password", ""),
match_type=room.type,
queue_mode=room.queue_mode,
auto_start_duration=timedelta(seconds=room.auto_start_duration),
auto_skip=room.auto_skip,
),
users=users,
host=host_user,
match_state=None,
playlist=playlist,
active_countdowns=[],
channel_id=getattr(room, "channel_id", 0),
)
class MultiplayerQueue:
def __init__(self, room: "ServerMultiplayerRoom"):
self.server_room = room
self.current_index = 0
@property
def hub(self) -> "MultiplayerHub":
return self.server_room.hub
@property
def upcoming_items(self):
return sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda i: i.playlist_order,
)
@property
def room(self):
return self.server_room.room
async def update_order(self):
from app.database import Playlist
match self.room.settings.queue_mode:
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
ordered_active_items = []
is_first_set = True
first_set_order_by_user_id = {}
active_items = [item for item in self.room.playlist if not item.expired]
active_items.sort(key=lambda x: x.id)
user_item_groups = {}
for item in active_items:
if item.owner_id not in user_item_groups:
user_item_groups[item.owner_id] = []
user_item_groups[item.owner_id].append(item)
max_items = max(
(len(items) for items in user_item_groups.values()), default=0
)
for i in range(max_items):
current_set = []
for user_id, items in user_item_groups.items():
if i < len(items):
current_set.append(items[i])
if is_first_set:
current_set.sort(
key=lambda item: (item.playlist_order, item.id)
)
ordered_active_items.extend(current_set)
first_set_order_by_user_id = {
item.owner_id: idx
for idx, item in enumerate(ordered_active_items)
}
else:
current_set.sort(
key=lambda item: first_set_order_by_user_id.get(
item.owner_id, 0
)
)
ordered_active_items.extend(current_set)
is_first_set = False
case _:
ordered_active_items = sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda x: x.id,
)
async with AsyncSession(engine) as session:
for idx, item in enumerate(ordered_active_items):
if item.playlist_order == idx:
continue
item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed(
self.server_room, item, beatmap_changed=False
)
async def update_current_item(self):
upcoming_items = self.upcoming_items
next_item = (
upcoming_items[0]
if upcoming_items
else max(
self.room.playlist,
key=lambda i: i.played_at or datetime.min,
)
)
self.current_index = self.room.playlist.index(next_item)
last_id = self.room.settings.playlist_item_id
self.room.settings.playlist_item_id = next_item.id
if last_id != next_item.id:
await self.hub.setting_changed(self.server_room, True)
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
is_host = self.room.host and self.room.host.user_id == user.user_id
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
raise InvokeException("You are not the host")
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if (
len([True for u in self.room.playlist if u.owner_id == user.user_id])
>= limit
):
raise InvokeException(f"You can only have {limit} items in the queue")
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=item.beatmap_id
)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
item.validate_playlist_item_mods()
item.owner_id = user.user_id
item.star_rating = beatmap.difficulty_rating
await Playlist.add_to_db(item, self.room.room_id, session)
self.room.playlist.append(item)
await self.hub.playlist_added(self.server_room, item)
await self.update_order()
await self.update_current_item()
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=item.beatmap_id
)
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
existing_item = next(
(i for i in self.room.playlist if i.id == item.id), None
)
if existing_item is None:
raise InvokeException(
"Attempted to change an item that doesn't exist"
)
if existing_item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to change an item which is not owned by the user"
)
if existing_item.expired:
raise InvokeException(
"Attempted to change an item which has already been played"
)
item.validate_playlist_item_mods()
item.owner_id = user.user_id
item.star_rating = float(beatmap.difficulty_rating)
item.playlist_order = existing_item.playlist_order
await Playlist.update(item, self.room.room_id, session)
# Update item in playlist
for idx, playlist_item in enumerate(self.room.playlist):
if playlist_item.id == item.id:
self.room.playlist[idx] = item
break
await self.hub.playlist_changed(
self.server_room,
item,
beatmap_changed=item.beatmap_checksum
!= existing_item.beatmap_checksum,
)
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
from app.database import Playlist
item = next(
(i for i in self.room.playlist if i.id == playlist_item_id),
None,
)
if item is None:
raise InvokeException("Item does not exist in the room")
# Check if it's the only item and current item
if item == self.current_item:
upcoming_items = [i for i in self.room.playlist if not i.expired]
if len(upcoming_items) == 1:
raise InvokeException("The only item in the room cannot be removed")
if item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to remove an item which is not owned by the user"
)
if item.expired:
raise InvokeException(
"Attempted to remove an item which has already been played"
)
async with AsyncSession(engine) as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
self.room.playlist.remove(item)
self.current_index = self.room.playlist.index(self.upcoming_items[0])
await self.update_order()
await self.update_current_item()
await self.hub.playlist_removed(self.server_room, item.id)
async def finish_current_item(self):
from app.database import Playlist
async with AsyncSession(engine) as session:
played_at = datetime.now(UTC)
await session.execute(
update(Playlist)
.where(
col(Playlist.id) == self.current_item.id,
col(Playlist.room_id) == self.room.room_id,
)
.values(expired=True, played_at=played_at)
)
self.room.playlist[self.current_index].expired = True
self.room.playlist[self.current_index].played_at = played_at
await self.hub.playlist_changed(self.server_room, self.current_item, True)
await self.update_order()
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
await self.update_current_item()
async def update_queue_mode(self):
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
await self.update_order()
await self.update_current_item()
@property
def current_item(self):
return self.room.playlist[self.current_index]
@dataclass
class CountdownInfo:
countdown: MultiplayerCountdown
duration: timedelta
task: asyncio.Task | None = None
def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown
self.duration = (
countdown.time_remaining
if countdown.time_remaining > timedelta(seconds=0)
else timedelta(seconds=0)
)
class _MatchRequest(SignalRUnionMessage): ...
class ChangeTeamRequest(_MatchRequest):
union_type: ClassVar[Literal[0]] = 0
team_id: int
class StartMatchCountdownRequest(_MatchRequest):
union_type: ClassVar[Literal[1]] = 1
duration: timedelta
class StopCountdownRequest(_MatchRequest):
union_type: ClassVar[Literal[2]] = 2
id: int
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
class MatchTypeHandler(ABC):
def __init__(self, room: "ServerMultiplayerRoom"):
self.room = room
self.hub = room.hub
@abstractmethod
async def handle_join(self, user: MultiplayerRoomUser): ...
@abstractmethod
async def handle_request(
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
@abstractmethod
async def handle_leave(self, user: MultiplayerRoomUser): ...
@abstractmethod
def get_details(self) -> MatchStartedEventDetail: ...
class HeadToHeadHandler(MatchTypeHandler):
@override
async def handle_join(self, user: MultiplayerRoomUser):
if user.match_state is not None:
user.match_state = None
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@override
def get_details(self) -> MatchStartedEventDetail:
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
return detail
class TeamVersusHandler(MatchTypeHandler):
@override
def __init__(self, room: "ServerMultiplayerRoom"):
super().__init__(room)
self.state = TeamVersusRoomState()
room.room.match_state = self.state
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
self.hub.tasks.add(task)
task.add_done_callback(self.hub.tasks.discard)
def _get_best_available_team(self) -> int:
for team in self.state.teams:
if all(
(
user.match_state is None
or not isinstance(user.match_state, TeamVersusUserState)
or user.match_state.team_id != team.id
)
for user in self.room.room.users
):
return team.id
from collections import defaultdict
team_counts = defaultdict(int)
for user in self.room.room.users:
if user.match_state is not None and isinstance(
user.match_state, TeamVersusUserState
):
team_counts[user.match_state.team_id] += 1
if team_counts:
min_count = min(team_counts.values())
for team_id, count in team_counts.items():
if count == min_count:
return team_id
return self.state.teams[0].id if self.state.teams else 0
@override
async def handle_join(self, user: MultiplayerRoomUser):
best_team_id = self._get_best_available_team()
user.match_state = TeamVersusUserState(team_id=best_team_id)
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
if not isinstance(request, ChangeTeamRequest):
return
if request.team_id not in [team.id for team in self.state.teams]:
raise InvokeException("Invalid team ID")
user.match_state = TeamVersusUserState(team_id=request.team_id)
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@override
def get_details(self) -> MatchStartedEventDetail:
teams: dict[int, Literal["blue", "red"]] = {}
for user in self.room.room.users:
if user.match_state is not None and isinstance(
user.match_state, TeamVersusUserState
):
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
return detail
MATCH_TYPE_HANDLERS = {
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
MatchType.TEAM_VERSUS: TeamVersusHandler,
}
@dataclass
class ServerMultiplayerRoom:
room: MultiplayerRoom
category: RoomCategory
status: RoomStatus
start_at: datetime
hub: "MultiplayerHub"
match_type_handler: MatchTypeHandler
queue: MultiplayerQueue
_next_countdown_id: int
_countdown_id_lock: asyncio.Lock
_tracked_countdown: dict[int, CountdownInfo]
def __init__(
self,
room: MultiplayerRoom,
category: RoomCategory,
start_at: datetime,
hub: "MultiplayerHub",
):
self.room = room
self.category = category
self.status = RoomStatus.IDLE
self.start_at = start_at
self.hub = hub
self.queue = MultiplayerQueue(self)
self._next_countdown_id = 0
self._countdown_id_lock = asyncio.Lock()
self._tracked_countdown = {}
async def set_handler(self):
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
self
)
for i in self.room.users:
await self.match_type_handler.handle_join(i)
async def get_next_countdown_id(self) -> int:
async with self._countdown_id_lock:
self._next_countdown_id += 1
return self._next_countdown_id
async def start_countdown(
self,
countdown: MultiplayerCountdown,
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
):
async def _countdown_task(self: "ServerMultiplayerRoom"):
await asyncio.sleep(info.duration.total_seconds())
if on_complete is not None:
await on_complete(self)
await self.stop_countdown(countdown)
if countdown.is_exclusive:
await self.stop_all_countdowns(countdown.__class__)
countdown.id = await self.get_next_countdown_id()
info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown)
self._tracked_countdown[countdown.id] = info
await self.hub.send_match_event(
self, CountdownStartedEvent(countdown=info.countdown)
)
info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown):
info = self._tracked_countdown.get(countdown.id)
if info is None:
return
del self._tracked_countdown[countdown.id]
self.room.active_countdowns.remove(countdown)
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
if info.task is not None and not info.task.done():
info.task.cancel()
async def stop_all_countdowns(self, typ: type[MultiplayerCountdown]):
for countdown in list(self._tracked_countdown.values()):
if isinstance(countdown.countdown, typ):
await self.stop_countdown(countdown.countdown)
class _MatchServerEvent(SignalRUnionMessage): ...
class CountdownStartedEvent(_MatchServerEvent):
countdown: MultiplayerCountdown
union_type: ClassVar[Literal[0]] = 0
class CountdownStoppedEvent(_MatchServerEvent):
id: int
union_type: ClassVar[Literal[1]] = 1
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
class GameplayAbortReason(IntEnum):
LOAD_TOOK_TOO_LONG = 0
HOST_ABORTED = 1
class MatchStartedEventDetail(TypedDict):
room_type: Literal["playlists", "head_to_head", "team_versus"]
team: dict[int, Literal["blue", "red"]] | None

View File

@@ -1,7 +1,6 @@
# OAuth 相关模型
from __future__ import annotations
from typing import List
from pydantic import BaseModel
@@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel):
class RegistrationErrorResponse(BaseModel):
"""注册错误响应模型"""
form_error: dict
class UserRegistrationErrors(BaseModel):
"""用户注册错误模型"""
username: List[str] = []
user_email: List[str] = []
password: List[str] = []
username: list[str] = []
user_email: list[str] = []
password: list[str] = []
class RegistrationRequestErrors(BaseModel):
"""注册请求错误模型"""
message: str | None = None
redirect: str | None = None
user: UserRegistrationErrors | None = None

View File

@@ -1,15 +1,8 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from app.database import User
from app.database.beatmap import Beatmap
from app.models.mods import APIMod
from .model import UTCBaseModel
from pydantic import BaseModel, Field
from pydantic import BaseModel
class RoomCategory(str, Enum):
@@ -17,6 +10,7 @@ class RoomCategory(str, Enum):
SPOTLIGHT = "spotlight"
FEATURED_ARTIST = "featured_artist"
DAILY_CHALLENGE = "daily_challenge"
REALTIME = "realtime" # INTERNAL USE ONLY, DO NOT USE IN API
class MatchType(str, Enum):
@@ -42,18 +36,40 @@ class RoomStatus(str, Enum):
PLAYING = "playing"
class PlaylistItem(UTCBaseModel):
id: int | None
owner_id: int
ruleset_id: int
expired: bool
playlist_order: int | None
played_at: datetime | None
allowed_mods: list[APIMod] = Field(default_factory=list)
required_mods: list[APIMod] = Field(default_factory=list)
beatmap_id: int
beatmap: Beatmap | None
freestyle: bool
class MultiplayerRoomState(str, Enum):
OPEN = "open"
WAITING_FOR_LOAD = "waiting_for_load"
PLAYING = "playing"
CLOSED = "closed"
class MultiplayerUserState(str, Enum):
IDLE = "idle"
READY = "ready"
WAITING_FOR_LOAD = "waiting_for_load"
LOADED = "loaded"
READY_FOR_GAMEPLAY = "ready_for_gameplay"
PLAYING = "playing"
FINISHED_PLAY = "finished_play"
RESULTS = "results"
SPECTATING = "spectating"
@property
def is_playing(self) -> bool:
return self in {
self.WAITING_FOR_LOAD,
self.PLAYING,
self.READY_FOR_GAMEPLAY,
self.LOADED,
}
class DownloadState(str, Enum):
UNKNOWN = "unknown"
NOT_DOWNLOADED = "not_downloaded"
DOWNLOADING = "downloading"
IMPORTING = "importing"
LOCALLY_AVAILABLE = "locally_available"
class RoomPlaylistItemStats(BaseModel):
@@ -67,39 +83,7 @@ class RoomDifficultyRange(BaseModel):
max: float
class ItemAttemptsCount(BaseModel):
id: int
attempts: int
passed: bool
class PlaylistAggregateScore(BaseModel):
playlist_item_attempts: list[ItemAttemptsCount]
class Room(UTCBaseModel):
id: int | None
name: str = ""
password: str | None
has_password: bool = False
host: User | None
category: RoomCategory = RoomCategory.NORMAL
duration: int | None
starts_at: datetime | None
ends_at: datetime | None
participant_count: int = 0
recent_participants: list[User] = Field(default_factory=list)
max_attempts: int | None
playlist: list[PlaylistItem] = Field(default_factory=list)
playlist_item_stats: RoomPlaylistItemStats | None
difficulty_range: RoomDifficultyRange | None
type: MatchType = MatchType.PLAYLISTS
queue_mode: QueueMode = QueueMode.HOST_ONLY
auto_skip: bool = False
auto_start_duration: int = 0
current_user_score: PlaylistAggregateScore | None
current_playlist_item: PlaylistItem | None
channel_id: int = 0
status: RoomStatus = RoomStatus.IDLE
# availability 字段在当前序列化中未包含,但可能在某些场景下需要
availability: RoomAvailability | None
class PlaylistStatus(BaseModel):
count_active: int
count_total: int
ruleset_ids: list[int]

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
from enum import Enum
from typing import Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict
from .mods import API_MODS, APIMod, init_mods
from pydantic import BaseModel, Field, ValidationInfo, field_validator
import rosu_pp_py as rosu
if TYPE_CHECKING:
import rosu_pp_py as rosu
class GameMode(str, Enum):
@@ -14,13 +16,19 @@ class GameMode(str, Enum):
TAIKO = "taiko"
FRUITS = "fruits"
MANIA = "mania"
OSURX = "osurx"
OSUAP = "osuap"
def to_rosu(self) -> "rosu.GameMode":
import rosu_pp_py as rosu
def to_rosu(self) -> rosu.GameMode:
return {
GameMode.OSU: rosu.GameMode.Osu,
GameMode.TAIKO: rosu.GameMode.Taiko,
GameMode.FRUITS: rosu.GameMode.Catch,
GameMode.MANIA: rosu.GameMode.Mania,
GameMode.OSURX: rosu.GameMode.Osu,
GameMode.OSUAP: rosu.GameMode.Osu,
}[self]
@@ -29,8 +37,11 @@ MODE_TO_INT = {
GameMode.TAIKO: 1,
GameMode.FRUITS: 2,
GameMode.MANIA: 3,
GameMode.OSURX: 0,
GameMode.OSUAP: 0,
}
INT_TO_MODE = {v: k for k, v in MODE_TO_INT.items()}
INT_TO_MODE[0] = GameMode.OSU
class Rank(str, Enum):

View File

@@ -1,12 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Any, ClassVar
from typing import ClassVar
from pydantic import (
BaseModel,
BeforeValidator,
Field,
)
@@ -15,23 +13,7 @@ from pydantic import (
class SignalRMeta:
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
json_ignore: bool = False # implement of JsonIgnore (json) attribute
use_upper_case: bool = False # use upper CamelCase for field names
def _by_index(v: Any, class_: type[Enum]):
enum_list = list(class_)
if not isinstance(v, int):
return v
if 0 <= v < len(enum_list):
return enum_list[v]
raise ValueError(
f"Value {v} is out of range for enum "
f"{class_.__name__} with {len(enum_list)} items"
)
def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator:
return BeforeValidator(lambda v: _by_index(v, enum_class))
use_abbr: bool = True
class SignalRUnionMessage(BaseModel):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import datetime
from enum import IntEnum
from typing import Any
from typing import Annotated, Any
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
@@ -89,9 +89,9 @@ class LegacyReplayFrame(BaseModel):
mouse_y: float | None = None
button_state: int
header: FrameHeader | None = Field(
default=None, metadata=[SignalRMeta(member_ignore=True)]
)
header: Annotated[
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
]
class FrameDataBundle(BaseModel):

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import NotRequired, TypedDict
from .model import UTCBaseModel
@@ -83,9 +84,9 @@ class RankHistory(BaseModel):
data: list[int]
class Page(BaseModel):
html: str = ""
raw: str = ""
class Page(TypedDict):
html: NotRequired[str]
raw: NotRequired[str]
class BeatmapsetType(str, Enum):

View File

@@ -3,6 +3,3 @@ from __future__ import annotations
from pathlib import Path
STATIC_DIR = Path(__file__).parent.parent / "static"
REPLAY_DIR = Path(__file__).parent.parent / "replays"
REPLAY_DIR.mkdir(exist_ok=True)

View File

@@ -2,16 +2,17 @@ from __future__ import annotations
from app.signalr import signalr_router as signalr_router
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
beatmap,
beatmapset,
me,
relationship,
score,
user,
)
from .api_router import router as api_router
from .auth import router as auth_router
from .fetcher import fetcher_router as fetcher_router
from .file import file_router as file_router
from .private import private_router as private_router
from .v2.router import router as api_v2_router
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]
__all__ = [
"api_v2_router",
"auth_router",
"fetcher_router",
"file_router",
"private_router",
"signalr_router",
]

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
import re
from typing import Literal
from app.auth import (
authenticate_user,
@@ -9,12 +10,14 @@ from app.auth import (
generate_refresh_token,
get_password_hash,
get_token_by_refresh_token,
get_user_by_authorization_code,
store_token,
)
from app.config import settings
from app.database import DailyChallengeStats, User
from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.dependencies.database import get_redis
from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
@@ -26,6 +29,7 @@ from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -159,14 +163,22 @@ async def register_user(
country_code="CN", # 默认国家
join_date=datetime.now(UTC),
last_visit=datetime.now(UTC),
is_supporter=settings.enable_supporter_for_all_users,
support_level=int(settings.enable_supporter_for_all_users),
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
assert new_user.id is not None, "New user ID should not be None"
for i in GameMode:
for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]:
statistics = UserStatistics(mode=i, user_id=new_user.id)
db.add(statistics)
if settings.enable_osu_rx:
statistics_rx = UserStatistics(mode=GameMode.OSURX, user_id=new_user.id)
db.add(statistics_rx)
if settings.enable_osu_ap:
statistics_ap = UserStatistics(mode=GameMode.OSUAP, user_id=new_user.id)
db.add(statistics_ap)
daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
db.add(daily_challenge_user_stats)
await db.commit()
@@ -187,21 +199,36 @@ async def register_user(
@router.post("/oauth/token", response_model=TokenResponse)
async def oauth_token(
grant_type: str = Form(...),
client_id: str = Form(...),
grant_type: Literal[
"authorization_code", "refresh_token", "password", "client_credentials"
] = Form(...),
client_id: int = Form(...),
client_secret: str = Form(...),
code: str | None = Form(None),
scope: str = Form("*"),
username: str | None = Form(None),
password: str | None = Form(None),
refresh_token: str | None = Form(None),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
"""OAuth 令牌端点"""
# 验证客户端凭据
if (
client_id != settings.OSU_CLIENT_ID
or client_secret != settings.OSU_CLIENT_SECRET
):
scopes = scope.split(" ")
client = (
await db.exec(
select(OAuthClient).where(
OAuthClient.client_id == client_id,
OAuthClient.client_secret == client_secret,
)
)
).first()
is_game_client = (client_id, client_secret) in [
(settings.osu_client_id, settings.osu_client_secret),
(settings.osu_web_client_id, settings.osu_web_client_secret),
]
if client is None and not is_game_client:
return create_oauth_error_response(
error="invalid_client",
description=(
@@ -214,7 +241,6 @@ async def oauth_token(
)
if grant_type == "password":
# 密码授权流程
if not username or not password:
return create_oauth_error_response(
error="invalid_request",
@@ -225,6 +251,16 @@ async def oauth_token(
),
hint="Username and password required",
)
if scopes != ["*"]:
return create_oauth_error_response(
error="invalid_scope",
description=(
"The requested scope is invalid, unknown, "
"or malformed. The client may not request "
"more than one scope at a time."
),
hint="Only '*' scope is allowed for password grant type",
)
# 验证用户
user = await authenticate_user(db, username, password)
@@ -242,7 +278,7 @@ async def oauth_token(
)
# 生成令牌
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": str(user.id)}, expires_delta=access_token_expires
)
@@ -253,15 +289,17 @@ async def oauth_token(
await store_token(
db,
user.id,
client_id,
scopes,
access_token,
refresh_token_str,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
settings.access_token_expire_minutes * 60,
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=scope,
)
@@ -295,7 +333,7 @@ async def oauth_token(
)
# 生成新的访问令牌
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
)
@@ -305,19 +343,83 @@ async def oauth_token(
await store_token(
db,
token_record.user_id,
client_id,
scopes,
access_token,
new_refresh_token,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
settings.access_token_expire_minutes * 60,
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=new_refresh_token,
scope=scope,
)
elif grant_type == "authorization_code":
if client is None:
return create_oauth_error_response(
error="invalid_client",
description=(
"Client authentication failed (e.g., unknown client, "
"no client authentication included, "
"or unsupported authentication method)."
),
hint="Invalid client credentials",
status_code=401,
)
if not code:
return create_oauth_error_response(
error="invalid_request",
description=(
"The request is missing a required parameter, "
"includes an invalid parameter value, "
"includes a parameter more than once, or is otherwise malformed."
),
hint="Authorization code required",
)
code_result = await get_user_by_authorization_code(db, redis, client_id, code)
if not code_result:
return create_oauth_error_response(
error="invalid_grant",
description=(
"The provided authorization grant (e.g., authorization code, "
"resource owner credentials) or refresh token is invalid, "
"expired, revoked, does not match the redirection URI used in "
"the authorization request, or was issued to another client."
),
hint="Invalid authorization code",
)
user, scopes = code_result
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": str(user.id)}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user.id
await store_token(
db,
user.id,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
)
else:
return create_oauth_error_response(
error="unsupported_grant_type",

View File

@@ -5,7 +5,7 @@ from app.fetcher import Fetcher
from fastapi import APIRouter, Depends
fetcher_router = APIRouter()
fetcher_router = APIRouter(prefix="/fetcher", tags=["fetcher"])
@fetcher_router.get("/callback")

26
app/router/file.py Normal file
View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from app.dependencies.storage import get_storage_service
from app.storage import LocalStorageService, StorageService
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
file_router = APIRouter(prefix="/file")
@file_router.get("/{path:path}")
async def get_file(path: str, storage: StorageService = Depends(get_storage_service)):
if not isinstance(storage, LocalStorageService):
raise HTTPException(404, "Not Found")
if not await storage.is_exists(path):
raise HTTPException(404, "Not Found")
try:
return FileResponse(
path=storage._get_file_path(path),
media_type="application/octet-stream",
filename=path.split("/")[-1],
)
except FileNotFoundError:
raise HTTPException(404, "Not Found")

View File

@@ -0,0 +1,8 @@
from __future__ import annotations
from . import avatar # noqa: F401
from .router import router as private_router
__all__ = [
"private_router",
]

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import base64
import hashlib
from io import BytesIO
from app.database.lazer_user import User
from app.dependencies.database import get_db
from app.dependencies.storage import get_storage_service
from app.storage.base import StorageService
from .router import router
from fastapi import Body, Depends, HTTPException
from PIL import Image
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post("/avatar/upload", tags=["avatar"])
async def upload_avatar(
file: str = Body(...),
user_id: int = Body(...),
storage: StorageService = Depends(get_storage_service),
session: AsyncSession = Depends(get_db),
):
content = base64.b64decode(file)
user = await session.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
# check file
if len(content) > 5 * 1024 * 1024: # 5MB limit
raise HTTPException(status_code=400, detail="File size exceeds 5MB limit")
elif len(content) == 0:
raise HTTPException(status_code=400, detail="File cannot be empty")
with Image.open(BytesIO(content)) as img:
if img.format not in ["PNG", "JPEG", "GIF"]:
raise HTTPException(status_code=400, detail="Invalid image format")
if img.size[0] > 256 or img.size[1] > 256:
raise HTTPException(
status_code=400, detail="Image size exceeds 256x256 pixels"
)
filehash = hashlib.sha256(content).hexdigest()
storage_path = f"avatars/{user_id}_{filehash}.png"
if not await storage.is_exists(storage_path):
await storage.write_file(storage_path, content)
url = await storage.get_file_url(storage_path)
user.avatar_url = url
await session.commit()
return {
"url": url,
"filehash": filehash,
}

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
import hashlib
import hmac
import time
from app.config import settings
from fastapi import APIRouter, Depends, Header, HTTPException, Request
async def verify_signature(
request: Request,
ts: int = Header(..., alias="X-Timestamp"),
nonce: str = Header(..., alias="X-Nonce"),
signature: str = Header(..., alias="X-Signature"),
):
path = request.url.path
data = await request.body()
body = data.decode("utf-8")
py_ts = ts // 1000
if abs(time.time() - py_ts) > 30:
raise HTTPException(status_code=403, detail="Invalid timestamp")
payload = f"{path}|{body}|{ts}|{nonce}"
expected_sig = hmac.new(
settings.private_api_secret.encode(), payload.encode(), hashlib.sha256
).hexdigest()
if not hmac.compare_digest(expected_sig, signature):
raise HTTPException(status_code=403, detail="Invalid signature")
router = APIRouter(
prefix="/api/private",
dependencies=[Depends(verify_signature)],
include_in_schema=False,
)

View File

@@ -1,33 +0,0 @@
from __future__ import annotations
from app.database.room import RoomIndex
from app.dependencies.database import get_db, get_redis
from app.models.room import Room
from .api_router import router
from fastapi import Depends, Query
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/rooms", tags=["rooms"], response_model=list[Room])
async def get_all_rooms(
mode: str = Query(
None
), # TODO: lazer源码显示房间不会是除了open以外的其他状态先放在这里
status: str = Query(None),
category: str = Query(None),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
all_room_ids = (await db.exec(select(RoomIndex).where(True))).all()
roomsList: list[Room] = []
for room_index in all_room_ids:
dumped_room = await redis.get(str(room_index.id))
if dumped_room:
actual_room = Room.model_validate_json(str(dumped_room))
if actual_room.status == status and actual_room.category == category:
roomsList.append(actual_room)
return roomsList

View File

@@ -1,227 +0,0 @@
from __future__ import annotations
from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User
from app.database.score import get_leaderboard, process_score, process_user
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
from app.models.beatmap import BeatmapRankStatus
from app.models.score import (
INT_TO_MODE,
GameMode,
LeaderboardType,
Rank,
SoloScoreSubmissionInfo,
)
from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class BeatmapScores(BaseModel):
scores: list[ScoreResp]
userScore: ScoreResp | None = None
@router.get(
"/beatmaps/{beatmap}/scores", tags=["beatmap"], response_model=BeatmapScores
)
async def get_beatmap_scores(
beatmap: int,
mode: GameMode,
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
mods: list[str] = Query(default_factory=set, alias="mods[]"),
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
limit: int = Query(50, ge=1, le=200),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="this server only contains lazer scores"
)
all_scores, user_score = await get_leaderboard(
db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
)
return BeatmapScores(
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
)
class BeatmapUserScore(BaseModel):
position: int
score: ScoreResp
@router.get(
"/beatmaps/{beatmap}/scores/users/{user}",
tags=["beatmap"],
response_model=BeatmapUserScore,
)
async def get_user_beatmap_score(
beatmap: int,
user: int,
legacy_only: bool = Query(None),
mode: str = Query(None),
mods: str = Query(None), # TODO:添加mods筛选
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
user_score = (
await db.exec(
select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
Score.user_id == user,
)
.order_by(col(Score.total_score).desc())
)
).first()
if not user_score:
raise HTTPException(
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
)
else:
return BeatmapUserScore(
position=user_score.position if user_score.position is not None else 0,
score=await ScoreResp.from_db(db, user_score),
)
@router.get(
"/beatmaps/{beatmap}/scores/users/{user}/all",
tags=["beatmap"],
response_model=list[ScoreResp],
)
async def get_user_all_beatmap_scores(
beatmap: int,
user: int,
legacy_only: bool = Query(None),
ruleset: str = Query(None),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
all_user_scores = (
await db.exec(
select(Score)
.where(
Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap,
Score.user_id == user,
)
.order_by(col(Score.classic_total_score).desc())
)
).all()
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
@router.post(
"/beatmaps/{beatmap}/solo/scores", tags=["beatmap"], response_model=ScoreTokenResp
)
async def create_solo_score(
beatmap: int,
version_hash: str = Form(""),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
assert current_user.id
async with db:
score_token = ScoreToken(
user_id=current_user.id,
beatmap_id=beatmap,
ruleset_id=INT_TO_MODE[ruleset_id],
)
db.add(score_token)
await db.commit()
await db.refresh(score_token)
return ScoreTokenResp.from_db(score_token)
@router.put(
"/beatmaps/{beatmap}/solo/scores/{token}",
tags=["beatmap"],
response_model=ScoreResp,
)
async def submit_solo_score(
beatmap: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
if not info.passed:
info.rank = Rank.F
async with db:
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
)
).first()
if not score_token or score_token.user_id != current_user.id:
raise HTTPException(status_code=404, detail="Score token not found")
if score_token.score_id:
score = (
await db.exec(
select(Score).where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
)
).first()
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
beatmap_status = (
await db.exec(
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
)
).first()
if beatmap_status is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
ranked = beatmap_status in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}
score = await process_score(
current_user,
beatmap,
ranked,
score_token,
info,
fetcher,
db,
redis,
)
await db.refresh(current_user)
score_id = score.id
score_token.score_id = score_id
await process_user(db, current_user, score, ranked)
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
assert score is not None
return await ScoreResp.from_db(db, score)

17
app/router/v2/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
beatmap,
beatmapset,
me,
misc,
relationship,
room,
score,
user,
)
from .router import router as api_v2_router
__all__ = [
"api_v2_router",
]

View File

@@ -17,9 +17,9 @@ from app.models.score import (
GameMode,
)
from .api_router import router
from .router import router
from fastapi import Depends, HTTPException, Query
from fastapi import Depends, HTTPException, Query, Security
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis.asyncio import Redis
@@ -33,7 +33,7 @@ async def lookup_beatmap(
id: int | None = Query(default=None, alias="id"),
md5: str | None = Query(default=None, alias="checksum"),
filename: str | None = Query(default=None, alias="filename"),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -56,7 +56,7 @@ async def lookup_beatmap(
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -74,9 +74,10 @@ class BatchGetResp(BaseModel):
@router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp)
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: User = Depends(get_current_user),
b_ids: list[int] = Query(alias="ids[]", default_factory=list),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
if not b_ids:
# select 50 beatmaps by last_updated
@@ -86,9 +87,29 @@ async def batch_get_beatmaps(
)
).all()
else:
beatmaps = (
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
).all()
beatmaps = list(
(
await db.exec(
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
)
).all()
)
not_found_beatmaps = [
bid for bid in b_ids if bid not in [bm.id for bm in beatmaps]
]
beatmaps.extend(
beatmap
for beatmap in await asyncio.gather(
*[
Beatmap.get_or_fetch(db, fetcher, bid=bid)
for bid in not_found_beatmaps
],
return_exceptions=True,
)
if isinstance(beatmap, Beatmap)
)
for beatmap in beatmaps:
await db.refresh(beatmap)
return BatchGetResp(
beatmaps=[
@@ -105,7 +126,7 @@ async def batch_get_beatmaps(
)
async def get_beatmap_attributes(
beatmap: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
mods: list[str] = Query(default_factory=list),
ruleset: GameMode | None = Query(default=None),
ruleset_id: int | None = Query(default=None),

View File

@@ -2,47 +2,56 @@ from __future__ import annotations
from typing import Literal
from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.dependencies.database import get_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from .api_router import router
from .router import router
from fastapi import Depends, Form, HTTPException, Query
from fastapi import Depends, Form, HTTPException, Query, Security
from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError
from httpx import HTTPError
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp)
async def lookup_beatmapset(
beatmap_id: int = Query(),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=db, user=current_user
)
return resp
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
if not beatmapset:
try:
resp = await fetcher.get_beatmapset(sid)
await Beatmapset.from_resp(db, resp)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
else:
resp = await BeatmapsetResp.from_db(
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, sid)
return await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
return resp
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
async def download_beatmapset(
beatmapset: int,
no_video: bool = Query(True, alias="noVideo"),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
):
if current_user.country_code == "CN":
return RedirectResponse(
@@ -59,7 +68,7 @@ async def download_beatmapset(
async def favourite_beatmapset(
beatmapset: int,
action: Literal["favourite", "unfavourite"] = Form(),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
existing_favourite = (

View File

@@ -6,9 +6,9 @@ from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.models.score import GameMode
from .api_router import router
from .router import router
from fastapi import Depends
from fastapi import Depends, Security
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -16,7 +16,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/me/", response_model=UserResp)
async def get_user_info_default(
ruleset: GameMode | None = None,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["identify"]),
session: AsyncSession = Depends(get_db),
):
return await UserResp.from_db(

25
app/router/v2/misc.py Normal file
View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from datetime import UTC, datetime
from app.config import settings
from .router import router
from pydantic import BaseModel
class Background(BaseModel):
url: str
class BackgroundsResp(BaseModel):
ends_at: datetime = datetime(year=9999, month=12, day=31, tzinfo=UTC)
backgrounds: list[Background]
@router.get("/seasonal-backgrounds", response_model=BackgroundsResp)
async def get_seasonal_backgrounds():
return BackgroundsResp(
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
)

View File

@@ -1,13 +1,12 @@
from __future__ import annotations
from app.database import User as DBUser
from app.database.relationship import Relationship, RelationshipResp, RelationshipType
from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
from .api_router import router
from .router import router
from fastapi import Depends, HTTPException, Query, Request
from fastapi import Depends, HTTPException, Query, Request, Security
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -17,7 +16,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/blocks", tags=["relationship"], response_model=list[RelationshipResp])
async def get_relationship(
request: Request,
current_user: DBUser = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["friends.read"]),
db: AsyncSession = Depends(get_db),
):
relationship_type = (
@@ -43,7 +42,7 @@ class AddFriendResp(BaseModel):
async def add_relationship(
request: Request,
target: int = Query(),
current_user: DBUser = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
relationship_type = (
@@ -106,7 +105,7 @@ async def add_relationship(
async def delete_relationship(
request: Request,
target: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
relationship_type = (

361
app/router/v2/room.py Normal file
View File

@@ -0,0 +1,361 @@
from __future__ import annotations
from datetime import UTC, datetime
from typing import Literal
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmapset import BeatmapsetResp
from app.database.lazer_user import User, UserResp
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
from app.database.playlists import Playlist, PlaylistResp
from app.database.room import APIUploadedRoom, Room, RoomResp
from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score
from app.dependencies.database import get_db, get_redis
from app.dependencies.user import get_current_user
from app.models.room import RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api
from app.signalr.hub import MultiplayerHubs
from .router import router
from fastapi import Depends, HTTPException, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp])
async def get_all_rooms(
mode: Literal["open", "ended", "participated", "owned", None] = Query(
default="open"
),
category: RoomCategory = Query(RoomCategory.NORMAL),
status: RoomStatus | None = Query(None),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
resp_list: list[RoomResp] = []
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
now = datetime.now(UTC)
if status is not None:
where_clauses.append(col(Room.status) == status)
if mode == "open":
where_clauses.append(
(col(Room.ends_at).is_(None))
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
)
if category == RoomCategory.REALTIME:
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
if mode == "participated":
where_clauses.append(
exists().where(
col(RoomParticipatedUser.room_id) == Room.id,
col(RoomParticipatedUser.user_id) == current_user.id,
)
)
if mode == "owned":
where_clauses.append(col(Room.host_id) == current_user.id)
if mode == "ended":
where_clauses.append(
(col(Room.ends_at).is_not(None))
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
)
db_rooms = (
(
await db.exec(
select(Room).where(
*where_clauses,
)
)
)
.unique()
.all()
)
for room in db_rooms:
resp = await RoomResp.from_db(room, db)
if category == RoomCategory.REALTIME:
mp_room = MultiplayerHubs.rooms.get(room.id)
resp.has_password = (
bool(mp_room.room.settings.password.strip())
if mp_room is not None
else False
)
resp.category = RoomCategory.NORMAL
resp_list.append(resp)
return resp_list
class APICreatedRoom(RoomResp):
error: str = ""
async def _participate_room(
room_id: int, user_id: int, db_room: Room, session: AsyncSession
):
participated_user = (
await session.exec(
select(RoomParticipatedUser).where(
RoomParticipatedUser.room_id == room_id,
RoomParticipatedUser.user_id == user_id,
)
)
).first()
if participated_user is None:
participated_user = RoomParticipatedUser(
room_id=room_id,
user_id=user_id,
joined_at=datetime.now(UTC),
)
session.add(participated_user)
else:
participated_user.left_at = None
participated_user.joined_at = datetime.now(UTC)
db_room.participant_count += 1
@router.post("/rooms", tags=["room"], response_model=APICreatedRoom)
async def create_room(
room: APIUploadedRoom,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
):
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db)
# await db.commit()
# await db.refresh(db_room)
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
created_room.error = ""
return created_room
@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp)
async def get_room(
room: int,
category: str = Query(default=""),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
redis: Redis = Depends(get_redis),
):
# 直接从db获取信息毕竟都一样
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
resp = await RoomResp.from_db(
db_room, include=["current_user_score"], session=db, user=current_user
)
return resp
@router.delete("/rooms/{room}", tags=["room"])
async def delete_room(
room: int,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
else:
db_room.ends_at = datetime.now(UTC)
await db.commit()
return None
@router.put("/rooms/{room}/users/{user}", tags=["room"])
async def add_user_to_room(
room: int,
user: int,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is not None:
await _participate_room(room, user, db_room, db)
await db.commit()
await db.refresh(db_room)
resp = await RoomResp.from_db(db_room, db)
return resp
else:
raise HTTPException(404, "room not found0")
@router.delete("/rooms/{room}/users/{user}", tags=["room"])
async def remove_user_from_room(
room: int,
user: int,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is not None:
participated_user = (
await db.exec(
select(RoomParticipatedUser).where(
RoomParticipatedUser.room_id == room,
RoomParticipatedUser.user_id == user,
)
)
).first()
if participated_user is not None:
participated_user.left_at = datetime.now(UTC)
db_room.participant_count -= 1
await db.commit()
return None
else:
raise HTTPException(404, "Room not found")
class APILeaderboard(BaseModel):
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
user_score: ItemAttemptsResp | None = None
@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard)
async def get_room_leaderboard(
room: int,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
aggs = await db.exec(
select(ItemAttemptsCount)
.where(ItemAttemptsCount.room_id == room)
.order_by(col(ItemAttemptsCount.total_score).desc())
)
aggs_resp = []
user_agg = None
for i, agg in enumerate(aggs):
resp = await ItemAttemptsResp.from_db(agg, db)
resp.position = i + 1
# resp.accuracy *= 100
aggs_resp.append(resp)
if agg.user_id == current_user.id:
user_agg = resp
return APILeaderboard(
leaderboard=aggs_resp,
user_score=user_agg,
)
class RoomEvents(BaseModel):
beatmaps: list[BeatmapResp] = Field(default_factory=list)
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
current_playlist_item_id: int = 0
events: list[MultiplayerEventResp] = Field(default_factory=list)
first_event_id: int = 0
last_event_id: int = 0
playlist_items: list[PlaylistResp] = Field(default_factory=list)
room: RoomResp
user: list[UserResp] = Field(default_factory=list)
@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"])
async def get_room_events(
room_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(100, ge=1, le=1000),
after: int | None = Query(None, ge=0),
before: int | None = Query(None, ge=0),
):
events = (
await db.exec(
select(MultiplayerEvent)
.where(
MultiplayerEvent.room_id == room_id,
col(MultiplayerEvent.id) > after if after is not None else True,
col(MultiplayerEvent.id) < before if before is not None else True,
)
.order_by(col(MultiplayerEvent.id).desc())
.limit(limit)
)
).all()
user_ids = set()
playlist_items = {}
beatmap_ids = set()
event_resps = []
first_event_id = 0
last_event_id = 0
current_playlist_item_id = 0
for event in events:
event_resps.append(MultiplayerEventResp.from_db(event))
if event.user_id:
user_ids.add(event.user_id)
if event.playlist_item_id is not None and (
playitem := (
await db.exec(
select(Playlist).where(
Playlist.id == event.playlist_item_id,
Playlist.room_id == room_id,
)
)
).first()
):
current_playlist_item_id = playitem.id
playlist_items[event.playlist_item_id] = playitem
beatmap_ids.add(playitem.beatmap_id)
scores = await db.exec(
select(Score).where(
Score.playlist_item_id == event.playlist_item_id,
Score.room_id == room_id,
)
)
for score in scores:
user_ids.add(score.user_id)
beatmap_ids.add(score.beatmap_id)
assert event.id is not None
first_event_id = min(first_event_id, event.id)
last_event_id = max(last_event_id, event.id)
if room := MultiplayerHubs.rooms.get(room_id):
current_playlist_item_id = room.queue.current_item.id
room_resp = await RoomResp.from_hub(room)
else:
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if room is None:
raise HTTPException(404, "Room not found")
room_resp = await RoomResp.from_db(room, db)
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users]
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
beatmap_resps = [
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
]
beatmapset_resps = {}
for beatmap_resp in beatmap_resps:
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
playlist_items_resps = [
await PlaylistResp.from_db(item) for item in playlist_items.values()
]
return RoomEvents(
beatmaps=beatmap_resps,
beatmapsets=beatmapset_resps,
current_playlist_item_id=current_playlist_item_id,
events=event_resps,
first_event_id=first_event_id,
last_event_id=last_event_id,
playlist_items=playlist_items_resps,
room=room_resp,
user=user_resps,
)

View File

@@ -2,4 +2,4 @@ from __future__ import annotations
from fastapi import APIRouter
router = APIRouter()
router = APIRouter(prefix="/api/v2")

770
app/router/v2/score.py Normal file
View File

@@ -0,0 +1,770 @@
from __future__ import annotations
from datetime import UTC, date, datetime
import time
from app.calculator import clamp
from app.config import settings
from app.database import (
Beatmap,
Playlist,
Room,
Score,
ScoreResp,
ScoreToken,
ScoreTokenResp,
User,
)
from app.database.counts import ReplayWatchedCount
from app.database.playlist_attempts import ItemAttemptsCount
from app.database.playlist_best_score import (
PlaylistBestScore,
get_position,
process_playlist_best_score,
)
from app.database.relationship import Relationship, RelationshipType
from app.database.score import (
MultiplayerScores,
ScoreAround,
get_leaderboard,
process_score,
process_user,
)
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from app.models.room import RoomCategory
from app.models.score import (
INT_TO_MODE,
GameMode,
LeaderboardType,
Rank,
SoloScoreSubmissionInfo,
)
from app.storage.base import StorageService
from app.storage.local import LocalStorageService
from .router import router
from fastapi import Body, Depends, Form, HTTPException, Query, Security
from fastapi.responses import FileResponse, RedirectResponse
from httpx import HTTPError
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 10
async def submit_score(
info: SoloScoreSubmissionInfo,
beatmap: int,
token: int,
current_user: User,
db: AsyncSession,
redis: Redis,
fetcher: Fetcher,
item_id: int | None = None,
room_id: int | None = None,
):
if not info.passed:
info.rank = Rank.F
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token)
)
).first()
if not score_token or score_token.user_id != current_user.id:
raise HTTPException(status_code=404, detail="Score token not found")
if score_token.score_id:
score = (
await db.exec(
select(Score).where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
)
).first()
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
try:
db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
ranked = (
db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_leaderboard
)
beatmap_length = db_beatmap.total_length
score = await process_score(
current_user,
beatmap,
ranked,
score_token,
info,
fetcher,
db,
redis,
item_id,
room_id,
)
await db.refresh(current_user)
score_id = score.id
score_token.score_id = score_id
await process_user(db, current_user, score, beatmap_length, ranked)
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
assert score is not None
return await ScoreResp.from_db(db, score)
class BeatmapScores(BaseModel):
scores: list[ScoreResp]
userScore: ScoreResp | None = None
@router.get(
"/beatmaps/{beatmap}/scores", tags=["beatmap"], response_model=BeatmapScores
)
async def get_beatmap_scores(
beatmap: int,
mode: GameMode,
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
mods: list[str] = Query(default_factory=set, alias="mods[]"),
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
limit: int = Query(50, ge=1, le=200),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="this server only contains lazer scores"
)
all_scores, user_score = await get_leaderboard(
db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
)
return BeatmapScores(
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
)
class BeatmapUserScore(BaseModel):
position: int
score: ScoreResp
@router.get(
"/beatmaps/{beatmap}/scores/users/{user}",
tags=["beatmap"],
response_model=BeatmapUserScore,
)
async def get_user_beatmap_score(
beatmap: int,
user: int,
legacy_only: bool = Query(None),
mode: str = Query(None),
mods: str = Query(None), # TODO:添加mods筛选
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
user_score = (
await db.exec(
select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
Score.user_id == user,
)
.order_by(col(Score.total_score).desc())
)
).first()
if not user_score:
raise HTTPException(
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
)
else:
resp = await ScoreResp.from_db(db, user_score)
return BeatmapUserScore(
position=resp.rank_global or 0,
score=resp,
)
@router.get(
"/beatmaps/{beatmap}/scores/users/{user}/all",
tags=["beatmap"],
response_model=list[ScoreResp],
)
async def get_user_all_beatmap_scores(
beatmap: int,
user: int,
legacy_only: bool = Query(None),
ruleset: str = Query(None),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
all_user_scores = (
await db.exec(
select(Score)
.where(
Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap,
Score.user_id == user,
)
.order_by(col(Score.classic_total_score).desc())
)
).all()
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
@router.post(
"/beatmaps/{beatmap}/solo/scores", tags=["beatmap"], response_model=ScoreTokenResp
)
async def create_solo_score(
beatmap: int,
version_hash: str = Form(""),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
assert current_user.id
async with db:
score_token = ScoreToken(
user_id=current_user.id,
beatmap_id=beatmap,
ruleset_id=INT_TO_MODE[ruleset_id],
)
db.add(score_token)
await db.commit()
await db.refresh(score_token)
return ScoreTokenResp.from_db(score_token)
@router.put(
"/beatmaps/{beatmap}/solo/scores/{token}",
tags=["beatmap"],
response_model=ScoreResp,
)
async def submit_solo_score(
beatmap: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
return await submit_score(info, beatmap, token, current_user, db, redis, fetcher)
@router.post(
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp
)
async def create_playlist_score(
room_id: int,
playlist_id: int,
beatmap_id: int = Form(),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
version_hash: str = Form(""),
current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db),
):
room = await session.get(Room, room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
raise HTTPException(status_code=400, detail="Room has ended")
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist not found")
# validate
if not item.freestyle:
if item.ruleset_id != ruleset_id:
raise HTTPException(
status_code=400, detail="Ruleset mismatch in playlist item"
)
if item.beatmap_id != beatmap_id:
raise HTTPException(
status_code=400, detail="Beatmap ID mismatch in playlist item"
)
agg = await session.exec(
select(ItemAttemptsCount).where(
ItemAttemptsCount.room_id == room_id,
ItemAttemptsCount.user_id == current_user.id,
)
)
agg = agg.first()
if agg and room.max_attempts and agg.attempts >= room.max_attempts:
raise HTTPException(
status_code=422,
detail="You have reached the maximum attempts for this room",
)
if item.expired:
raise HTTPException(status_code=400, detail="Playlist item has expired")
if item.played_at:
raise HTTPException(
status_code=400, detail="Playlist item has already been played"
)
# 这里应该不用验证mod了吧。。。
score_token = ScoreToken(
user_id=current_user.id,
beatmap_id=beatmap_id,
ruleset_id=INT_TO_MODE[ruleset_id],
playlist_item_id=playlist_id,
)
session.add(score_token)
await session.commit()
await session.refresh(score_token)
return ScoreTokenResp.from_db(score_token)
@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}")
async def submit_playlist_score(
room_id: int,
playlist_id: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist item not found")
user_id = current_user.id
score_resp = await submit_score(
info,
item.beatmap_id,
token,
current_user,
session,
redis,
fetcher,
item.id,
room_id,
)
await process_playlist_best_score(
room_id,
playlist_id,
user_id,
score_resp.id,
score_resp.total_score,
session,
redis,
)
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
return score_resp
class IndexedScoreResp(MultiplayerScores):
total: int
user_score: ScoreResp | None = None
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp
)
async def index_playlist_scores(
room_id: int,
playlist_id: int,
limit: int = 50,
cursor: int = Query(2000000, alias="cursor[total_score]"),
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
):
room = await session.get(Room, room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
limit = clamp(limit, 1, 50)
scores = (
await session.exec(
select(PlaylistBestScore)
.where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
PlaylistBestScore.total_score < cursor,
)
.order_by(col(PlaylistBestScore.total_score).desc())
.limit(limit + 1)
)
).all()
has_more = len(scores) > limit
if has_more:
scores = scores[:-1]
user_score = None
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
for score in score_resp:
score.position = await get_position(room_id, playlist_id, score.id, session)
if score.user_id == current_user.id:
user_score = score
if room.category == RoomCategory.DAILY_CHALLENGE:
score_resp = [s for s in score_resp if s.passed]
if user_score and not user_score.passed:
user_score = None
resp = IndexedScoreResp(
scores=score_resp,
user_score=user_score,
total=len(scores),
params={
"limit": limit,
},
)
if has_more:
resp.cursor = {
"total_score": scores[-1].total_score,
}
return resp
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
response_model=ScoreResp,
)
async def show_playlist_score(
room_id: int,
playlist_id: int,
score_id: int,
current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
room = await session.get(Room, room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
start_time = time.time()
score_record = None
completed = room.category != RoomCategory.REALTIME
while time.time() - start_time < READ_SCORE_TIMEOUT:
if score_record is None:
score_record = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.score_id == score_id,
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
)
)
).first()
if completed_players := await redis.get(
f"multiplayer:{room_id}:gameplay:players"
):
completed = completed_players == "0"
if score_record and completed:
break
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(room_id, playlist_id, score_id, session)
if completed:
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
)
)
).all()
higher_scores = []
lower_scores = []
for score in scores:
if score.total_score > resp.total_score:
higher_scores.append(await ScoreResp.from_db(session, score.score))
elif score.total_score < resp.total_score:
lower_scores.append(await ScoreResp.from_db(session, score.score))
resp.scores_around = ScoreAround(
higher=MultiplayerScores(scores=higher_scores),
lower=MultiplayerScores(scores=lower_scores),
)
return resp
@router.get(
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
response_model=ScoreResp,
)
async def get_user_playlist_score(
room_id: int,
playlist_id: int,
user_id: int,
current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db),
):
score_record = None
start_time = time.time()
while time.time() - start_time < READ_SCORE_TIMEOUT:
score_record = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.user_id == user_id,
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
)
)
).first()
if score_record:
break
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(
room_id, playlist_id, score_record.score_id, session
)
return resp
@router.put("/score-pins/{score}", status_code=204)
async def pin_score(
score: int,
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
score_record = (
await db.exec(
select(Score).where(
Score.id == score,
Score.user_id == current_user.id,
col(Score.passed).is_(True),
)
)
).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
if score_record.pinned_order > 0:
return
next_order = (
(
await db.exec(
select(func.max(Score.pinned_order)).where(
Score.user_id == current_user.id,
Score.gamemode == score_record.gamemode,
)
)
).first()
or 0
) + 1
score_record.pinned_order = next_order
await db.commit()
@router.delete("/score-pins/{score}", status_code=204)
async def unpin_score(
score: int,
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
score_record = (
await db.exec(
select(Score).where(Score.id == score, Score.user_id == current_user.id)
)
).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
if score_record.pinned_order == 0:
return
changed_score = (
await db.exec(
select(Score).where(
Score.user_id == current_user.id,
Score.pinned_order > score_record.pinned_order,
Score.gamemode == score_record.gamemode,
)
)
).all()
for s in changed_score:
s.pinned_order -= 1
await db.commit()
@router.post("/score-pins/{score}/reorder", status_code=204)
async def reorder_score_pin(
score: int,
after_score_id: int | None = Body(default=None),
before_score_id: int | None = Body(default=None),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
score_record = (
await db.exec(
select(Score).where(Score.id == score, Score.user_id == current_user.id)
)
).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
if score_record.pinned_order == 0:
raise HTTPException(status_code=400, detail="Score is not pinned")
if (after_score_id is None) == (before_score_id is None):
raise HTTPException(
status_code=400,
detail="Either after_score_id or before_score_id "
"must be provided (but not both)",
)
all_pinned_scores = (
await db.exec(
select(Score)
.where(
Score.user_id == current_user.id,
Score.pinned_order > 0,
Score.gamemode == score_record.gamemode,
)
.order_by(col(Score.pinned_order))
)
).all()
target_order = None
reference_score_id = after_score_id or before_score_id
reference_score = next(
(s for s in all_pinned_scores if s.id == reference_score_id), None
)
if not reference_score:
detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail)
if after_score_id:
target_order = reference_score.pinned_order + 1
else:
target_order = reference_score.pinned_order
current_order = score_record.pinned_order
if current_order == target_order:
return
updates = []
if current_order < target_order:
for s in all_pinned_scores:
if current_order < s.pinned_order <= target_order and s.id != score:
updates.append((s.id, s.pinned_order - 1))
if after_score_id:
final_target = (
target_order - 1 if target_order > current_order else target_order
)
else:
final_target = target_order
else:
for s in all_pinned_scores:
if target_order <= s.pinned_order < current_order and s.id != score:
updates.append((s.id, s.pinned_order + 1))
final_target = target_order
for score_id, new_order in updates:
await db.exec(select(Score).where(Score.id == score_id))
score_to_update = (
await db.exec(select(Score).where(Score.id == score_id))
).first()
if score_to_update:
score_to_update.pinned_order = new_order
score_record.pinned_order = final_target
await db.commit()
@router.get("/scores/{score_id}/download")
async def download_score_replay(
score_id: int,
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
storage_service: StorageService = Depends(get_storage_service),
):
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
if not score:
raise HTTPException(status_code=404, detail="Score not found")
filepath = f"replays/{score.id}_{score.beatmap_id}_{score.user_id}_lazer_replay.osr"
if not await storage_service.is_exists(filepath):
raise HTTPException(status_code=404, detail="Replay file not found")
is_friend = (
score.user_id == current_user.id
or (
await db.exec(
select(exists()).where(
Relationship.user_id == current_user.id,
Relationship.target_id == score.user_id,
Relationship.type == RelationshipType.FOLLOW,
)
)
).first()
)
if not is_friend:
replay_watched_count = (
await db.exec(
select(ReplayWatchedCount).where(
ReplayWatchedCount.user_id == score.user_id,
ReplayWatchedCount.year == date.today().year,
ReplayWatchedCount.month == date.today().month,
)
)
).first()
if replay_watched_count is None:
replay_watched_count = ReplayWatchedCount(
user_id=score.user_id, year=date.today().year, month=date.today().month
)
db.add(replay_watched_count)
replay_watched_count.count += 1
await db.commit()
if isinstance(storage_service, LocalStorageService):
return FileResponse(
path=await storage_service.get_file_url(filepath),
filename=filepath,
media_type="application/x-osu-replay",
)
else:
return RedirectResponse(
await storage_service.get_file_url(filepath),
301,
)

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Literal
from app.database import (
BeatmapPlaycounts,
BeatmapPlaycountsResp,
@@ -8,16 +11,18 @@ from app.database import (
UserResp,
)
from app.database.lazer_user import SEARCH_INCLUDED
from app.database.pp_best_score import PPBestScore
from app.database.score import Score, ScoreResp
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from .api_router import router
from .router import router
from fastapi import Depends, HTTPException, Query
from fastapi import Depends, HTTPException, Query, Security
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel import exists, false, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col
@@ -31,6 +36,7 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup/", response_model=BatchUserResponse)
async def get_users(
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query(default=False), # TODO: future use
session: AsyncSession = Depends(get_db),
):
@@ -59,6 +65,7 @@ async def get_user_info(
user: str,
ruleset: GameMode | None = None,
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
searched_user = (
await session.exec(
@@ -86,7 +93,7 @@ async def get_user_info(
async def get_user_beatmapsets(
user_id: int,
type: BeatmapsetType,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
@@ -130,3 +137,59 @@ async def get_user_beatmapsets(
raise HTTPException(400, detail="Invalid beatmapset type")
return resp
@router.get("/users/{user}/scores/{type}", response_model=list[ScoreResp])
async def get_user_scores(
user: int,
type: Literal["best", "recent", "firsts", "pinned"],
legacy_only: bool = Query(False),
include_fails: bool = Query(False),
mode: GameMode | None = None,
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
db_user = await session.get(User, user)
if not db_user:
raise HTTPException(404, detail="User not found")
gamemode = mode or db_user.playmode
order_by = None
where_clause = (col(Score.user_id) == db_user.id) & (
col(Score.gamemode) == gamemode
)
if not include_fails:
where_clause &= col(Score.passed).is_(True)
if type == "pinned":
where_clause &= Score.pinned_order > 0
order_by = col(Score.pinned_order).asc()
elif type == "best":
where_clause &= exists().where(col(PPBestScore.score_id) == Score.id)
order_by = col(Score.pp).desc()
elif type == "recent":
where_clause &= Score.ended_at > datetime.now(UTC) - timedelta(hours=24)
order_by = col(Score.ended_at).desc()
elif type == "firsts":
# TODO
where_clause &= false()
scores = (
await session.exec(
select(Score)
.where(where_clause)
.order_by(order_by)
.limit(limit)
.offset(offset)
)
).all()
if not scores:
return []
return [
await ScoreResp.from_db(
session,
score,
)
for score in scores
]

10
app/service/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from .daily_challenge import create_daily_challenge_room
from .room import create_playlist_room, create_playlist_room_from_api
__all__ = [
"create_daily_challenge_room",
"create_playlist_room",
"create_playlist_room_from_api",
]

View File

@@ -0,0 +1,121 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import json
from app.database.playlists import Playlist
from app.database.room import Room
from app.dependencies.database import engine, get_redis
from app.dependencies.scheduler import get_scheduler
from app.log import logger
from app.models.metadata_hub import DailyChallengeInfo
from app.models.mods import APIMod
from app.models.room import RoomCategory
from .room import create_playlist_room
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_daily_challenge_room(
beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = []
) -> Room:
async with AsyncSession(engine) as session:
today = datetime.now(UTC).date()
return await create_playlist_room(
session=session,
name=str(today),
host_id=3,
playlist=[
Playlist(
id=0,
room_id=0,
owner_id=3,
ruleset_id=ruleset_id,
beatmap_id=beatmap,
required_mods=required_mods,
)
],
category=RoomCategory.DAILY_CHALLENGE,
duration=duration,
)
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="daily_challenge")
async def daily_challenge_job():
from app.signalr.hub import MetadataHubs
now = datetime.now(UTC)
redis = get_redis()
key = f"daily_challenge:{now.date()}"
if not await redis.exists(key):
return
async with AsyncSession(engine) as session:
room = (
await session.exec(
select(Room).where(
Room.category == RoomCategory.DAILY_CHALLENGE,
col(Room.ends_at) > datetime.now(UTC),
)
)
).first()
if room:
return
try:
beatmap = await redis.hget(key, "beatmap") # pyright: ignore[reportGeneralTypeIssues]
ruleset_id = await redis.hget(key, "ruleset_id") # pyright: ignore[reportGeneralTypeIssues]
required_mods = await redis.hget(key, "required_mods") # pyright: ignore[reportGeneralTypeIssues]
if beatmap is None or ruleset_id is None:
logger.warning(
f"[DailyChallenge] Missing required data for daily challenge {now}."
" Will try again in 5 minutes."
)
get_scheduler().add_job(
daily_challenge_job,
"date",
run_date=datetime.now(UTC) + timedelta(minutes=5),
)
return
beatmap_int = int(beatmap)
ruleset_id_int = int(ruleset_id)
mods_list = []
if required_mods:
mods_list = json.loads(required_mods)
next_day = (now + timedelta(days=1)).replace(
hour=0, minute=0, second=0, microsecond=0
)
room = await create_daily_challenge_room(
beatmap=beatmap_int,
ruleset_id=ruleset_id_int,
required_mods=mods_list,
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
)
await MetadataHubs.broadcast_call(
"DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id)
)
logger.success(
"[DailyChallenge] Added today's daily challenge: "
f"{beatmap=}, {ruleset_id=}, {required_mods=}"
)
return
except (ValueError, json.JSONDecodeError) as e:
logger.warning(
f"[DailyChallenge] Error processing daily challenge data: {e}"
" Will try again in 5 minutes."
)
except Exception as e:
logger.exception(
f"[DailyChallenge] Unexpected error in daily challenge job: {e}"
" Will try again in 5 minutes."
)
get_scheduler().add_job(
daily_challenge_job,
"date",
run_date=datetime.now(UTC) + timedelta(minutes=5),
)

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from app.config import settings
from app.database.lazer_user import User
from app.database.statistics import UserStatistics
from app.dependencies.database import engine
from app.models.score import GameMode
from sqlalchemy import exists
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_rx_statistics():
async with AsyncSession(engine) as session:
users = (await session.exec(select(User.id))).all()
for i in users:
if settings.enable_osu_rx:
is_exist = (
await session.exec(
select(exists()).where(
UserStatistics.user_id == i,
UserStatistics.mode == GameMode.OSURX,
)
)
).first()
if not is_exist:
statistics_rx = UserStatistics(mode=GameMode.OSURX, user_id=i)
session.add(statistics_rx)
if settings.enable_osu_ap:
is_exist = (
await session.exec(
select(exists()).where(
UserStatistics.user_id == i,
UserStatistics.mode == GameMode.OSUAP,
)
)
).first()
if not is_exist:
statistics_ap = UserStatistics(mode=GameMode.OSUAP, user_id=i)
session.add(statistics_ap)
await session.commit()

78
app/service/room.py Normal file
View File

@@ -0,0 +1,78 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from app.database.beatmap import Beatmap
from app.database.playlists import Playlist
from app.database.room import APIUploadedRoom, Room
from app.dependencies.fetcher import get_fetcher
from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus
from sqlalchemy import exists
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_playlist_room_from_api(
session: AsyncSession, room: APIUploadedRoom, host_id: int
) -> Room:
db_room = room.to_room()
db_room.host_id = host_id
db_room.starts_at = datetime.now(UTC)
db_room.ends_at = db_room.starts_at + timedelta(
minutes=db_room.duration if db_room.duration is not None else 0
)
session.add(db_room)
await session.commit()
await session.refresh(db_room)
await add_playlists_to_room(session, db_room.id, room.playlist, host_id)
await session.refresh(db_room)
return db_room
async def create_playlist_room(
session: AsyncSession,
name: str,
host_id: int,
category: RoomCategory = RoomCategory.NORMAL,
duration: int = 30,
max_attempts: int | None = None,
playlist: list[Playlist] = [],
) -> Room:
db_room = Room(
name=name,
category=category,
duration=duration,
starts_at=datetime.now(UTC),
ends_at=datetime.now(UTC) + timedelta(minutes=duration),
participant_count=0,
max_attempts=max_attempts,
type=MatchType.PLAYLISTS,
queue_mode=QueueMode.HOST_ONLY,
auto_skip=False,
auto_start_duration=0,
status=RoomStatus.IDLE,
host_id=host_id,
)
session.add(db_room)
await session.commit()
await session.refresh(db_room)
await add_playlists_to_room(session, db_room.id, playlist, host_id)
await session.refresh(db_room)
return db_room
async def add_playlists_to_room(
session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int
):
for item in playlist:
if not (
await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap)))
).first():
fetcher = await get_fetcher()
await Beatmap.get_or_fetch(session, fetcher, item.beatmap_id)
item.id = await Playlist.get_next_id_for_room(room_id, session)
item.room_id = room_id
item.owner_id = owner_id
session.add(item)
await session.commit()

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any
from app.dependencies.database import get_redis_pubsub
class RedisSubscriber:
def __init__(self):
self.pubsub = get_redis_pubsub()
self.handlers: dict[str, list[Callable[[str, str], Awaitable[Any]]]] = {}
self.task: asyncio.Task | None = None
async def subscribe(self, channel: str):
await self.pubsub.subscribe(channel)
if channel not in self.handlers:
self.handlers[channel] = []
async def unsubscribe(self, channel: str):
if channel in self.handlers:
del self.handlers[channel]
await self.pubsub.unsubscribe(channel)
async def listen(self):
while True:
message = await self.pubsub.get_message(
ignore_subscribe_messages=True, timeout=None
)
if message is not None and message["type"] == "message":
method = self.handlers.get(message["channel"])
if method:
await asyncio.gather(
*[
handler(message["channel"], message["data"])
for handler in method
]
)
def start(self):
if self.task is None or self.task.done():
self.task = asyncio.create_task(self.listen())
def stop(self):
if self.task is not None and not self.task.done():
self.task.cancel()
self.task = None

View File

@@ -0,0 +1,87 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from app.database import PlaylistBestScore, Score
from app.database.playlist_best_score import get_position
from app.dependencies.database import engine
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
from .base import RedisSubscriber
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.signalr.hub import MetadataHub
CHANNEL = "score:processed"
class ScoreSubscriber(RedisSubscriber):
def __init__(self):
super().__init__()
self.room_subscriber: dict[int, list[int]] = {}
self.metadata_hub: "MetadataHub | None " = None
self.subscribed = False
self.handlers[CHANNEL] = [self._handler]
async def subscribe_room_score(self, room_id: int, user_id: int):
if room_id not in self.room_subscriber:
await self.subscribe(CHANNEL)
self.start()
self.room_subscriber.setdefault(room_id, []).append(user_id)
async def unsubscribe_room_score(self, room_id: int, user_id: int):
if room_id in self.room_subscriber:
self.room_subscriber[room_id].remove(user_id)
if not self.room_subscriber[room_id]:
del self.room_subscriber[room_id]
async def _notify_room_score_processed(self, score_id: int):
if not self.metadata_hub:
return
async with AsyncSession(engine) as session:
score = await session.get(Score, score_id)
if (
not score
or not score.passed
or score.room_id is None
or score.playlist_item_id is None
):
return
if not self.room_subscriber.get(score.room_id, []):
return
new_rank = None
user_best = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.user_id == score.user_id,
PlaylistBestScore.room_id == score.room_id,
)
)
).first()
if user_best and user_best.score_id == score_id:
new_rank = await get_position(
user_best.room_id,
user_best.playlist_id,
user_best.score_id,
session,
)
event = MultiplayerRoomScoreSetEvent(
room_id=score.room_id,
playlist_item_id=score.playlist_item_id,
score_id=score_id,
user_id=score.user_id,
total_score=score.total_score,
new_rank=new_rank,
)
await self.metadata_hub.notify_room_score_processed(event)
async def _handler(self, channel: str, data: str):
score_id = int(data)
if self.metadata_hub:
await self._notify_room_score_processed(score_id)

View File

@@ -6,9 +6,9 @@ import time
from typing import Any
from app.config import settings
from app.exception import InvokeException
from app.log import logger
from app.models.signalr import UserState
from app.signalr.exception import InvokeException
from app.signalr.packet import (
ClosePacket,
CompletionPacket,
@@ -74,7 +74,7 @@ class Client:
while True:
try:
await self.send_packet(PingPacket())
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
await asyncio.sleep(settings.signalr_ping_interval)
except WebSocketDisconnect:
break
except Exception as e:
@@ -99,6 +99,16 @@ class Hub[TState: UserState]:
return client
return default
def get_before_clients(self, id: str, current_token: str) -> list[Client]:
clients = []
for client in self.clients.values():
if client.connection_id != id:
continue
if client.connection_token == current_token:
continue
clients.append(client)
return clients
@abstractmethod
def create_state(self, client: Client) -> TState:
raise NotImplementedError
@@ -117,6 +127,11 @@ class Hub[TState: UserState]:
if group_id in self.groups:
self.groups[group_id].discard(client)
async def kick_client(self, client: Client) -> None:
await self.call_noblock(client, "DisconnectRequested")
await client.send_packet(ClosePacket(allow_reconnect=False))
await client.connection.close(code=1000, reason="Disconnected by server")
async def add_client(
self,
connection_id: str,
@@ -131,7 +146,7 @@ class Hub[TState: UserState]:
if connection_token in self.waited_clients:
if (
self.waited_clients[connection_token]
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
< time.time() - settings.signalr_negotiate_timeout
):
raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token]

View File

@@ -1,18 +1,34 @@
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Coroutine
from datetime import UTC, datetime
import math
from typing import override
from app.database import Relationship, RelationshipType
from app.database.lazer_user import User
from app.calculator import clamp
from app.database import Relationship, RelationshipType, User
from app.database.playlist_best_score import PlaylistBestScore
from app.database.playlists import Playlist
from app.database.room import Room
from app.dependencies.database import engine, get_redis
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
from app.models.metadata_hub import (
TOTAL_SCORE_DISTRIBUTION_BINS,
DailyChallengeInfo,
MetadataClientState,
MultiplayerPlaylistItemStats,
MultiplayerRoomScoreSetEvent,
MultiplayerRoomStats,
OnlineStatus,
UserActivity,
)
from app.models.room import RoomCategory
from app.service.subscribers.score_processed import ScoreSubscriber
from .hub import Client, Hub
from sqlmodel import select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
@@ -21,11 +37,33 @@ ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
class MetadataHub(Hub[MetadataClientState]):
def __init__(self) -> None:
super().__init__()
self.subscriber = ScoreSubscriber()
self.subscriber.metadata_hub = self
self._daily_challenge_stats: MultiplayerRoomStats | None = None
self._today = datetime.now(UTC).date()
self._lock = asyncio.Lock()
def get_daily_challenge_stats(
self, daily_challenge_room: int
) -> MultiplayerRoomStats:
if (
self._daily_challenge_stats is None
or self._today != datetime.now(UTC).date()
):
self._daily_challenge_stats = MultiplayerRoomStats(
room_id=daily_challenge_room,
playlist_item_stats={},
)
return self._daily_challenge_stats
@staticmethod
def online_presence_watchers_group() -> str:
return ONLINE_PRESENCE_WATCHERS_GROUP
@staticmethod
def room_watcher_group(room_id: int) -> str:
return f"metadata:multiplayer-room-watchers:{room_id}"
def broadcast_tasks(
self, user_id: int, store: MetadataClientState | None
) -> set[Coroutine]:
@@ -102,10 +140,29 @@ class MetadataHub(Hub[MetadataClientState]):
self.friend_presence_watchers_group(friend_id),
"FriendPresenceUpdated",
friend_id,
friend_state if friend_state.pushable else None,
friend_state.for_push
if friend_state.pushable
else None,
)
)
await asyncio.gather(*tasks)
daily_challenge_room = (
await session.exec(
select(Room).where(
col(Room.ends_at) > datetime.now(UTC),
Room.category == RoomCategory.DAILY_CHALLENGE,
)
)
).first()
if daily_challenge_room:
await self.call_noblock(
client,
"DailyChallengeUpdated",
DailyChallengeInfo(
room_id=daily_challenge_room.id,
),
)
redis = get_redis()
await redis.set(f"metadata:online:{user_id}", "")
@@ -161,3 +218,76 @@ class MetadataHub(Hub[MetadataClientState]):
async def EndWatchingUserPresence(self, client: Client) -> None:
self.remove_from_group(client, self.online_presence_watchers_group())
async def notify_room_score_processed(self, event: MultiplayerRoomScoreSetEvent):
await self.broadcast_group_call(
self.room_watcher_group(event.room_id), "MultiplayerRoomScoreSet", event
)
async def BeginWatchingMultiplayerRoom(self, client: Client, room_id: int):
self.add_to_group(client, self.room_watcher_group(room_id))
await self.subscriber.subscribe_room_score(room_id, client.user_id)
stats = self.get_daily_challenge_stats(room_id)
await self.update_daily_challenge_stats(stats)
return list(stats.playlist_item_stats.values())
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
async with AsyncSession(engine) as session:
playlist_ids = (
await session.exec(
select(Playlist.id).where(
Playlist.room_id == stats.room_id,
)
)
).all()
for playlist_id in playlist_ids:
item = stats.playlist_item_stats.get(playlist_id, None)
if item is None:
item = MultiplayerPlaylistItemStats(
playlist_item_id=playlist_id,
total_score_distribution=[0] * TOTAL_SCORE_DISTRIBUTION_BINS,
cumulative_score=0,
last_processed_score_id=0,
)
stats.playlist_item_stats[playlist_id] = item
last_processed_score_id = item.last_processed_score_id
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == stats.room_id,
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.score_id > last_processed_score_id,
)
)
).all()
if len(scores) == 0:
continue
async with self._lock:
if item.last_processed_score_id == last_processed_score_id:
totals = defaultdict(int)
for score in scores:
bin_index = int(
clamp(
math.floor(score.total_score / 100000),
0,
TOTAL_SCORE_DISTRIBUTION_BINS - 1,
)
)
totals[bin_index] += 1
item.cumulative_score += sum(
score.total_score for score in scores
)
for j in range(TOTAL_SCORE_DISTRIBUTION_BINS):
item.total_score_distribution[j] += totals.get(j, 0)
if scores:
item.last_processed_score_id = max(
score.score_id for score in scores
)
async def EndWatchingMultiplayerRoom(self, client: Client, room_id: int):
self.remove_from_group(client, self.room_watcher_group(room_id))
await self.subscriber.unsubscribe_room_score(room_id, client.user_id)

File diff suppressed because it is too large Load Diff

View File

@@ -7,11 +7,13 @@ import struct
import time
from typing import override
from app.config import settings
from app.database import Beatmap, User
from app.database.score import Score
from app.database.score_token import ScoreToken
from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
from app.models.spectator_hub import (
@@ -24,7 +26,6 @@ from app.models.spectator_hub import (
StoreClientState,
StoreScore,
)
from app.path import REPLAY_DIR
from app.utils import unix_timestamp_to_windows
from .hub import Client, Hub
@@ -63,7 +64,7 @@ def encode_string(s: str) -> bytes:
return ret
def save_replay(
async def save_replay(
ruleset_id: int,
md5: str,
username: str,
@@ -135,8 +136,14 @@ def save_replay(
data.extend(struct.pack("<i", len(compressed)))
data.extend(compressed)
replay_path = REPLAY_DIR / f"lazer-{score.type}-{username}-{score.id}.osr"
replay_path.write_bytes(data)
storage_service = get_storage_service()
replay_path = (
f"replays/{score.id}_{score.beatmap_id}_{score.user_id}_lazer_replay.osr"
)
await storage_service.write_file(
replay_path,
bytes(data),
)
class SpectatorHub(Hub[StoreClientState]):
@@ -179,15 +186,13 @@ class SpectatorHub(Hub[StoreClientState]):
return
if state.beatmap_id is None or state.ruleset_id is None:
return
fetcher = await get_fetcher()
async with AsyncSession(engine) as session:
async with session.begin():
beatmap = (
await session.exec(
select(Beatmap).where(Beatmap.id == state.beatmap_id)
)
).first()
if not beatmap:
return
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=state.beatmap_id
)
user = (
await session.exec(select(User).where(User.id == user_id))
).first()
@@ -237,16 +242,17 @@ class SpectatorHub(Hub[StoreClientState]):
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
score = store.score
assert store.beatmap_status is not None
assert store.state is not None
assert store.score is not None
if not score or not store.score_token:
if (
score is None
or store.score_token is None
or store.beatmap_status is None
or store.state is None
):
return
if (
BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED
) and any(
k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items()
):
settings.enable_all_beatmap_leaderboard
and store.beatmap_status.has_leaderboard()
) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()):
await self._process_score(store, client)
store.state = None
store.beatmap_status = None
@@ -296,7 +302,7 @@ class SpectatorHub(Hub[StoreClientState]):
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
save_replay(
await save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,

View File

@@ -15,7 +15,7 @@ from typing import (
)
from app.models.signalr import SignalRMeta, SignalRUnionMessage
from app.utils import camel_to_snake, snake_to_camel
from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
import msgpack_lazer_api as m
from pydantic import BaseModel
@@ -97,6 +97,8 @@ class MsgpackProtocol:
return [cls.serialize_msgpack(item) for item in v]
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif issubclass(typ, datetime.timedelta):
return int(v.total_seconds() * 10_000_000)
elif isinstance(v, dict):
return {
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
@@ -126,15 +128,19 @@ class MsgpackProtocol:
def process_object(v: Any, typ: type[BaseModel]) -> Any:
if isinstance(v, list):
d = {}
for i, f in enumerate(typ.model_fields.items()):
field, info = f
if info.exclude:
i = 0
for field, info in typ.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.member_ignore:
continue
anno = info.annotation
if anno is None:
d[camel_to_snake(field)] = v[i]
continue
d[field] = MsgpackProtocol.validate_object(v[i], anno)
else:
d[field] = MsgpackProtocol.validate_object(v[i], anno)
i += 1
return d
return v
@@ -209,7 +215,9 @@ class MsgpackProtocol:
return typ.model_validate(obj=cls.process_object(v, typ))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0]
elif isinstance(v, list):
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
return datetime.timedelta(seconds=int(v / 10_000_000))
elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
@@ -234,7 +242,9 @@ class MsgpackProtocol:
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
if not all(
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
@@ -292,36 +302,55 @@ class MsgpackProtocol:
class JSONProtocol:
@classmethod
def serialize_to_json(cls, v: Any):
def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_model(v)
return cls.serialize_model(v, in_union)
elif isinstance(v, dict):
return {
cls.serialize_to_json(k): cls.serialize_to_json(value)
cls.serialize_to_json(k, True): cls.serialize_to_json(value)
for k, value in v.items()
}
elif isinstance(v, list):
return [cls.serialize_to_json(item) for item in v]
elif isinstance(v, datetime.datetime):
return v.isoformat()
elif isinstance(v, Enum):
elif isinstance(v, datetime.timedelta):
# d.hh:mm:ss
total_seconds = int(v.total_seconds())
hours, remainder = divmod(total_seconds, 3600)
minutes, seconds = divmod(remainder, 60)
return f"{hours:02}:{minutes:02}:{seconds:02}"
elif isinstance(v, Enum) and dict_key:
return v.value
elif isinstance(v, Enum):
list_ = list(typ)
return list_.index(v)
return v
@classmethod
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
d = {}
is_union = issubclass(v.__class__, SignalRUnionMessage)
for field, info in v.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
cls.serialize_to_json(getattr(v, field))
name = (
snake_to_camel(
field,
metadata.use_abbr if metadata else True,
)
if not is_union
else snake_to_pascal(
field,
metadata.use_abbr if metadata else True,
)
)
if issubclass(v.__class__, SignalRUnionMessage):
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
if is_union and not in_union:
return {
"$dtype": v.__class__.__name__,
"$value": d,
@@ -339,7 +368,12 @@ class JSONProtocol:
)
if metadata and metadata.json_ignore:
continue
value = v.get(snake_to_camel(field, not from_union))
name = (
snake_to_camel(field, metadata.use_abbr if metadata else True)
if not from_union
else snake_to_pascal(field, metadata.use_abbr if metadata else True)
)
value = v.get(name)
anno = typ.model_fields[field].annotation
if anno is None:
d[field] = value
@@ -397,7 +431,18 @@ class JSONProtocol:
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return datetime.datetime.fromisoformat(v)
elif isinstance(v, list):
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
# d.hh:mm:ss
parts = v.split(":")
if len(parts) == 3:
return datetime.timedelta(
hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])
)
elif len(parts) == 2:
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
elif len(parts) == 1:
return datetime.timedelta(seconds=int(parts[0]))
elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)

View File

@@ -6,26 +6,26 @@ import time
from typing import Literal
import uuid
from app.database import User
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport
from .hub import Hubs
from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
from fastapi.security import SecurityScopes
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter()
router = APIRouter(prefix="/signalr", tags=["SignalR"])
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
async def negotiate(
hub: Literal["spectator", "multiplayer", "metadata"],
negotiate_version: int = Query(1, alias="negotiateVersion"),
user: User = Depends(get_current_user),
user: DBUser = Depends(get_current_user),
):
connectionId = str(user.id)
connectionToken = f"{connectionId}:{uuid.uuid4()}"
@@ -55,9 +55,15 @@ async def connect(
if id not in hub_:
await websocket.close(code=1008)
return
if (user := await get_current_user_by_token(token, db)) is None or str(
user.id
) != user_id:
try:
if (
user := await get_current_user(
SecurityScopes(scopes=["*"]), db, token_pw=token
)
) is None or str(user.id) != user_id:
await websocket.close(code=1008)
return
except HTTPException:
await websocket.close(code=1008)
return
await websocket.accept()
@@ -92,6 +98,11 @@ async def connect(
if error or not client:
await websocket.close(code=1008)
return
connected_clients = hub_.get_before_clients(user_id, id)
for connected_client in connected_clients:
await hub_.kick_client(connected_client)
await hub_.clean_state(client, False)
task = asyncio.create_task(hub_.on_connect(client))
hub_.tasks.add(task)

13
app/storage/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from .aws_s3 import AWSS3StorageService
from .base import StorageService
from .cloudflare_r2 import CloudflareR2StorageService
from .local import LocalStorageService
__all__ = [
"AWSS3StorageService",
"CloudflareR2StorageService",
"LocalStorageService",
"StorageService",
]

103
app/storage/aws_s3.py Normal file
View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from .base import StorageService
import aioboto3
from botocore.exceptions import ClientError
class AWSS3StorageService(StorageService):
def __init__(
self,
access_key_id: str,
secret_access_key: str,
bucket_name: str,
region_name: str,
public_url_base: str | None = None,
):
self.bucket_name = bucket_name
self.public_url_base = public_url_base
self.session = aioboto3.Session()
self.access_key_id = access_key_id
self.secret_access_key = secret_access_key
self.region_name = region_name
@property
def endpoint_url(self) -> str | None:
return None
def _get_client(self):
return self.session.client(
"s3",
endpoint_url=self.endpoint_url,
aws_access_key_id=self.access_key_id,
aws_secret_access_key=self.secret_access_key,
region_name=self.region_name,
)
async def write_file(
self,
file_path: str,
content: bytes,
content_type: str = "application/octet-stream",
cache_control: str = "public, max-age=31536000",
) -> None:
async with self._get_client() as client:
await client.put_object(
Bucket=self.bucket_name,
Key=file_path,
Body=content,
ContentType=content_type,
CacheControl=cache_control,
)
async def read_file(self, file_path: str) -> bytes:
async with self._get_client() as client:
try:
response = await client.get_object(
Bucket=self.bucket_name,
Key=file_path,
)
async with response["Body"] as stream:
return await stream.read()
except ClientError as e:
if e.response.get("Error", {}).get("Code") == "404":
raise FileNotFoundError(f"File not found: {file_path}")
raise RuntimeError(f"Failed to read file from R2: {e}")
async def delete_file(self, file_path: str) -> None:
async with self._get_client() as client:
try:
await client.delete_object(
Bucket=self.bucket_name,
Key=file_path,
)
except ClientError as e:
raise RuntimeError(f"Failed to delete file from R2: {e}")
async def is_exists(self, file_path: str) -> bool:
async with self._get_client() as client:
try:
await client.head_object(
Bucket=self.bucket_name,
Key=file_path,
)
return True
except ClientError as e:
if e.response.get("Error", {}).get("Code") == "404":
return False
raise RuntimeError(f"Failed to check file existence in R2: {e}")
async def get_file_url(self, file_path: str) -> str:
if self.public_url_base:
return f"{self.public_url_base.rstrip('/')}/{file_path.lstrip('/')}"
async with self._get_client() as client:
try:
url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": self.bucket_name, "Key": file_path},
)
return url
except ClientError as e:
raise RuntimeError(f"Failed to generate file URL: {e}")

34
app/storage/base.py Normal file
View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import abc
class StorageService(abc.ABC):
@abc.abstractmethod
async def write_file(
self,
file_path: str,
content: bytes,
content_type: str = "application/octet-stream",
cache_control: str = "public, max-age=31536000",
) -> None:
raise NotImplementedError
@abc.abstractmethod
async def read_file(self, file_path: str) -> bytes:
raise NotImplementedError
@abc.abstractmethod
async def delete_file(self, file_path: str) -> None:
raise NotImplementedError
@abc.abstractmethod
async def is_exists(self, file_path: str) -> bool:
raise NotImplementedError
@abc.abstractmethod
async def get_file_url(self, file_path: str) -> str:
raise NotImplementedError
async def close(self) -> None:
pass

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from .aws_s3 import AWSS3StorageService
class CloudflareR2StorageService(AWSS3StorageService):
def __init__(
self,
account_id: str,
access_key_id: str,
secret_access_key: str,
bucket_name: str,
public_url_base: str | None = None,
):
super().__init__(
access_key_id=access_key_id,
secret_access_key=secret_access_key,
bucket_name=bucket_name,
public_url_base=public_url_base,
region_name="auto",
)
self.account_id = account_id
@property
def endpoint_url(self) -> str:
return f"https://{self.account_id}.r2.cloudflarestorage.com"

80
app/storage/local.py Normal file
View File

@@ -0,0 +1,80 @@
from __future__ import annotations
from pathlib import Path
from app.config import settings
from .base import StorageService
import aiofiles
class LocalStorageService(StorageService):
def __init__(
self,
storage_path: str,
):
self.storage_path = Path(storage_path).resolve()
self.storage_path.mkdir(parents=True, exist_ok=True)
def _get_file_path(self, file_path: str) -> Path:
clean_path = file_path.lstrip("/")
full_path = self.storage_path / clean_path
try:
full_path.resolve().relative_to(self.storage_path)
except ValueError:
raise ValueError(f"Invalid file path: {file_path}")
return full_path
async def write_file(
self,
file_path: str,
content: bytes,
content_type: str = "application/octet-stream",
cache_control: str = "public, max-age=31536000",
) -> None:
full_path = self._get_file_path(file_path)
full_path.parent.mkdir(parents=True, exist_ok=True)
try:
async with aiofiles.open(full_path, "wb") as f:
await f.write(content)
except OSError as e:
raise RuntimeError(f"Failed to write file: {e}")
async def read_file(self, file_path: str) -> bytes:
full_path = self._get_file_path(file_path)
if not full_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
try:
async with aiofiles.open(full_path, "rb") as f:
return await f.read()
except OSError as e:
raise RuntimeError(f"Failed to read file: {e}")
async def delete_file(self, file_path: str) -> None:
full_path = self._get_file_path(file_path)
if not full_path.exists():
return
try:
full_path.unlink()
parent = full_path.parent
while parent != self.storage_path and not any(parent.iterdir()):
parent.rmdir()
parent = parent.parent
except OSError as e:
raise RuntimeError(f"Failed to delete file: {e}")
async def is_exists(self, file_path: str) -> bool:
full_path = self._get_file_path(file_path)
return full_path.exists() and full_path.is_file()
async def get_file_url(self, file_path: str) -> str:
return f"{settings.server_url}file/{file_path.lstrip('/')}"

View File

@@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str:
return "".join(result)
def snake_to_camel(name: str, lower_case: bool = True) -> str:
def snake_to_camel(name: str, use_abbr: bool = True) -> str:
"""Convert a snake_case string to camelCase."""
if not name:
return name
@@ -47,12 +47,46 @@ def snake_to_camel(name: str, lower_case: bool = True) -> str:
result = []
for part in parts:
if part.lower() in abbreviations:
if part.lower() in abbreviations and use_abbr:
result.append(part.upper())
else:
if result or not lower_case:
if result:
result.append(part.capitalize())
else:
result.append(part.lower())
return "".join(result)
def snake_to_pascal(name: str, use_abbr: bool = True) -> str:
"""Convert a snake_case string to PascalCase."""
if not name:
return name
parts = name.split("_")
if not parts:
return name
# 常见缩写词列表
abbreviations = {
"id",
"url",
"api",
"http",
"https",
"xml",
"json",
"css",
"html",
"sql",
"db",
}
result = []
for part in parts:
if part.lower() in abbreviations and use_abbr:
result.append(part.upper())
else:
result.append(part.capitalize())
return "".join(result)

79
docker-compose-osurx.yml Normal file
View File

@@ -0,0 +1,79 @@
version: '3.8'
services:
app:
# or use
# image: mingxuangame/osu-lazer-api-osurx:latest
build:
context: .
dockerfile: Dockerfile-osurx
container_name: osu_api_server_osurx
ports:
- "8000:8000"
environment:
- MYSQL_HOST=mysql
- MYSQL_PORT=3306
- REDIS_URL=redis://redis:6379/0
- ENABLE_OSU_RX=true
- ENABLE_OSU_AP=true
- ENABLE_ALL_MODS_PP=true
- ENABLE_SUPPORTER_FOR_ALL_USERS=true
- ENABLE_ALL_BEATMAP_LEADERBOARD=true
env_file:
- .env
depends_on:
mysql:
condition: service_healthy
redis:
condition: service_healthy
volumes:
- ./replays:/app/replays
- ./static:/app/static
restart: unless-stopped
networks:
- osu-network
mysql:
image: mysql:8.0
container_name: osu_api_mysql_osurx
environment:
- MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
- MYSQL_DATABASE=${MYSQL_DATABASE}
- MYSQL_USER=${MYSQL_USER}
- MYSQL_PASSWORD=${MYSQL_PASSWORD}
volumes:
- mysql_data:/var/lib/mysql
- ./mysql-init:/docker-entrypoint-initdb.d
healthcheck:
test: ["CMD", "mysqladmin", "ping", "-h", "localhost"]
timeout: 20s
retries: 10
interval: 10s
start_period: 40s
restart: unless-stopped
networks:
- osu-network
redis:
image: redis:7-alpine
container_name: osu_api_redis_osurx
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
timeout: 5s
retries: 5
interval: 10s
start_period: 10s
restart: unless-stopped
networks:
- osu-network
command: redis-server --appendonly yes
volumes:
mysql_data:
redis_data:
networks:
osu-network:
driver: bridge

View File

@@ -1,50 +1,74 @@
version: '3.8'
services:
mysql:
image: mysql:8.0
container_name: osu_api_mysql
environment:
MYSQL_ROOT_PASSWORD: password
MYSQL_DATABASE: osu_api
MYSQL_USER: osu_user
MYSQL_PASSWORD: osu_password
ports:
- "3306:3306"
volumes:
- mysql_data:/var/lib/mysql
- ./mysql-init:/docker-entrypoint-initdb.d
restart: unless-stopped
redis:
image: redis:7-alpine
container_name: osu_api_redis
ports:
- "6379:6379"
volumes:
- redis_data:/data
restart: unless-stopped
command: redis-server --appendonly yes
api:
build: .
container_name: osu_api_server
ports:
- "8000:8000"
environment:
DATABASE_URL: mysql+aiomysql://osu_user:osu_password@mysql:3306/osu_api
REDIS_URL: redis://redis:6379/0
SECRET_KEY: your-production-secret-key-here
OSU_CLIENT_ID: "5"
OSU_CLIENT_SECRET: "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
depends_on:
- mysql
- redis
restart: unless-stopped
volumes:
- ./:/app
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
volumes:
mysql_data:
redis_data:
version: '3.8'
services:
app:
# or use
# image: mingxuangame/osu-lazer-api:latest
build:
context: .
dockerfile: Dockerfile
container_name: osu_api_server
ports:
- "8000:8000"
environment:
- MYSQL_HOST=mysql
- MYSQL_PORT=3306
- REDIS_URL=redis://redis:6379/0
env_file:
- .env
depends_on:
mysql:
condition: service_healthy
redis:
condition: service_healthy
volumes:
- ./replays:/app/replays
- ./static:/app/static
restart: unless-stopped
networks:
- osu-network
mysql:
image: mysql:8.0
container_name: osu_api_mysql
environment:
- MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
- MYSQL_DATABASE=${MYSQL_DATABASE}
- MYSQL_USER=${MYSQL_USER}
- MYSQL_PASSWORD=${MYSQL_PASSWORD}
volumes:
- mysql_data:/var/lib/mysql
- ./mysql-init:/docker-entrypoint-initdb.d
healthcheck:
test: ["CMD", "mysqladmin", "ping", "-h", "localhost"]
timeout: 20s
retries: 10
interval: 10s
start_period: 40s
restart: unless-stopped
networks:
- osu-network
redis:
image: redis:7-alpine
container_name: osu_api_redis
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
timeout: 5s
retries: 5
interval: 10s
start_period: 10s
restart: unless-stopped
networks:
- osu-network
command: redis-server --appendonly yes
volumes:
mysql_data:
redis_data:
networks:
osu-network:
driver: bridge

13
docker-entrypoint.sh Normal file
View File

@@ -0,0 +1,13 @@
#!/bin/bash
set -e
echo "Waiting for database connection..."
while ! nc -z $MYSQL_HOST $MYSQL_PORT; do
sleep 1
done
echo "Database connected"
echo "Running alembic..."
uv run --no-sync alembic upgrade head
exec "$@"

158
main.py
View File

@@ -4,29 +4,55 @@ from contextlib import asynccontextmanager
from datetime import datetime
from app.config import settings
from app.dependencies.database import create_tables, engine, redis_client
from app.dependencies.database import engine, redis_client
from app.dependencies.fetcher import get_fetcher
from app.router import api_router, auth_router, fetcher_router, signalr_router
from app.dependencies.scheduler import init_scheduler, stop_scheduler
from app.log import logger
from app.router import (
api_v2_router,
auth_router,
fetcher_router,
file_router,
private_router,
signalr_router,
)
from app.service.daily_challenge import daily_challenge_job
from app.service.osu_rx_statistics import create_rx_statistics
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@asynccontextmanager
async def lifespan(app: FastAPI):
# on startup
await create_tables()
await create_rx_statistics()
await get_fetcher() # 初始化 fetcher
init_scheduler()
await daily_challenge_job()
# on shutdown
yield
stop_scheduler()
await engine.dispose()
await redis_client.aclose()
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
app.include_router(api_router, prefix="/api/v2")
app.include_router(signalr_router, prefix="/signalr")
app.include_router(fetcher_router, prefix="/fetcher")
app.include_router(api_v2_router)
app.include_router(signalr_router)
app.include_router(fetcher_router)
app.include_router(file_router)
app.include_router(auth_router)
app.include_router(private_router)
# CORS 配置
app.add_middleware(
CORSMiddleware,
allow_origins=[str(settings.server_url)],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
@@ -41,114 +67,30 @@ async def health_check():
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
# @app.get("/api/v2/friends")
# async def get_friends():
# return JSONResponse(
# content=[
# {
# "id": 123456,
# "username": "BestFriend",
# "is_online": True,
# "is_supporter": False,
# "country": {"code": "US", "name": "United States"},
# }
# ]
# )
# @app.get("/api/v2/notifications")
# async def get_notifications():
# return JSONResponse(content={"notifications": [], "unread_count": 0})
# @app.post("/api/v2/chat/ack")
# async def chat_ack():
# return JSONResponse(content={"status": "ok"})
# @app.get("/api/v2/users/{user_id}/{mode}")
# async def get_user_mode(user_id: int, mode: str):
# return JSONResponse(
# content={
# "id": user_id,
# "username": "测试测试测",
# "statistics": {
# "level": {"current": 97, "progress": 96},
# "pp": 114514,
# "global_rank": 666,
# "country_rank": 1,
# "hit_accuracy": 100,
# },
# "country": {"code": "JP", "name": "Japan"},
# }
# )
# @app.get("/api/v2/me")
# async def get_me():
# return JSONResponse(
# content={
# "id": 15651670,
# "username": "Googujiang",
# "is_online": True,
# "country": {"code": "JP", "name": "Japan"},
# "statistics": {
# "level": {"current": 97, "progress": 96},
# "pp": 2826.26,
# "global_rank": 298026,
# "country_rank": 11220,
# "hit_accuracy": 95.7168,
# },
# }
# )
# @app.post("/signalr/metadata/negotiate")
# async def metadata_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "abc123",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
# @app.post("/signalr/spectator/negotiate")
# async def spectator_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "spec456",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
# @app.post("/signalr/multiplayer/negotiate")
# async def multiplayer_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "multi789",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
if settings.secret_key == "your_jwt_secret_here":
logger.warning(
"jwt_secret_key is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 32"
)
if settings.osu_web_client_secret == "your_osu_web_client_secret_here":
logger.warning(
"osu_web_client_secret is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 40"
)
if settings.private_api_secret == "your_private_api_secret_here":
logger.warning(
"private_api_secret is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 32"
)
if __name__ == "__main__":
from app.log import logger # noqa: F401
import uvicorn
uvicorn.run(
"main:app",
host=settings.HOST,
port=settings.PORT,
reload=settings.DEBUG,
host=settings.host,
port=settings.port,
reload=settings.debug,
log_config=None, # 禁用uvicorn默认日志配置
access_log=True, # 启用访问日志
)

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
import asyncio
from logging.config import fileConfig
import os
from app.config import settings
from app.database import * # noqa: F403
from alembic import context
@@ -45,7 +45,8 @@ def run_migrations_offline() -> None:
script output.
"""
url = os.environ.get("DATABASE_URL", config.get_main_option("sqlalchemy.url"))
url = settings.database_url
print(url)
context.configure(
url=url,
target_metadata=target_metadata,
@@ -73,8 +74,7 @@ async def run_async_migrations() -> None:
"""
sa_config = config.get_section(config.config_ini_section, {})
if db_url := os.environ.get("DATABASE_URL"):
sa_config["sqlalchemy.url"] = db_url
sa_config["sqlalchemy.url"] = settings.database_url
connectable = async_engine_from_config(
sa_config,
prefix="sqlalchemy.",

View File

@@ -0,0 +1,116 @@
"""gamemode: add osurx & osupp
Revision ID: 19cdc9ce4dcb
Revises: fdb3822a30ba
Create Date: 2025-08-10 06:10:08.093591
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "19cdc9ce4dcb"
down_revision: str | Sequence[str] | None = "fdb3822a30ba"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"lazer_users",
"playmode",
type_=sa.Enum(
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
),
)
op.alter_column(
"beatmaps",
"mode",
type_=sa.Enum(
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
),
)
op.alter_column(
"lazer_user_statistics",
"mode",
type_=sa.Enum(
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
),
)
op.alter_column(
"score_tokens",
"ruleset_id",
type_=sa.Enum(
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
),
)
op.alter_column(
"scores",
"gamemode",
type_=sa.Enum(
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
),
)
op.alter_column(
"best_scores",
"gamemode",
type_=sa.Enum(
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
),
)
op.alter_column(
"total_score_best_scores",
"gamemode",
type_=sa.Enum(
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
),
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"total_score_best_scores",
"gamemode",
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
)
op.alter_column(
"best_scores",
"gamemode",
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
)
op.alter_column(
"scores",
"gamemode",
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
)
op.alter_column(
"score_tokens",
"ruleset_id",
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
)
op.alter_column(
"lazer_user_statistics",
"mode",
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
)
op.alter_column(
"beatmaps",
"mode",
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
)
op.alter_column(
"lazer_users",
"playmode",
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
)
# ### end Alembic commands ###

Some files were not shown because too many files have changed in this diff Show More