Merge branch 'main' of https://github.com/GooGuTeam/osu_lazer_api
This commit is contained in:
@@ -4,6 +4,13 @@
|
||||
"service": "devcontainer",
|
||||
"shutdownAction": "stopCompose",
|
||||
"workspaceFolder": "/workspaces/osu_lazer_api",
|
||||
"containerEnv": {
|
||||
"MYSQL_DATABASE": "osu_api",
|
||||
"MYSQL_USER": "osu_user",
|
||||
"MYSQL_PASSWORD": "osu_password",
|
||||
"MYSQL_HOST": "mysql",
|
||||
"MYSQL_PORT": "3306"
|
||||
},
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
@@ -66,6 +73,6 @@
|
||||
3306,
|
||||
6379
|
||||
],
|
||||
"postCreateCommand": "uv sync --dev && uv run pre-commit install && cd packages/msgpack_lazer_api && cargo check",
|
||||
"postCreateCommand": "uv sync --dev && uv pip install rosu-pp-py && uv run alembic upgrade head && uv run pre-commit install && cd packages/msgpack_lazer_api && cargo check",
|
||||
"remoteUser": "vscode"
|
||||
}
|
||||
}
|
||||
5
.dockerignore
Normal file
5
.dockerignore
Normal file
@@ -0,0 +1,5 @@
|
||||
.venv/
|
||||
.ruff_cache/
|
||||
.vscode/
|
||||
storage/
|
||||
replays/
|
||||
@@ -1,4 +0,0 @@
|
||||
# osu! API 客户端配置
|
||||
OSU_CLIENT_ID=5
|
||||
OSU_CLIENT_SECRET=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk
|
||||
OSU_API_URL=http://localhost:8000
|
||||
78
.env.example
Normal file
78
.env.example
Normal file
@@ -0,0 +1,78 @@
|
||||
# 数据库设置
|
||||
MYSQL_HOST="localhost"
|
||||
MYSQL_PORT=3306
|
||||
MYSQL_DATABASE="osu_api"
|
||||
MYSQL_USER="osu_api"
|
||||
MYSQL_PASSWORD="password"
|
||||
MYSQL_ROOT_PASSWORD="password"
|
||||
# Redis URL
|
||||
REDIS_URL="redis://127.0.0.1:6379/0"
|
||||
|
||||
# JWT 密钥,使用 openssl rand -hex 32 生成
|
||||
JWT_SECRET_KEY="your_jwt_secret_here"
|
||||
# JWT 算法
|
||||
ALGORITHM="HS256"
|
||||
# JWT 过期时间
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=1440
|
||||
|
||||
# 服务器地址
|
||||
HOST="0.0.0.0"
|
||||
PORT=8000
|
||||
# 服务器 URL
|
||||
SERVER_URL="http://localhost:8000"
|
||||
# 调试模式,生产环境请设置为 false
|
||||
DEBUG=false
|
||||
# 私有 API 密钥,用于前后端 API 调用,使用 openssl rand -hex 32 生成
|
||||
PRIVATE_API_SECRET="your_private_api_secret_here"
|
||||
|
||||
# osu! 登录设置
|
||||
OSU_CLIENT_ID=5 # lazer client ID
|
||||
OSU_CLIENT_SECRET="FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" # lazer client secret
|
||||
OSU_WEB_CLIENT_ID=6 # 网页端 client ID
|
||||
OSU_WEB_CLIENT_SECRET="your_osu_web_client_secret_here" # 网页端 client secret,使用 openssl rand -hex 40 生成
|
||||
|
||||
# SignalR 服务器设置
|
||||
SIGNALR_NEGOTIATE_TIMEOUT=30
|
||||
SIGNALR_PING_INTERVAL=15
|
||||
|
||||
# Fetcher 设置
|
||||
FETCHER_CLIENT_ID=""
|
||||
FETCHER_CLIENT_SECRET=""
|
||||
FETCHER_SCOPES=public
|
||||
|
||||
# 日志设置
|
||||
LOG_LEVEL="INFO"
|
||||
|
||||
# 游戏设置
|
||||
ENABLE_OSU_RX=false # 启用 osu!RX 统计数据
|
||||
ENABLE_OSU_AP=false # 启用 osu!AP 统计数据
|
||||
ENABLE_ALL_MODS_PP=false # 启用所有 Mod 的 PP 计算
|
||||
ENABLE_SUPPORTER_FOR_ALL_USERS=false # 启用所有新注册用户的支持者状态
|
||||
ENABLE_ALL_BEATMAP_LEADERBOARD=false # 启用所有谱面的排行榜(没有排行榜的谱面会以 APPROVED 状态返回)
|
||||
SEASONAL_BACKGROUNDS='[]' # 季节背景图 URL 列表
|
||||
|
||||
# 存储服务设置
|
||||
# 支持的存储类型:local(本地存储)、r2(Cloudflare R2)、s3(AWS S3)
|
||||
STORAGE_SERVICE="local"
|
||||
|
||||
# 存储服务配置 (JSON 格式)
|
||||
# 本地存储配置(当 STORAGE_SERVICE=local 时)
|
||||
STORAGE_SETTINGS='{"local_storage_path": "./storage"}'
|
||||
|
||||
# Cloudflare R2 存储配置(当 STORAGE_SERVICE=r2 时)
|
||||
# STORAGE_SETTINGS='{
|
||||
# "r2_account_id": "your_cloudflare_r2_account_id",
|
||||
# "r2_access_key_id": "your_r2_access_key_id",
|
||||
# "r2_secret_access_key": "your_r2_secret_access_key",
|
||||
# "r2_bucket_name": "your_r2_bucket_name",
|
||||
# "r2_public_url_base": "https://your-custom-domain.com"
|
||||
# }'
|
||||
|
||||
# AWS S3 存储配置(当 STORAGE_SERVICE=s3 时)
|
||||
# STORAGE_SETTINGS='{
|
||||
# "s3_access_key_id": "your_aws_access_key_id",
|
||||
# "s3_secret_access_key": "your_aws_secret_access_key",
|
||||
# "s3_bucket_name": "your_s3_bucket_name",
|
||||
# "s3_region_name": "us-east-1",
|
||||
# "s3_public_url_base": "https://your-custom-domain.com"
|
||||
# }'
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -37,6 +37,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
test-cert/
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
@@ -184,9 +185,9 @@ cython_debug/
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
# .vscode/
|
||||
|
||||
@@ -211,5 +212,6 @@ bancho.py-master/*
|
||||
.vscode/settings.json
|
||||
|
||||
# runtime file
|
||||
storage/
|
||||
replays/
|
||||
osu-master/*
|
||||
osu-master/*
|
||||
|
||||
3
.idea/.gitignore
generated
vendored
3
.idea/.gitignore
generated
vendored
@@ -1,3 +0,0 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
17
.idea/inspectionProfiles/Project_Default.xml
generated
17
.idea/inspectionProfiles/Project_Default.xml
generated
@@ -1,17 +0,0 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyCompatibilityInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ourVersions">
|
||||
<value>
|
||||
<list size="4">
|
||||
<item index="0" class="java.lang.String" itemvalue="3.7" />
|
||||
<item index="1" class="java.lang.String" itemvalue="3.11" />
|
||||
<item index="2" class="java.lang.String" itemvalue="3.12" />
|
||||
<item index="3" class="java.lang.String" itemvalue="3.13" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
6
.idea/inspectionProfiles/profiles_settings.xml
generated
@@ -1,6 +0,0 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
10
.idea/misc.xml
generated
10
.idea/misc.xml
generated
@@ -1,10 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="osu_lazer_api" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="uv (osu_lazer_api)" project-jdk-type="Python SDK" />
|
||||
<component name="PythonCompatibilityInspectionAdvertiser">
|
||||
<option name="version" value="3" />
|
||||
</component>
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
8
.idea/modules.xml
generated
@@ -1,8 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/osu_lazer_api.iml" filepath="$PROJECT_DIR$/.idea/osu_lazer_api.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
14
.idea/osu_lazer_api.iml
generated
14
.idea/osu_lazer_api.iml
generated
@@ -1,14 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="uv (osu_lazer_api)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
6
.idea/vcs.xml
generated
6
.idea/vcs.xml
generated
@@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
@@ -1,140 +0,0 @@
|
||||
# Lazer API 数据同步指南
|
||||
|
||||
本指南将帮助您将现有的 bancho.py 数据库数据同步到新的 Lazer API 专用表中。
|
||||
|
||||
## 文件说明
|
||||
|
||||
1. **`migrations_old/add_missing_fields.sql`** - 创建 Lazer API 专用表结构
|
||||
2. **`migrations_old/sync_legacy_data.sql`** - 数据同步脚本
|
||||
3. **`sync_data.py`** - 交互式数据同步工具
|
||||
4. **`quick_sync.py`** - 快速同步脚本(使用项目配置)
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 方法一:快速同步(推荐)
|
||||
|
||||
如果您已经配置好了项目的数据库连接,可以直接使用快速同步脚本:
|
||||
|
||||
```bash
|
||||
python quick_sync.py
|
||||
```
|
||||
|
||||
此脚本会:
|
||||
1. 自动读取项目配置中的数据库连接信息
|
||||
2. 创建 Lazer API 专用表结构
|
||||
3. 同步现有数据到新表
|
||||
|
||||
### 方法二:交互式同步
|
||||
|
||||
如果需要使用不同的数据库连接配置:
|
||||
|
||||
```bash
|
||||
python sync_data.py
|
||||
```
|
||||
|
||||
此脚本会:
|
||||
1. 交互式地询问数据库连接信息
|
||||
2. 检查必要表是否存在
|
||||
3. 显示详细的同步过程和结果
|
||||
|
||||
### 方法三:手动执行 SQL
|
||||
|
||||
如果您熟悉 SQL 操作,可以手动执行:
|
||||
|
||||
```bash
|
||||
# 1. 创建表结构
|
||||
mysql -u username -p database_name < migrations_old/add_missing_fields.sql
|
||||
|
||||
# 2. 同步数据
|
||||
mysql -u username -p database_name < migrations_old/sync_legacy_data.sql
|
||||
```
|
||||
|
||||
## 同步内容
|
||||
|
||||
### 创建的新表
|
||||
|
||||
- `lazer_user_profiles` - 用户扩展资料
|
||||
- `lazer_user_countries` - 用户国家信息
|
||||
- `lazer_user_kudosu` - 用户 Kudosu 统计
|
||||
- `lazer_user_counts` - 用户各项计数统计
|
||||
- `lazer_user_statistics` - 用户游戏统计(按模式)
|
||||
- `lazer_user_achievements` - 用户成就
|
||||
- `lazer_oauth_tokens` - OAuth 访问令牌
|
||||
- 其他相关表...
|
||||
|
||||
### 同步的数据
|
||||
|
||||
1. **用户基本信息**
|
||||
- 从 `users` 表同步基本资料
|
||||
- 自动转换时间戳格式
|
||||
- 设置合理的默认值
|
||||
|
||||
2. **游戏统计**
|
||||
- 从 `stats` 表同步各模式的游戏数据
|
||||
- 计算命中精度和其他衍生统计
|
||||
|
||||
3. **用户成就**
|
||||
- 从 `user_achievements` 表同步成就数据(如果存在)
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **安全性**
|
||||
- 脚本只会创建新表和插入数据
|
||||
- 不会修改或删除现有的原始表数据
|
||||
- 使用 `ON DUPLICATE KEY UPDATE` 避免重复插入
|
||||
|
||||
2. **兼容性**
|
||||
- 兼容现有的 bancho.py 数据库结构
|
||||
- 支持标准的 osu! 数据格式
|
||||
|
||||
3. **性能**
|
||||
- 大量数据可能需要较长时间
|
||||
- 建议在维护窗口期间执行
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 常见错误
|
||||
|
||||
1. **"Unknown column" 错误**
|
||||
```
|
||||
ERROR 1054: Unknown column 'users.is_active' in 'field list'
|
||||
```
|
||||
**解决方案**: 确保先执行了 `add_missing_fields.sql` 创建表结构
|
||||
|
||||
2. **"Table doesn't exist" 错误**
|
||||
```
|
||||
ERROR 1146: Table 'database.users' doesn't exist
|
||||
```
|
||||
**解决方案**: 确认数据库中存在 bancho.py 的原始表
|
||||
|
||||
3. **连接错误**
|
||||
```
|
||||
ERROR 2003: Can't connect to MySQL server
|
||||
```
|
||||
**解决方案**: 检查数据库连接配置和权限
|
||||
|
||||
### 验证同步结果
|
||||
|
||||
同步完成后,可以执行以下查询验证结果:
|
||||
|
||||
```sql
|
||||
-- 检查同步的用户数量
|
||||
SELECT COUNT(*) FROM lazer_user_profiles;
|
||||
|
||||
-- 查看样本数据
|
||||
SELECT
|
||||
u.id, u.name,
|
||||
lup.playmode, lup.is_supporter,
|
||||
lus.pp, lus.play_count
|
||||
FROM users u
|
||||
LEFT JOIN lazer_user_profiles lup ON u.id = lup.user_id
|
||||
LEFT JOIN lazer_user_statistics lus ON u.id = lus.user_id AND lus.mode = 'osu'
|
||||
LIMIT 5;
|
||||
```
|
||||
|
||||
## 支持
|
||||
|
||||
如果遇到问题,请:
|
||||
1. 检查日志文件 `data_sync.log`
|
||||
2. 确认数据库权限
|
||||
3. 验证原始表数据完整性
|
||||
76
Dockerfile
76
Dockerfile
@@ -1,28 +1,48 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
||||
|
||||
WORKDIR /app
|
||||
ENV UV_PROJECT_ENVIRONMENT=syncvenv
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
pkg-config \
|
||||
default-libmysqlclient-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制依赖文件
|
||||
COPY uv.lock .
|
||||
COPY pyproject.toml .
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装Python依赖
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
# 复制应用代码
|
||||
COPY . .
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
|
||||
# 启动命令
|
||||
CMD ["uv", "run", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y gcc pkg-config default-libmysqlclient-dev \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
|
||||
ENV PATH="/root/.cargo/bin:${PATH}" \
|
||||
PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1 UV_PROJECT_ENVIRONMENT=/app/.venv
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV UV_PROJECT_ENVIRONMENT=/app/.venv
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY packages/ ./packages/
|
||||
|
||||
RUN uv sync --frozen --no-dev
|
||||
RUN uv pip install rosu-pp-py
|
||||
|
||||
COPY . .
|
||||
|
||||
# ---
|
||||
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y curl netcat-openbsd \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV PATH="/app/.venv/bin:${PATH}" \
|
||||
PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
COPY --from=builder /app/.venv /app/.venv
|
||||
COPY --from=builder /app /app
|
||||
|
||||
COPY docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||
RUN chmod +x /app/docker-entrypoint.sh
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["uv", "run", "--no-sync", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
48
Dockerfile-osurx
Normal file
48
Dockerfile-osurx
Normal file
@@ -0,0 +1,48 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y gcc pkg-config default-libmysqlclient-dev git \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
|
||||
ENV PATH="/root/.cargo/bin:${PATH}" \
|
||||
PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1 UV_PROJECT_ENVIRONMENT=/app/.venv
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV UV_PROJECT_ENVIRONMENT=/app/.venv
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY packages/ ./packages/
|
||||
|
||||
RUN uv sync --frozen --no-dev
|
||||
RUN uv pip install git+https://github.com/ppy-sb/rosu-pp-py.git
|
||||
|
||||
COPY . .
|
||||
|
||||
# ---
|
||||
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y curl netcat-openbsd \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV PATH="/app/.venv/bin:${PATH}" \
|
||||
PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
COPY --from=builder /app/.venv /app/.venv
|
||||
COPY --from=builder /app /app
|
||||
|
||||
COPY docker-entrypoint.sh /app/docker-entrypoint.sh
|
||||
RUN chmod +x /app/docker-entrypoint.sh
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/docker-entrypoint.sh"]
|
||||
CMD ["uv", "run", "--no-sync", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 GooGuTeam
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
268
README.md
268
README.md
@@ -6,193 +6,159 @@
|
||||
|
||||
- **OAuth 2.0 认证**: 支持密码流和刷新令牌流
|
||||
- **用户数据管理**: 完整的用户信息、统计数据、成就等
|
||||
- **多游戏模式支持**: osu!, taiko, fruits, mania
|
||||
- **多游戏模式支持**: osu! (osu!rx, osu!ap), taiko, fruits, mania
|
||||
- **数据库持久化**: MySQL 存储用户数据
|
||||
- **缓存支持**: Redis 缓存令牌和会话信息
|
||||
- **多种存储后端**: 支持本地存储、Cloudflare R2、AWS S3
|
||||
- **容器化部署**: Docker 和 Docker Compose 支持
|
||||
|
||||
## API 端点
|
||||
|
||||
### 认证端点
|
||||
- `POST /oauth/token` - OAuth 令牌获取/刷新
|
||||
|
||||
### 用户端点
|
||||
- `GET /api/v2/me/{ruleset}` - 获取当前用户信息
|
||||
|
||||
### 其他端点
|
||||
- `GET /` - 根端点
|
||||
- `GET /health` - 健康检查
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 使用 Docker Compose (推荐)
|
||||
|
||||
1. 克隆项目
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
git clone https://github.com/GooGuTeam/osu_lazer_api.git
|
||||
cd osu_lazer_api
|
||||
```
|
||||
|
||||
2. 启动服务
|
||||
2. 创建 `.env` 文件
|
||||
|
||||
请参考下方的服务器配置修改 .env 文件
|
||||
|
||||
```bash
|
||||
docker-compose up -d
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
3. 创建示例数据
|
||||
3. 启动服务
|
||||
```bash
|
||||
docker-compose exec api python create_sample_data.py
|
||||
# 标准服务器
|
||||
docker-compose -f docker-compose.yml up -d
|
||||
# 启用 osu!RX 和 osu!AP 模式 (偏偏要上班 pp 算法)
|
||||
docker-compose -f docker-compose-osurx.yml up -d
|
||||
```
|
||||
|
||||
4. 测试 API
|
||||
```bash
|
||||
# 获取访问令牌
|
||||
curl -X POST http://localhost:8000/oauth/token \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*"
|
||||
4. 通过游戏连接服务器
|
||||
|
||||
# 使用令牌获取用户信息
|
||||
curl -X GET http://localhost:8000/api/v2/me/osu \
|
||||
-H "Authorization: Bearer YOUR_ACCESS_TOKEN"
|
||||
```
|
||||
|
||||
### 本地开发
|
||||
|
||||
1. 安装依赖
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. 配置环境变量
|
||||
```bash
|
||||
# 复制服务器配置文件
|
||||
cp .env .env.local
|
||||
|
||||
# 复制客户端配置文件(用于测试脚本)
|
||||
cp .env.client .env.client.local
|
||||
```
|
||||
|
||||
3. 启动 MySQL 和 Redis
|
||||
```bash
|
||||
# 使用 Docker
|
||||
docker run -d --name mysql -e MYSQL_ROOT_PASSWORD=password -e MYSQL_DATABASE=osu_api -p 3306:3306 mysql:8.0
|
||||
docker run -d --name redis -p 6379:6379 redis:7-alpine
|
||||
```
|
||||
|
||||
|
||||
4. 启动应用
|
||||
```bash
|
||||
uvicorn main:app --reload
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
osu_lazer_api/
|
||||
├── app/
|
||||
│ ├── __init__.py
|
||||
│ ├── models.py # Pydantic 数据模型
|
||||
│ ├── database.py # SQLAlchemy 数据库模型
|
||||
│ ├── config.py # 配置设置
|
||||
│ ├── dependencies.py # 依赖注入
|
||||
│ ├── auth.py # 认证和令牌管理
|
||||
│ └── utils.py # 工具函数
|
||||
├── main.py # FastAPI 应用主文件
|
||||
├── create_sample_data.py # 示例数据创建脚本
|
||||
├── requirements.txt # Python 依赖
|
||||
├── .env # 环境变量配置
|
||||
├── docker-compose.yml # Docker Compose 配置
|
||||
├── Dockerfile # Docker 镜像配置
|
||||
└── README.md # 项目说明
|
||||
```
|
||||
|
||||
## 示例用户
|
||||
|
||||
创建示例数据后,您可以使用以下凭据进行测试:
|
||||
|
||||
- **用户名**: `Googujiang`
|
||||
- **密码**: `password123`
|
||||
- **用户ID**: `15651670`
|
||||
使用[自定义的 osu!lazer 客户端](https://github.com/GooGuTeam/osu),或者使用 [LazerAuthlibInjection](https://github.com/MingxuanGame/LazerAuthlibInjection),修改服务器设置为服务器的 IP
|
||||
|
||||
## 环境变量配置
|
||||
|
||||
项目包含两个环境配置文件:
|
||||
|
||||
### 服务器配置 (`.env`)
|
||||
用于配置 FastAPI 服务器的运行参数:
|
||||
|
||||
### 数据库设置
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `DATABASE_URL` | MySQL 数据库连接字符串 | `mysql+pymysql://root:password@localhost:3306/osu_api` |
|
||||
| `REDIS_URL` | Redis 连接字符串 | `redis://localhost:6379/0` |
|
||||
| `SECRET_KEY` | JWT 签名密钥 | `your-secret-key-here` |
|
||||
| `MYSQL_HOST` | MySQL 主机地址 | `localhost` |
|
||||
| `MYSQL_PORT` | MySQL 端口 | `3306` |
|
||||
| `MYSQL_DATABASE` | MySQL 数据库名 | `osu_api` |
|
||||
| `MYSQL_USER` | MySQL 用户名 | `osu_api` |
|
||||
| `MYSQL_PASSWORD` | MySQL 密码 | `password` |
|
||||
| `MYSQL_ROOT_PASSWORD` | MySQL root 密码 | `password` |
|
||||
| `REDIS_URL` | Redis 连接字符串 | `redis://127.0.0.1:6379/0` |
|
||||
|
||||
### JWT 设置
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `JWT_SECRET_KEY` | JWT 签名密钥 | `your_jwt_secret_here` |
|
||||
| `ALGORITHM` | JWT 算法 | `HS256` |
|
||||
| `ACCESS_TOKEN_EXPIRE_MINUTES` | 访问令牌过期时间(分钟) | `1440` |
|
||||
| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` |
|
||||
| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` |
|
||||
|
||||
### 服务器设置
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `HOST` | 服务器监听地址 | `0.0.0.0` |
|
||||
| `PORT` | 服务器监听端口 | `8000` |
|
||||
| `DEBUG` | 调试模式 | `True` |
|
||||
|
||||
### 客户端配置 (`.env.client`)
|
||||
用于配置客户端脚本的 API 连接参数:
|
||||
| `DEBUG` | 调试模式 | `false` |
|
||||
| `SERVER_URL` | 服务器 URL | `http://localhost:8000` |
|
||||
| `PRIVATE_API_SECRET` | 私有 API 密钥,用于前后端 API 调用 | `your_private_api_secret_here` |
|
||||
|
||||
### OAuth 设置
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` |
|
||||
| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` |
|
||||
| `OSU_API_URL` | API 服务器地址 | `http://localhost:8000` |
|
||||
| `OSU_WEB_CLIENT_ID` | Web OAuth 客户端 ID | `6` |
|
||||
| `OSU_WEB_CLIENT_SECRET` | Web OAuth 客户端密钥 | `your_osu_web_client_secret_here`
|
||||
|
||||
### SignalR 服务器设置
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `SIGNALR_NEGOTIATE_TIMEOUT` | SignalR 协商超时时间(秒) | `30` |
|
||||
| `SIGNALR_PING_INTERVAL` | SignalR ping 间隔(秒) | `15` |
|
||||
|
||||
### Fetcher 设置
|
||||
|
||||
Fetcher 用于从 osu! 官方 API 获取数据,使用 osu! 官方 API 的 OAuth 2.0 认证
|
||||
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `FETCHER_CLIENT_ID` | Fetcher 客户端 ID | `""` |
|
||||
| `FETCHER_CLIENT_SECRET` | Fetcher 客户端密钥 | `""` |
|
||||
| `FETCHER_SCOPES` | Fetcher 权限范围 | `public` |
|
||||
|
||||
### 日志设置
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `LOG_LEVEL` | 日志级别 | `INFO` |
|
||||
|
||||
### 游戏设置
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `ENABLE_OSU_RX` | 启用 osu!RX 统计数据 | `false` |
|
||||
| `ENABLE_OSU_AP` | 启用 osu!AP 统计数据 | `false` |
|
||||
| `ENABLE_ALL_MODS_PP` | 启用所有 Mod 的 PP 计算 | `false` |
|
||||
| `ENABLE_SUPPORTER_FOR_ALL_USERS` | 启用所有新注册用户的支持者状态 | `false` |
|
||||
| `ENABLE_ALL_BEATMAP_LEADERBOARD` | 启用所有谱面的排行榜 | `false` |
|
||||
| `SEASONAL_BACKGROUNDS` | 季节背景图 URL 列表 | `[]` |
|
||||
|
||||
### 存储服务设置
|
||||
|
||||
用于存储回放文件、头像等静态资源。
|
||||
|
||||
| 变量名 | 描述 | 默认值 |
|
||||
|--------|------|--------|
|
||||
| `STORAGE_SERVICE` | 存储服务类型:`local`、`r2`、`s3` | `local` |
|
||||
| `STORAGE_SETTINGS` | 存储服务配置 (JSON 格式),配置见下 | `{"local_storage_path": "./storage"}` |
|
||||
|
||||
## 存储服务配置
|
||||
|
||||
### 本地存储 (推荐用于开发环境)
|
||||
|
||||
本地存储将文件保存在服务器的本地文件系统中,适合开发和小规模部署。
|
||||
|
||||
```bash
|
||||
STORAGE_SERVICE="local"
|
||||
STORAGE_SETTINGS='{"local_storage_path": "./storage"}'
|
||||
```
|
||||
|
||||
### Cloudflare R2 存储 (推荐用于生产环境)
|
||||
|
||||
```bash
|
||||
STORAGE_SERVICE="r2"
|
||||
STORAGE_SETTINGS='{
|
||||
"r2_account_id": "your_cloudflare_account_id",
|
||||
"r2_access_key_id": "your_r2_access_key_id",
|
||||
"r2_secret_access_key": "your_r2_secret_access_key",
|
||||
"r2_bucket_name": "your_bucket_name",
|
||||
"r2_public_url_base": "https://your-custom-domain.com"
|
||||
}'
|
||||
```
|
||||
|
||||
### AWS S3 存储
|
||||
|
||||
```bash
|
||||
STORAGE_SERVICE="s3"
|
||||
STORAGE_SETTINGS='{
|
||||
"s3_access_key_id": "your_aws_access_key_id",
|
||||
"s3_secret_access_key": "your_aws_secret_access_key",
|
||||
"s3_bucket_name": "your_s3_bucket_name",
|
||||
"s3_region_name": "us-east-1",
|
||||
"s3_public_url_base": "https://your-custom-domain.com"
|
||||
}'
|
||||
```
|
||||
|
||||
> **注意**: 在生产环境中,请务必更改默认的密钥和密码!
|
||||
|
||||
## API 使用示例
|
||||
|
||||
### 获取访问令牌
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/oauth/token \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "grant_type=password&username=Googujiang&password=password123&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk&scope=*"
|
||||
```
|
||||
|
||||
响应:
|
||||
```json
|
||||
{
|
||||
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
"refresh_token": "abc123...",
|
||||
"scope": "*"
|
||||
}
|
||||
```
|
||||
|
||||
### 获取用户信息
|
||||
|
||||
```bash
|
||||
curl -X GET http://localhost:8000/api/v2/me/osu \
|
||||
-H "Authorization: Bearer YOUR_ACCESS_TOKEN"
|
||||
```
|
||||
|
||||
### 刷新令牌
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/oauth/token \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "grant_type=refresh_token&refresh_token=YOUR_REFRESH_TOKEN&client_id=5&client_secret=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
|
||||
```
|
||||
|
||||
## 开发
|
||||
|
||||
### 添加新用户
|
||||
|
||||
您可以通过修改 `create_sample_data.py` 文件来添加更多示例用户,或者扩展 API 来支持用户注册功能。
|
||||
|
||||
### 扩展功能
|
||||
|
||||
- 添加更多 API 端点(排行榜、谱面信息等)
|
||||
- 实现实时功能(WebSocket)
|
||||
- 添加管理面板
|
||||
- 实现数据导入/导出功能
|
||||
|
||||
### 迁移数据库
|
||||
### 更新数据库
|
||||
|
||||
参考[数据库迁移指南](./MIGRATE_GUIDE.md)
|
||||
|
||||
|
||||
30
app/auth.py
30
app/auth.py
@@ -15,6 +15,7 @@ from app.log import logger
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -125,12 +126,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
minutes=settings.access_token_expire_minutes
|
||||
)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
to_encode, settings.secret_key, algorithm=settings.algorithm
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
@@ -146,7 +147,7 @@ def verify_token(token: str) -> dict | None:
|
||||
"""验证访问令牌"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
token, settings.secret_key, algorithms=[settings.algorithm]
|
||||
)
|
||||
return payload
|
||||
except JWTError:
|
||||
@@ -156,6 +157,8 @@ def verify_token(token: str) -> dict | None:
|
||||
async def store_token(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
client_id: int,
|
||||
scopes: list[str],
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_in: int,
|
||||
@@ -164,7 +167,9 @@ async def store_token(
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# 删除用户的旧令牌
|
||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.user_id == user_id, OAuthToken.client_id == client_id
|
||||
)
|
||||
old_tokens = (await db.exec(statement)).all()
|
||||
for token in old_tokens:
|
||||
await db.delete(token)
|
||||
@@ -179,7 +184,9 @@ async def store_token(
|
||||
# 创建新令牌记录
|
||||
token_record = OAuthToken(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
access_token=access_token,
|
||||
scope=",".join(scopes),
|
||||
refresh_token=refresh_token,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
@@ -209,3 +216,18 @@ async def get_token_by_refresh_token(
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
|
||||
async def get_user_by_authorization_code(
|
||||
db: AsyncSession, redis: Redis, client_id: int, code: str
|
||||
) -> tuple[User, list[str]] | None:
|
||||
user_id = await redis.hget(f"oauth:code:{client_id}:{code}", "user_id") # pyright: ignore[reportGeneralTypeIssues]
|
||||
scopes = await redis.hget(f"oauth:code:{client_id}:{code}", "scopes") # pyright: ignore[reportGeneralTypeIssues]
|
||||
if not user_id or not scopes:
|
||||
return None
|
||||
|
||||
await redis.hdel(f"oauth:code:{client_id}:{code}", "user_id", "scopes") # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
statement = select(User).where(User.id == int(user_id))
|
||||
user = (await db.exec(statement)).first()
|
||||
return (user, scopes.split(",")) if user else None
|
||||
|
||||
@@ -7,7 +7,15 @@ from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import GameMode
|
||||
|
||||
import rosu_pp_py as rosu
|
||||
try:
|
||||
import rosu_pp_py as rosu
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"rosu-pp-py is not installed. "
|
||||
"Please install it.\n"
|
||||
" Official: uv add rosu-pp-py\n"
|
||||
" ppy-sb: uv add git+https://github.com/ppy-sb/rosu-pp-py.git"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.score import Score
|
||||
@@ -51,8 +59,6 @@ def calculate_pp(
|
||||
) -> float:
|
||||
map = rosu.Beatmap(content=beatmap)
|
||||
map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType]
|
||||
if map.is_suspicious():
|
||||
return 0.0
|
||||
perf = rosu.Performance(
|
||||
mods=score.mods,
|
||||
lazer=True,
|
||||
@@ -67,7 +73,6 @@ def calculate_pp(
|
||||
n100=score.n100,
|
||||
n50=score.n50,
|
||||
misses=score.nmiss,
|
||||
hitresult_priority=rosu.HitResultPriority.Fastest,
|
||||
)
|
||||
attrs = perf.calculate(map)
|
||||
return attrs.pp
|
||||
|
||||
142
app/config.py
142
app/config.py
@@ -1,51 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
from pydantic import Field, HttpUrl, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings:
|
||||
class AWSS3StorageSettings(BaseSettings):
|
||||
s3_access_key_id: str
|
||||
s3_secret_access_key: str
|
||||
s3_bucket_name: str
|
||||
s3_region_name: str
|
||||
s3_public_url_base: str | None = None
|
||||
|
||||
|
||||
class CloudflareR2Settings(BaseSettings):
|
||||
r2_account_id: str
|
||||
r2_access_key_id: str
|
||||
r2_secret_access_key: str
|
||||
r2_bucket_name: str
|
||||
r2_public_url_base: str | None = None
|
||||
|
||||
|
||||
class LocalStorageSettings(BaseSettings):
|
||||
local_storage_path: str = "./storage"
|
||||
|
||||
|
||||
class StorageServiceType(str, Enum):
|
||||
LOCAL = "local"
|
||||
CLOUDFLARE_R2 = "r2"
|
||||
AWS_S3 = "s3"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
# 数据库设置
|
||||
DATABASE_URL: str = os.getenv(
|
||||
"DATABASE_URL", "mysql+aiomysql://root:password@127.0.0.1:3306/osu_api"
|
||||
)
|
||||
REDIS_URL: str = os.getenv("REDIS_URL", "redis://127.0.0.1:6379/0")
|
||||
mysql_host: str = "localhost"
|
||||
mysql_port: int = 3306
|
||||
mysql_database: str = "osu_api"
|
||||
mysql_user: str = "osu_api"
|
||||
mysql_password: str = "password"
|
||||
mysql_root_password: str = "password"
|
||||
redis_url: str = "redis://127.0.0.1:6379/0"
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return f"mysql+aiomysql://{self.mysql_user}:{self.mysql_password}@{self.mysql_host}:{self.mysql_port}/{self.mysql_database}"
|
||||
|
||||
# JWT 设置
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-here")
|
||||
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
|
||||
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")
|
||||
)
|
||||
secret_key: str = Field(default="your_jwt_secret_here", alias="jwt_secret_key")
|
||||
algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 1440
|
||||
|
||||
# OAuth 设置
|
||||
OSU_CLIENT_ID: str = os.getenv("OSU_CLIENT_ID", "5")
|
||||
OSU_CLIENT_SECRET: str = os.getenv(
|
||||
"OSU_CLIENT_SECRET", "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
|
||||
)
|
||||
osu_client_id: int = 5
|
||||
osu_client_secret: str = "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
|
||||
osu_web_client_id: int = 6
|
||||
osu_web_client_secret: str = "your_osu_web_client_secret_here"
|
||||
|
||||
# 服务器设置
|
||||
HOST: str = os.getenv("HOST", "0.0.0.0")
|
||||
PORT: int = int(os.getenv("PORT", "8000"))
|
||||
DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
debug: bool = False
|
||||
private_api_secret: str = "your_private_api_secret_here"
|
||||
server_url: HttpUrl = HttpUrl("http://localhost:8000")
|
||||
|
||||
# SignalR 设置
|
||||
SIGNALR_NEGOTIATE_TIMEOUT: int = int(os.getenv("SIGNALR_NEGOTIATE_TIMEOUT", "30"))
|
||||
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "15"))
|
||||
signalr_negotiate_timeout: int = 30
|
||||
signalr_ping_interval: int = 15
|
||||
|
||||
# Fetcher 设置
|
||||
FETCHER_CLIENT_ID: str = os.getenv("FETCHER_CLIENT_ID", "")
|
||||
FETCHER_CLIENT_SECRET: str = os.getenv("FETCHER_CLIENT_SECRET", "")
|
||||
FETCHER_SCOPES: list[str] = os.getenv("FETCHER_SCOPES", "public").split(",")
|
||||
FETCHER_CALLBACK_URL: str = os.getenv(
|
||||
"FETCHER_CALLBACK_URL", "http://localhost:8000/fetcher/callback"
|
||||
)
|
||||
fetcher_client_id: str = ""
|
||||
fetcher_client_secret: str = ""
|
||||
fetcher_scopes: Annotated[list[str], NoDecode] = ["public"]
|
||||
|
||||
@property
|
||||
def fetcher_callback_url(self) -> str:
|
||||
return f"{self.server_url}fetcher/callback"
|
||||
|
||||
# 日志设置
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
log_level: str = "INFO"
|
||||
|
||||
# 游戏设置
|
||||
enable_osu_rx: bool = False
|
||||
enable_osu_ap: bool = False
|
||||
enable_all_mods_pp: bool = False
|
||||
enable_supporter_for_all_users: bool = False
|
||||
enable_all_beatmap_leaderboard: bool = False
|
||||
seasonal_backgrounds: list[str] = []
|
||||
|
||||
# 存储设置
|
||||
storage_service: StorageServiceType = StorageServiceType.LOCAL
|
||||
storage_settings: (
|
||||
LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings
|
||||
) = LocalStorageSettings()
|
||||
|
||||
@field_validator("fetcher_scopes", mode="before")
|
||||
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
|
||||
if isinstance(v, str):
|
||||
return v.split(",")
|
||||
return v
|
||||
|
||||
@field_validator("storage_settings", mode="after")
|
||||
def validate_storage_settings(
|
||||
cls,
|
||||
v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
|
||||
info: ValidationInfo,
|
||||
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
|
||||
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
|
||||
if not isinstance(v, CloudflareR2Settings):
|
||||
raise ValueError(
|
||||
"When storage_service is 'r2', "
|
||||
"storage_settings must be CloudflareR2Settings"
|
||||
)
|
||||
elif info.data.get("storage_service") == StorageServiceType.LOCAL:
|
||||
if not isinstance(v, LocalStorageSettings):
|
||||
raise ValueError(
|
||||
"When storage_service is 'local', "
|
||||
"storage_settings must be LocalStorageSettings"
|
||||
)
|
||||
elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
|
||||
if not isinstance(v, AWSS3StorageSettings):
|
||||
raise ValueError(
|
||||
"When storage_service is 's3', "
|
||||
"storage_settings must be AWSS3StorageSettings"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .achievement import UserAchievement, UserAchievementResp
|
||||
from .auth import OAuthToken
|
||||
from .auth import OAuthClient, OAuthToken
|
||||
from .beatmap import (
|
||||
Beatmap as Beatmap,
|
||||
BeatmapResp as BeatmapResp,
|
||||
@@ -10,16 +10,33 @@ from .beatmapset import (
|
||||
BeatmapsetResp as BeatmapsetResp,
|
||||
)
|
||||
from .best_score import BestScore
|
||||
from .counts import (
|
||||
CountResp,
|
||||
MonthlyPlaycounts,
|
||||
ReplayWatchedCount,
|
||||
)
|
||||
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
from .lazer_user import (
|
||||
User,
|
||||
UserResp,
|
||||
)
|
||||
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from .playlist_attempts import (
|
||||
ItemAttemptsCount,
|
||||
ItemAttemptsResp,
|
||||
PlaylistAggregateScore,
|
||||
)
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
from .pp_best_score import PPBestScore
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from .room import APIUploadedRoom, Room, RoomResp
|
||||
from .room_participated_user import RoomParticipatedUser
|
||||
from .score import (
|
||||
MultiplayerScores,
|
||||
Score,
|
||||
ScoreAround,
|
||||
ScoreBase,
|
||||
ScoreResp,
|
||||
ScoreStatistics,
|
||||
@@ -37,21 +54,39 @@ from .user_account_history import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"APIUploadedRoom",
|
||||
"Beatmap",
|
||||
"BeatmapPlaycounts",
|
||||
"BeatmapPlaycountsResp",
|
||||
"Beatmapset",
|
||||
"BeatmapsetResp",
|
||||
"BestScore",
|
||||
"CountResp",
|
||||
"DailyChallengeStats",
|
||||
"DailyChallengeStatsResp",
|
||||
"FavouriteBeatmapset",
|
||||
"ItemAttemptsCount",
|
||||
"ItemAttemptsResp",
|
||||
"MonthlyPlaycounts",
|
||||
"MultiplayerEvent",
|
||||
"MultiplayerEventResp",
|
||||
"MultiplayerScores",
|
||||
"OAuthClient",
|
||||
"OAuthToken",
|
||||
"PPBestScore",
|
||||
"Playlist",
|
||||
"PlaylistAggregateScore",
|
||||
"PlaylistBestScore",
|
||||
"PlaylistResp",
|
||||
"Relationship",
|
||||
"RelationshipResp",
|
||||
"RelationshipType",
|
||||
"ReplayWatchedCount",
|
||||
"Room",
|
||||
"RoomParticipatedUser",
|
||||
"RoomResp",
|
||||
"Score",
|
||||
"ScoreAround",
|
||||
"ScoreBase",
|
||||
"ScoreResp",
|
||||
"ScoreStatistics",
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from datetime import datetime
|
||||
import secrets
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
from sqlmodel import JSON, BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
@@ -17,6 +18,7 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
client_id: int = Field(index=True)
|
||||
access_token: str = Field(max_length=500, unique=True)
|
||||
refresh_token: str = Field(max_length=500, unique=True)
|
||||
token_type: str = Field(default="Bearer", max_length=20)
|
||||
@@ -27,3 +29,13 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
)
|
||||
|
||||
user: "User" = Relationship()
|
||||
|
||||
|
||||
class OAuthClient(SQLModel, table=True):
|
||||
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType]
|
||||
client_id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
|
||||
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
owner_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.config import settings
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.score import MODE_TO_INT, GameMode
|
||||
|
||||
from .beatmap_playcounts import BeatmapPlaycounts
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -23,14 +23,12 @@ class BeatmapOwner(SQLModel):
|
||||
username: str
|
||||
|
||||
|
||||
class BeatmapBase(SQLModel, UTCBaseModel):
|
||||
class BeatmapBase(SQLModel):
|
||||
# Beatmap
|
||||
url: str
|
||||
mode: GameMode
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
difficulty_rating: float = Field(
|
||||
default=0.0, sa_column=Column(DECIMAL(precision=10, scale=6))
|
||||
)
|
||||
difficulty_rating: float = Field(default=0.0)
|
||||
total_length: int
|
||||
user_id: int
|
||||
version: str
|
||||
@@ -42,17 +40,11 @@ class BeatmapBase(SQLModel, UTCBaseModel):
|
||||
# TODO: failtimes, owners
|
||||
|
||||
# BeatmapExtended
|
||||
ar: float = Field(default=0.0, sa_column=Column(DECIMAL(precision=10, scale=2)))
|
||||
cs: float = Field(default=0.0, sa_column=Column(DECIMAL(precision=10, scale=2)))
|
||||
drain: float = Field(
|
||||
default=0.0,
|
||||
sa_column=Column(DECIMAL(precision=10, scale=2)),
|
||||
) # hp
|
||||
accuracy: float = Field(
|
||||
default=0.0,
|
||||
sa_column=Column(DECIMAL(precision=10, scale=2)),
|
||||
) # od
|
||||
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(precision=10, scale=2)))
|
||||
ar: float = Field(default=0.0)
|
||||
cs: float = Field(default=0.0)
|
||||
drain: float = Field(default=0.0) # hp
|
||||
accuracy: float = Field(default=0.0) # od
|
||||
bpm: float = Field(default=0.0)
|
||||
count_circles: int = Field(default=0)
|
||||
count_sliders: int = Field(default=0)
|
||||
count_spinners: int = Field(default=0)
|
||||
@@ -63,7 +55,7 @@ class BeatmapBase(SQLModel, UTCBaseModel):
|
||||
|
||||
class Beatmap(BeatmapBase, table=True):
|
||||
__tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(primary_key=True, index=True)
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
beatmap_status: BeatmapRankStatus
|
||||
# optional
|
||||
@@ -71,10 +63,6 @@ class Beatmap(BeatmapBase, table=True):
|
||||
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
|
||||
@property
|
||||
def can_ranked(self) -> bool:
|
||||
return self.beatmap_status > BeatmapRankStatus.PENDING
|
||||
|
||||
@classmethod
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
d = resp.model_dump()
|
||||
@@ -170,11 +158,19 @@ class BeatmapResp(BeatmapBase):
|
||||
from .score import Score
|
||||
|
||||
beatmap_ = beatmap.model_dump()
|
||||
beatmap_status = beatmap.beatmap_status
|
||||
if query_mode is not None and beatmap.mode != query_mode:
|
||||
beatmap_["convert"] = True
|
||||
beatmap_["is_scoreable"] = beatmap.beatmap_status > BeatmapRankStatus.PENDING
|
||||
beatmap_["status"] = beatmap.beatmap_status.name.lower()
|
||||
beatmap_["ranked"] = beatmap.beatmap_status.value
|
||||
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
|
||||
if (
|
||||
settings.enable_all_beatmap_leaderboard
|
||||
and not beatmap_status.has_leaderboard()
|
||||
):
|
||||
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
else:
|
||||
beatmap_["status"] = beatmap_status.name.lower()
|
||||
beatmap_["ranked"] = beatmap_status.value
|
||||
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
|
||||
if not from_set:
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
|
||||
|
||||
@@ -1,58 +1,38 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||
|
||||
from app.config import settings
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .lazer_user import BASE_INCLUDES, User, UserResp
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column, DateTime, Text
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import Field, Relationship, SQLModel, col, func, select
|
||||
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
|
||||
class BeatmapCovers(SQLModel):
|
||||
cover: str
|
||||
card: str
|
||||
list: str
|
||||
slimcover: str
|
||||
cover_2_x: str | None = Field(default=None, alias="cover@2x")
|
||||
card_2_x: str | None = Field(default=None, alias="card@2x")
|
||||
list_2_x: str | None = Field(default=None, alias="list@2x")
|
||||
slimcover_2_x: str | None = Field(default=None, alias="slimcover@2x")
|
||||
|
||||
@model_serializer
|
||||
def _(self) -> dict[str, str | None]:
|
||||
self = cast(dict[str, str | None] | BeatmapCovers, self)
|
||||
if isinstance(self, dict):
|
||||
return {
|
||||
"cover": self["cover"],
|
||||
"card": self["card"],
|
||||
"list": self["list"],
|
||||
"slimcover": self["slimcover"],
|
||||
"cover@2x": self.get("cover@2x"),
|
||||
"card@2x": self.get("card@2x"),
|
||||
"list@2x": self.get("list@2x"),
|
||||
"slimcover@2x": self.get("slimcover@2x"),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"cover": self.cover,
|
||||
"card": self.card,
|
||||
"list": self.list,
|
||||
"slimcover": self.slimcover,
|
||||
"cover@2x": self.cover_2_x,
|
||||
"card@2x": self.card_2_x,
|
||||
"list@2x": self.list_2_x,
|
||||
"slimcover@2x": self.slimcover_2_x,
|
||||
}
|
||||
BeatmapCovers = TypedDict(
|
||||
"BeatmapCovers",
|
||||
{
|
||||
"cover": str,
|
||||
"card": str,
|
||||
"list": str,
|
||||
"slimcover": str,
|
||||
"cover@2x": NotRequired[str | None],
|
||||
"card@2x": NotRequired[str | None],
|
||||
"list@2x": NotRequired[str | None],
|
||||
"slimcover@2x": NotRequired[str | None],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class BeatmapHype(BaseModel):
|
||||
@@ -74,12 +54,12 @@ class BeatmapNomination(TypedDict):
|
||||
beatmapset_id: int
|
||||
reset: bool
|
||||
user_id: int
|
||||
rulesets: list[GameMode] | None
|
||||
rulesets: NotRequired[list[GameMode] | None]
|
||||
|
||||
|
||||
class BeatmapDescription(SQLModel):
|
||||
bbcode: str | None = None
|
||||
description: str | None = None
|
||||
class BeatmapDescription(TypedDict):
|
||||
bbcode: NotRequired[str | None]
|
||||
description: NotRequired[str | None]
|
||||
|
||||
|
||||
class BeatmapTranslationText(BaseModel):
|
||||
@@ -87,7 +67,7 @@ class BeatmapTranslationText(BaseModel):
|
||||
id: int | None = None
|
||||
|
||||
|
||||
class BeatmapsetBase(SQLModel, UTCBaseModel):
|
||||
class BeatmapsetBase(SQLModel):
|
||||
# Beatmapset
|
||||
artist: str = Field(index=True)
|
||||
artist_unicode: str = Field(index=True)
|
||||
@@ -121,7 +101,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel):
|
||||
track_id: int | None = Field(default=None) # feature artist?
|
||||
|
||||
# BeatmapsetExtended
|
||||
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2)))
|
||||
bpm: float = Field(default=0.0)
|
||||
can_be_hyped: bool = Field(default=False)
|
||||
discussion_locked: bool = Field(default=False)
|
||||
last_updated: datetime = Field(sa_column=Column(DateTime))
|
||||
@@ -181,11 +161,24 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
"download_disabled": resp.availability.download_disabled or False,
|
||||
}
|
||||
)
|
||||
session.add(beatmapset)
|
||||
await session.commit()
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmapset.id == resp.id))
|
||||
).first():
|
||||
session.add(beatmapset)
|
||||
await session.commit()
|
||||
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
|
||||
return beatmapset
|
||||
|
||||
@classmethod
|
||||
async def get_or_fetch(
|
||||
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
|
||||
) -> "Beatmapset":
|
||||
beatmapset = await session.get(Beatmapset, sid)
|
||||
if not beatmapset:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
beatmapset = await cls.from_resp(session, resp)
|
||||
return beatmapset
|
||||
|
||||
|
||||
class BeatmapsetResp(BeatmapsetBase):
|
||||
id: int
|
||||
@@ -193,7 +186,7 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
discussion_enabled: bool = True
|
||||
status: str
|
||||
ranked: int
|
||||
legacy_thread_url: str = ""
|
||||
legacy_thread_url: str | None = ""
|
||||
is_scoreable: bool
|
||||
hype: BeatmapHype | None = None
|
||||
availability: BeatmapAvailability
|
||||
@@ -239,11 +232,21 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
required=beatmapset.nominations_required,
|
||||
current=beatmapset.nominations_current,
|
||||
),
|
||||
"status": beatmapset.beatmap_status.name.lower(),
|
||||
"ranked": beatmapset.beatmap_status.value,
|
||||
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
|
||||
"is_scoreable": beatmapset.beatmap_status.has_leaderboard(),
|
||||
**beatmapset.model_dump(),
|
||||
}
|
||||
|
||||
beatmap_status = beatmapset.beatmap_status
|
||||
if (
|
||||
settings.enable_all_beatmap_leaderboard
|
||||
and not beatmap_status.has_leaderboard()
|
||||
):
|
||||
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
update["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
else:
|
||||
update["status"] = beatmap_status.name.lower()
|
||||
update["ranked"] = beatmap_status.value
|
||||
|
||||
if session and user:
|
||||
existing_favourite = (
|
||||
await session.exec(
|
||||
|
||||
@@ -29,9 +29,7 @@ class BestScore(SQLModel, table=True):
|
||||
)
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
total_score: int = Field(
|
||||
default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score"))
|
||||
)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
mods: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
|
||||
@@ -14,7 +14,13 @@ if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
|
||||
|
||||
class MonthlyPlaycounts(SQLModel, table=True):
|
||||
class CountBase(SQLModel):
|
||||
year: int = Field(index=True)
|
||||
month: int = Field(index=True)
|
||||
count: int = Field(default=0)
|
||||
|
||||
|
||||
class MonthlyPlaycounts(CountBase, table=True):
|
||||
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(
|
||||
@@ -24,20 +30,29 @@ class MonthlyPlaycounts(SQLModel, table=True):
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
year: int = Field(index=True)
|
||||
month: int = Field(index=True)
|
||||
playcount: int = Field(default=0)
|
||||
|
||||
user: "User" = Relationship(back_populates="monthly_playcounts")
|
||||
|
||||
|
||||
class MonthlyPlaycountsResp(SQLModel):
|
||||
class ReplayWatchedCount(CountBase, table=True):
|
||||
__tablename__ = "replays_watched_counts" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user: "User" = Relationship(back_populates="replays_watched_counts")
|
||||
|
||||
|
||||
class CountResp(SQLModel):
|
||||
start_date: date
|
||||
count: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_model: MonthlyPlaycounts) -> "MonthlyPlaycountsResp":
|
||||
def from_db(cls, db_model: CountBase) -> "CountResp":
|
||||
return cls(
|
||||
start_date=date(db_model.year, db_model.month, 1),
|
||||
count=db_model.playcount,
|
||||
count=db_model.count,
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import UTC, datetime
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
@@ -6,8 +6,9 @@ from app.models.score import GameMode
|
||||
from app.models.user import Country, Page, RankHistory
|
||||
|
||||
from .achievement import UserAchievement, UserAchievementResp
|
||||
from .beatmap_playcounts import BeatmapPlaycounts
|
||||
from .counts import CountResp, MonthlyPlaycounts, ReplayWatchedCount
|
||||
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
|
||||
from .monthly_playcounts import MonthlyPlaycounts, MonthlyPlaycountsResp
|
||||
from .statistics import UserStatistics, UserStatisticsResp
|
||||
from .team import Team, TeamMember
|
||||
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
|
||||
@@ -21,6 +22,7 @@ from sqlmodel import (
|
||||
Field,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
@@ -74,7 +76,6 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
username: str = Field(max_length=32, unique=True, index=True)
|
||||
page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw=""))
|
||||
previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
# TODO: replays_watched_counts
|
||||
support_level: int = 0
|
||||
badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
|
||||
@@ -144,6 +145,9 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
back_populates="user"
|
||||
)
|
||||
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
|
||||
replays_watched_counts: list[ReplayWatchedCount] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
@@ -164,7 +168,7 @@ class UserResp(UserBase):
|
||||
is_online: bool = False
|
||||
groups: list = [] # TODO
|
||||
country: Country = Field(default_factory=lambda: Country(code="CN", name="China"))
|
||||
favourite_beatmapset_count: int = 0 # TODO
|
||||
favourite_beatmapset_count: int = 0
|
||||
graveyard_beatmapset_count: int = 0 # TODO
|
||||
guest_beatmapset_count: int = 0 # TODO
|
||||
loved_beatmapset_count: int = 0 # TODO
|
||||
@@ -176,13 +180,15 @@ class UserResp(UserBase):
|
||||
follower_count: int = 0
|
||||
friends: list["RelationshipResp"] | None = None
|
||||
scores_best_count: int = 0
|
||||
scores_first_count: int = 0
|
||||
scores_first_count: int = 0 # TODO
|
||||
scores_recent_count: int = 0
|
||||
scores_pinned_count: int = 0
|
||||
beatmap_playcounts_count: int = 0
|
||||
account_history: list[UserAccountHistoryResp] = []
|
||||
active_tournament_banners: list[dict] = [] # TODO
|
||||
kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO
|
||||
monthly_playcounts: list[MonthlyPlaycountsResp] = Field(default_factory=list)
|
||||
monthly_playcounts: list[CountResp] = Field(default_factory=list)
|
||||
replay_watched_counts: list[CountResp] = Field(default_factory=list)
|
||||
unread_pm_count: int = 0 # TODO
|
||||
rank_history: RankHistory | None = None # TODO
|
||||
rank_highest: RankHighest | None = None # TODO
|
||||
@@ -207,7 +213,11 @@ class UserResp(UserBase):
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
from .best_score import BestScore
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from .score import Score
|
||||
|
||||
ruleset = ruleset or obj.playmode
|
||||
|
||||
u = cls.model_validate(obj.model_dump())
|
||||
u.id = obj.id
|
||||
@@ -275,7 +285,7 @@ class UserResp(UserBase):
|
||||
if "statistics" in include:
|
||||
current_stattistics = None
|
||||
for i in await obj.awaitable_attrs.statistics:
|
||||
if i.mode == (ruleset or obj.playmode):
|
||||
if i.mode == ruleset:
|
||||
current_stattistics = i
|
||||
break
|
||||
u.statistics = (
|
||||
@@ -292,16 +302,74 @@ class UserResp(UserBase):
|
||||
|
||||
if "monthly_playcounts" in include:
|
||||
u.monthly_playcounts = [
|
||||
MonthlyPlaycountsResp.from_db(pc)
|
||||
CountResp.from_db(pc)
|
||||
for pc in await obj.awaitable_attrs.monthly_playcounts
|
||||
]
|
||||
|
||||
if "replays_watched_counts" in include:
|
||||
u.replay_watched_counts = [
|
||||
CountResp.from_db(rwc)
|
||||
for rwc in await obj.awaitable_attrs.replays_watched_counts
|
||||
]
|
||||
|
||||
if "achievements" in include:
|
||||
u.user_achievements = [
|
||||
UserAchievementResp.from_db(ua)
|
||||
for ua in await obj.awaitable_attrs.achievement
|
||||
]
|
||||
|
||||
u.favourite_beatmapset_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(FavouriteBeatmapset)
|
||||
.where(FavouriteBeatmapset.user_id == obj.id)
|
||||
)
|
||||
).one()
|
||||
u.scores_pinned_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(Score)
|
||||
.where(
|
||||
Score.user_id == obj.id,
|
||||
Score.pinned_order > 0,
|
||||
Score.gamemode == ruleset,
|
||||
col(Score.passed).is_(True),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
u.scores_best_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(BestScore)
|
||||
.where(
|
||||
BestScore.user_id == obj.id,
|
||||
BestScore.gamemode == ruleset,
|
||||
)
|
||||
.limit(200)
|
||||
)
|
||||
).one()
|
||||
u.scores_recent_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(Score)
|
||||
.where(
|
||||
Score.user_id == obj.id,
|
||||
Score.gamemode == ruleset,
|
||||
col(Score.passed).is_(True),
|
||||
Score.ended_at > datetime.now(UTC) - timedelta(hours=24),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
u.beatmap_playcounts_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(BeatmapPlaycounts)
|
||||
.where(
|
||||
BeatmapPlaycounts.user_id == obj.id,
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
return u
|
||||
|
||||
|
||||
@@ -314,6 +382,7 @@ ALL_INCLUDED = [
|
||||
"statistics_rulesets",
|
||||
"achievements",
|
||||
"monthly_playcounts",
|
||||
"replays_watched_counts",
|
||||
]
|
||||
|
||||
|
||||
@@ -324,6 +393,7 @@ SEARCH_INCLUDED = [
|
||||
"statistics_rulesets",
|
||||
"achievements",
|
||||
"monthly_playcounts",
|
||||
"replays_watched_counts",
|
||||
]
|
||||
|
||||
BASE_INCLUDES = [
|
||||
|
||||
56
app/database/multiplayer_event.py
Normal file
56
app/database/multiplayer_event.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerEventBase(SQLModel, UTCBaseModel):
|
||||
playlist_item_id: int | None = None
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
event_type: str = Field(index=True)
|
||||
|
||||
|
||||
class MultiplayerEvent(MultiplayerEventBase, table=True):
|
||||
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
|
||||
)
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
updated_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
event_detail: dict[str, Any] | None = Field(
|
||||
sa_column=Column(JSON),
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerEventResp(MultiplayerEventBase):
|
||||
id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp":
|
||||
return cls.model_validate(event)
|
||||
152
app/database/playlist_attempts.py
Normal file
152
app/database/playlist_attempts.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class ItemAttemptsCountBase(SQLModel):
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
attempts: int = Field(default=0)
|
||||
completed: int = Field(default=0)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
accuracy: float = 0.0
|
||||
pp: float = 0
|
||||
total_score: int = 0
|
||||
|
||||
|
||||
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
||||
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
user: User = Relationship()
|
||||
|
||||
async def get_position(self, session: AsyncSession) -> int:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(ItemAttemptsCountBase.room_id),
|
||||
order_by=col(ItemAttemptsCountBase.total_score).desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
subq = select(ItemAttemptsCountBase, rownum).subquery()
|
||||
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
|
||||
async def update(self, session: AsyncSession):
|
||||
playlist_scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == self.room_id,
|
||||
PlaylistBestScore.user_id == self.user_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
self.attempts = sum(score.attempts for score in playlist_scores)
|
||||
self.total_score = sum(score.total_score for score in playlist_scores)
|
||||
self.pp = sum(score.score.pp for score in playlist_scores)
|
||||
self.completed = len(playlist_scores)
|
||||
self.accuracy = (
|
||||
sum(score.score.accuracy for score in playlist_scores) / self.completed
|
||||
if self.completed > 0
|
||||
else 0.0
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
|
||||
@classmethod
|
||||
async def get_or_create(
|
||||
cls,
|
||||
room_id: int,
|
||||
user_id: int,
|
||||
session: AsyncSession,
|
||||
) -> "ItemAttemptsCount":
|
||||
item_attempts = await session.exec(
|
||||
select(cls).where(
|
||||
cls.room_id == room_id,
|
||||
cls.user_id == user_id,
|
||||
)
|
||||
)
|
||||
item_attempts = item_attempts.first()
|
||||
if item_attempts is None:
|
||||
item_attempts = cls(room_id=room_id, user_id=user_id)
|
||||
session.add(item_attempts)
|
||||
await session.commit()
|
||||
await session.refresh(item_attempts)
|
||||
await item_attempts.update(session)
|
||||
return item_attempts
|
||||
|
||||
|
||||
class ItemAttemptsResp(ItemAttemptsCountBase):
|
||||
user: UserResp | None = None
|
||||
position: int | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
item_attempts: ItemAttemptsCount,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
) -> "ItemAttemptsResp":
|
||||
resp = cls.model_validate(item_attempts.model_dump())
|
||||
resp.user = await UserResp.from_db(
|
||||
await item_attempts.awaitable_attrs.user,
|
||||
session=session,
|
||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
||||
)
|
||||
if "position" in include:
|
||||
resp.position = await item_attempts.get_position(session)
|
||||
# resp.accuracy *= 100
|
||||
return resp
|
||||
|
||||
|
||||
class ItemAttemptsCountForItem(BaseModel):
|
||||
id: int
|
||||
attempts: int
|
||||
passed: bool
|
||||
|
||||
|
||||
class PlaylistAggregateScore(BaseModel):
|
||||
playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
room_id: int,
|
||||
user_id: int,
|
||||
session: AsyncSession,
|
||||
) -> "PlaylistAggregateScore":
|
||||
playlist_scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
playlist_item_attempts = []
|
||||
for score in playlist_scores:
|
||||
playlist_item_attempts.append(
|
||||
ItemAttemptsCountForItem(
|
||||
id=score.playlist_id,
|
||||
attempts=score.attempts,
|
||||
passed=score.score.passed,
|
||||
)
|
||||
)
|
||||
return cls(playlist_item_attempts=playlist_item_attempts)
|
||||
110
app/database/playlist_best_score.py
Normal file
110
app/database/playlist_best_score.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .lazer_user import User
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .score import Score
|
||||
|
||||
|
||||
class PlaylistBestScore(SQLModel, table=True):
|
||||
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
playlist_id: int = Field(foreign_key="room_playlists.id", index=True)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
attempts: int = Field(default=0) # playlist
|
||||
|
||||
user: User = Relationship()
|
||||
score: "Score" = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[PlaylistBestScore.score_id]",
|
||||
"lazy": "joined",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def process_playlist_best_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
score_id: int,
|
||||
total_score: int,
|
||||
session: AsyncSession,
|
||||
redis: Redis,
|
||||
):
|
||||
previous = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if previous is None:
|
||||
previous = PlaylistBestScore(
|
||||
user_id=user_id,
|
||||
score_id=score_id,
|
||||
room_id=room_id,
|
||||
playlist_id=playlist_id,
|
||||
total_score=total_score,
|
||||
)
|
||||
session.add(previous)
|
||||
elif not previous.score.passed or previous.total_score < total_score:
|
||||
previous.score_id = score_id
|
||||
previous.total_score = total_score
|
||||
previous.attempts += 1
|
||||
await session.commit()
|
||||
if await redis.exists(f"multiplayer:{room_id}:gameplay:players"):
|
||||
await redis.decr(f"multiplayer:{room_id}:gameplay:players")
|
||||
|
||||
|
||||
async def get_position(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
score_id: int,
|
||||
session: AsyncSession,
|
||||
) -> int:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=(
|
||||
col(PlaylistBestScore.playlist_id),
|
||||
col(PlaylistBestScore.room_id),
|
||||
),
|
||||
order_by=col(PlaylistBestScore.total_score).desc(),
|
||||
)
|
||||
.label("row_number")
|
||||
)
|
||||
subq = (
|
||||
select(PlaylistBestScore, rownum)
|
||||
.where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
|
||||
result = await session.exec(stmt)
|
||||
s = result.one_or_none()
|
||||
return s if s else 0
|
||||
143
app/database/playlists.py
Normal file
143
app/database/playlists.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
from app.models.multiplayer_hub import PlaylistItem
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .room import Room
|
||||
|
||||
|
||||
class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
id: int = Field(index=True)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
ruleset_id: int = Field(ge=0, le=3)
|
||||
expired: bool = Field(default=False)
|
||||
playlist_order: int = Field(default=0)
|
||||
played_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True)),
|
||||
default=None,
|
||||
)
|
||||
allowed_mods: list[APIMod] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
required_mods: list[APIMod] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
beatmap_id: int = Field(
|
||||
foreign_key="beatmaps.id",
|
||||
)
|
||||
freestyle: bool = Field(default=False)
|
||||
|
||||
|
||||
class Playlist(PlaylistBase, table=True):
|
||||
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
|
||||
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
||||
room_id: int = Field(foreign_key="rooms.id", exclude=True)
|
||||
|
||||
beatmap: Beatmap = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "joined",
|
||||
}
|
||||
)
|
||||
room: "Room" = Relationship()
|
||||
|
||||
@classmethod
|
||||
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
|
||||
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
|
||||
cls.room_id == room_id
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
|
||||
@classmethod
|
||||
async def from_hub(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
) -> "Playlist":
|
||||
next_id = await cls.get_next_id_for_room(room_id, session=session)
|
||||
return cls(
|
||||
id=next_id,
|
||||
owner_id=playlist.owner_id,
|
||||
ruleset_id=playlist.ruleset_id,
|
||||
beatmap_id=playlist.beatmap_id,
|
||||
required_mods=playlist.required_mods,
|
||||
allowed_mods=playlist.allowed_mods,
|
||||
expired=playlist.expired,
|
||||
playlist_order=playlist.playlist_order,
|
||||
played_at=playlist.played_at,
|
||||
freestyle=playlist.freestyle,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
db_playlist.owner_id = playlist.owner_id
|
||||
db_playlist.ruleset_id = playlist.ruleset_id
|
||||
db_playlist.beatmap_id = playlist.beatmap_id
|
||||
db_playlist.required_mods = playlist.required_mods
|
||||
db_playlist.allowed_mods = playlist.allowed_mods
|
||||
db_playlist.expired = playlist.expired
|
||||
db_playlist.playlist_order = playlist.playlist_order
|
||||
db_playlist.played_at = playlist.played_at
|
||||
db_playlist.freestyle = playlist.freestyle
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def add_to_db(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.commit()
|
||||
await session.refresh(db_playlist)
|
||||
playlist.id = db_playlist.id
|
||||
|
||||
@classmethod
|
||||
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == item_id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
await session.delete(db_playlist)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class PlaylistResp(PlaylistBase):
|
||||
beatmap: BeatmapResp | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, playlist: Playlist, include: list[str] = []
|
||||
) -> "PlaylistResp":
|
||||
data = playlist.model_dump()
|
||||
if "beatmap" in include:
|
||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
|
||||
resp = cls.model_validate(data)
|
||||
return resp
|
||||
@@ -1,6 +1,177 @@
|
||||
from sqlmodel import Field, SQLModel
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.database.playlist_attempts import PlaylistAggregateScore
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.multiplayer_hub import ServerMultiplayerRoom
|
||||
from app.models.room import (
|
||||
MatchType,
|
||||
QueueMode,
|
||||
RoomCategory,
|
||||
RoomDifficultyRange,
|
||||
RoomPlaylistItemStats,
|
||||
RoomStatus,
|
||||
)
|
||||
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class RoomIndex(SQLModel, table=True):
|
||||
__tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue]
|
||||
class RoomBase(SQLModel, UTCBaseModel):
|
||||
name: str = Field(index=True)
|
||||
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
|
||||
duration: int | None = Field(default=None) # minutes
|
||||
starts_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
ends_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
participant_count: int = Field(default=0)
|
||||
max_attempts: int | None = Field(default=None) # playlists
|
||||
type: MatchType
|
||||
queue_mode: QueueMode
|
||||
auto_skip: bool
|
||||
auto_start_duration: int
|
||||
status: RoomStatus
|
||||
# TODO: channel_id
|
||||
|
||||
|
||||
class Room(AsyncAttrs, RoomBase, table=True):
|
||||
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
host_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
|
||||
host: User = Relationship()
|
||||
playlist: list[Playlist] = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "selectin",
|
||||
"cascade": "all, delete-orphan",
|
||||
"overlaps": "room",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RoomResp(RoomBase):
|
||||
id: int
|
||||
has_password: bool = False
|
||||
host: UserResp | None = None
|
||||
playlist: list[PlaylistResp] = []
|
||||
playlist_item_stats: RoomPlaylistItemStats | None = None
|
||||
difficulty_range: RoomDifficultyRange | None = None
|
||||
current_playlist_item: PlaylistResp | None = None
|
||||
current_user_score: PlaylistAggregateScore | None = None
|
||||
recent_participants: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
room: Room,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
user: User | None = None,
|
||||
) -> "RoomResp":
|
||||
resp = cls.model_validate(room.model_dump())
|
||||
|
||||
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
|
||||
difficulty_range = RoomDifficultyRange(
|
||||
min=0,
|
||||
max=0,
|
||||
)
|
||||
rulesets = set()
|
||||
for playlist in room.playlist:
|
||||
stats.count_total += 1
|
||||
if not playlist.expired:
|
||||
stats.count_active += 1
|
||||
rulesets.add(playlist.ruleset_id)
|
||||
difficulty_range.min = min(
|
||||
difficulty_range.min, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
difficulty_range.max = max(
|
||||
difficulty_range.max, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
|
||||
stats.ruleset_ids = list(rulesets)
|
||||
resp.playlist_item_stats = stats
|
||||
resp.difficulty_range = difficulty_range
|
||||
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
|
||||
resp.recent_participants = []
|
||||
for recent_participant in await session.exec(
|
||||
select(RoomParticipatedUser)
|
||||
.where(
|
||||
RoomParticipatedUser.room_id == room.id,
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
.limit(8)
|
||||
.order_by(col(RoomParticipatedUser.joined_at).desc())
|
||||
):
|
||||
resp.recent_participants.append(
|
||||
await UserResp.from_db(
|
||||
await recent_participant.awaitable_attrs.user,
|
||||
session,
|
||||
include=["statistics"],
|
||||
)
|
||||
)
|
||||
resp.host = await UserResp.from_db(
|
||||
await room.awaitable_attrs.host, session, include=["statistics"]
|
||||
)
|
||||
if "current_user_score" in include and user:
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(
|
||||
room.id, user.id, session
|
||||
)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
|
||||
room = server_room.room
|
||||
resp = cls(
|
||||
id=room.room_id,
|
||||
name=room.settings.name,
|
||||
type=room.settings.match_type,
|
||||
queue_mode=room.settings.queue_mode,
|
||||
auto_skip=room.settings.auto_skip,
|
||||
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
|
||||
status=server_room.status,
|
||||
category=server_room.category,
|
||||
# duration = room.settings.duration,
|
||||
starts_at=server_room.start_at,
|
||||
participant_count=len(room.users),
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class APIUploadedRoom(RoomBase):
|
||||
def to_room(self) -> Room:
|
||||
"""
|
||||
将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。
|
||||
"""
|
||||
room_dict = self.model_dump()
|
||||
room_dict.pop("playlist", None)
|
||||
# host_id 已在字段中
|
||||
return Room(**room_dict)
|
||||
|
||||
id: int | None
|
||||
host_id: int | None = None
|
||||
playlist: list[Playlist] = Field(default_factory=list)
|
||||
|
||||
39
app/database/room_participated_user.py
Normal file
39
app/database/room_participated_user.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .room import Room
|
||||
|
||||
|
||||
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
)
|
||||
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)
|
||||
)
|
||||
joined_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
left_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True), default=None
|
||||
)
|
||||
|
||||
room: "Room" = Relationship()
|
||||
user: "User" = Relationship()
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Sequence
|
||||
from datetime import UTC, date, datetime
|
||||
import json
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.calculator import (
|
||||
calculate_pp,
|
||||
@@ -13,8 +13,14 @@ from app.calculator import (
|
||||
calculate_weighted_pp,
|
||||
clamp,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database.team import TeamMember
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.model import (
|
||||
CurrentUserAttributes,
|
||||
PinAttributes,
|
||||
RespWithCursor,
|
||||
UTCBaseModel,
|
||||
)
|
||||
from app.models.mods import APIMod, mods_can_get_pp
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
@@ -31,8 +37,8 @@ from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmap_playcounts import process_beatmap_playcount
|
||||
from .beatmapset import BeatmapsetResp
|
||||
from .best_score import BestScore
|
||||
from .counts import MonthlyPlaycounts
|
||||
from .lazer_user import User, UserResp
|
||||
from .monthly_playcounts import MonthlyPlaycounts
|
||||
from .pp_best_score import PPBestScore
|
||||
from .relationship import (
|
||||
Relationship as DBRelationship,
|
||||
@@ -89,10 +95,11 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
default=0, sa_column=Column(BigInteger), exclude=True
|
||||
)
|
||||
type: str
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
|
||||
# optional
|
||||
# TODO: current_user_attributes
|
||||
position: int | None = Field(default=None) # multiplayer
|
||||
# position: int | None = Field(default=None) # multiplayer
|
||||
|
||||
|
||||
class Score(ScoreBase, table=True):
|
||||
@@ -100,7 +107,6 @@ class Score(ScoreBase, table=True):
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
|
||||
)
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
@@ -121,6 +127,7 @@ class Score(ScoreBase, table=True):
|
||||
nslider_tail_hit: int | None = Field(default=None, exclude=True)
|
||||
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
pinned_order: int = Field(default=0, exclude=True)
|
||||
|
||||
# optional
|
||||
beatmap: Beatmap = Relationship()
|
||||
@@ -163,6 +170,9 @@ class ScoreResp(ScoreBase):
|
||||
maximum_statistics: ScoreStatistics | None = None
|
||||
rank_global: int | None = None
|
||||
rank_country: int | None = None
|
||||
position: int | None = None
|
||||
scores_around: "ScoreAround | None" = None
|
||||
current_user_attributes: CurrentUserAttributes | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
||||
@@ -231,9 +241,22 @@ class ScoreResp(ScoreBase):
|
||||
)
|
||||
or None
|
||||
)
|
||||
s.current_user_attributes = CurrentUserAttributes(
|
||||
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
|
||||
)
|
||||
return s
|
||||
|
||||
|
||||
class MultiplayerScores(RespWithCursor):
|
||||
scores: list[ScoreResp] = Field(default_factory=list)
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ScoreAround(SQLModel):
|
||||
higher: MultiplayerScores | None = None
|
||||
lower: MultiplayerScores | None = None
|
||||
|
||||
|
||||
async def get_best_id(session: AsyncSession, score_id: int) -> None:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
@@ -312,6 +335,13 @@ async def get_leaderboard(
|
||||
user: User | None = None,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[Score], Score | None]:
|
||||
is_rx = "RX" in (mods or [])
|
||||
is_ap = "AP" in (mods or [])
|
||||
if settings.enable_osu_rx and is_rx:
|
||||
mode = GameMode.OSURX
|
||||
elif settings.enable_osu_ap and is_ap:
|
||||
mode = GameMode.OSUAP
|
||||
|
||||
wheres = await _score_where(type, beatmap, mode, mods, user)
|
||||
if wheres is None:
|
||||
return [], None
|
||||
@@ -329,6 +359,10 @@ async def get_leaderboard(
|
||||
self_query = (
|
||||
select(BestScore)
|
||||
.where(BestScore.user_id == user.id)
|
||||
.where(
|
||||
col(BestScore.beatmap_id) == beatmap,
|
||||
col(BestScore.gamemode) == mode,
|
||||
)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
.limit(1)
|
||||
)
|
||||
@@ -461,12 +495,13 @@ async def get_user_best_pp_in_beatmap(
|
||||
async def get_user_best_pp(
|
||||
session: AsyncSession,
|
||||
user: int,
|
||||
mode: GameMode,
|
||||
limit: int = 200,
|
||||
) -> Sequence[PPBestScore]:
|
||||
return (
|
||||
await session.exec(
|
||||
select(PPBestScore)
|
||||
.where(PPBestScore.user_id == user)
|
||||
.where(PPBestScore.user_id == user, PPBestScore.gamemode == mode)
|
||||
.order_by(col(PPBestScore.pp).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
@@ -474,7 +509,7 @@ async def get_user_best_pp(
|
||||
|
||||
|
||||
async def process_user(
|
||||
session: AsyncSession, user: User, score: Score, ranked: bool = False
|
||||
session: AsyncSession, user: User, score: Score, length: int, ranked: bool = False
|
||||
):
|
||||
assert user.id
|
||||
assert score.id
|
||||
@@ -577,8 +612,8 @@ async def process_user(
|
||||
)
|
||||
)
|
||||
statistics.play_count += 1
|
||||
mouthly_playcount.playcount += 1
|
||||
statistics.play_time += int((score.ended_at - score.started_at).total_seconds())
|
||||
mouthly_playcount.count += 1
|
||||
statistics.play_time += length
|
||||
statistics.count_100 += score.n100 + score.nkatu
|
||||
statistics.count_300 += score.n300 + score.ngeki
|
||||
statistics.count_50 += score.n50
|
||||
@@ -588,7 +623,7 @@ async def process_user(
|
||||
)
|
||||
|
||||
if score.passed and ranked:
|
||||
best_pp_scores = await get_user_best_pp(session, user.id)
|
||||
best_pp_scores = await get_user_best_pp(session, user.id, score.gamemode)
|
||||
pp_sum = 0.0
|
||||
acc_sum = 0.0
|
||||
for i, bp in enumerate(best_pp_scores):
|
||||
@@ -616,9 +651,19 @@ async def process_score(
|
||||
fetcher: "Fetcher",
|
||||
session: AsyncSession,
|
||||
redis: Redis,
|
||||
item_id: int | None = None,
|
||||
room_id: int | None = None,
|
||||
) -> Score:
|
||||
assert user.id
|
||||
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
|
||||
acronyms = [mod["acronym"] for mod in info.mods]
|
||||
is_rx = "RX" in acronyms
|
||||
is_ap = "AP" in acronyms
|
||||
gamemode = INT_TO_MODE[info.ruleset_id]
|
||||
if settings.enable_osu_rx and is_rx and gamemode == GameMode.OSU:
|
||||
gamemode = GameMode.OSURX
|
||||
elif settings.enable_osu_ap and is_ap and gamemode == GameMode.OSU:
|
||||
gamemode = GameMode.OSUAP
|
||||
score = Score(
|
||||
accuracy=info.accuracy,
|
||||
max_combo=info.max_combo,
|
||||
@@ -630,7 +675,7 @@ async def process_score(
|
||||
total_score_without_mods=info.total_score_without_mods,
|
||||
beatmap_id=beatmap_id,
|
||||
ended_at=datetime.now(UTC),
|
||||
gamemode=INT_TO_MODE[info.ruleset_id],
|
||||
gamemode=gamemode,
|
||||
started_at=score_token.created_at,
|
||||
user_id=user.id,
|
||||
preserve=info.passed,
|
||||
@@ -647,6 +692,8 @@ async def process_score(
|
||||
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
|
||||
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
|
||||
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
|
||||
playlist_item_id=item_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
if can_get_pp:
|
||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||
@@ -678,4 +725,5 @@ async def process_score(
|
||||
await session.refresh(score)
|
||||
await session.refresh(score_token)
|
||||
await session.refresh(user)
|
||||
await redis.publish("score:processed", score.id)
|
||||
return score
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
import json
|
||||
|
||||
from app.config import settings
|
||||
@@ -18,23 +19,36 @@ def json_serializer(value):
|
||||
|
||||
|
||||
# 数据库引擎
|
||||
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
|
||||
engine = create_async_engine(settings.database_url, json_serializer=json_serializer)
|
||||
|
||||
# Redis 连接
|
||||
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
db_session_context: ContextVar[AsyncSession | None] = ContextVar(
|
||||
"db_session_context", default=None
|
||||
)
|
||||
|
||||
|
||||
async def get_db():
|
||||
async with AsyncSession(engine) as session:
|
||||
session = db_session_context.get()
|
||||
if session is None:
|
||||
session = AsyncSession(engine)
|
||||
db_session_context.set(session)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
db_session_context.set(None)
|
||||
else:
|
||||
yield session
|
||||
|
||||
|
||||
async def create_tables():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
|
||||
# Redis 依赖
|
||||
def get_redis():
|
||||
return redis_client
|
||||
|
||||
|
||||
def get_redis_pubsub():
|
||||
return redis_client.pubsub()
|
||||
|
||||
@@ -12,10 +12,10 @@ async def get_fetcher() -> Fetcher:
|
||||
global fetcher
|
||||
if fetcher is None:
|
||||
fetcher = Fetcher(
|
||||
settings.FETCHER_CLIENT_ID,
|
||||
settings.FETCHER_CLIENT_SECRET,
|
||||
settings.FETCHER_SCOPES,
|
||||
settings.FETCHER_CALLBACK_URL,
|
||||
settings.fetcher_client_id,
|
||||
settings.fetcher_client_secret,
|
||||
settings.fetcher_scopes,
|
||||
settings.fetcher_callback_url,
|
||||
)
|
||||
redis = get_redis()
|
||||
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
|
||||
|
||||
26
app/dependencies/scheduler.py
Normal file
26
app/dependencies/scheduler.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
scheduler: AsyncIOScheduler | None = None
|
||||
|
||||
|
||||
def init_scheduler():
|
||||
global scheduler
|
||||
scheduler = AsyncIOScheduler(timezone=UTC)
|
||||
scheduler.start()
|
||||
|
||||
|
||||
def get_scheduler() -> AsyncIOScheduler:
|
||||
global scheduler
|
||||
if scheduler is None:
|
||||
init_scheduler()
|
||||
return scheduler # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
def stop_scheduler():
|
||||
global scheduler
|
||||
if scheduler:
|
||||
scheduler.shutdown()
|
||||
52
app/dependencies/storage.py
Normal file
52
app/dependencies/storage.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from app.config import (
|
||||
AWSS3StorageSettings,
|
||||
CloudflareR2Settings,
|
||||
LocalStorageSettings,
|
||||
StorageServiceType,
|
||||
settings,
|
||||
)
|
||||
from app.storage import StorageService
|
||||
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
|
||||
from app.storage.local import LocalStorageService
|
||||
|
||||
storage: StorageService | None = None
|
||||
|
||||
|
||||
def init_storage_service():
|
||||
global storage
|
||||
if settings.storage_service == StorageServiceType.LOCAL:
|
||||
storage_settings = cast(LocalStorageSettings, settings.storage_settings)
|
||||
storage = LocalStorageService(
|
||||
storage_path=storage_settings.local_storage_path,
|
||||
)
|
||||
elif settings.storage_service == StorageServiceType.CLOUDFLARE_R2:
|
||||
storage_settings = cast(CloudflareR2Settings, settings.storage_settings)
|
||||
storage = CloudflareR2StorageService(
|
||||
account_id=storage_settings.r2_account_id,
|
||||
access_key_id=storage_settings.r2_access_key_id,
|
||||
secret_access_key=storage_settings.r2_secret_access_key,
|
||||
bucket_name=storage_settings.r2_bucket_name,
|
||||
public_url_base=storage_settings.r2_public_url_base,
|
||||
)
|
||||
elif settings.storage_service == StorageServiceType.AWS_S3:
|
||||
storage_settings = cast(AWSS3StorageSettings, settings.storage_settings)
|
||||
storage = AWSS3StorageService(
|
||||
access_key_id=storage_settings.s3_access_key_id,
|
||||
secret_access_key=storage_settings.s3_secret_access_key,
|
||||
bucket_name=storage_settings.s3_bucket_name,
|
||||
public_url_base=storage_settings.s3_public_url_base,
|
||||
region_name=storage_settings.s3_region_name,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported storage service: {settings.storage_service}")
|
||||
return storage
|
||||
|
||||
|
||||
def get_storage_service():
|
||||
if storage is None:
|
||||
return init_storage_service()
|
||||
return storage
|
||||
@@ -1,34 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.config import settings
|
||||
from app.database import User
|
||||
|
||||
from .database import get_db
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi.security import (
|
||||
HTTPBearer,
|
||||
OAuth2AuthorizationCodeBearer,
|
||||
OAuth2PasswordBearer,
|
||||
SecurityScopes,
|
||||
)
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
oauth2_password = OAuth2PasswordBearer(
|
||||
tokenUrl="oauth/token",
|
||||
scopes={"*": "Allows access to all scopes."},
|
||||
)
|
||||
|
||||
oauth2_code = OAuth2AuthorizationCodeBearer(
|
||||
authorizationUrl="oauth/authorize",
|
||||
tokenUrl="oauth/token",
|
||||
scopes={
|
||||
"chat.read": "Allows read chat messages on a user's behalf.",
|
||||
"chat.write": "Allows sending chat messages on a user's behalf.",
|
||||
"chat.write_manage": (
|
||||
"Allows joining and leaving chat channels on a user's behalf."
|
||||
),
|
||||
"delegate": (
|
||||
"Allows acting as the owner of a client; "
|
||||
"only available for Client Credentials Grant."
|
||||
),
|
||||
"forum.write": "Allows creating and editing forum posts on a user's behalf.",
|
||||
"friends.read": "Allows reading of the user's friend list.",
|
||||
"identify": "Allows reading of the public profile of the user (/me).",
|
||||
"public": "Allows reading of publicly available data on behalf of the user.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
security_scopes: SecurityScopes,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||||
) -> User:
|
||||
"""获取当前认证用户"""
|
||||
token = credentials.credentials
|
||||
token = token_pw or token_code
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user = await get_current_user_by_token(token, db)
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
is_client = token_record.client_id in (
|
||||
settings.osu_client_id,
|
||||
settings.osu_web_client_id,
|
||||
)
|
||||
|
||||
if security_scopes.scopes == ["*"]:
|
||||
# client/web only
|
||||
if not token_pw or not is_client:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
elif not is_client:
|
||||
for scope in security_scopes.scopes:
|
||||
if scope not in token_record.scope.split(","):
|
||||
raise HTTPException(
|
||||
status_code=403, detail=f"Insufficient scope: {scope}"
|
||||
)
|
||||
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
return user
|
||||
|
||||
@@ -38,6 +38,22 @@ class BaseFetcher:
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
|
||||
if self.is_token_expired():
|
||||
await self.refresh_access_token()
|
||||
header = kwargs.pop("headers", {})
|
||||
header = self.header
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.request(
|
||||
method,
|
||||
url,
|
||||
headers=header,
|
||||
**kwargs,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def is_token_expired(self) -> bool:
|
||||
return self.token_expiry <= int(time.time())
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ from app.log import logger
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class BeatmapFetcher(BaseFetcher):
|
||||
async def get_beatmap(
|
||||
@@ -21,11 +19,10 @@ class BeatmapFetcher(BaseFetcher):
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>"
|
||||
)
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
|
||||
return BeatmapResp.model_validate(
|
||||
await self.request_api(
|
||||
"https://osu.ppy.sh/api/v2/beatmaps/lookup",
|
||||
headers=self.header,
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return BeatmapResp.model_validate(response.json())
|
||||
)
|
||||
|
||||
@@ -5,18 +5,15 @@ from app.log import logger
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class BeatmapsetFetcher(BaseFetcher):
|
||||
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
|
||||
)
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}",
|
||||
headers=self.header,
|
||||
|
||||
return BeatmapsetResp.model_validate(
|
||||
await self.request_api(
|
||||
f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
return BeatmapsetResp.model_validate(response.json())
|
||||
)
|
||||
|
||||
@@ -120,10 +120,10 @@ logger.add(
|
||||
format=(
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"
|
||||
),
|
||||
level=settings.LOG_LEVEL,
|
||||
diagnose=settings.DEBUG,
|
||||
level=settings.log_level,
|
||||
diagnose=settings.debug,
|
||||
)
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=settings.LOG_LEVEL, force=True)
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=settings.log_level, force=True)
|
||||
|
||||
uvicorn_loggers = [
|
||||
"uvicorn",
|
||||
|
||||
@@ -14,6 +14,20 @@ class BeatmapRankStatus(IntEnum):
|
||||
QUALIFIED = 3
|
||||
LOVED = 4
|
||||
|
||||
def has_leaderboard(self) -> bool:
|
||||
return self in {
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
BeatmapRankStatus.QUALIFIED,
|
||||
BeatmapRankStatus.LOVED,
|
||||
}
|
||||
|
||||
def has_pp(self) -> bool:
|
||||
return self in {
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
}
|
||||
|
||||
|
||||
class Genre(IntEnum):
|
||||
ANY = 0
|
||||
|
||||
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState
|
||||
from app.models.signalr import SignalRUnionMessage, UserState
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS = 13
|
||||
|
||||
|
||||
class _UserActivity(SignalRUnionMessage): ...
|
||||
|
||||
@@ -96,16 +98,14 @@ UserActivity = (
|
||||
| ModdingBeatmap
|
||||
| TestingBeatmap
|
||||
| InDailyChallengeLobby
|
||||
| PlayingDailyChallenge
|
||||
)
|
||||
|
||||
|
||||
class UserPresence(BaseModel):
|
||||
activity: UserActivity | None = Field(
|
||||
default=None, metadata=SignalRMeta(use_upper_case=True)
|
||||
)
|
||||
status: OnlineStatus | None = Field(
|
||||
default=None, metadata=SignalRMeta(use_upper_case=True)
|
||||
)
|
||||
activity: UserActivity | None = None
|
||||
|
||||
status: OnlineStatus | None = None
|
||||
|
||||
@property
|
||||
def pushable(self) -> bool:
|
||||
@@ -126,3 +126,34 @@ class OnlineStatus(IntEnum):
|
||||
OFFLINE = 0 # 隐身
|
||||
DO_NOT_DISTURB = 1
|
||||
ONLINE = 2
|
||||
|
||||
|
||||
class DailyChallengeInfo(BaseModel):
|
||||
room_id: int
|
||||
|
||||
|
||||
class MultiplayerPlaylistItemStats(BaseModel):
|
||||
playlist_item_id: int = 0
|
||||
total_score_distribution: list[int] = Field(
|
||||
default_factory=list,
|
||||
min_length=TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
max_length=TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
)
|
||||
cumulative_score: int = 0
|
||||
last_processed_score_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomStats(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerRoomScoreSetEvent(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_id: int
|
||||
score_id: int
|
||||
user_id: int
|
||||
total_score: int
|
||||
new_rank: int | None = None
|
||||
|
||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.models.score import GameMode
|
||||
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
|
||||
@@ -13,3 +15,41 @@ class UTCBaseModel(BaseModel):
|
||||
v = v.replace(tzinfo=UTC)
|
||||
return v.astimezone(UTC).isoformat()
|
||||
return v
|
||||
|
||||
|
||||
Cursor = dict[str, int]
|
||||
|
||||
|
||||
class RespWithCursor(BaseModel):
|
||||
cursor: Cursor | None = None
|
||||
|
||||
|
||||
class PinAttributes(BaseModel):
|
||||
is_pinned: bool
|
||||
score_id: int
|
||||
|
||||
|
||||
class CurrentUserAttributes(BaseModel):
|
||||
can_beatmap_update_owner: bool | None = None
|
||||
can_delete: bool | None = None
|
||||
can_edit_metadata: bool | None = None
|
||||
can_edit_tags: bool | None = None
|
||||
can_hype: bool | None = None
|
||||
can_hype_reason: str | None = None
|
||||
can_love: bool | None = None
|
||||
can_remove_from_loved: bool | None = None
|
||||
is_watching: bool | None = None
|
||||
new_hype_time: datetime | None = None
|
||||
nomination_modes: list[GameMode] | None = None
|
||||
remaining_hype: int | None = None
|
||||
can_destroy: bool | None = None
|
||||
can_reopen: bool | None = None
|
||||
can_moderate_kudosu: bool | None = None
|
||||
can_resolve: bool | None = None
|
||||
vote_score: int | None = None
|
||||
can_message: bool | None = None
|
||||
can_message_error: str | None = None
|
||||
last_read_id: int | None = None
|
||||
can_new_comment: bool | None = None
|
||||
can_new_comment_reason: str | None = None
|
||||
pin: PinAttributes | None = None
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
import json
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
|
||||
from app.config import settings as app_settings
|
||||
from app.path import STATIC_DIR
|
||||
|
||||
|
||||
class APIMod(TypedDict):
|
||||
acronym: str
|
||||
settings: NotRequired[dict[str, bool | float | str]]
|
||||
settings: NotRequired[dict[str, bool | float | str | int]]
|
||||
|
||||
|
||||
# https://github.com/ppy/osu-api/wiki#mods
|
||||
@@ -129,10 +131,10 @@ COMMON_CONFIG: dict[str, dict] = {
|
||||
}
|
||||
|
||||
RANKED_MODS: dict[int, dict[str, dict]] = {
|
||||
0: COMMON_CONFIG,
|
||||
1: COMMON_CONFIG,
|
||||
2: COMMON_CONFIG,
|
||||
3: COMMON_CONFIG,
|
||||
0: deepcopy(COMMON_CONFIG),
|
||||
1: deepcopy(COMMON_CONFIG),
|
||||
2: deepcopy(COMMON_CONFIG),
|
||||
3: deepcopy(COMMON_CONFIG),
|
||||
}
|
||||
# osu
|
||||
RANKED_MODS[0]["HD"]["only_fade_approach_circles"] = False
|
||||
@@ -154,8 +156,15 @@ for i in range(4, 10):
|
||||
|
||||
|
||||
def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
|
||||
if app_settings.enable_all_mods_pp:
|
||||
return True
|
||||
ranked_mods = RANKED_MODS[ruleset_id]
|
||||
for mod in mods:
|
||||
if app_settings.enable_osu_rx and mod["acronym"] == "RX" and ruleset_id == 0:
|
||||
continue
|
||||
if app_settings.enable_osu_ap and mod["acronym"] == "AP" and ruleset_id == 0:
|
||||
continue
|
||||
|
||||
mod["settings"] = mod.get("settings", {})
|
||||
if (settings := ranked_mods.get(mod["acronym"])) is None:
|
||||
return False
|
||||
|
||||
924
app/models/multiplayer_hub.py
Normal file
924
app/models/multiplayer_hub.py
Normal file
@@ -0,0 +1,924 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import IntEnum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
TypedDict,
|
||||
cast,
|
||||
override,
|
||||
)
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.dependencies.database import engine
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.exception import InvokeException
|
||||
|
||||
from .mods import APIMod
|
||||
from .room import (
|
||||
DownloadState,
|
||||
MatchType,
|
||||
MultiplayerRoomState,
|
||||
MultiplayerUserState,
|
||||
QueueMode,
|
||||
RoomCategory,
|
||||
RoomStatus,
|
||||
)
|
||||
from .signalr import (
|
||||
SignalRMeta,
|
||||
SignalRUnionMessage,
|
||||
UserState,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.signalr.hub import MultiplayerHub
|
||||
|
||||
HOST_LIMIT = 50
|
||||
PER_USER_LIMIT = 3
|
||||
|
||||
|
||||
class MultiplayerClientState(UserState):
|
||||
room_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomSettings(BaseModel):
|
||||
name: str = "Unnamed Room"
|
||||
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
password: str = ""
|
||||
match_type: MatchType = MatchType.HEAD_TO_HEAD
|
||||
queue_mode: QueueMode = QueueMode.HOST_ONLY
|
||||
auto_start_duration: timedelta = timedelta(seconds=0)
|
||||
auto_skip: bool = False
|
||||
|
||||
@property
|
||||
def auto_start_enabled(self) -> bool:
|
||||
return self.auto_start_duration != timedelta(seconds=0)
|
||||
|
||||
|
||||
class BeatmapAvailability(BaseModel):
|
||||
state: DownloadState = DownloadState.UNKNOWN
|
||||
download_progress: float | None = None
|
||||
|
||||
|
||||
class _MatchUserState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class TeamVersusUserState(_MatchUserState):
|
||||
team_id: int
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchUserState = TeamVersusUserState
|
||||
|
||||
|
||||
class _MatchRoomState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class MultiplayerTeam(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class TeamVersusRoomState(_MatchRoomState):
|
||||
teams: list[MultiplayerTeam] = Field(
|
||||
default_factory=lambda: [
|
||||
MultiplayerTeam(id=0, name="Team Red"),
|
||||
MultiplayerTeam(id=1, name="Team Blue"),
|
||||
]
|
||||
)
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchRoomState = TeamVersusRoomState
|
||||
|
||||
|
||||
class PlaylistItem(BaseModel):
|
||||
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
owner_id: int
|
||||
beatmap_id: int
|
||||
beatmap_checksum: str
|
||||
ruleset_id: int
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
expired: bool
|
||||
playlist_order: int
|
||||
played_at: datetime | None = None
|
||||
star_rating: float
|
||||
freestyle: bool
|
||||
|
||||
def _get_api_mods(self):
|
||||
from app.models.mods import API_MODS, init_mods
|
||||
|
||||
if not API_MODS:
|
||||
init_mods()
|
||||
return API_MODS
|
||||
|
||||
def _validate_mod_for_ruleset(
|
||||
self, mod: APIMod, ruleset_key: int, context: str = "mod"
|
||||
) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
# Check if mod is valid for ruleset
|
||||
if (
|
||||
typed_ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[typed_ruleset_key]
|
||||
):
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is invalid for this ruleset"
|
||||
)
|
||||
|
||||
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
|
||||
|
||||
# Check if mod is unplayable in multiplayer
|
||||
if mod_settings.get("UserPlayable", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not playable by users"
|
||||
)
|
||||
|
||||
if mod_settings.get("ValidForMultiplayer", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not valid for multiplayer"
|
||||
)
|
||||
|
||||
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
for i, mod1 in enumerate(mods):
|
||||
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
|
||||
if mod1_settings:
|
||||
incompatible = set(mod1_settings.get("IncompatibleMods", []))
|
||||
for mod2 in mods[i + 1 :]:
|
||||
if mod2["acronym"] in incompatible:
|
||||
raise InvokeException(
|
||||
f"Mods {mod1['acronym']} and "
|
||||
f"{mod2['acronym']} are incompatible"
|
||||
)
|
||||
|
||||
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
|
||||
for req_mod in self.required_mods:
|
||||
req_acronym = req_mod["acronym"]
|
||||
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
|
||||
if req_settings:
|
||||
incompatible = set(req_settings.get("IncompatibleMods", []))
|
||||
conflicting_allowed = allowed_acronyms & incompatible
|
||||
if conflicting_allowed:
|
||||
conflict_list = ", ".join(conflicting_allowed)
|
||||
raise InvokeException(
|
||||
f"Required mod {req_acronym} conflicts with "
|
||||
f"allowed mods: {conflict_list}"
|
||||
)
|
||||
|
||||
def validate_playlist_item_mods(self) -> None:
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
|
||||
|
||||
# Validate required mods
|
||||
for mod in self.required_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
|
||||
|
||||
# Validate allowed mods
|
||||
for mod in self.allowed_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
|
||||
|
||||
# Check internal compatibility of required mods
|
||||
self._check_mod_compatibility(self.required_mods, ruleset_key)
|
||||
|
||||
# Check compatibility between required and allowed mods
|
||||
self._check_required_allowed_compatibility(ruleset_key)
|
||||
|
||||
def validate_user_mods(
|
||||
self,
|
||||
user: "MultiplayerRoomUser",
|
||||
proposed_mods: list[APIMod],
|
||||
) -> tuple[bool, list[APIMod]]:
|
||||
"""
|
||||
Validates user mods against playlist item rules and returns valid mods.
|
||||
Returns (is_valid, valid_mods).
|
||||
"""
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
|
||||
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
|
||||
|
||||
valid_mods = []
|
||||
all_proposed_valid = True
|
||||
|
||||
# Check if mods are valid for the ruleset
|
||||
for mod in proposed_mods:
|
||||
if (
|
||||
ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[ruleset_key]
|
||||
):
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
valid_mods.append(mod)
|
||||
|
||||
# Check mod compatibility within user mods
|
||||
incompatible_mods = set()
|
||||
final_valid_mods = []
|
||||
for mod in valid_mods:
|
||||
if mod["acronym"] in incompatible_mods:
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
|
||||
if setting_mods:
|
||||
incompatible_mods.update(setting_mods["IncompatibleMods"])
|
||||
final_valid_mods.append(mod)
|
||||
|
||||
# If not freestyle, check against allowed mods
|
||||
if not self.freestyle:
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
if mod["acronym"] not in allowed_acronyms:
|
||||
all_proposed_valid = False
|
||||
else:
|
||||
filtered_valid_mods.append(mod)
|
||||
final_valid_mods = filtered_valid_mods
|
||||
|
||||
# Check compatibility with required mods
|
||||
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
|
||||
all_mod_acronyms = {
|
||||
mod["acronym"] for mod in final_valid_mods
|
||||
} | required_mod_acronyms
|
||||
|
||||
# Check for incompatibility between required and user mods
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
mod_acronym = mod["acronym"]
|
||||
is_compatible = True
|
||||
|
||||
for other_acronym in all_mod_acronyms:
|
||||
if other_acronym == mod_acronym:
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
|
||||
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
|
||||
is_compatible = False
|
||||
all_proposed_valid = False
|
||||
break
|
||||
|
||||
if is_compatible:
|
||||
filtered_valid_mods.append(mod)
|
||||
|
||||
return all_proposed_valid, filtered_valid_mods
|
||||
|
||||
def clone(self) -> "PlaylistItem":
|
||||
copy = self.model_copy()
|
||||
copy.required_mods = list(self.required_mods)
|
||||
copy.allowed_mods = list(self.allowed_mods)
|
||||
copy.expired = False
|
||||
copy.played_at = None
|
||||
return copy
|
||||
|
||||
|
||||
class _MultiplayerCountdown(SignalRUnionMessage):
|
||||
id: int = 0
|
||||
time_remaining: timedelta
|
||||
is_exclusive: Annotated[
|
||||
bool, Field(default=True), SignalRMeta(member_ignore=True)
|
||||
] = True
|
||||
|
||||
|
||||
class MatchStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class ForceGameplayStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
class ServerShuttingDownCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
|
||||
|
||||
MultiplayerCountdown = (
|
||||
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerRoomUser(BaseModel):
|
||||
user_id: int
|
||||
state: MultiplayerUserState = MultiplayerUserState.IDLE
|
||||
availability: BeatmapAvailability = BeatmapAvailability(
|
||||
state=DownloadState.UNKNOWN, download_progress=None
|
||||
)
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
match_state: MatchUserState | None = None
|
||||
ruleset_id: int | None = None # freestyle
|
||||
beatmap_id: int | None = None # freestyle
|
||||
|
||||
|
||||
class MultiplayerRoom(BaseModel):
|
||||
room_id: int
|
||||
state: MultiplayerRoomState
|
||||
settings: MultiplayerRoomSettings
|
||||
users: list[MultiplayerRoomUser] = Field(default_factory=list)
|
||||
host: MultiplayerRoomUser | None = None
|
||||
match_state: MatchRoomState | None = None
|
||||
playlist: list[PlaylistItem] = Field(default_factory=list)
|
||||
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
|
||||
channel_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, room) -> "MultiplayerRoom":
|
||||
"""
|
||||
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
|
||||
"""
|
||||
|
||||
# 用户列表
|
||||
users = [MultiplayerRoomUser(user_id=room.host_id)]
|
||||
host_user = MultiplayerRoomUser(user_id=room.host_id)
|
||||
# playlist 转换
|
||||
playlist = []
|
||||
if hasattr(room, "playlist"):
|
||||
for item in room.playlist:
|
||||
playlist.append(
|
||||
PlaylistItem(
|
||||
id=item.id,
|
||||
owner_id=item.owner_id,
|
||||
beatmap_id=item.beatmap_id,
|
||||
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
|
||||
ruleset_id=item.ruleset_id,
|
||||
required_mods=item.required_mods,
|
||||
allowed_mods=item.allowed_mods,
|
||||
expired=item.expired,
|
||||
playlist_order=item.playlist_order,
|
||||
played_at=item.played_at,
|
||||
star_rating=item.beatmap.difficulty_rating
|
||||
if item.beatmap is not None
|
||||
else 0.0,
|
||||
freestyle=item.freestyle,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
room_id=room.id,
|
||||
state=getattr(room, "state", MultiplayerRoomState.OPEN),
|
||||
settings=MultiplayerRoomSettings(
|
||||
name=room.name,
|
||||
playlist_item_id=playlist[0].id if playlist else 0,
|
||||
password=getattr(room, "password", ""),
|
||||
match_type=room.type,
|
||||
queue_mode=room.queue_mode,
|
||||
auto_start_duration=timedelta(seconds=room.auto_start_duration),
|
||||
auto_skip=room.auto_skip,
|
||||
),
|
||||
users=users,
|
||||
host=host_user,
|
||||
match_state=None,
|
||||
playlist=playlist,
|
||||
active_countdowns=[],
|
||||
channel_id=getattr(room, "channel_id", 0),
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerQueue:
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.server_room = room
|
||||
self.current_index = 0
|
||||
|
||||
@property
|
||||
def hub(self) -> "MultiplayerHub":
|
||||
return self.server_room.hub
|
||||
|
||||
@property
|
||||
def upcoming_items(self):
|
||||
return sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda i: i.playlist_order,
|
||||
)
|
||||
|
||||
@property
|
||||
def room(self):
|
||||
return self.server_room.room
|
||||
|
||||
async def update_order(self):
|
||||
from app.database import Playlist
|
||||
|
||||
match self.room.settings.queue_mode:
|
||||
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
|
||||
ordered_active_items = []
|
||||
|
||||
is_first_set = True
|
||||
first_set_order_by_user_id = {}
|
||||
|
||||
active_items = [item for item in self.room.playlist if not item.expired]
|
||||
active_items.sort(key=lambda x: x.id)
|
||||
|
||||
user_item_groups = {}
|
||||
for item in active_items:
|
||||
if item.owner_id not in user_item_groups:
|
||||
user_item_groups[item.owner_id] = []
|
||||
user_item_groups[item.owner_id].append(item)
|
||||
|
||||
max_items = max(
|
||||
(len(items) for items in user_item_groups.values()), default=0
|
||||
)
|
||||
|
||||
for i in range(max_items):
|
||||
current_set = []
|
||||
for user_id, items in user_item_groups.items():
|
||||
if i < len(items):
|
||||
current_set.append(items[i])
|
||||
|
||||
if is_first_set:
|
||||
current_set.sort(
|
||||
key=lambda item: (item.playlist_order, item.id)
|
||||
)
|
||||
ordered_active_items.extend(current_set)
|
||||
first_set_order_by_user_id = {
|
||||
item.owner_id: idx
|
||||
for idx, item in enumerate(ordered_active_items)
|
||||
}
|
||||
else:
|
||||
current_set.sort(
|
||||
key=lambda item: first_set_order_by_user_id.get(
|
||||
item.owner_id, 0
|
||||
)
|
||||
)
|
||||
ordered_active_items.extend(current_set)
|
||||
|
||||
is_first_set = False
|
||||
case _:
|
||||
ordered_active_items = sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
async with AsyncSession(engine) as session:
|
||||
for idx, item in enumerate(ordered_active_items):
|
||||
if item.playlist_order == idx:
|
||||
continue
|
||||
item.playlist_order = idx
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room, item, beatmap_changed=False
|
||||
)
|
||||
|
||||
async def update_current_item(self):
|
||||
upcoming_items = self.upcoming_items
|
||||
next_item = (
|
||||
upcoming_items[0]
|
||||
if upcoming_items
|
||||
else max(
|
||||
self.room.playlist,
|
||||
key=lambda i: i.played_at or datetime.min,
|
||||
)
|
||||
)
|
||||
self.current_index = self.room.playlist.index(next_item)
|
||||
last_id = self.room.settings.playlist_item_id
|
||||
self.room.settings.playlist_item_id = next_item.id
|
||||
if last_id != next_item.id:
|
||||
await self.hub.setting_changed(self.server_room, True)
|
||||
|
||||
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
is_host = self.room.host and self.room.host.user_id == user.user_id
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
|
||||
raise InvokeException("You are not the host")
|
||||
|
||||
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
|
||||
if (
|
||||
len([True for u in self.room.playlist if u.owner_id == user.user_id])
|
||||
>= limit
|
||||
):
|
||||
raise InvokeException(f"You can only have {limit} items in the queue")
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=item.beatmap_id
|
||||
)
|
||||
if beatmap is None:
|
||||
raise InvokeException("Beatmap not found")
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = beatmap.difficulty_rating
|
||||
await Playlist.add_to_db(item, self.room.room_id, session)
|
||||
self.room.playlist.append(item)
|
||||
await self.hub.playlist_added(self.server_room, item)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=item.beatmap_id
|
||||
)
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
existing_item = next(
|
||||
(i for i in self.room.playlist if i.id == item.id), None
|
||||
)
|
||||
if existing_item is None:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item that doesn't exist"
|
||||
)
|
||||
|
||||
if existing_item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item which is not owned by the user"
|
||||
)
|
||||
|
||||
if existing_item.expired:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item which has already been played"
|
||||
)
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(beatmap.difficulty_rating)
|
||||
item.playlist_order = existing_item.playlist_order
|
||||
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
|
||||
# Update item in playlist
|
||||
for idx, playlist_item in enumerate(self.room.playlist):
|
||||
if playlist_item.id == item.id:
|
||||
self.room.playlist[idx] = item
|
||||
break
|
||||
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room,
|
||||
item,
|
||||
beatmap_changed=item.beatmap_checksum
|
||||
!= existing_item.beatmap_checksum,
|
||||
)
|
||||
|
||||
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
item = next(
|
||||
(i for i in self.room.playlist if i.id == playlist_item_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if item is None:
|
||||
raise InvokeException("Item does not exist in the room")
|
||||
|
||||
# Check if it's the only item and current item
|
||||
if item == self.current_item:
|
||||
upcoming_items = [i for i in self.room.playlist if not i.expired]
|
||||
if len(upcoming_items) == 1:
|
||||
raise InvokeException("The only item in the room cannot be removed")
|
||||
|
||||
if item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException(
|
||||
"Attempted to remove an item which is not owned by the user"
|
||||
)
|
||||
|
||||
if item.expired:
|
||||
raise InvokeException(
|
||||
"Attempted to remove an item which has already been played"
|
||||
)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
await Playlist.delete_item(item.id, self.room.room_id, session)
|
||||
|
||||
self.room.playlist.remove(item)
|
||||
self.current_index = self.room.playlist.index(self.upcoming_items[0])
|
||||
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
await self.hub.playlist_removed(self.server_room, item.id)
|
||||
|
||||
async def finish_current_item(self):
|
||||
from app.database import Playlist
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
played_at = datetime.now(UTC)
|
||||
await session.execute(
|
||||
update(Playlist)
|
||||
.where(
|
||||
col(Playlist.id) == self.current_item.id,
|
||||
col(Playlist.room_id) == self.room.room_id,
|
||||
)
|
||||
.values(expired=True, played_at=played_at)
|
||||
)
|
||||
self.room.playlist[self.current_index].expired = True
|
||||
self.room.playlist[self.current_index].played_at = played_at
|
||||
await self.hub.playlist_changed(self.server_room, self.current_item, True)
|
||||
await self.update_order()
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_current_item()
|
||||
|
||||
async def update_queue_mode(self):
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
@property
|
||||
def current_item(self):
|
||||
return self.room.playlist[self.current_index]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountdownInfo:
|
||||
countdown: MultiplayerCountdown
|
||||
duration: timedelta
|
||||
task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, countdown: MultiplayerCountdown):
|
||||
self.countdown = countdown
|
||||
self.duration = (
|
||||
countdown.time_remaining
|
||||
if countdown.time_remaining > timedelta(seconds=0)
|
||||
else timedelta(seconds=0)
|
||||
)
|
||||
|
||||
|
||||
class _MatchRequest(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class ChangeTeamRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
team_id: int
|
||||
|
||||
|
||||
class StartMatchCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
duration: timedelta
|
||||
|
||||
|
||||
class StopCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
id: int
|
||||
|
||||
|
||||
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
|
||||
|
||||
|
||||
class MatchTypeHandler(ABC):
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.room = room
|
||||
self.hub = room.hub
|
||||
|
||||
@abstractmethod
|
||||
async def handle_join(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_request(
|
||||
self, user: MultiplayerRoomUser, request: MatchRequest
|
||||
): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_details(self) -> MatchStartedEventDetail: ...
|
||||
|
||||
|
||||
class HeadToHeadHandler(MatchTypeHandler):
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
if user.match_state is not None:
|
||||
user.match_state = None
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(
|
||||
self, user: MultiplayerRoomUser, request: MatchRequest
|
||||
): ...
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
|
||||
return detail
|
||||
|
||||
|
||||
class TeamVersusHandler(MatchTypeHandler):
|
||||
@override
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
super().__init__(room)
|
||||
self.state = TeamVersusRoomState()
|
||||
room.room.match_state = self.state
|
||||
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
|
||||
self.hub.tasks.add(task)
|
||||
task.add_done_callback(self.hub.tasks.discard)
|
||||
|
||||
def _get_best_available_team(self) -> int:
|
||||
for team in self.state.teams:
|
||||
if all(
|
||||
(
|
||||
user.match_state is None
|
||||
or not isinstance(user.match_state, TeamVersusUserState)
|
||||
or user.match_state.team_id != team.id
|
||||
)
|
||||
for user in self.room.room.users
|
||||
):
|
||||
return team.id
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
team_counts = defaultdict(int)
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(
|
||||
user.match_state, TeamVersusUserState
|
||||
):
|
||||
team_counts[user.match_state.team_id] += 1
|
||||
|
||||
if team_counts:
|
||||
min_count = min(team_counts.values())
|
||||
for team_id, count in team_counts.items():
|
||||
if count == min_count:
|
||||
return team_id
|
||||
return self.state.teams[0].id if self.state.teams else 0
|
||||
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
best_team_id = self._get_best_available_team()
|
||||
user.match_state = TeamVersusUserState(team_id=best_team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
|
||||
if not isinstance(request, ChangeTeamRequest):
|
||||
return
|
||||
|
||||
if request.team_id not in [team.id for team in self.state.teams]:
|
||||
raise InvokeException("Invalid team ID")
|
||||
|
||||
user.match_state = TeamVersusUserState(team_id=request.team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
teams: dict[int, Literal["blue", "red"]] = {}
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(
|
||||
user.match_state, TeamVersusUserState
|
||||
):
|
||||
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
|
||||
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
|
||||
return detail
|
||||
|
||||
|
||||
MATCH_TYPE_HANDLERS = {
|
||||
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
|
||||
MatchType.TEAM_VERSUS: TeamVersusHandler,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerMultiplayerRoom:
|
||||
room: MultiplayerRoom
|
||||
category: RoomCategory
|
||||
status: RoomStatus
|
||||
start_at: datetime
|
||||
hub: "MultiplayerHub"
|
||||
match_type_handler: MatchTypeHandler
|
||||
queue: MultiplayerQueue
|
||||
_next_countdown_id: int
|
||||
_countdown_id_lock: asyncio.Lock
|
||||
_tracked_countdown: dict[int, CountdownInfo]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
room: MultiplayerRoom,
|
||||
category: RoomCategory,
|
||||
start_at: datetime,
|
||||
hub: "MultiplayerHub",
|
||||
):
|
||||
self.room = room
|
||||
self.category = category
|
||||
self.status = RoomStatus.IDLE
|
||||
self.start_at = start_at
|
||||
self.hub = hub
|
||||
self.queue = MultiplayerQueue(self)
|
||||
self._next_countdown_id = 0
|
||||
self._countdown_id_lock = asyncio.Lock()
|
||||
self._tracked_countdown = {}
|
||||
|
||||
async def set_handler(self):
|
||||
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
|
||||
self
|
||||
)
|
||||
for i in self.room.users:
|
||||
await self.match_type_handler.handle_join(i)
|
||||
|
||||
async def get_next_countdown_id(self) -> int:
|
||||
async with self._countdown_id_lock:
|
||||
self._next_countdown_id += 1
|
||||
return self._next_countdown_id
|
||||
|
||||
async def start_countdown(
|
||||
self,
|
||||
countdown: MultiplayerCountdown,
|
||||
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
|
||||
):
|
||||
async def _countdown_task(self: "ServerMultiplayerRoom"):
|
||||
await asyncio.sleep(info.duration.total_seconds())
|
||||
if on_complete is not None:
|
||||
await on_complete(self)
|
||||
await self.stop_countdown(countdown)
|
||||
|
||||
if countdown.is_exclusive:
|
||||
await self.stop_all_countdowns(countdown.__class__)
|
||||
countdown.id = await self.get_next_countdown_id()
|
||||
info = CountdownInfo(countdown)
|
||||
self.room.active_countdowns.append(info.countdown)
|
||||
self._tracked_countdown[countdown.id] = info
|
||||
await self.hub.send_match_event(
|
||||
self, CountdownStartedEvent(countdown=info.countdown)
|
||||
)
|
||||
info.task = asyncio.create_task(_countdown_task(self))
|
||||
|
||||
async def stop_countdown(self, countdown: MultiplayerCountdown):
|
||||
info = self._tracked_countdown.get(countdown.id)
|
||||
if info is None:
|
||||
return
|
||||
del self._tracked_countdown[countdown.id]
|
||||
self.room.active_countdowns.remove(countdown)
|
||||
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
|
||||
if info.task is not None and not info.task.done():
|
||||
info.task.cancel()
|
||||
|
||||
async def stop_all_countdowns(self, typ: type[MultiplayerCountdown]):
|
||||
for countdown in list(self._tracked_countdown.values()):
|
||||
if isinstance(countdown.countdown, typ):
|
||||
await self.stop_countdown(countdown.countdown)
|
||||
|
||||
|
||||
class _MatchServerEvent(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class CountdownStartedEvent(_MatchServerEvent):
|
||||
countdown: MultiplayerCountdown
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class CountdownStoppedEvent(_MatchServerEvent):
|
||||
id: int
|
||||
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
|
||||
|
||||
|
||||
class GameplayAbortReason(IntEnum):
|
||||
LOAD_TOOK_TOO_LONG = 0
|
||||
HOST_ABORTED = 1
|
||||
|
||||
|
||||
class MatchStartedEventDetail(TypedDict):
|
||||
room_type: Literal["playlists", "head_to_head", "team_versus"]
|
||||
team: dict[int, Literal["blue", "red"]] | None
|
||||
@@ -1,7 +1,6 @@
|
||||
# OAuth 相关模型
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel):
|
||||
|
||||
class RegistrationErrorResponse(BaseModel):
|
||||
"""注册错误响应模型"""
|
||||
|
||||
form_error: dict
|
||||
|
||||
|
||||
class UserRegistrationErrors(BaseModel):
|
||||
"""用户注册错误模型"""
|
||||
username: List[str] = []
|
||||
user_email: List[str] = []
|
||||
password: List[str] = []
|
||||
|
||||
username: list[str] = []
|
||||
user_email: list[str] = []
|
||||
password: list[str] = []
|
||||
|
||||
|
||||
class RegistrationRequestErrors(BaseModel):
|
||||
"""注册请求错误模型"""
|
||||
|
||||
message: str | None = None
|
||||
redirect: str | None = None
|
||||
user: UserRegistrationErrors | None = None
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from app.database import User
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.models.mods import APIMod
|
||||
|
||||
from .model import UTCBaseModel
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RoomCategory(str, Enum):
|
||||
@@ -17,6 +10,7 @@ class RoomCategory(str, Enum):
|
||||
SPOTLIGHT = "spotlight"
|
||||
FEATURED_ARTIST = "featured_artist"
|
||||
DAILY_CHALLENGE = "daily_challenge"
|
||||
REALTIME = "realtime" # INTERNAL USE ONLY, DO NOT USE IN API
|
||||
|
||||
|
||||
class MatchType(str, Enum):
|
||||
@@ -42,18 +36,40 @@ class RoomStatus(str, Enum):
|
||||
PLAYING = "playing"
|
||||
|
||||
|
||||
class PlaylistItem(UTCBaseModel):
|
||||
id: int | None
|
||||
owner_id: int
|
||||
ruleset_id: int
|
||||
expired: bool
|
||||
playlist_order: int | None
|
||||
played_at: datetime | None
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
beatmap_id: int
|
||||
beatmap: Beatmap | None
|
||||
freestyle: bool
|
||||
class MultiplayerRoomState(str, Enum):
|
||||
OPEN = "open"
|
||||
WAITING_FOR_LOAD = "waiting_for_load"
|
||||
PLAYING = "playing"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class MultiplayerUserState(str, Enum):
|
||||
IDLE = "idle"
|
||||
READY = "ready"
|
||||
WAITING_FOR_LOAD = "waiting_for_load"
|
||||
LOADED = "loaded"
|
||||
READY_FOR_GAMEPLAY = "ready_for_gameplay"
|
||||
PLAYING = "playing"
|
||||
FINISHED_PLAY = "finished_play"
|
||||
RESULTS = "results"
|
||||
SPECTATING = "spectating"
|
||||
|
||||
@property
|
||||
def is_playing(self) -> bool:
|
||||
return self in {
|
||||
self.WAITING_FOR_LOAD,
|
||||
self.PLAYING,
|
||||
self.READY_FOR_GAMEPLAY,
|
||||
self.LOADED,
|
||||
}
|
||||
|
||||
|
||||
class DownloadState(str, Enum):
|
||||
UNKNOWN = "unknown"
|
||||
NOT_DOWNLOADED = "not_downloaded"
|
||||
DOWNLOADING = "downloading"
|
||||
IMPORTING = "importing"
|
||||
LOCALLY_AVAILABLE = "locally_available"
|
||||
|
||||
|
||||
class RoomPlaylistItemStats(BaseModel):
|
||||
@@ -67,39 +83,7 @@ class RoomDifficultyRange(BaseModel):
|
||||
max: float
|
||||
|
||||
|
||||
class ItemAttemptsCount(BaseModel):
|
||||
id: int
|
||||
attempts: int
|
||||
passed: bool
|
||||
|
||||
|
||||
class PlaylistAggregateScore(BaseModel):
|
||||
playlist_item_attempts: list[ItemAttemptsCount]
|
||||
|
||||
|
||||
class Room(UTCBaseModel):
|
||||
id: int | None
|
||||
name: str = ""
|
||||
password: str | None
|
||||
has_password: bool = False
|
||||
host: User | None
|
||||
category: RoomCategory = RoomCategory.NORMAL
|
||||
duration: int | None
|
||||
starts_at: datetime | None
|
||||
ends_at: datetime | None
|
||||
participant_count: int = 0
|
||||
recent_participants: list[User] = Field(default_factory=list)
|
||||
max_attempts: int | None
|
||||
playlist: list[PlaylistItem] = Field(default_factory=list)
|
||||
playlist_item_stats: RoomPlaylistItemStats | None
|
||||
difficulty_range: RoomDifficultyRange | None
|
||||
type: MatchType = MatchType.PLAYLISTS
|
||||
queue_mode: QueueMode = QueueMode.HOST_ONLY
|
||||
auto_skip: bool = False
|
||||
auto_start_duration: int = 0
|
||||
current_user_score: PlaylistAggregateScore | None
|
||||
current_playlist_item: PlaylistItem | None
|
||||
channel_id: int = 0
|
||||
status: RoomStatus = RoomStatus.IDLE
|
||||
# availability 字段在当前序列化中未包含,但可能在某些场景下需要
|
||||
availability: RoomAvailability | None
|
||||
class PlaylistStatus(BaseModel):
|
||||
count_active: int
|
||||
count_total: int
|
||||
ruleset_ids: list[int]
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
from .mods import API_MODS, APIMod, init_mods
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
class GameMode(str, Enum):
|
||||
@@ -14,13 +16,19 @@ class GameMode(str, Enum):
|
||||
TAIKO = "taiko"
|
||||
FRUITS = "fruits"
|
||||
MANIA = "mania"
|
||||
OSURX = "osurx"
|
||||
OSUAP = "osuap"
|
||||
|
||||
def to_rosu(self) -> "rosu.GameMode":
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
def to_rosu(self) -> rosu.GameMode:
|
||||
return {
|
||||
GameMode.OSU: rosu.GameMode.Osu,
|
||||
GameMode.TAIKO: rosu.GameMode.Taiko,
|
||||
GameMode.FRUITS: rosu.GameMode.Catch,
|
||||
GameMode.MANIA: rosu.GameMode.Mania,
|
||||
GameMode.OSURX: rosu.GameMode.Osu,
|
||||
GameMode.OSUAP: rosu.GameMode.Osu,
|
||||
}[self]
|
||||
|
||||
|
||||
@@ -29,8 +37,11 @@ MODE_TO_INT = {
|
||||
GameMode.TAIKO: 1,
|
||||
GameMode.FRUITS: 2,
|
||||
GameMode.MANIA: 3,
|
||||
GameMode.OSURX: 0,
|
||||
GameMode.OSUAP: 0,
|
||||
}
|
||||
INT_TO_MODE = {v: k for k, v in MODE_TO_INT.items()}
|
||||
INT_TO_MODE[0] = GameMode.OSU
|
||||
|
||||
|
||||
class Rank(str, Enum):
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
)
|
||||
|
||||
@@ -15,23 +13,7 @@ from pydantic import (
|
||||
class SignalRMeta:
|
||||
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
|
||||
json_ignore: bool = False # implement of JsonIgnore (json) attribute
|
||||
use_upper_case: bool = False # use upper CamelCase for field names
|
||||
|
||||
|
||||
def _by_index(v: Any, class_: type[Enum]):
|
||||
enum_list = list(class_)
|
||||
if not isinstance(v, int):
|
||||
return v
|
||||
if 0 <= v < len(enum_list):
|
||||
return enum_list[v]
|
||||
raise ValueError(
|
||||
f"Value {v} is out of range for enum "
|
||||
f"{class_.__name__} with {len(enum_list)} items"
|
||||
)
|
||||
|
||||
|
||||
def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator:
|
||||
return BeforeValidator(lambda v: _by_index(v, enum_class))
|
||||
use_abbr: bool = True
|
||||
|
||||
|
||||
class SignalRUnionMessage(BaseModel):
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
@@ -89,9 +89,9 @@ class LegacyReplayFrame(BaseModel):
|
||||
mouse_y: float | None = None
|
||||
button_state: int
|
||||
|
||||
header: FrameHeader | None = Field(
|
||||
default=None, metadata=[SignalRMeta(member_ignore=True)]
|
||||
)
|
||||
header: Annotated[
|
||||
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
|
||||
]
|
||||
|
||||
|
||||
class FrameDataBundle(BaseModel):
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
from .model import UTCBaseModel
|
||||
|
||||
@@ -83,9 +84,9 @@ class RankHistory(BaseModel):
|
||||
data: list[int]
|
||||
|
||||
|
||||
class Page(BaseModel):
|
||||
html: str = ""
|
||||
raw: str = ""
|
||||
class Page(TypedDict):
|
||||
html: NotRequired[str]
|
||||
raw: NotRequired[str]
|
||||
|
||||
|
||||
class BeatmapsetType(str, Enum):
|
||||
|
||||
@@ -3,6 +3,3 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
|
||||
STATIC_DIR = Path(__file__).parent.parent / "static"
|
||||
|
||||
REPLAY_DIR = Path(__file__).parent.parent / "replays"
|
||||
REPLAY_DIR.mkdir(exist_ok=True)
|
||||
|
||||
@@ -2,16 +2,17 @@ from __future__ import annotations
|
||||
|
||||
from app.signalr import signalr_router as signalr_router
|
||||
|
||||
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
beatmap,
|
||||
beatmapset,
|
||||
me,
|
||||
relationship,
|
||||
score,
|
||||
user,
|
||||
)
|
||||
from .api_router import router as api_router
|
||||
from .auth import router as auth_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
from .file import file_router as file_router
|
||||
from .private import private_router as private_router
|
||||
from .v2.router import router as api_v2_router
|
||||
|
||||
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]
|
||||
__all__ = [
|
||||
"api_v2_router",
|
||||
"auth_router",
|
||||
"fetcher_router",
|
||||
"file_router",
|
||||
"private_router",
|
||||
"signalr_router",
|
||||
]
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
@@ -9,12 +10,14 @@ from app.auth import (
|
||||
generate_refresh_token,
|
||||
get_password_hash,
|
||||
get_token_by_refresh_token,
|
||||
get_user_by_authorization_code,
|
||||
store_token,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database import DailyChallengeStats, User
|
||||
from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies import get_db
|
||||
from app.dependencies.database import get_redis
|
||||
from app.log import logger
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
@@ -26,6 +29,7 @@ from app.models.score import GameMode
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -159,14 +163,22 @@ async def register_user(
|
||||
country_code="CN", # 默认国家
|
||||
join_date=datetime.now(UTC),
|
||||
last_visit=datetime.now(UTC),
|
||||
is_supporter=settings.enable_supporter_for_all_users,
|
||||
support_level=int(settings.enable_supporter_for_all_users),
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
assert new_user.id is not None, "New user ID should not be None"
|
||||
for i in GameMode:
|
||||
for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]:
|
||||
statistics = UserStatistics(mode=i, user_id=new_user.id)
|
||||
db.add(statistics)
|
||||
if settings.enable_osu_rx:
|
||||
statistics_rx = UserStatistics(mode=GameMode.OSURX, user_id=new_user.id)
|
||||
db.add(statistics_rx)
|
||||
if settings.enable_osu_ap:
|
||||
statistics_ap = UserStatistics(mode=GameMode.OSUAP, user_id=new_user.id)
|
||||
db.add(statistics_ap)
|
||||
daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
|
||||
db.add(daily_challenge_user_stats)
|
||||
await db.commit()
|
||||
@@ -187,21 +199,36 @@ async def register_user(
|
||||
|
||||
@router.post("/oauth/token", response_model=TokenResponse)
|
||||
async def oauth_token(
|
||||
grant_type: str = Form(...),
|
||||
client_id: str = Form(...),
|
||||
grant_type: Literal[
|
||||
"authorization_code", "refresh_token", "password", "client_credentials"
|
||||
] = Form(...),
|
||||
client_id: int = Form(...),
|
||||
client_secret: str = Form(...),
|
||||
code: str | None = Form(None),
|
||||
scope: str = Form("*"),
|
||||
username: str | None = Form(None),
|
||||
password: str | None = Form(None),
|
||||
refresh_token: str | None = Form(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
"""OAuth 令牌端点"""
|
||||
# 验证客户端凭据
|
||||
if (
|
||||
client_id != settings.OSU_CLIENT_ID
|
||||
or client_secret != settings.OSU_CLIENT_SECRET
|
||||
):
|
||||
scopes = scope.split(" ")
|
||||
|
||||
client = (
|
||||
await db.exec(
|
||||
select(OAuthClient).where(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.client_secret == client_secret,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
is_game_client = (client_id, client_secret) in [
|
||||
(settings.osu_client_id, settings.osu_client_secret),
|
||||
(settings.osu_web_client_id, settings.osu_web_client_secret),
|
||||
]
|
||||
|
||||
if client is None and not is_game_client:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_client",
|
||||
description=(
|
||||
@@ -214,7 +241,6 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
if grant_type == "password":
|
||||
# 密码授权流程
|
||||
if not username or not password:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_request",
|
||||
@@ -225,6 +251,16 @@ async def oauth_token(
|
||||
),
|
||||
hint="Username and password required",
|
||||
)
|
||||
if scopes != ["*"]:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_scope",
|
||||
description=(
|
||||
"The requested scope is invalid, unknown, "
|
||||
"or malformed. The client may not request "
|
||||
"more than one scope at a time."
|
||||
),
|
||||
hint="Only '*' scope is allowed for password grant type",
|
||||
)
|
||||
|
||||
# 验证用户
|
||||
user = await authenticate_user(db, username, password)
|
||||
@@ -242,7 +278,7 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user.id)}, expires_delta=access_token_expires
|
||||
)
|
||||
@@ -253,15 +289,17 @@ async def oauth_token(
|
||||
await store_token(
|
||||
db,
|
||||
user.id,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=scope,
|
||||
)
|
||||
@@ -295,7 +333,7 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
# 生成新的访问令牌
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
@@ -305,19 +343,83 @@ async def oauth_token(
|
||||
await store_token(
|
||||
db,
|
||||
token_record.user_id,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
new_refresh_token,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=new_refresh_token,
|
||||
scope=scope,
|
||||
)
|
||||
elif grant_type == "authorization_code":
|
||||
if client is None:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_client",
|
||||
description=(
|
||||
"Client authentication failed (e.g., unknown client, "
|
||||
"no client authentication included, "
|
||||
"or unsupported authentication method)."
|
||||
),
|
||||
hint="Invalid client credentials",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
if not code:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_request",
|
||||
description=(
|
||||
"The request is missing a required parameter, "
|
||||
"includes an invalid parameter value, "
|
||||
"includes a parameter more than once, or is otherwise malformed."
|
||||
),
|
||||
hint="Authorization code required",
|
||||
)
|
||||
|
||||
code_result = await get_user_by_authorization_code(db, redis, client_id, code)
|
||||
if not code_result:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_grant",
|
||||
description=(
|
||||
"The provided authorization grant (e.g., authorization code, "
|
||||
"resource owner credentials) or refresh token is invalid, "
|
||||
"expired, revoked, does not match the redirection URI used in "
|
||||
"the authorization request, or was issued to another client."
|
||||
),
|
||||
hint="Invalid authorization code",
|
||||
)
|
||||
user, scopes = code_result
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user.id)}, expires_delta=access_token_expires
|
||||
)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user.id
|
||||
await store_token(
|
||||
db,
|
||||
user.id,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=" ".join(scopes),
|
||||
)
|
||||
else:
|
||||
return create_oauth_error_response(
|
||||
error="unsupported_grant_type",
|
||||
|
||||
@@ -5,7 +5,7 @@ from app.fetcher import Fetcher
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
fetcher_router = APIRouter()
|
||||
fetcher_router = APIRouter(prefix="/fetcher", tags=["fetcher"])
|
||||
|
||||
|
||||
@fetcher_router.get("/callback")
|
||||
|
||||
26
app/router/file.py
Normal file
26
app/router/file.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.storage import LocalStorageService, StorageService
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
file_router = APIRouter(prefix="/file")
|
||||
|
||||
|
||||
@file_router.get("/{path:path}")
|
||||
async def get_file(path: str, storage: StorageService = Depends(get_storage_service)):
|
||||
if not isinstance(storage, LocalStorageService):
|
||||
raise HTTPException(404, "Not Found")
|
||||
if not await storage.is_exists(path):
|
||||
raise HTTPException(404, "Not Found")
|
||||
|
||||
try:
|
||||
return FileResponse(
|
||||
path=storage._get_file_path(path),
|
||||
media_type="application/octet-stream",
|
||||
filename=path.split("/")[-1],
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(404, "Not Found")
|
||||
8
app/router/private/__init__.py
Normal file
8
app/router/private/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import avatar # noqa: F401
|
||||
from .router import router as private_router
|
||||
|
||||
__all__ = [
|
||||
"private_router",
|
||||
]
|
||||
56
app/router/private/avatar.py
Normal file
56
app/router/private/avatar.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.storage.base import StorageService
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Body, Depends, HTTPException
|
||||
from PIL import Image
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.post("/avatar/upload", tags=["avatar"])
|
||||
async def upload_avatar(
|
||||
file: str = Body(...),
|
||||
user_id: int = Body(...),
|
||||
storage: StorageService = Depends(get_storage_service),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
content = base64.b64decode(file)
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# check file
|
||||
if len(content) > 5 * 1024 * 1024: # 5MB limit
|
||||
raise HTTPException(status_code=400, detail="File size exceeds 5MB limit")
|
||||
elif len(content) == 0:
|
||||
raise HTTPException(status_code=400, detail="File cannot be empty")
|
||||
with Image.open(BytesIO(content)) as img:
|
||||
if img.format not in ["PNG", "JPEG", "GIF"]:
|
||||
raise HTTPException(status_code=400, detail="Invalid image format")
|
||||
if img.size[0] > 256 or img.size[1] > 256:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Image size exceeds 256x256 pixels"
|
||||
)
|
||||
|
||||
filehash = hashlib.sha256(content).hexdigest()
|
||||
storage_path = f"avatars/{user_id}_{filehash}.png"
|
||||
if not await storage.is_exists(storage_path):
|
||||
await storage.write_file(storage_path, content)
|
||||
url = await storage.get_file_url(storage_path)
|
||||
user.avatar_url = url
|
||||
await session.commit()
|
||||
|
||||
return {
|
||||
"url": url,
|
||||
"filehash": filehash,
|
||||
}
|
||||
39
app/router/private/router.py
Normal file
39
app/router/private/router.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
|
||||
|
||||
async def verify_signature(
|
||||
request: Request,
|
||||
ts: int = Header(..., alias="X-Timestamp"),
|
||||
nonce: str = Header(..., alias="X-Nonce"),
|
||||
signature: str = Header(..., alias="X-Signature"),
|
||||
):
|
||||
path = request.url.path
|
||||
data = await request.body()
|
||||
body = data.decode("utf-8")
|
||||
|
||||
py_ts = ts // 1000
|
||||
if abs(time.time() - py_ts) > 30:
|
||||
raise HTTPException(status_code=403, detail="Invalid timestamp")
|
||||
|
||||
payload = f"{path}|{body}|{ts}|{nonce}"
|
||||
expected_sig = hmac.new(
|
||||
settings.private_api_secret.encode(), payload.encode(), hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(expected_sig, signature):
|
||||
raise HTTPException(status_code=403, detail="Invalid signature")
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/private",
|
||||
dependencies=[Depends(verify_signature)],
|
||||
include_in_schema=False,
|
||||
)
|
||||
@@ -1,33 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.room import RoomIndex
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.models.room import Room
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, Query
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/rooms", tags=["rooms"], response_model=list[Room])
|
||||
async def get_all_rooms(
|
||||
mode: str = Query(
|
||||
None
|
||||
), # TODO: lazer源码显示房间不会是除了open以外的其他状态,先放在这里
|
||||
status: str = Query(None),
|
||||
category: str = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
all_room_ids = (await db.exec(select(RoomIndex).where(True))).all()
|
||||
roomsList: list[Room] = []
|
||||
for room_index in all_room_ids:
|
||||
dumped_room = await redis.get(str(room_index.id))
|
||||
if dumped_room:
|
||||
actual_room = Room.model_validate_json(str(dumped_room))
|
||||
if actual_room.status == status and actual_room.category == category:
|
||||
roomsList.append(actual_room)
|
||||
return roomsList
|
||||
@@ -1,227 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User
|
||||
from app.database.score import get_leaderboard, process_score, process_user
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
LeaderboardType,
|
||||
Rank,
|
||||
SoloScoreSubmissionInfo,
|
||||
)
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[ScoreResp]
|
||||
userScore: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores", tags=["beatmap"], response_model=BeatmapScores
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
mode: GameMode,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]"),
|
||||
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
|
||||
all_scores, user_score = await get_leaderboard(
|
||||
db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
|
||||
)
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
class BeatmapUserScore(BaseModel):
|
||||
position: int
|
||||
score: ScoreResp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}",
|
||||
tags=["beatmap"],
|
||||
response_model=BeatmapUserScore,
|
||||
)
|
||||
async def get_user_beatmap_score(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
mode: str = Query(None),
|
||||
mods: str = Query(None), # TODO:添加mods筛选
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
)
|
||||
.order_by(col(Score.total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
if not user_score:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
|
||||
)
|
||||
else:
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=await ScoreResp.from_db(db, user_score),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}/all",
|
||||
tags=["beatmap"],
|
||||
response_model=list[ScoreResp],
|
||||
)
|
||||
async def get_user_all_beatmap_scores(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
ruleset: str = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == ruleset if ruleset is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/beatmaps/{beatmap}/solo/scores", tags=["beatmap"], response_model=ScoreTokenResp
|
||||
)
|
||||
async def create_solo_score(
|
||||
beatmap: int,
|
||||
version_hash: str = Form(""),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
assert current_user.id
|
||||
async with db:
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
beatmap_id=beatmap,
|
||||
ruleset_id=INT_TO_MODE[ruleset_id],
|
||||
)
|
||||
db.add(score_token)
|
||||
await db.commit()
|
||||
await db.refresh(score_token)
|
||||
return ScoreTokenResp.from_db(score_token)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/beatmaps/{beatmap}/solo/scores/{token}",
|
||||
tags=["beatmap"],
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def submit_solo_score(
|
||||
beatmap: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
async with db:
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
if not score_token or score_token.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score).where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
beatmap_status = (
|
||||
await db.exec(
|
||||
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
|
||||
)
|
||||
).first()
|
||||
if beatmap_status is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
ranked = beatmap_status in {
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
}
|
||||
score = await process_score(
|
||||
current_user,
|
||||
beatmap,
|
||||
ranked,
|
||||
score_token,
|
||||
info,
|
||||
fetcher,
|
||||
db,
|
||||
redis,
|
||||
)
|
||||
await db.refresh(current_user)
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await process_user(db, current_user, score, ranked)
|
||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score)
|
||||
17
app/router/v2/__init__.py
Normal file
17
app/router/v2/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
beatmap,
|
||||
beatmapset,
|
||||
me,
|
||||
misc,
|
||||
relationship,
|
||||
room,
|
||||
score,
|
||||
user,
|
||||
)
|
||||
from .router import router as api_v2_router
|
||||
|
||||
__all__ = [
|
||||
"api_v2_router",
|
||||
]
|
||||
@@ -17,9 +17,9 @@ from app.models.score import (
|
||||
GameMode,
|
||||
)
|
||||
|
||||
from .api_router import router
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from fastapi import Depends, HTTPException, Query, Security
|
||||
from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
@@ -33,7 +33,7 @@ async def lookup_beatmap(
|
||||
id: int | None = Query(default=None, alias="id"),
|
||||
md5: str | None = Query(default=None, alias="checksum"),
|
||||
filename: str | None = Query(default=None, alias="filename"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -56,7 +56,7 @@ async def lookup_beatmap(
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def get_beatmap(
|
||||
bid: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -74,9 +74,10 @@ class BatchGetResp(BaseModel):
|
||||
@router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp)
|
||||
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
|
||||
async def batch_get_beatmaps(
|
||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||
current_user: User = Depends(get_current_user),
|
||||
b_ids: list[int] = Query(alias="ids[]", default_factory=list),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if not b_ids:
|
||||
# select 50 beatmaps by last_updated
|
||||
@@ -86,9 +87,29 @@ async def batch_get_beatmaps(
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmaps = (
|
||||
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
|
||||
).all()
|
||||
beatmaps = list(
|
||||
(
|
||||
await db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
|
||||
)
|
||||
).all()
|
||||
)
|
||||
not_found_beatmaps = [
|
||||
bid for bid in b_ids if bid not in [bm.id for bm in beatmaps]
|
||||
]
|
||||
beatmaps.extend(
|
||||
beatmap
|
||||
for beatmap in await asyncio.gather(
|
||||
*[
|
||||
Beatmap.get_or_fetch(db, fetcher, bid=bid)
|
||||
for bid in not_found_beatmaps
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(beatmap, Beatmap)
|
||||
)
|
||||
for beatmap in beatmaps:
|
||||
await db.refresh(beatmap)
|
||||
|
||||
return BatchGetResp(
|
||||
beatmaps=[
|
||||
@@ -105,7 +126,7 @@ async def batch_get_beatmaps(
|
||||
)
|
||||
async def get_beatmap_attributes(
|
||||
beatmap: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
mods: list[str] = Query(default_factory=list),
|
||||
ruleset: GameMode | None = Query(default=None),
|
||||
ruleset_id: int | None = Query(default=None),
|
||||
@@ -2,47 +2,56 @@ from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
from .api_router import router
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from fastapi import Depends, Form, HTTPException, Query, Security
|
||||
from fastapi.responses import RedirectResponse
|
||||
from httpx import HTTPStatusError
|
||||
from httpx import HTTPError
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def lookup_beatmapset(
|
||||
beatmap_id: int = Query(),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
beatmap.beatmapset, session=db, user=current_user
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def get_beatmapset(
|
||||
sid: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
|
||||
if not beatmapset:
|
||||
try:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
await Beatmapset.from_resp(db, resp)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
else:
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
try:
|
||||
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, sid)
|
||||
return await BeatmapsetResp.from_db(
|
||||
beatmapset, session=db, include=["recent_favourites"], user=current_user
|
||||
)
|
||||
return resp
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
|
||||
async def download_beatmapset(
|
||||
beatmapset: int,
|
||||
no_video: bool = Query(True, alias="noVideo"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
):
|
||||
if current_user.country_code == "CN":
|
||||
return RedirectResponse(
|
||||
@@ -59,7 +68,7 @@ async def download_beatmapset(
|
||||
async def favourite_beatmapset(
|
||||
beatmapset: int,
|
||||
action: Literal["favourite", "unfavourite"] = Form(),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
existing_favourite = (
|
||||
@@ -6,9 +6,9 @@ from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .api_router import router
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Security
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@router.get("/me/", response_model=UserResp)
|
||||
async def get_user_info_default(
|
||||
ruleset: GameMode | None = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["identify"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await UserResp.from_db(
|
||||
25
app/router/v2/misc.py
Normal file
25
app/router/v2/misc.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from .router import router
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Background(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class BackgroundsResp(BaseModel):
|
||||
ends_at: datetime = datetime(year=9999, month=12, day=31, tzinfo=UTC)
|
||||
backgrounds: list[Background]
|
||||
|
||||
|
||||
@router.get("/seasonal-backgrounds", response_model=BackgroundsResp)
|
||||
async def get_seasonal_backgrounds():
|
||||
return BackgroundsResp(
|
||||
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
|
||||
)
|
||||
@@ -1,13 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.database.relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from app.database import Relationship, RelationshipResp, RelationshipType, User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
|
||||
from .api_router import router
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Request
|
||||
from fastapi import Depends, HTTPException, Query, Request, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -17,7 +16,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@router.get("/blocks", tags=["relationship"], response_model=list[RelationshipResp])
|
||||
async def get_relationship(
|
||||
request: Request,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["friends.read"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
relationship_type = (
|
||||
@@ -43,7 +42,7 @@ class AddFriendResp(BaseModel):
|
||||
async def add_relationship(
|
||||
request: Request,
|
||||
target: int = Query(),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
relationship_type = (
|
||||
@@ -106,7 +105,7 @@ async def add_relationship(
|
||||
async def delete_relationship(
|
||||
request: Request,
|
||||
target: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
relationship_type = (
|
||||
361
app/router/v2/room.py
Normal file
361
app/router/v2/room.py
Normal file
@@ -0,0 +1,361 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
|
||||
from app.database.beatmap import Beatmap, BeatmapResp
|
||||
from app.database.beatmapset import BeatmapsetResp
|
||||
from app.database.lazer_user import User, UserResp
|
||||
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
|
||||
from app.database.playlists import Playlist, PlaylistResp
|
||||
from app.database.room import APIUploadedRoom, Room, RoomResp
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.room import RoomCategory, RoomStatus
|
||||
from app.service.room import create_playlist_room_from_api
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp])
|
||||
async def get_all_rooms(
|
||||
mode: Literal["open", "ended", "participated", "owned", None] = Query(
|
||||
default="open"
|
||||
),
|
||||
category: RoomCategory = Query(RoomCategory.NORMAL),
|
||||
status: RoomStatus | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
resp_list: list[RoomResp] = []
|
||||
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
|
||||
now = datetime.now(UTC)
|
||||
if status is not None:
|
||||
where_clauses.append(col(Room.status) == status)
|
||||
if mode == "open":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_(None))
|
||||
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
|
||||
)
|
||||
if category == RoomCategory.REALTIME:
|
||||
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
|
||||
if mode == "participated":
|
||||
where_clauses.append(
|
||||
exists().where(
|
||||
col(RoomParticipatedUser.room_id) == Room.id,
|
||||
col(RoomParticipatedUser.user_id) == current_user.id,
|
||||
)
|
||||
)
|
||||
if mode == "owned":
|
||||
where_clauses.append(col(Room.host_id) == current_user.id)
|
||||
if mode == "ended":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_not(None))
|
||||
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
|
||||
)
|
||||
|
||||
db_rooms = (
|
||||
(
|
||||
await db.exec(
|
||||
select(Room).where(
|
||||
*where_clauses,
|
||||
)
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
|
||||
for room in db_rooms:
|
||||
resp = await RoomResp.from_db(room, db)
|
||||
if category == RoomCategory.REALTIME:
|
||||
mp_room = MultiplayerHubs.rooms.get(room.id)
|
||||
resp.has_password = (
|
||||
bool(mp_room.room.settings.password.strip())
|
||||
if mp_room is not None
|
||||
else False
|
||||
)
|
||||
resp.category = RoomCategory.NORMAL
|
||||
resp_list.append(resp)
|
||||
|
||||
return resp_list
|
||||
|
||||
|
||||
class APICreatedRoom(RoomResp):
|
||||
error: str = ""
|
||||
|
||||
|
||||
async def _participate_room(
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession
|
||||
):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
select(RoomParticipatedUser).where(
|
||||
RoomParticipatedUser.room_id == room_id,
|
||||
RoomParticipatedUser.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if participated_user is None:
|
||||
participated_user = RoomParticipatedUser(
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
joined_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(participated_user)
|
||||
else:
|
||||
participated_user.left_at = None
|
||||
participated_user.joined_at = datetime.now(UTC)
|
||||
db_room.participant_count += 1
|
||||
|
||||
|
||||
@router.post("/rooms", tags=["room"], response_model=APICreatedRoom)
|
||||
async def create_room(
|
||||
room: APIUploadedRoom,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
):
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db)
|
||||
# await db.commit()
|
||||
# await db.refresh(db_room)
|
||||
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
|
||||
created_room.error = ""
|
||||
return created_room
|
||||
|
||||
|
||||
@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp)
|
||||
async def get_room(
|
||||
room: int,
|
||||
category: str = Query(default=""),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
# 直接从db获取信息,毕竟都一样
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
resp = await RoomResp.from_db(
|
||||
db_room, include=["current_user_score"], session=db, user=current_user
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.delete("/rooms/{room}", tags=["room"])
|
||||
async def delete_room(
|
||||
room: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
else:
|
||||
db_room.ends_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.put("/rooms/{room}/users/{user}", tags=["room"])
|
||||
async def add_user_to_room(
|
||||
room: int,
|
||||
user: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is not None:
|
||||
await _participate_room(room, user, db_room, db)
|
||||
await db.commit()
|
||||
await db.refresh(db_room)
|
||||
resp = await RoomResp.from_db(db_room, db)
|
||||
|
||||
return resp
|
||||
else:
|
||||
raise HTTPException(404, "room not found0")
|
||||
|
||||
|
||||
@router.delete("/rooms/{room}/users/{user}", tags=["room"])
|
||||
async def remove_user_from_room(
|
||||
room: int,
|
||||
user: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is not None:
|
||||
participated_user = (
|
||||
await db.exec(
|
||||
select(RoomParticipatedUser).where(
|
||||
RoomParticipatedUser.room_id == room,
|
||||
RoomParticipatedUser.user_id == user,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if participated_user is not None:
|
||||
participated_user.left_at = datetime.now(UTC)
|
||||
db_room.participant_count -= 1
|
||||
await db.commit()
|
||||
return None
|
||||
else:
|
||||
raise HTTPException(404, "Room not found")
|
||||
|
||||
|
||||
class APILeaderboard(BaseModel):
|
||||
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
|
||||
user_score: ItemAttemptsResp | None = None
|
||||
|
||||
|
||||
@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard)
|
||||
async def get_room_leaderboard(
|
||||
room: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
|
||||
aggs = await db.exec(
|
||||
select(ItemAttemptsCount)
|
||||
.where(ItemAttemptsCount.room_id == room)
|
||||
.order_by(col(ItemAttemptsCount.total_score).desc())
|
||||
)
|
||||
aggs_resp = []
|
||||
user_agg = None
|
||||
for i, agg in enumerate(aggs):
|
||||
resp = await ItemAttemptsResp.from_db(agg, db)
|
||||
resp.position = i + 1
|
||||
# resp.accuracy *= 100
|
||||
aggs_resp.append(resp)
|
||||
if agg.user_id == current_user.id:
|
||||
user_agg = resp
|
||||
return APILeaderboard(
|
||||
leaderboard=aggs_resp,
|
||||
user_score=user_agg,
|
||||
)
|
||||
|
||||
|
||||
class RoomEvents(BaseModel):
|
||||
beatmaps: list[BeatmapResp] = Field(default_factory=list)
|
||||
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
|
||||
current_playlist_item_id: int = 0
|
||||
events: list[MultiplayerEventResp] = Field(default_factory=list)
|
||||
first_event_id: int = 0
|
||||
last_event_id: int = 0
|
||||
playlist_items: list[PlaylistResp] = Field(default_factory=list)
|
||||
room: RoomResp
|
||||
user: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"])
|
||||
async def get_room_events(
|
||||
room_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
after: int | None = Query(None, ge=0),
|
||||
before: int | None = Query(None, ge=0),
|
||||
):
|
||||
events = (
|
||||
await db.exec(
|
||||
select(MultiplayerEvent)
|
||||
.where(
|
||||
MultiplayerEvent.room_id == room_id,
|
||||
col(MultiplayerEvent.id) > after if after is not None else True,
|
||||
col(MultiplayerEvent.id) < before if before is not None else True,
|
||||
)
|
||||
.order_by(col(MultiplayerEvent.id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_ids = set()
|
||||
playlist_items = {}
|
||||
beatmap_ids = set()
|
||||
|
||||
event_resps = []
|
||||
first_event_id = 0
|
||||
last_event_id = 0
|
||||
|
||||
current_playlist_item_id = 0
|
||||
for event in events:
|
||||
event_resps.append(MultiplayerEventResp.from_db(event))
|
||||
|
||||
if event.user_id:
|
||||
user_ids.add(event.user_id)
|
||||
|
||||
if event.playlist_item_id is not None and (
|
||||
playitem := (
|
||||
await db.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == event.playlist_item_id,
|
||||
Playlist.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
):
|
||||
current_playlist_item_id = playitem.id
|
||||
playlist_items[event.playlist_item_id] = playitem
|
||||
beatmap_ids.add(playitem.beatmap_id)
|
||||
scores = await db.exec(
|
||||
select(Score).where(
|
||||
Score.playlist_item_id == event.playlist_item_id,
|
||||
Score.room_id == room_id,
|
||||
)
|
||||
)
|
||||
for score in scores:
|
||||
user_ids.add(score.user_id)
|
||||
beatmap_ids.add(score.beatmap_id)
|
||||
|
||||
assert event.id is not None
|
||||
first_event_id = min(first_event_id, event.id)
|
||||
last_event_id = max(last_event_id, event.id)
|
||||
|
||||
if room := MultiplayerHubs.rooms.get(room_id):
|
||||
current_playlist_item_id = room.queue.current_item.id
|
||||
room_resp = await RoomResp.from_hub(room)
|
||||
else:
|
||||
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
room_resp = await RoomResp.from_db(room, db)
|
||||
|
||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
||||
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
||||
beatmap_resps = [
|
||||
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
|
||||
]
|
||||
beatmapset_resps = {}
|
||||
for beatmap_resp in beatmap_resps:
|
||||
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
|
||||
|
||||
playlist_items_resps = [
|
||||
await PlaylistResp.from_db(item) for item in playlist_items.values()
|
||||
]
|
||||
|
||||
return RoomEvents(
|
||||
beatmaps=beatmap_resps,
|
||||
beatmapsets=beatmapset_resps,
|
||||
current_playlist_item_id=current_playlist_item_id,
|
||||
events=event_resps,
|
||||
first_event_id=first_event_id,
|
||||
last_event_id=last_event_id,
|
||||
playlist_items=playlist_items_resps,
|
||||
room=room_resp,
|
||||
user=user_resps,
|
||||
)
|
||||
@@ -2,4 +2,4 @@ from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
router = APIRouter(prefix="/api/v2")
|
||||
770
app/router/v2/score.py
Normal file
770
app/router/v2/score.py
Normal file
@@ -0,0 +1,770 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, date, datetime
|
||||
import time
|
||||
|
||||
from app.calculator import clamp
|
||||
from app.config import settings
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
Playlist,
|
||||
Room,
|
||||
Score,
|
||||
ScoreResp,
|
||||
ScoreToken,
|
||||
ScoreTokenResp,
|
||||
User,
|
||||
)
|
||||
from app.database.counts import ReplayWatchedCount
|
||||
from app.database.playlist_attempts import ItemAttemptsCount
|
||||
from app.database.playlist_best_score import (
|
||||
PlaylistBestScore,
|
||||
get_position,
|
||||
process_playlist_best_score,
|
||||
)
|
||||
from app.database.relationship import Relationship, RelationshipType
|
||||
from app.database.score import (
|
||||
MultiplayerScores,
|
||||
ScoreAround,
|
||||
get_leaderboard,
|
||||
process_score,
|
||||
process_user,
|
||||
)
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.room import RoomCategory
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
LeaderboardType,
|
||||
Rank,
|
||||
SoloScoreSubmissionInfo,
|
||||
)
|
||||
from app.storage.base import StorageService
|
||||
from app.storage.local import LocalStorageService
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Body, Depends, Form, HTTPException, Query, Security
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from httpx import HTTPError
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, exists, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
READ_SCORE_TIMEOUT = 10
|
||||
|
||||
|
||||
async def submit_score(
|
||||
info: SoloScoreSubmissionInfo,
|
||||
beatmap: int,
|
||||
token: int,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
item_id: int | None = None,
|
||||
room_id: int | None = None,
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token)
|
||||
)
|
||||
).first()
|
||||
if not score_token or score_token.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score).where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
try:
|
||||
db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
ranked = (
|
||||
db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_leaderboard
|
||||
)
|
||||
beatmap_length = db_beatmap.total_length
|
||||
score = await process_score(
|
||||
current_user,
|
||||
beatmap,
|
||||
ranked,
|
||||
score_token,
|
||||
info,
|
||||
fetcher,
|
||||
db,
|
||||
redis,
|
||||
item_id,
|
||||
room_id,
|
||||
)
|
||||
await db.refresh(current_user)
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await process_user(db, current_user, score, beatmap_length, ranked)
|
||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score)
|
||||
|
||||
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[ScoreResp]
|
||||
userScore: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores", tags=["beatmap"], response_model=BeatmapScores
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
mode: GameMode,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]"),
|
||||
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
|
||||
all_scores, user_score = await get_leaderboard(
|
||||
db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
|
||||
)
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
class BeatmapUserScore(BaseModel):
|
||||
position: int
|
||||
score: ScoreResp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}",
|
||||
tags=["beatmap"],
|
||||
response_model=BeatmapUserScore,
|
||||
)
|
||||
async def get_user_beatmap_score(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
mode: str = Query(None),
|
||||
mods: str = Query(None), # TODO:添加mods筛选
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
)
|
||||
.order_by(col(Score.total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
if not user_score:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
|
||||
)
|
||||
else:
|
||||
resp = await ScoreResp.from_db(db, user_score)
|
||||
return BeatmapUserScore(
|
||||
position=resp.rank_global or 0,
|
||||
score=resp,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}/all",
|
||||
tags=["beatmap"],
|
||||
response_model=list[ScoreResp],
|
||||
)
|
||||
async def get_user_all_beatmap_scores(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
ruleset: str = Query(None),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == ruleset if ruleset is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/beatmaps/{beatmap}/solo/scores", tags=["beatmap"], response_model=ScoreTokenResp
|
||||
)
|
||||
async def create_solo_score(
|
||||
beatmap: int,
|
||||
version_hash: str = Form(""),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
assert current_user.id
|
||||
async with db:
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
beatmap_id=beatmap,
|
||||
ruleset_id=INT_TO_MODE[ruleset_id],
|
||||
)
|
||||
db.add(score_token)
|
||||
await db.commit()
|
||||
await db.refresh(score_token)
|
||||
return ScoreTokenResp.from_db(score_token)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/beatmaps/{beatmap}/solo/scores/{token}",
|
||||
tags=["beatmap"],
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def submit_solo_score(
|
||||
beatmap: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
):
|
||||
return await submit_score(info, beatmap, token, current_user, db, redis, fetcher)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp
|
||||
)
|
||||
async def create_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
beatmap_id: int = Form(),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
version_hash: str = Form(""),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
|
||||
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
|
||||
raise HTTPException(status_code=400, detail="Room has ended")
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist not found")
|
||||
|
||||
# validate
|
||||
if not item.freestyle:
|
||||
if item.ruleset_id != ruleset_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Ruleset mismatch in playlist item"
|
||||
)
|
||||
if item.beatmap_id != beatmap_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Beatmap ID mismatch in playlist item"
|
||||
)
|
||||
agg = await session.exec(
|
||||
select(ItemAttemptsCount).where(
|
||||
ItemAttemptsCount.room_id == room_id,
|
||||
ItemAttemptsCount.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
agg = agg.first()
|
||||
if agg and room.max_attempts and agg.attempts >= room.max_attempts:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="You have reached the maximum attempts for this room",
|
||||
)
|
||||
if item.expired:
|
||||
raise HTTPException(status_code=400, detail="Playlist item has expired")
|
||||
if item.played_at:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Playlist item has already been played"
|
||||
)
|
||||
# 这里应该不用验证mod了吧。。。
|
||||
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
beatmap_id=beatmap_id,
|
||||
ruleset_id=INT_TO_MODE[ruleset_id],
|
||||
playlist_item_id=playlist_id,
|
||||
)
|
||||
session.add(score_token)
|
||||
await session.commit()
|
||||
await session.refresh(score_token)
|
||||
return ScoreTokenResp.from_db(score_token)
|
||||
|
||||
|
||||
@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}")
|
||||
async def submit_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist item not found")
|
||||
|
||||
user_id = current_user.id
|
||||
score_resp = await submit_score(
|
||||
info,
|
||||
item.beatmap_id,
|
||||
token,
|
||||
current_user,
|
||||
session,
|
||||
redis,
|
||||
fetcher,
|
||||
item.id,
|
||||
room_id,
|
||||
)
|
||||
await process_playlist_best_score(
|
||||
room_id,
|
||||
playlist_id,
|
||||
user_id,
|
||||
score_resp.id,
|
||||
score_resp.total_score,
|
||||
session,
|
||||
redis,
|
||||
)
|
||||
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
|
||||
return score_resp
|
||||
|
||||
|
||||
class IndexedScoreResp(MultiplayerScores):
|
||||
total: int
|
||||
user_score: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp
|
||||
)
|
||||
async def index_playlist_scores(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
limit: int = 50,
|
||||
cursor: int = Query(2000000, alias="cursor[total_score]"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
limit = clamp(limit, 1, 50)
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore)
|
||||
.where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.total_score < cursor,
|
||||
)
|
||||
.order_by(col(PlaylistBestScore.total_score).desc())
|
||||
.limit(limit + 1)
|
||||
)
|
||||
).all()
|
||||
has_more = len(scores) > limit
|
||||
if has_more:
|
||||
scores = scores[:-1]
|
||||
|
||||
user_score = None
|
||||
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
|
||||
for score in score_resp:
|
||||
score.position = await get_position(room_id, playlist_id, score.id, session)
|
||||
if score.user_id == current_user.id:
|
||||
user_score = score
|
||||
|
||||
if room.category == RoomCategory.DAILY_CHALLENGE:
|
||||
score_resp = [s for s in score_resp if s.passed]
|
||||
if user_score and not user_score.passed:
|
||||
user_score = None
|
||||
|
||||
resp = IndexedScoreResp(
|
||||
scores=score_resp,
|
||||
user_score=user_score,
|
||||
total=len(scores),
|
||||
params={
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
if has_more:
|
||||
resp.cursor = {
|
||||
"total_score": scores[-1].total_score,
|
||||
}
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def show_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
score_id: int,
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
start_time = time.time()
|
||||
score_record = None
|
||||
completed = room.category != RoomCategory.REALTIME
|
||||
while time.time() - start_time < READ_SCORE_TIMEOUT:
|
||||
if score_record is None:
|
||||
score_record = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.score_id == score_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if completed_players := await redis.get(
|
||||
f"multiplayer:{room_id}:gameplay:players"
|
||||
):
|
||||
completed = completed_players == "0"
|
||||
if score_record and completed:
|
||||
break
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(room_id, playlist_id, score_id, session)
|
||||
if completed:
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
higher_scores = []
|
||||
lower_scores = []
|
||||
for score in scores:
|
||||
if score.total_score > resp.total_score:
|
||||
higher_scores.append(await ScoreResp.from_db(session, score.score))
|
||||
elif score.total_score < resp.total_score:
|
||||
lower_scores.append(await ScoreResp.from_db(session, score.score))
|
||||
resp.scores_around = ScoreAround(
|
||||
higher=MultiplayerScores(scores=higher_scores),
|
||||
lower=MultiplayerScores(scores=lower_scores),
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def get_user_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
score_record = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < READ_SCORE_TIMEOUT:
|
||||
score_record = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if score_record:
|
||||
break
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(
|
||||
room_id, playlist_id, score_record.score_id, session
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.put("/score-pins/{score}", status_code=204)
|
||||
async def pin_score(
|
||||
score: int,
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(
|
||||
Score.id == score,
|
||||
Score.user_id == current_user.id,
|
||||
col(Score.passed).is_(True),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
if score_record.pinned_order > 0:
|
||||
return
|
||||
|
||||
next_order = (
|
||||
(
|
||||
await db.exec(
|
||||
select(func.max(Score.pinned_order)).where(
|
||||
Score.user_id == current_user.id,
|
||||
Score.gamemode == score_record.gamemode,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
or 0
|
||||
) + 1
|
||||
score_record.pinned_order = next_order
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.delete("/score-pins/{score}", status_code=204)
|
||||
async def unpin_score(
|
||||
score: int,
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.id == score, Score.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
if score_record.pinned_order == 0:
|
||||
return
|
||||
changed_score = (
|
||||
await db.exec(
|
||||
select(Score).where(
|
||||
Score.user_id == current_user.id,
|
||||
Score.pinned_order > score_record.pinned_order,
|
||||
Score.gamemode == score_record.gamemode,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
for s in changed_score:
|
||||
s.pinned_order -= 1
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.post("/score-pins/{score}/reorder", status_code=204)
|
||||
async def reorder_score_pin(
|
||||
score: int,
|
||||
after_score_id: int | None = Body(default=None),
|
||||
before_score_id: int | None = Body(default=None),
|
||||
current_user: User = Security(get_current_user, scopes=["*"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.id == score, Score.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
if score_record.pinned_order == 0:
|
||||
raise HTTPException(status_code=400, detail="Score is not pinned")
|
||||
|
||||
if (after_score_id is None) == (before_score_id is None):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either after_score_id or before_score_id "
|
||||
"must be provided (but not both)",
|
||||
)
|
||||
|
||||
all_pinned_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.user_id == current_user.id,
|
||||
Score.pinned_order > 0,
|
||||
Score.gamemode == score_record.gamemode,
|
||||
)
|
||||
.order_by(col(Score.pinned_order))
|
||||
)
|
||||
).all()
|
||||
|
||||
target_order = None
|
||||
reference_score_id = after_score_id or before_score_id
|
||||
|
||||
reference_score = next(
|
||||
(s for s in all_pinned_scores if s.id == reference_score_id), None
|
||||
)
|
||||
if not reference_score:
|
||||
detail = "After score not found" if after_score_id else "Before score not found"
|
||||
raise HTTPException(status_code=404, detail=detail)
|
||||
|
||||
if after_score_id:
|
||||
target_order = reference_score.pinned_order + 1
|
||||
else:
|
||||
target_order = reference_score.pinned_order
|
||||
|
||||
current_order = score_record.pinned_order
|
||||
|
||||
if current_order == target_order:
|
||||
return
|
||||
|
||||
updates = []
|
||||
|
||||
if current_order < target_order:
|
||||
for s in all_pinned_scores:
|
||||
if current_order < s.pinned_order <= target_order and s.id != score:
|
||||
updates.append((s.id, s.pinned_order - 1))
|
||||
if after_score_id:
|
||||
final_target = (
|
||||
target_order - 1 if target_order > current_order else target_order
|
||||
)
|
||||
else:
|
||||
final_target = target_order
|
||||
else:
|
||||
for s in all_pinned_scores:
|
||||
if target_order <= s.pinned_order < current_order and s.id != score:
|
||||
updates.append((s.id, s.pinned_order + 1))
|
||||
final_target = target_order
|
||||
|
||||
for score_id, new_order in updates:
|
||||
await db.exec(select(Score).where(Score.id == score_id))
|
||||
score_to_update = (
|
||||
await db.exec(select(Score).where(Score.id == score_id))
|
||||
).first()
|
||||
if score_to_update:
|
||||
score_to_update.pinned_order = new_order
|
||||
|
||||
score_record.pinned_order = final_target
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.get("/scores/{score_id}/download")
|
||||
async def download_score_replay(
|
||||
score_id: int,
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
storage_service: StorageService = Depends(get_storage_service),
|
||||
):
|
||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
filepath = f"replays/{score.id}_{score.beatmap_id}_{score.user_id}_lazer_replay.osr"
|
||||
|
||||
if not await storage_service.is_exists(filepath):
|
||||
raise HTTPException(status_code=404, detail="Replay file not found")
|
||||
|
||||
is_friend = (
|
||||
score.user_id == current_user.id
|
||||
or (
|
||||
await db.exec(
|
||||
select(exists()).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
Relationship.target_id == score.user_id,
|
||||
Relationship.type == RelationshipType.FOLLOW,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
)
|
||||
if not is_friend:
|
||||
replay_watched_count = (
|
||||
await db.exec(
|
||||
select(ReplayWatchedCount).where(
|
||||
ReplayWatchedCount.user_id == score.user_id,
|
||||
ReplayWatchedCount.year == date.today().year,
|
||||
ReplayWatchedCount.month == date.today().month,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if replay_watched_count is None:
|
||||
replay_watched_count = ReplayWatchedCount(
|
||||
user_id=score.user_id, year=date.today().year, month=date.today().month
|
||||
)
|
||||
db.add(replay_watched_count)
|
||||
replay_watched_count.count += 1
|
||||
await db.commit()
|
||||
if isinstance(storage_service, LocalStorageService):
|
||||
return FileResponse(
|
||||
path=await storage_service.get_file_url(filepath),
|
||||
filename=filepath,
|
||||
media_type="application/x-osu-replay",
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
await storage_service.get_file_url(filepath),
|
||||
301,
|
||||
)
|
||||
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
from app.database import (
|
||||
BeatmapPlaycounts,
|
||||
BeatmapPlaycountsResp,
|
||||
@@ -8,16 +11,18 @@ from app.database import (
|
||||
UserResp,
|
||||
)
|
||||
from app.database.lazer_user import SEARCH_INCLUDED
|
||||
from app.database.pp_best_score import PPBestScore
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import BeatmapsetType
|
||||
|
||||
from .api_router import router
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from fastapi import Depends, HTTPException, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel import exists, false, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import col
|
||||
|
||||
@@ -31,6 +36,7 @@ class BatchUserResponse(BaseModel):
|
||||
@router.get("/users/lookup/", response_model=BatchUserResponse)
|
||||
async def get_users(
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
include_variant_statistics: bool = Query(default=False), # TODO: future use
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
@@ -59,6 +65,7 @@ async def get_user_info(
|
||||
user: str,
|
||||
ruleset: GameMode | None = None,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
@@ -86,7 +93,7 @@ async def get_user_info(
|
||||
async def get_user_beatmapsets(
|
||||
user_id: int,
|
||||
type: BeatmapsetType,
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
@@ -130,3 +137,59 @@ async def get_user_beatmapsets(
|
||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@router.get("/users/{user}/scores/{type}", response_model=list[ScoreResp])
|
||||
async def get_user_scores(
|
||||
user: int,
|
||||
type: Literal["best", "recent", "firsts", "pinned"],
|
||||
legacy_only: bool = Query(False),
|
||||
include_fails: bool = Query(False),
|
||||
mode: GameMode | None = None,
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
db_user = await session.get(User, user)
|
||||
if not db_user:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
|
||||
gamemode = mode or db_user.playmode
|
||||
order_by = None
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (
|
||||
col(Score.gamemode) == gamemode
|
||||
)
|
||||
if not include_fails:
|
||||
where_clause &= col(Score.passed).is_(True)
|
||||
if type == "pinned":
|
||||
where_clause &= Score.pinned_order > 0
|
||||
order_by = col(Score.pinned_order).asc()
|
||||
elif type == "best":
|
||||
where_clause &= exists().where(col(PPBestScore.score_id) == Score.id)
|
||||
order_by = col(Score.pp).desc()
|
||||
elif type == "recent":
|
||||
where_clause &= Score.ended_at > datetime.now(UTC) - timedelta(hours=24)
|
||||
order_by = col(Score.ended_at).desc()
|
||||
elif type == "firsts":
|
||||
# TODO
|
||||
where_clause &= false()
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(where_clause)
|
||||
.order_by(order_by)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
).all()
|
||||
if not scores:
|
||||
return []
|
||||
return [
|
||||
await ScoreResp.from_db(
|
||||
session,
|
||||
score,
|
||||
)
|
||||
for score in scores
|
||||
]
|
||||
10
app/service/__init__.py
Normal file
10
app/service/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .daily_challenge import create_daily_challenge_room
|
||||
from .room import create_playlist_room, create_playlist_room_from_api
|
||||
|
||||
__all__ = [
|
||||
"create_daily_challenge_room",
|
||||
"create_playlist_room",
|
||||
"create_playlist_room_from_api",
|
||||
]
|
||||
121
app/service/daily_challenge.py
Normal file
121
app/service/daily_challenge.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import Room
|
||||
from app.dependencies.database import engine, get_redis
|
||||
from app.dependencies.scheduler import get_scheduler
|
||||
from app.log import logger
|
||||
from app.models.metadata_hub import DailyChallengeInfo
|
||||
from app.models.mods import APIMod
|
||||
from app.models.room import RoomCategory
|
||||
|
||||
from .room import create_playlist_room
|
||||
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def create_daily_challenge_room(
|
||||
beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = []
|
||||
) -> Room:
|
||||
async with AsyncSession(engine) as session:
|
||||
today = datetime.now(UTC).date()
|
||||
return await create_playlist_room(
|
||||
session=session,
|
||||
name=str(today),
|
||||
host_id=3,
|
||||
playlist=[
|
||||
Playlist(
|
||||
id=0,
|
||||
room_id=0,
|
||||
owner_id=3,
|
||||
ruleset_id=ruleset_id,
|
||||
beatmap_id=beatmap,
|
||||
required_mods=required_mods,
|
||||
)
|
||||
],
|
||||
category=RoomCategory.DAILY_CHALLENGE,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
|
||||
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="daily_challenge")
|
||||
async def daily_challenge_job():
|
||||
from app.signalr.hub import MetadataHubs
|
||||
|
||||
now = datetime.now(UTC)
|
||||
redis = get_redis()
|
||||
key = f"daily_challenge:{now.date()}"
|
||||
if not await redis.exists(key):
|
||||
return
|
||||
async with AsyncSession(engine) as session:
|
||||
room = (
|
||||
await session.exec(
|
||||
select(Room).where(
|
||||
Room.category == RoomCategory.DAILY_CHALLENGE,
|
||||
col(Room.ends_at) > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if room:
|
||||
return
|
||||
|
||||
try:
|
||||
beatmap = await redis.hget(key, "beatmap") # pyright: ignore[reportGeneralTypeIssues]
|
||||
ruleset_id = await redis.hget(key, "ruleset_id") # pyright: ignore[reportGeneralTypeIssues]
|
||||
required_mods = await redis.hget(key, "required_mods") # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
if beatmap is None or ruleset_id is None:
|
||||
logger.warning(
|
||||
f"[DailyChallenge] Missing required data for daily challenge {now}."
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
get_scheduler().add_job(
|
||||
daily_challenge_job,
|
||||
"date",
|
||||
run_date=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
return
|
||||
|
||||
beatmap_int = int(beatmap)
|
||||
ruleset_id_int = int(ruleset_id)
|
||||
|
||||
mods_list = []
|
||||
if required_mods:
|
||||
mods_list = json.loads(required_mods)
|
||||
|
||||
next_day = (now + timedelta(days=1)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
room = await create_daily_challenge_room(
|
||||
beatmap=beatmap_int,
|
||||
ruleset_id=ruleset_id_int,
|
||||
required_mods=mods_list,
|
||||
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
|
||||
)
|
||||
await MetadataHubs.broadcast_call(
|
||||
"DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id)
|
||||
)
|
||||
logger.success(
|
||||
"[DailyChallenge] Added today's daily challenge: "
|
||||
f"{beatmap=}, {ruleset_id=}, {required_mods=}"
|
||||
)
|
||||
return
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
logger.warning(
|
||||
f"[DailyChallenge] Error processing daily challenge data: {e}"
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"[DailyChallenge] Unexpected error in daily challenge job: {e}"
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
get_scheduler().add_job(
|
||||
daily_challenge_job,
|
||||
"date",
|
||||
run_date=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
42
app/service/osu_rx_statistics.py
Normal file
42
app/service/osu_rx_statistics.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
from app.database.lazer_user import User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import engine
|
||||
from app.models.score import GameMode
|
||||
|
||||
from sqlalchemy import exists
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def create_rx_statistics():
|
||||
async with AsyncSession(engine) as session:
|
||||
users = (await session.exec(select(User.id))).all()
|
||||
for i in users:
|
||||
if settings.enable_osu_rx:
|
||||
is_exist = (
|
||||
await session.exec(
|
||||
select(exists()).where(
|
||||
UserStatistics.user_id == i,
|
||||
UserStatistics.mode == GameMode.OSURX,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not is_exist:
|
||||
statistics_rx = UserStatistics(mode=GameMode.OSURX, user_id=i)
|
||||
session.add(statistics_rx)
|
||||
if settings.enable_osu_ap:
|
||||
is_exist = (
|
||||
await session.exec(
|
||||
select(exists()).where(
|
||||
UserStatistics.user_id == i,
|
||||
UserStatistics.mode == GameMode.OSUAP,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not is_exist:
|
||||
statistics_ap = UserStatistics(mode=GameMode.OSUAP, user_id=i)
|
||||
session.add(statistics_ap)
|
||||
await session.commit()
|
||||
78
app/service/room.py
Normal file
78
app/service/room.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import APIUploadedRoom, Room
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus
|
||||
|
||||
from sqlalchemy import exists
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def create_playlist_room_from_api(
|
||||
session: AsyncSession, room: APIUploadedRoom, host_id: int
|
||||
) -> Room:
|
||||
db_room = room.to_room()
|
||||
db_room.host_id = host_id
|
||||
db_room.starts_at = datetime.now(UTC)
|
||||
db_room.ends_at = db_room.starts_at + timedelta(
|
||||
minutes=db_room.duration if db_room.duration is not None else 0
|
||||
)
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
await add_playlists_to_room(session, db_room.id, room.playlist, host_id)
|
||||
await session.refresh(db_room)
|
||||
return db_room
|
||||
|
||||
|
||||
async def create_playlist_room(
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
host_id: int,
|
||||
category: RoomCategory = RoomCategory.NORMAL,
|
||||
duration: int = 30,
|
||||
max_attempts: int | None = None,
|
||||
playlist: list[Playlist] = [],
|
||||
) -> Room:
|
||||
db_room = Room(
|
||||
name=name,
|
||||
category=category,
|
||||
duration=duration,
|
||||
starts_at=datetime.now(UTC),
|
||||
ends_at=datetime.now(UTC) + timedelta(minutes=duration),
|
||||
participant_count=0,
|
||||
max_attempts=max_attempts,
|
||||
type=MatchType.PLAYLISTS,
|
||||
queue_mode=QueueMode.HOST_ONLY,
|
||||
auto_skip=False,
|
||||
auto_start_duration=0,
|
||||
status=RoomStatus.IDLE,
|
||||
host_id=host_id,
|
||||
)
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
await add_playlists_to_room(session, db_room.id, playlist, host_id)
|
||||
await session.refresh(db_room)
|
||||
return db_room
|
||||
|
||||
|
||||
async def add_playlists_to_room(
|
||||
session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int
|
||||
):
|
||||
for item in playlist:
|
||||
if not (
|
||||
await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap)))
|
||||
).first():
|
||||
fetcher = await get_fetcher()
|
||||
await Beatmap.get_or_fetch(session, fetcher, item.beatmap_id)
|
||||
item.id = await Playlist.get_next_id_for_room(room_id, session)
|
||||
item.room_id = room_id
|
||||
item.owner_id = owner_id
|
||||
session.add(item)
|
||||
await session.commit()
|
||||
48
app/service/subscribers/base.py
Normal file
48
app/service/subscribers/base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from app.dependencies.database import get_redis_pubsub
|
||||
|
||||
|
||||
class RedisSubscriber:
|
||||
def __init__(self):
|
||||
self.pubsub = get_redis_pubsub()
|
||||
self.handlers: dict[str, list[Callable[[str, str], Awaitable[Any]]]] = {}
|
||||
self.task: asyncio.Task | None = None
|
||||
|
||||
async def subscribe(self, channel: str):
|
||||
await self.pubsub.subscribe(channel)
|
||||
if channel not in self.handlers:
|
||||
self.handlers[channel] = []
|
||||
|
||||
async def unsubscribe(self, channel: str):
|
||||
if channel in self.handlers:
|
||||
del self.handlers[channel]
|
||||
await self.pubsub.unsubscribe(channel)
|
||||
|
||||
async def listen(self):
|
||||
while True:
|
||||
message = await self.pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=None
|
||||
)
|
||||
if message is not None and message["type"] == "message":
|
||||
method = self.handlers.get(message["channel"])
|
||||
if method:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
handler(message["channel"], message["data"])
|
||||
for handler in method
|
||||
]
|
||||
)
|
||||
|
||||
def start(self):
|
||||
if self.task is None or self.task.done():
|
||||
self.task = asyncio.create_task(self.listen())
|
||||
|
||||
def stop(self):
|
||||
if self.task is not None and not self.task.done():
|
||||
self.task.cancel()
|
||||
self.task = None
|
||||
87
app/service/subscribers/score_processed.py
Normal file
87
app/service/subscribers/score_processed.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.database import PlaylistBestScore, Score
|
||||
from app.database.playlist_best_score import get_position
|
||||
from app.dependencies.database import engine
|
||||
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
|
||||
|
||||
from .base import RedisSubscriber
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.signalr.hub import MetadataHub
|
||||
|
||||
|
||||
CHANNEL = "score:processed"
|
||||
|
||||
|
||||
class ScoreSubscriber(RedisSubscriber):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.room_subscriber: dict[int, list[int]] = {}
|
||||
self.metadata_hub: "MetadataHub | None " = None
|
||||
self.subscribed = False
|
||||
self.handlers[CHANNEL] = [self._handler]
|
||||
|
||||
async def subscribe_room_score(self, room_id: int, user_id: int):
|
||||
if room_id not in self.room_subscriber:
|
||||
await self.subscribe(CHANNEL)
|
||||
self.start()
|
||||
self.room_subscriber.setdefault(room_id, []).append(user_id)
|
||||
|
||||
async def unsubscribe_room_score(self, room_id: int, user_id: int):
|
||||
if room_id in self.room_subscriber:
|
||||
self.room_subscriber[room_id].remove(user_id)
|
||||
if not self.room_subscriber[room_id]:
|
||||
del self.room_subscriber[room_id]
|
||||
|
||||
async def _notify_room_score_processed(self, score_id: int):
|
||||
if not self.metadata_hub:
|
||||
return
|
||||
async with AsyncSession(engine) as session:
|
||||
score = await session.get(Score, score_id)
|
||||
if (
|
||||
not score
|
||||
or not score.passed
|
||||
or score.room_id is None
|
||||
or score.playlist_item_id is None
|
||||
):
|
||||
return
|
||||
if not self.room_subscriber.get(score.room_id, []):
|
||||
return
|
||||
|
||||
new_rank = None
|
||||
user_best = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.user_id == score.user_id,
|
||||
PlaylistBestScore.room_id == score.room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if user_best and user_best.score_id == score_id:
|
||||
new_rank = await get_position(
|
||||
user_best.room_id,
|
||||
user_best.playlist_id,
|
||||
user_best.score_id,
|
||||
session,
|
||||
)
|
||||
|
||||
event = MultiplayerRoomScoreSetEvent(
|
||||
room_id=score.room_id,
|
||||
playlist_item_id=score.playlist_item_id,
|
||||
score_id=score_id,
|
||||
user_id=score.user_id,
|
||||
total_score=score.total_score,
|
||||
new_rank=new_rank,
|
||||
)
|
||||
await self.metadata_hub.notify_room_score_processed(event)
|
||||
|
||||
async def _handler(self, channel: str, data: str):
|
||||
score_id = int(data)
|
||||
if self.metadata_hub:
|
||||
await self._notify_room_score_processed(score_id)
|
||||
@@ -6,9 +6,9 @@ import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.exception import InvokeException
|
||||
from app.log import logger
|
||||
from app.models.signalr import UserState
|
||||
from app.signalr.exception import InvokeException
|
||||
from app.signalr.packet import (
|
||||
ClosePacket,
|
||||
CompletionPacket,
|
||||
@@ -74,7 +74,7 @@ class Client:
|
||||
while True:
|
||||
try:
|
||||
await self.send_packet(PingPacket())
|
||||
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
||||
await asyncio.sleep(settings.signalr_ping_interval)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -99,6 +99,16 @@ class Hub[TState: UserState]:
|
||||
return client
|
||||
return default
|
||||
|
||||
def get_before_clients(self, id: str, current_token: str) -> list[Client]:
|
||||
clients = []
|
||||
for client in self.clients.values():
|
||||
if client.connection_id != id:
|
||||
continue
|
||||
if client.connection_token == current_token:
|
||||
continue
|
||||
clients.append(client)
|
||||
return clients
|
||||
|
||||
@abstractmethod
|
||||
def create_state(self, client: Client) -> TState:
|
||||
raise NotImplementedError
|
||||
@@ -117,6 +127,11 @@ class Hub[TState: UserState]:
|
||||
if group_id in self.groups:
|
||||
self.groups[group_id].discard(client)
|
||||
|
||||
async def kick_client(self, client: Client) -> None:
|
||||
await self.call_noblock(client, "DisconnectRequested")
|
||||
await client.send_packet(ClosePacket(allow_reconnect=False))
|
||||
await client.connection.close(code=1000, reason="Disconnected by server")
|
||||
|
||||
async def add_client(
|
||||
self,
|
||||
connection_id: str,
|
||||
@@ -131,7 +146,7 @@ class Hub[TState: UserState]:
|
||||
if connection_token in self.waited_clients:
|
||||
if (
|
||||
self.waited_clients[connection_token]
|
||||
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
|
||||
< time.time() - settings.signalr_negotiate_timeout
|
||||
):
|
||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||
del self.waited_clients[connection_token]
|
||||
|
||||
@@ -1,18 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Coroutine
|
||||
from datetime import UTC, datetime
|
||||
import math
|
||||
from typing import override
|
||||
|
||||
from app.database import Relationship, RelationshipType
|
||||
from app.database.lazer_user import User
|
||||
from app.calculator import clamp
|
||||
from app.database import Relationship, RelationshipType, User
|
||||
from app.database.playlist_best_score import PlaylistBestScore
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import Room
|
||||
from app.dependencies.database import engine, get_redis
|
||||
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
|
||||
from app.models.metadata_hub import (
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
DailyChallengeInfo,
|
||||
MetadataClientState,
|
||||
MultiplayerPlaylistItemStats,
|
||||
MultiplayerRoomScoreSetEvent,
|
||||
MultiplayerRoomStats,
|
||||
OnlineStatus,
|
||||
UserActivity,
|
||||
)
|
||||
from app.models.room import RoomCategory
|
||||
from app.service.subscribers.score_processed import ScoreSubscriber
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||
@@ -21,11 +37,33 @@ ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||
class MetadataHub(Hub[MetadataClientState]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.subscriber = ScoreSubscriber()
|
||||
self.subscriber.metadata_hub = self
|
||||
self._daily_challenge_stats: MultiplayerRoomStats | None = None
|
||||
self._today = datetime.now(UTC).date()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def get_daily_challenge_stats(
|
||||
self, daily_challenge_room: int
|
||||
) -> MultiplayerRoomStats:
|
||||
if (
|
||||
self._daily_challenge_stats is None
|
||||
or self._today != datetime.now(UTC).date()
|
||||
):
|
||||
self._daily_challenge_stats = MultiplayerRoomStats(
|
||||
room_id=daily_challenge_room,
|
||||
playlist_item_stats={},
|
||||
)
|
||||
return self._daily_challenge_stats
|
||||
|
||||
@staticmethod
|
||||
def online_presence_watchers_group() -> str:
|
||||
return ONLINE_PRESENCE_WATCHERS_GROUP
|
||||
|
||||
@staticmethod
|
||||
def room_watcher_group(room_id: int) -> str:
|
||||
return f"metadata:multiplayer-room-watchers:{room_id}"
|
||||
|
||||
def broadcast_tasks(
|
||||
self, user_id: int, store: MetadataClientState | None
|
||||
) -> set[Coroutine]:
|
||||
@@ -102,10 +140,29 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
self.friend_presence_watchers_group(friend_id),
|
||||
"FriendPresenceUpdated",
|
||||
friend_id,
|
||||
friend_state if friend_state.pushable else None,
|
||||
friend_state.for_push
|
||||
if friend_state.pushable
|
||||
else None,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
daily_challenge_room = (
|
||||
await session.exec(
|
||||
select(Room).where(
|
||||
col(Room.ends_at) > datetime.now(UTC),
|
||||
Room.category == RoomCategory.DAILY_CHALLENGE,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if daily_challenge_room:
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"DailyChallengeUpdated",
|
||||
DailyChallengeInfo(
|
||||
room_id=daily_challenge_room.id,
|
||||
),
|
||||
)
|
||||
redis = get_redis()
|
||||
await redis.set(f"metadata:online:{user_id}", "")
|
||||
|
||||
@@ -161,3 +218,76 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
|
||||
async def EndWatchingUserPresence(self, client: Client) -> None:
|
||||
self.remove_from_group(client, self.online_presence_watchers_group())
|
||||
|
||||
async def notify_room_score_processed(self, event: MultiplayerRoomScoreSetEvent):
|
||||
await self.broadcast_group_call(
|
||||
self.room_watcher_group(event.room_id), "MultiplayerRoomScoreSet", event
|
||||
)
|
||||
|
||||
async def BeginWatchingMultiplayerRoom(self, client: Client, room_id: int):
|
||||
self.add_to_group(client, self.room_watcher_group(room_id))
|
||||
await self.subscriber.subscribe_room_score(room_id, client.user_id)
|
||||
stats = self.get_daily_challenge_stats(room_id)
|
||||
await self.update_daily_challenge_stats(stats)
|
||||
return list(stats.playlist_item_stats.values())
|
||||
|
||||
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
|
||||
async with AsyncSession(engine) as session:
|
||||
playlist_ids = (
|
||||
await session.exec(
|
||||
select(Playlist.id).where(
|
||||
Playlist.room_id == stats.room_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
for playlist_id in playlist_ids:
|
||||
item = stats.playlist_item_stats.get(playlist_id, None)
|
||||
if item is None:
|
||||
item = MultiplayerPlaylistItemStats(
|
||||
playlist_item_id=playlist_id,
|
||||
total_score_distribution=[0] * TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
cumulative_score=0,
|
||||
last_processed_score_id=0,
|
||||
)
|
||||
stats.playlist_item_stats[playlist_id] = item
|
||||
last_processed_score_id = item.last_processed_score_id
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == stats.room_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.score_id > last_processed_score_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
if len(scores) == 0:
|
||||
continue
|
||||
|
||||
async with self._lock:
|
||||
if item.last_processed_score_id == last_processed_score_id:
|
||||
totals = defaultdict(int)
|
||||
for score in scores:
|
||||
bin_index = int(
|
||||
clamp(
|
||||
math.floor(score.total_score / 100000),
|
||||
0,
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS - 1,
|
||||
)
|
||||
)
|
||||
totals[bin_index] += 1
|
||||
|
||||
item.cumulative_score += sum(
|
||||
score.total_score for score in scores
|
||||
)
|
||||
|
||||
for j in range(TOTAL_SCORE_DISTRIBUTION_BINS):
|
||||
item.total_score_distribution[j] += totals.get(j, 0)
|
||||
|
||||
if scores:
|
||||
item.last_processed_score_id = max(
|
||||
score.score_id for score in scores
|
||||
)
|
||||
|
||||
async def EndWatchingMultiplayerRoom(self, client: Client, room_id: int):
|
||||
self.remove_from_group(client, self.room_watcher_group(room_id))
|
||||
await self.subscriber.unsubscribe_room_score(room_id, client.user_id)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,11 +7,13 @@ import struct
|
||||
import time
|
||||
from typing import override
|
||||
|
||||
from app.config import settings
|
||||
from app.database import Beatmap, User
|
||||
from app.database.score import Score
|
||||
from app.database.score_token import ScoreToken
|
||||
from app.dependencies.database import engine
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.models.mods import mods_to_int
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
|
||||
from app.models.spectator_hub import (
|
||||
@@ -24,7 +26,6 @@ from app.models.spectator_hub import (
|
||||
StoreClientState,
|
||||
StoreScore,
|
||||
)
|
||||
from app.path import REPLAY_DIR
|
||||
from app.utils import unix_timestamp_to_windows
|
||||
|
||||
from .hub import Client, Hub
|
||||
@@ -63,7 +64,7 @@ def encode_string(s: str) -> bytes:
|
||||
return ret
|
||||
|
||||
|
||||
def save_replay(
|
||||
async def save_replay(
|
||||
ruleset_id: int,
|
||||
md5: str,
|
||||
username: str,
|
||||
@@ -135,8 +136,14 @@ def save_replay(
|
||||
data.extend(struct.pack("<i", len(compressed)))
|
||||
data.extend(compressed)
|
||||
|
||||
replay_path = REPLAY_DIR / f"lazer-{score.type}-{username}-{score.id}.osr"
|
||||
replay_path.write_bytes(data)
|
||||
storage_service = get_storage_service()
|
||||
replay_path = (
|
||||
f"replays/{score.id}_{score.beatmap_id}_{score.user_id}_lazer_replay.osr"
|
||||
)
|
||||
await storage_service.write_file(
|
||||
replay_path,
|
||||
bytes(data),
|
||||
)
|
||||
|
||||
|
||||
class SpectatorHub(Hub[StoreClientState]):
|
||||
@@ -179,15 +186,13 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
return
|
||||
if state.beatmap_id is None or state.ruleset_id is None:
|
||||
return
|
||||
|
||||
fetcher = await get_fetcher()
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap).where(Beatmap.id == state.beatmap_id)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
return
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=state.beatmap_id
|
||||
)
|
||||
user = (
|
||||
await session.exec(select(User).where(User.id == user_id))
|
||||
).first()
|
||||
@@ -237,16 +242,17 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
score = store.score
|
||||
assert store.beatmap_status is not None
|
||||
assert store.state is not None
|
||||
assert store.score is not None
|
||||
if not score or not store.score_token:
|
||||
if (
|
||||
score is None
|
||||
or store.score_token is None
|
||||
or store.beatmap_status is None
|
||||
or store.state is None
|
||||
):
|
||||
return
|
||||
if (
|
||||
BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED
|
||||
) and any(
|
||||
k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items()
|
||||
):
|
||||
settings.enable_all_beatmap_leaderboard
|
||||
and store.beatmap_status.has_leaderboard()
|
||||
) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()):
|
||||
await self._process_score(store, client)
|
||||
store.state = None
|
||||
store.beatmap_status = None
|
||||
@@ -296,7 +302,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
score_record.has_replay = True
|
||||
await session.commit()
|
||||
await session.refresh(score_record)
|
||||
save_replay(
|
||||
await save_replay(
|
||||
ruleset_id=store.ruleset_id,
|
||||
md5=store.checksum,
|
||||
username=store.score.score_info.user.name,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from app.models.signalr import SignalRMeta, SignalRUnionMessage
|
||||
from app.utils import camel_to_snake, snake_to_camel
|
||||
from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
|
||||
|
||||
import msgpack_lazer_api as m
|
||||
from pydantic import BaseModel
|
||||
@@ -97,6 +97,8 @@ class MsgpackProtocol:
|
||||
return [cls.serialize_msgpack(item) for item in v]
|
||||
elif issubclass(typ, datetime.datetime):
|
||||
return [v, 0]
|
||||
elif issubclass(typ, datetime.timedelta):
|
||||
return int(v.total_seconds() * 10_000_000)
|
||||
elif isinstance(v, dict):
|
||||
return {
|
||||
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
|
||||
@@ -126,15 +128,19 @@ class MsgpackProtocol:
|
||||
def process_object(v: Any, typ: type[BaseModel]) -> Any:
|
||||
if isinstance(v, list):
|
||||
d = {}
|
||||
for i, f in enumerate(typ.model_fields.items()):
|
||||
field, info = f
|
||||
if info.exclude:
|
||||
i = 0
|
||||
for field, info in typ.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.member_ignore:
|
||||
continue
|
||||
anno = info.annotation
|
||||
if anno is None:
|
||||
d[camel_to_snake(field)] = v[i]
|
||||
continue
|
||||
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
||||
else:
|
||||
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
||||
i += 1
|
||||
return d
|
||||
return v
|
||||
|
||||
@@ -209,7 +215,9 @@ class MsgpackProtocol:
|
||||
return typ.model_validate(obj=cls.process_object(v, typ))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return v[0]
|
||||
elif isinstance(v, list):
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||
return datetime.timedelta(seconds=int(v / 10_000_000))
|
||||
elif get_origin(typ) is list:
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
@@ -234,7 +242,9 @@ class MsgpackProtocol:
|
||||
# except `X (Other Type) | None`
|
||||
if NoneType in args and v is None:
|
||||
return None
|
||||
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
|
||||
if not all(
|
||||
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot validate {v} to {typ}, "
|
||||
"only SignalRUnionMessage subclasses are supported"
|
||||
@@ -292,36 +302,55 @@ class MsgpackProtocol:
|
||||
|
||||
class JSONProtocol:
|
||||
@classmethod
|
||||
def serialize_to_json(cls, v: Any):
|
||||
def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return cls.serialize_model(v)
|
||||
return cls.serialize_model(v, in_union)
|
||||
elif isinstance(v, dict):
|
||||
return {
|
||||
cls.serialize_to_json(k): cls.serialize_to_json(value)
|
||||
cls.serialize_to_json(k, True): cls.serialize_to_json(value)
|
||||
for k, value in v.items()
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
return [cls.serialize_to_json(item) for item in v]
|
||||
elif isinstance(v, datetime.datetime):
|
||||
return v.isoformat()
|
||||
elif isinstance(v, Enum):
|
||||
elif isinstance(v, datetime.timedelta):
|
||||
# d.hh:mm:ss
|
||||
total_seconds = int(v.total_seconds())
|
||||
hours, remainder = divmod(total_seconds, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
return f"{hours:02}:{minutes:02}:{seconds:02}"
|
||||
elif isinstance(v, Enum) and dict_key:
|
||||
return v.value
|
||||
elif isinstance(v, Enum):
|
||||
list_ = list(typ)
|
||||
return list_.index(v)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
|
||||
def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
|
||||
d = {}
|
||||
is_union = issubclass(v.__class__, SignalRUnionMessage)
|
||||
for field, info in v.__class__.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
|
||||
cls.serialize_to_json(getattr(v, field))
|
||||
name = (
|
||||
snake_to_camel(
|
||||
field,
|
||||
metadata.use_abbr if metadata else True,
|
||||
)
|
||||
if not is_union
|
||||
else snake_to_pascal(
|
||||
field,
|
||||
metadata.use_abbr if metadata else True,
|
||||
)
|
||||
)
|
||||
if issubclass(v.__class__, SignalRUnionMessage):
|
||||
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
|
||||
if is_union and not in_union:
|
||||
return {
|
||||
"$dtype": v.__class__.__name__,
|
||||
"$value": d,
|
||||
@@ -339,7 +368,12 @@ class JSONProtocol:
|
||||
)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
value = v.get(snake_to_camel(field, not from_union))
|
||||
name = (
|
||||
snake_to_camel(field, metadata.use_abbr if metadata else True)
|
||||
if not from_union
|
||||
else snake_to_pascal(field, metadata.use_abbr if metadata else True)
|
||||
)
|
||||
value = v.get(name)
|
||||
anno = typ.model_fields[field].annotation
|
||||
if anno is None:
|
||||
d[field] = value
|
||||
@@ -397,7 +431,18 @@ class JSONProtocol:
|
||||
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return datetime.datetime.fromisoformat(v)
|
||||
elif isinstance(v, list):
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||
# d.hh:mm:ss
|
||||
parts = v.split(":")
|
||||
if len(parts) == 3:
|
||||
return datetime.timedelta(
|
||||
hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])
|
||||
)
|
||||
elif len(parts) == 2:
|
||||
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
|
||||
elif len(parts) == 1:
|
||||
return datetime.timedelta(seconds=int(parts[0]))
|
||||
elif get_origin(typ) is list:
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
|
||||
@@ -6,26 +6,26 @@ import time
|
||||
from typing import Literal
|
||||
import uuid
|
||||
|
||||
from app.database import User
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user_by_token
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
|
||||
from .hub import Hubs
|
||||
from .packet import PROTOCOLS, SEP
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
|
||||
from fastapi.security import SecurityScopes
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
router = APIRouter(prefix="/signalr", tags=["SignalR"])
|
||||
|
||||
|
||||
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
|
||||
async def negotiate(
|
||||
hub: Literal["spectator", "multiplayer", "metadata"],
|
||||
negotiate_version: int = Query(1, alias="negotiateVersion"),
|
||||
user: User = Depends(get_current_user),
|
||||
user: DBUser = Depends(get_current_user),
|
||||
):
|
||||
connectionId = str(user.id)
|
||||
connectionToken = f"{connectionId}:{uuid.uuid4()}"
|
||||
@@ -55,9 +55,15 @@ async def connect(
|
||||
if id not in hub_:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
if (user := await get_current_user_by_token(token, db)) is None or str(
|
||||
user.id
|
||||
) != user_id:
|
||||
try:
|
||||
if (
|
||||
user := await get_current_user(
|
||||
SecurityScopes(scopes=["*"]), db, token_pw=token
|
||||
)
|
||||
) is None or str(user.id) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
except HTTPException:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await websocket.accept()
|
||||
@@ -92,6 +98,11 @@ async def connect(
|
||||
if error or not client:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
connected_clients = hub_.get_before_clients(user_id, id)
|
||||
for connected_client in connected_clients:
|
||||
await hub_.kick_client(connected_client)
|
||||
|
||||
await hub_.clean_state(client, False)
|
||||
task = asyncio.create_task(hub_.on_connect(client))
|
||||
hub_.tasks.add(task)
|
||||
|
||||
13
app/storage/__init__.py
Normal file
13
app/storage/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .aws_s3 import AWSS3StorageService
|
||||
from .base import StorageService
|
||||
from .cloudflare_r2 import CloudflareR2StorageService
|
||||
from .local import LocalStorageService
|
||||
|
||||
__all__ = [
|
||||
"AWSS3StorageService",
|
||||
"CloudflareR2StorageService",
|
||||
"LocalStorageService",
|
||||
"StorageService",
|
||||
]
|
||||
103
app/storage/aws_s3.py
Normal file
103
app/storage/aws_s3.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import StorageService
|
||||
|
||||
import aioboto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
class AWSS3StorageService(StorageService):
|
||||
def __init__(
|
||||
self,
|
||||
access_key_id: str,
|
||||
secret_access_key: str,
|
||||
bucket_name: str,
|
||||
region_name: str,
|
||||
public_url_base: str | None = None,
|
||||
):
|
||||
self.bucket_name = bucket_name
|
||||
self.public_url_base = public_url_base
|
||||
self.session = aioboto3.Session()
|
||||
self.access_key_id = access_key_id
|
||||
self.secret_access_key = secret_access_key
|
||||
self.region_name = region_name
|
||||
|
||||
@property
|
||||
def endpoint_url(self) -> str | None:
|
||||
return None
|
||||
|
||||
def _get_client(self):
|
||||
return self.session.client(
|
||||
"s3",
|
||||
endpoint_url=self.endpoint_url,
|
||||
aws_access_key_id=self.access_key_id,
|
||||
aws_secret_access_key=self.secret_access_key,
|
||||
region_name=self.region_name,
|
||||
)
|
||||
|
||||
async def write_file(
|
||||
self,
|
||||
file_path: str,
|
||||
content: bytes,
|
||||
content_type: str = "application/octet-stream",
|
||||
cache_control: str = "public, max-age=31536000",
|
||||
) -> None:
|
||||
async with self._get_client() as client:
|
||||
await client.put_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=file_path,
|
||||
Body=content,
|
||||
ContentType=content_type,
|
||||
CacheControl=cache_control,
|
||||
)
|
||||
|
||||
async def read_file(self, file_path: str) -> bytes:
|
||||
async with self._get_client() as client:
|
||||
try:
|
||||
response = await client.get_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=file_path,
|
||||
)
|
||||
async with response["Body"] as stream:
|
||||
return await stream.read()
|
||||
except ClientError as e:
|
||||
if e.response.get("Error", {}).get("Code") == "404":
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
raise RuntimeError(f"Failed to read file from R2: {e}")
|
||||
|
||||
async def delete_file(self, file_path: str) -> None:
|
||||
async with self._get_client() as client:
|
||||
try:
|
||||
await client.delete_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=file_path,
|
||||
)
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Failed to delete file from R2: {e}")
|
||||
|
||||
async def is_exists(self, file_path: str) -> bool:
|
||||
async with self._get_client() as client:
|
||||
try:
|
||||
await client.head_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=file_path,
|
||||
)
|
||||
return True
|
||||
except ClientError as e:
|
||||
if e.response.get("Error", {}).get("Code") == "404":
|
||||
return False
|
||||
raise RuntimeError(f"Failed to check file existence in R2: {e}")
|
||||
|
||||
async def get_file_url(self, file_path: str) -> str:
|
||||
if self.public_url_base:
|
||||
return f"{self.public_url_base.rstrip('/')}/{file_path.lstrip('/')}"
|
||||
|
||||
async with self._get_client() as client:
|
||||
try:
|
||||
url = await client.generate_presigned_url(
|
||||
"get_object",
|
||||
Params={"Bucket": self.bucket_name, "Key": file_path},
|
||||
)
|
||||
return url
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Failed to generate file URL: {e}")
|
||||
34
app/storage/base.py
Normal file
34
app/storage/base.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
|
||||
class StorageService(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def write_file(
|
||||
self,
|
||||
file_path: str,
|
||||
content: bytes,
|
||||
content_type: str = "application/octet-stream",
|
||||
cache_control: str = "public, max-age=31536000",
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def read_file(self, file_path: str) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_file(self, file_path: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def is_exists(self, file_path: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_file_url(self, file_path: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
26
app/storage/cloudflare_r2.py
Normal file
26
app/storage/cloudflare_r2.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .aws_s3 import AWSS3StorageService
|
||||
|
||||
|
||||
class CloudflareR2StorageService(AWSS3StorageService):
|
||||
def __init__(
|
||||
self,
|
||||
account_id: str,
|
||||
access_key_id: str,
|
||||
secret_access_key: str,
|
||||
bucket_name: str,
|
||||
public_url_base: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
access_key_id=access_key_id,
|
||||
secret_access_key=secret_access_key,
|
||||
bucket_name=bucket_name,
|
||||
public_url_base=public_url_base,
|
||||
region_name="auto",
|
||||
)
|
||||
self.account_id = account_id
|
||||
|
||||
@property
|
||||
def endpoint_url(self) -> str:
|
||||
return f"https://{self.account_id}.r2.cloudflarestorage.com"
|
||||
80
app/storage/local.py
Normal file
80
app/storage/local.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from .base import StorageService
|
||||
|
||||
import aiofiles
|
||||
|
||||
|
||||
class LocalStorageService(StorageService):
|
||||
def __init__(
|
||||
self,
|
||||
storage_path: str,
|
||||
):
|
||||
self.storage_path = Path(storage_path).resolve()
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_file_path(self, file_path: str) -> Path:
|
||||
clean_path = file_path.lstrip("/")
|
||||
full_path = self.storage_path / clean_path
|
||||
|
||||
try:
|
||||
full_path.resolve().relative_to(self.storage_path)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid file path: {file_path}")
|
||||
|
||||
return full_path
|
||||
|
||||
async def write_file(
|
||||
self,
|
||||
file_path: str,
|
||||
content: bytes,
|
||||
content_type: str = "application/octet-stream",
|
||||
cache_control: str = "public, max-age=31536000",
|
||||
) -> None:
|
||||
full_path = self._get_file_path(file_path)
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
async with aiofiles.open(full_path, "wb") as f:
|
||||
await f.write(content)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Failed to write file: {e}")
|
||||
|
||||
async def read_file(self, file_path: str) -> bytes:
|
||||
full_path = self._get_file_path(file_path)
|
||||
|
||||
if not full_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
try:
|
||||
async with aiofiles.open(full_path, "rb") as f:
|
||||
return await f.read()
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Failed to read file: {e}")
|
||||
|
||||
async def delete_file(self, file_path: str) -> None:
|
||||
full_path = self._get_file_path(file_path)
|
||||
|
||||
if not full_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
full_path.unlink()
|
||||
|
||||
parent = full_path.parent
|
||||
while parent != self.storage_path and not any(parent.iterdir()):
|
||||
parent.rmdir()
|
||||
parent = parent.parent
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Failed to delete file: {e}")
|
||||
|
||||
async def is_exists(self, file_path: str) -> bool:
|
||||
full_path = self._get_file_path(file_path)
|
||||
return full_path.exists() and full_path.is_file()
|
||||
|
||||
async def get_file_url(self, file_path: str) -> str:
|
||||
return f"{settings.server_url}file/{file_path.lstrip('/')}"
|
||||
40
app/utils.py
40
app/utils.py
@@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str:
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def snake_to_camel(name: str, lower_case: bool = True) -> str:
|
||||
def snake_to_camel(name: str, use_abbr: bool = True) -> str:
|
||||
"""Convert a snake_case string to camelCase."""
|
||||
if not name:
|
||||
return name
|
||||
@@ -47,12 +47,46 @@ def snake_to_camel(name: str, lower_case: bool = True) -> str:
|
||||
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.lower() in abbreviations:
|
||||
if part.lower() in abbreviations and use_abbr:
|
||||
result.append(part.upper())
|
||||
else:
|
||||
if result or not lower_case:
|
||||
if result:
|
||||
result.append(part.capitalize())
|
||||
else:
|
||||
result.append(part.lower())
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def snake_to_pascal(name: str, use_abbr: bool = True) -> str:
|
||||
"""Convert a snake_case string to PascalCase."""
|
||||
if not name:
|
||||
return name
|
||||
|
||||
parts = name.split("_")
|
||||
if not parts:
|
||||
return name
|
||||
|
||||
# 常见缩写词列表
|
||||
abbreviations = {
|
||||
"id",
|
||||
"url",
|
||||
"api",
|
||||
"http",
|
||||
"https",
|
||||
"xml",
|
||||
"json",
|
||||
"css",
|
||||
"html",
|
||||
"sql",
|
||||
"db",
|
||||
}
|
||||
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.lower() in abbreviations and use_abbr:
|
||||
result.append(part.upper())
|
||||
else:
|
||||
result.append(part.capitalize())
|
||||
|
||||
return "".join(result)
|
||||
|
||||
79
docker-compose-osurx.yml
Normal file
79
docker-compose-osurx.yml
Normal file
@@ -0,0 +1,79 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
app:
|
||||
# or use
|
||||
# image: mingxuangame/osu-lazer-api-osurx:latest
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile-osurx
|
||||
container_name: osu_api_server_osurx
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- MYSQL_HOST=mysql
|
||||
- MYSQL_PORT=3306
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- ENABLE_OSU_RX=true
|
||||
- ENABLE_OSU_AP=true
|
||||
- ENABLE_ALL_MODS_PP=true
|
||||
- ENABLE_SUPPORTER_FOR_ALL_USERS=true
|
||||
- ENABLE_ALL_BEATMAP_LEADERBOARD=true
|
||||
env_file:
|
||||
- .env
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- ./replays:/app/replays
|
||||
- ./static:/app/static
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- osu-network
|
||||
|
||||
mysql:
|
||||
image: mysql:8.0
|
||||
container_name: osu_api_mysql_osurx
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
|
||||
- MYSQL_DATABASE=${MYSQL_DATABASE}
|
||||
- MYSQL_USER=${MYSQL_USER}
|
||||
- MYSQL_PASSWORD=${MYSQL_PASSWORD}
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql
|
||||
- ./mysql-init:/docker-entrypoint-initdb.d
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost"]
|
||||
timeout: 20s
|
||||
retries: 10
|
||||
interval: 10s
|
||||
start_period: 40s
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- osu-network
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: osu_api_redis_osurx
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
interval: 10s
|
||||
start_period: 10s
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- osu-network
|
||||
command: redis-server --appendonly yes
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
redis_data:
|
||||
|
||||
networks:
|
||||
osu-network:
|
||||
driver: bridge
|
||||
@@ -1,50 +1,74 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: mysql:8.0
|
||||
container_name: osu_api_mysql
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: password
|
||||
MYSQL_DATABASE: osu_api
|
||||
MYSQL_USER: osu_user
|
||||
MYSQL_PASSWORD: osu_password
|
||||
ports:
|
||||
- "3306:3306"
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql
|
||||
- ./mysql-init:/docker-entrypoint-initdb.d
|
||||
restart: unless-stopped
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: osu_api_redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
restart: unless-stopped
|
||||
command: redis-server --appendonly yes
|
||||
|
||||
api:
|
||||
build: .
|
||||
container_name: osu_api_server
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
DATABASE_URL: mysql+aiomysql://osu_user:osu_password@mysql:3306/osu_api
|
||||
REDIS_URL: redis://redis:6379/0
|
||||
SECRET_KEY: your-production-secret-key-here
|
||||
OSU_CLIENT_ID: "5"
|
||||
OSU_CLIENT_SECRET: "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
|
||||
depends_on:
|
||||
- mysql
|
||||
- redis
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./:/app
|
||||
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
redis_data:
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
app:
|
||||
# or use
|
||||
# image: mingxuangame/osu-lazer-api:latest
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: osu_api_server
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- MYSQL_HOST=mysql
|
||||
- MYSQL_PORT=3306
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
env_file:
|
||||
- .env
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
volumes:
|
||||
- ./replays:/app/replays
|
||||
- ./static:/app/static
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- osu-network
|
||||
|
||||
mysql:
|
||||
image: mysql:8.0
|
||||
container_name: osu_api_mysql
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
|
||||
- MYSQL_DATABASE=${MYSQL_DATABASE}
|
||||
- MYSQL_USER=${MYSQL_USER}
|
||||
- MYSQL_PASSWORD=${MYSQL_PASSWORD}
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql
|
||||
- ./mysql-init:/docker-entrypoint-initdb.d
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost"]
|
||||
timeout: 20s
|
||||
retries: 10
|
||||
interval: 10s
|
||||
start_period: 40s
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- osu-network
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: osu_api_redis
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
interval: 10s
|
||||
start_period: 10s
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- osu-network
|
||||
command: redis-server --appendonly yes
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
redis_data:
|
||||
|
||||
networks:
|
||||
osu-network:
|
||||
driver: bridge
|
||||
|
||||
13
docker-entrypoint.sh
Normal file
13
docker-entrypoint.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "Waiting for database connection..."
|
||||
while ! nc -z $MYSQL_HOST $MYSQL_PORT; do
|
||||
sleep 1
|
||||
done
|
||||
echo "Database connected"
|
||||
|
||||
echo "Running alembic..."
|
||||
uv run --no-sync alembic upgrade head
|
||||
|
||||
exec "$@"
|
||||
158
main.py
158
main.py
@@ -4,29 +4,55 @@ from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies.database import create_tables, engine, redis_client
|
||||
from app.dependencies.database import engine, redis_client
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.router import api_router, auth_router, fetcher_router, signalr_router
|
||||
from app.dependencies.scheduler import init_scheduler, stop_scheduler
|
||||
from app.log import logger
|
||||
from app.router import (
|
||||
api_v2_router,
|
||||
auth_router,
|
||||
fetcher_router,
|
||||
file_router,
|
||||
private_router,
|
||||
signalr_router,
|
||||
)
|
||||
from app.service.daily_challenge import daily_challenge_job
|
||||
from app.service.osu_rx_statistics import create_rx_statistics
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# on startup
|
||||
await create_tables()
|
||||
await create_rx_statistics()
|
||||
await get_fetcher() # 初始化 fetcher
|
||||
init_scheduler()
|
||||
await daily_challenge_job()
|
||||
# on shutdown
|
||||
yield
|
||||
stop_scheduler()
|
||||
await engine.dispose()
|
||||
await redis_client.aclose()
|
||||
|
||||
|
||||
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
|
||||
app.include_router(api_router, prefix="/api/v2")
|
||||
app.include_router(signalr_router, prefix="/signalr")
|
||||
app.include_router(fetcher_router, prefix="/fetcher")
|
||||
|
||||
app.include_router(api_v2_router)
|
||||
app.include_router(signalr_router)
|
||||
app.include_router(fetcher_router)
|
||||
app.include_router(file_router)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(private_router)
|
||||
# CORS 配置
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[str(settings.server_url)],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
@@ -41,114 +67,30 @@ async def health_check():
|
||||
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
|
||||
|
||||
|
||||
# @app.get("/api/v2/friends")
|
||||
# async def get_friends():
|
||||
# return JSONResponse(
|
||||
# content=[
|
||||
# {
|
||||
# "id": 123456,
|
||||
# "username": "BestFriend",
|
||||
# "is_online": True,
|
||||
# "is_supporter": False,
|
||||
# "country": {"code": "US", "name": "United States"},
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
|
||||
|
||||
# @app.get("/api/v2/notifications")
|
||||
# async def get_notifications():
|
||||
# return JSONResponse(content={"notifications": [], "unread_count": 0})
|
||||
|
||||
|
||||
# @app.post("/api/v2/chat/ack")
|
||||
# async def chat_ack():
|
||||
# return JSONResponse(content={"status": "ok"})
|
||||
|
||||
|
||||
# @app.get("/api/v2/users/{user_id}/{mode}")
|
||||
# async def get_user_mode(user_id: int, mode: str):
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "id": user_id,
|
||||
# "username": "测试测试测",
|
||||
# "statistics": {
|
||||
# "level": {"current": 97, "progress": 96},
|
||||
# "pp": 114514,
|
||||
# "global_rank": 666,
|
||||
# "country_rank": 1,
|
||||
# "hit_accuracy": 100,
|
||||
# },
|
||||
# "country": {"code": "JP", "name": "Japan"},
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
# @app.get("/api/v2/me")
|
||||
# async def get_me():
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "id": 15651670,
|
||||
# "username": "Googujiang",
|
||||
# "is_online": True,
|
||||
# "country": {"code": "JP", "name": "Japan"},
|
||||
# "statistics": {
|
||||
# "level": {"current": 97, "progress": 96},
|
||||
# "pp": 2826.26,
|
||||
# "global_rank": 298026,
|
||||
# "country_rank": 11220,
|
||||
# "hit_accuracy": 95.7168,
|
||||
# },
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
# @app.post("/signalr/metadata/negotiate")
|
||||
# async def metadata_negotiate(negotiateVersion: int = 1):
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "connectionId": "abc123",
|
||||
# "availableTransports": [
|
||||
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
|
||||
# ],
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
# @app.post("/signalr/spectator/negotiate")
|
||||
# async def spectator_negotiate(negotiateVersion: int = 1):
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "connectionId": "spec456",
|
||||
# "availableTransports": [
|
||||
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
|
||||
# ],
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
# @app.post("/signalr/multiplayer/negotiate")
|
||||
# async def multiplayer_negotiate(negotiateVersion: int = 1):
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "connectionId": "multi789",
|
||||
# "availableTransports": [
|
||||
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
|
||||
# ],
|
||||
# }
|
||||
# )
|
||||
|
||||
if settings.secret_key == "your_jwt_secret_here":
|
||||
logger.warning(
|
||||
"jwt_secret_key is unset. Your server is unsafe. "
|
||||
"Use this command to generate: openssl rand -hex 32"
|
||||
)
|
||||
if settings.osu_web_client_secret == "your_osu_web_client_secret_here":
|
||||
logger.warning(
|
||||
"osu_web_client_secret is unset. Your server is unsafe. "
|
||||
"Use this command to generate: openssl rand -hex 40"
|
||||
)
|
||||
if settings.private_api_secret == "your_private_api_secret_here":
|
||||
logger.warning(
|
||||
"private_api_secret is unset. Your server is unsafe. "
|
||||
"Use this command to generate: openssl rand -hex 32"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from app.log import logger # noqa: F401
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
reload=settings.DEBUG,
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.debug,
|
||||
log_config=None, # 禁用uvicorn默认日志配置
|
||||
access_log=True, # 启用访问日志
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import os
|
||||
|
||||
from app.config import settings
|
||||
from app.database import * # noqa: F403
|
||||
|
||||
from alembic import context
|
||||
@@ -45,7 +45,8 @@ def run_migrations_offline() -> None:
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = os.environ.get("DATABASE_URL", config.get_main_option("sqlalchemy.url"))
|
||||
url = settings.database_url
|
||||
print(url)
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
@@ -73,8 +74,7 @@ async def run_async_migrations() -> None:
|
||||
|
||||
"""
|
||||
sa_config = config.get_section(config.config_ini_section, {})
|
||||
if db_url := os.environ.get("DATABASE_URL"):
|
||||
sa_config["sqlalchemy.url"] = db_url
|
||||
sa_config["sqlalchemy.url"] = settings.database_url
|
||||
connectable = async_engine_from_config(
|
||||
sa_config,
|
||||
prefix="sqlalchemy.",
|
||||
|
||||
116
migrations/versions/19cdc9ce4dcb_gamemode_add_osurx_osupp.py
Normal file
116
migrations/versions/19cdc9ce4dcb_gamemode_add_osurx_osupp.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""gamemode: add osurx & osupp
|
||||
|
||||
Revision ID: 19cdc9ce4dcb
|
||||
Revises: fdb3822a30ba
|
||||
Create Date: 2025-08-10 06:10:08.093591
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "19cdc9ce4dcb"
|
||||
down_revision: str | Sequence[str] | None = "fdb3822a30ba"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"lazer_users",
|
||||
"playmode",
|
||||
type_=sa.Enum(
|
||||
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
|
||||
),
|
||||
)
|
||||
op.alter_column(
|
||||
"beatmaps",
|
||||
"mode",
|
||||
type_=sa.Enum(
|
||||
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
|
||||
),
|
||||
)
|
||||
op.alter_column(
|
||||
"lazer_user_statistics",
|
||||
"mode",
|
||||
type_=sa.Enum(
|
||||
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
|
||||
),
|
||||
)
|
||||
op.alter_column(
|
||||
"score_tokens",
|
||||
"ruleset_id",
|
||||
type_=sa.Enum(
|
||||
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
|
||||
),
|
||||
)
|
||||
op.alter_column(
|
||||
"scores",
|
||||
"gamemode",
|
||||
type_=sa.Enum(
|
||||
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
|
||||
),
|
||||
)
|
||||
op.alter_column(
|
||||
"best_scores",
|
||||
"gamemode",
|
||||
type_=sa.Enum(
|
||||
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
|
||||
),
|
||||
)
|
||||
op.alter_column(
|
||||
"total_score_best_scores",
|
||||
"gamemode",
|
||||
type_=sa.Enum(
|
||||
"OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"total_score_best_scores",
|
||||
"gamemode",
|
||||
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
|
||||
)
|
||||
op.alter_column(
|
||||
"best_scores",
|
||||
"gamemode",
|
||||
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
|
||||
)
|
||||
op.alter_column(
|
||||
"scores",
|
||||
"gamemode",
|
||||
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
|
||||
)
|
||||
op.alter_column(
|
||||
"score_tokens",
|
||||
"ruleset_id",
|
||||
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
|
||||
)
|
||||
op.alter_column(
|
||||
"lazer_user_statistics",
|
||||
"mode",
|
||||
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
|
||||
)
|
||||
op.alter_column(
|
||||
"beatmaps",
|
||||
"mode",
|
||||
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
|
||||
)
|
||||
op.alter_column(
|
||||
"lazer_users",
|
||||
"playmode",
|
||||
type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", name="gamemode"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user