chore(merge): merge #53
## Removed
- SignalR server
- `msgpack_lazer_api`
- Unused services
## Changed
- Move once tasks and scheduled tasks into `app.tasks`
- Improve Logs
- Yellow: Tasks
- Blue: Services
- Magenta: Fetcher
- Dark green: Uvicorn
- Red: System
- `#FFC1C1`: Normal
- Redis: use multiple logical databases
- db0: general cache (`redis_client`)
- db1: message cache (`redis_message_client`)
- db2: binary storage (`redis_binary_client`)
- db3: rate limiting (`redis_rate_limit_client`)
- API: move user page API (`/api/v2/users/{user_id}/page`, `/api/v2/me/validate-bbcode`) into private APIs (`/api/private/user/page`, `/api/private/user/validate-bbcode`)
- Remove `from __future__ import annotations` to avoid `ForwardRef` problems
- Assets Proxy: use decorators to simplify code
- Ruff: add rules
- API Router: use Annotated-style dependency injections.
- Database: rename filenames to easily find the model
## Documents
- CONTRIBUTING.md
- AGENTS.md
- copilot-instructions.md
This commit is contained in:
@@ -107,6 +107,6 @@
|
||||
80,
|
||||
8080
|
||||
],
|
||||
"postCreateCommand": "uv sync --dev && uv run alembic upgrade head && uv run pre-commit install && cd packages/msgpack_lazer_api && cargo check && cd ../../spectator-server && dotnet restore",
|
||||
"postCreateCommand": "uv sync --dev && uv run alembic upgrade head && uv run pre-commit install && cd spectator-server && dotnet restore",
|
||||
"remoteUser": "vscode"
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ MYSQL_DATABASE="osu_api"
|
||||
MYSQL_USER="osu_api"
|
||||
MYSQL_PASSWORD="password"
|
||||
MYSQL_ROOT_PASSWORD="password"
|
||||
REDIS_URL="redis://127.0.0.1:6379/0"
|
||||
REDIS_URL="redis://127.0.0.1:6379"
|
||||
|
||||
# JWT Settings
|
||||
# Use `openssl rand -hex 32` to generate a secure key
|
||||
|
||||
184
.github/copilot-instructions.md
vendored
Normal file
184
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,184 @@
|
||||
# copilot-instruction
|
||||
|
||||
> 此文件是 AGENTS.md 的复制。一切以 AGENTS.md 为主。
|
||||
|
||||
> 使用自动化与 AI 代理(GitHub Copilot、依赖/CI 机器人,以及仓库中的运行时调度器/worker)的指导原则,适用于 g0v0-server 仓库。
|
||||
|
||||
---
|
||||
|
||||
## API 参考
|
||||
|
||||
本项目必须保持与公开的 osu! API 兼容。在添加或映射端点时请参考:
|
||||
|
||||
- **v1(旧版):** [https://github.com/ppy/osu-api/wiki](https://github.com/ppy/osu-api/wiki)
|
||||
- **v2(OpenAPI):** [https://osu.ppy.sh/docs/openapi.yaml](https://osu.ppy.sh/docs/openapi.yaml)
|
||||
|
||||
任何在 `app/router/v1/`、`app/router/v2/` 或 `app/router/notification/` 中的实现必须与官方规范保持一致。自定义或实验性的端点应放在 `app/router/private/` 中。
|
||||
|
||||
---
|
||||
|
||||
## 代理类别
|
||||
|
||||
允许的代理分为三类:
|
||||
|
||||
- **代码生成/补全代理**(如 GitHub Copilot 或其他 LLM)—— **仅当** 有维护者审核并批准输出时允许使用。
|
||||
- **自动维护代理**(如 Dependabot、Renovate、pre-commit.ci)—— 允许使用,但必须遵守严格的 PR 和 CI 政策。
|
||||
- **运行时/后台代理**(调度器、worker)—— 属于产品代码的一部分;必须遵守生命周期、并发和幂等性规范。
|
||||
|
||||
所有由代理生成或建议的更改必须遵守以下规则。
|
||||
|
||||
---
|
||||
|
||||
## 所有代理的规则
|
||||
|
||||
1. **单一职责的 PR。** 代理的 PR 必须只解决一个问题(一个功能、一个 bug 修复或一次依赖更新)。提交信息应使用 Angular 风格(如 `feat(api): add ...`)。
|
||||
2. **通过 Lint 与 CI 检查。** 每个 PR(包括代理创建的)在合并前必须通过 `pyright`、`ruff`、`pre-commit` 钩子和仓库 CI。PR 中应附带 CI 运行结果链接。
|
||||
3. **绝不可提交敏感信息。** 代理不得提交密钥、密码、token 或真实 `.env` 值。如果检测到可能的敏感信息,代理必须中止并通知指定的维护者。
|
||||
4. **API 位置限制。** 不得在 `app/router/v1` 或 `app/router/v2` 下添加新的公开端点,除非该端点在官方 v1/v2 规范中存在。自定义或实验性端点必须放在 `app/router/private/`。
|
||||
5. **保持公共契约稳定。** 未经批准的迁移计划,不得随意修改响应 schema、路由前缀或其他公共契约。若有变更,PR 中必须包含明确的兼容性说明。
|
||||
|
||||
---
|
||||
|
||||
## Copilot / LLM 使用
|
||||
|
||||
> 关于在本仓库中使用 GitHub Copilot 和其他基于 LLM 的辅助工具的统一指导。
|
||||
|
||||
### 关键项目结构(需要了解的内容)
|
||||
|
||||
- **应用入口:** `main.py` —— FastAPI 应用,包含启动/关闭生命周期管理(fetchers、GeoIP、调度器、缓存与健康检查、Redis 消息、统计、成就系统)。
|
||||
|
||||
- **路由:** `app/router/` 包含所有路由组。主要的路由包括:
|
||||
- `v1/`(v1 端点)
|
||||
- `v2/`(v2 端点)
|
||||
- `notification/` 路由(聊天/通知子系统)
|
||||
- `auth.py`(认证/token 流程)
|
||||
- `private/`(自定义或实验性的端点)
|
||||
|
||||
**规则:** `v1/` 和 `v2/` 必须与官方 API 对应。仅内部或实验端点应放在 `app/router/private/`。
|
||||
|
||||
- **模型与数据库工具:**
|
||||
- SQLModel/ORM 模型在 `app/database/`。
|
||||
- 非数据库模型在 `app/models/`。
|
||||
- 修改模型/schema 时必须生成 Alembic 迁移,并手动检查生成的 SQL 与索引。
|
||||
|
||||
- **服务层:** `app/service/` 保存领域逻辑(如缓存工具、通知/邮件逻辑)。复杂逻辑应放在 service,而不是路由处理器中。
|
||||
|
||||
- **任务:** `app/tasks/` 保存任务(定时任务、启动任务、关闭任务)。
|
||||
- 均在 `__init__.py` 进行导出。
|
||||
- 对于启动任务/关闭任务,在 `main.py` 的 `lifespan` 调用。
|
||||
- 定时任务使用 APScheduler
|
||||
|
||||
- **缓存与依赖:** 使用 `app/dependencies/` 提供的 Redis 依赖和缓存服务(遵循现有 key 命名约定,如 `user:{id}:...`)。
|
||||
|
||||
- **日志:** 使用 `app/log` 提供的日志工具。
|
||||
|
||||
### 实用工作流(提示模式)
|
||||
|
||||
- **添加 v2 端点(正确方式):** 在 `app/router/v2/` 下添加文件,导出路由,实现基于数据库与缓存依赖的异步处理函数。**不得**在 v1/v2 添加非官方端点。
|
||||
- **添加自定义端点:** 放在 `app/router/private/`,保持处理器精简,将业务逻辑放入 `app/service/`。
|
||||
- **鉴权:** 使用 [`app.dependencies.user`](../app/dependencies/user.py) 提供的依赖注入,如 `ClientUser` 和 `get_current_user`,参考下方。
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from fastapi import Security
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
|
||||
|
||||
@router.get("/some-api")
|
||||
async def _(current_user: Annotated[User, Security(get_current_user, scopes=["public"])]):
|
||||
...
|
||||
|
||||
|
||||
@router.get("/some-client-api")
|
||||
async def _(current_user: ClientUser):
|
||||
...
|
||||
```
|
||||
|
||||
- **添加后台任务:** 将任务逻辑写在 `app/service/_job.py`(幂等、可重试)。调度器入口放在 `app/scheduler/_scheduler.py`,并在应用生命周期注册。
|
||||
- **数据库 schema 变更:** 修改 `app/models/` 中的 SQLModel 模型,运行 `alembic revision --autogenerate`,检查迁移并本地测试 `alembic upgrade head` 后再提交。
|
||||
- **缓存写入与响应:** 使用现有的 `UserResp` 模式和 `UserCacheService`;异步缓存写入应使用后台任务。
|
||||
|
||||
### 提示指导(给 LLM/Copilot 的输入)
|
||||
|
||||
- 明确文件位置和限制(如:`Add an async endpoint under app/router/private/... DO NOT add to app/router/v1 or v2`)。
|
||||
- 要求异步处理函数、依赖注入 DB/Redis、复用已有服务/工具、加上类型注解,并生成最小化 pytest 测试样例。
|
||||
|
||||
### 约定与质量要求
|
||||
|
||||
- **使用 Annotated-style 依赖注入** 在路由处理器中。
|
||||
- **提交信息风格:** `type(scope): subject`(Angular 风格)。
|
||||
- **优先异步:** 路由必须为异步函数;避免阻塞事件循环。
|
||||
- **关注点分离:** 业务逻辑应放在 service,而不是路由中。
|
||||
- **错误处理:** 客户端错误用 `HTTPException`,服务端错误使用结构化日志。
|
||||
- **类型与 lint:** 在请求评审前,代码必须通过 `pyright` 和 `ruff` 检查。
|
||||
- **注释:** 避免过多注释,仅为晦涩逻辑添加简洁的“魔法注释”。
|
||||
- **日志:** 使用 `app.log` 提供的 `log` 函数获取 logger 实例。(服务、任务除外)
|
||||
|
||||
### 工具参考
|
||||
|
||||
```
|
||||
uv sync
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
pyright
|
||||
ruff .
|
||||
alembic revision --autogenerate -m "feat(db): ..."
|
||||
alembic upgrade head
|
||||
uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### PR 范围指导
|
||||
|
||||
- 保持 PR 专注:一次只做一件事(如端点或重构,不要混合)。
|
||||
- 不确定时,请参考现有服务,并添加简短说明性注释。
|
||||
|
||||
### PR 审核规则
|
||||
|
||||
> GitHub Copilot PR review 可参考。
|
||||
|
||||
1. 如果 PR 修改了端点,简要说明端点的用途和预期行为。同时检查是否满足上述的 API 位置限制。
|
||||
2. 如果 PR 修改了数据库模型,必须包含 Alembic 迁移。检查迁移的 SQL 语句和索引是否合理。
|
||||
3. 修改的其他功能需要提供简短的说明。
|
||||
4. 提供性能优化的建议(见下文)。
|
||||
|
||||
---
|
||||
|
||||
## 性能优化提示
|
||||
|
||||
以下为结合本仓库架构(FastAPI + SQLModel/SQLAlchemy、Redis 缓存、后台调度器)总结的性能优化建议:
|
||||
|
||||
### 数据库
|
||||
|
||||
- **仅选择必要字段。** 使用 `select(Model.col1, Model.col2)`,避免 `select(Model)`。
|
||||
|
||||
```py
|
||||
stmt = select(User.id, User.username).where(User.active == True)
|
||||
rows = await session.execute(stmt)
|
||||
```
|
||||
|
||||
- **使用 `select(exists())` 检查存在性。** 避免加载整行:
|
||||
|
||||
```py
|
||||
from sqlalchemy import select, exists
|
||||
exists_stmt = select(exists().where(User.id == some_id))
|
||||
found = await session.scalar(exists_stmt)
|
||||
```
|
||||
|
||||
- **避免 N+1 查询。** 需要关联对象时用 `selectinload`、`joinedload`。
|
||||
|
||||
- **批量操作。** 插入/更新时应批量执行,并放在一个事务中,而不是多个小事务。
|
||||
|
||||
|
||||
### 耗时任务
|
||||
|
||||
- 如果这个任务来自 API Router,请使用 FastAPI 提供的 [`BackgroundTasks`](https://fastapi.tiangolo.com/tutorial/background-tasks)
|
||||
- 其他情况,使用 `app.utils` 的 `bg_tasks`,它提供了与 FastAPI 的 `BackgroundTasks` 类似的功能。
|
||||
|
||||
---
|
||||
|
||||
## 部分 LLM 的额外要求
|
||||
|
||||
### Claude Code
|
||||
|
||||
- 禁止创建额外的测试脚本。
|
||||
|
||||
2
.github/scripts/generate_config_doc.py
vendored
2
.github/scripts/generate_config_doc.py
vendored
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from enum import Enum
|
||||
import importlib.util
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -222,7 +222,6 @@ newrelic.ini
|
||||
logs/
|
||||
osu-server-spectator-master/*
|
||||
spectator-server/
|
||||
.github/copilot-instructions.md
|
||||
osu-web-master/*
|
||||
osu-web-master/.env.dusk.local.example
|
||||
osu-web-master/.env.example
|
||||
|
||||
203
AGENTS.md
203
AGENTS.md
@@ -1,117 +1,118 @@
|
||||
# AGENTS.md
|
||||
# AGENTS
|
||||
|
||||
> Guidelines for using automation and AI agents (GitHub Copilot, dependency/CI bots, and in-repo runtime schedulers/workers) with the g0v0-server repository.
|
||||
> 使用自动化与 AI 代理(GitHub Copilot、依赖/CI 机器人,以及仓库中的运行时调度器/worker)的指导原则,适用于 g0v0-server 仓库。
|
||||
|
||||
---
|
||||
|
||||
## API References
|
||||
## API 参考
|
||||
|
||||
This project must stay compatible with the public osu! APIs. Use these references when adding or mapping endpoints:
|
||||
本项目必须保持与公开的 osu! API 兼容。在添加或映射端点时请参考:
|
||||
|
||||
- **v1 (legacy):** [https://github.com/ppy/osu-api/wiki](https://github.com/ppy/osu-api/wiki)
|
||||
- **v2 (OpenAPI):** [https://osu.ppy.sh/docs/openapi.yaml](https://osu.ppy.sh/docs/openapi.yaml)
|
||||
- **v1(旧版):** [https://github.com/ppy/osu-api/wiki](https://github.com/ppy/osu-api/wiki)
|
||||
- **v2(OpenAPI):** [https://osu.ppy.sh/docs/openapi.yaml](https://osu.ppy.sh/docs/openapi.yaml)
|
||||
|
||||
Any implementation in `app/router/v1/`, `app/router/v2/`, or `app/router/notification/` must match official endpoints from the corresponding specification above. Custom or experimental endpoints belong in `app/router/private/`.
|
||||
任何在 `app/router/v1/`、`app/router/v2/` 或 `app/router/notification/` 中的实现必须与官方规范保持一致。自定义或实验性的端点应放在 `app/router/private/` 中。
|
||||
|
||||
---
|
||||
|
||||
## Agent Categories
|
||||
## 代理类别
|
||||
|
||||
Agents are allowed in three categories:
|
||||
允许的代理分为三类:
|
||||
|
||||
- **Code authoring / completion agents** (e.g. GitHub Copilot or other LLMs) — allowed **only** when a human maintainer reviews and approves the output.
|
||||
- **Automated maintenance agents** (e.g. Dependabot, Renovate, pre-commit.ci) — allowed but must follow strict PR and CI policies.
|
||||
- **Runtime / background agents** (schedulers, workers) — part of the product code; must follow lifecycle, concurrency, and idempotency conventions.
|
||||
- **代码生成/补全代理**(如 GitHub Copilot 或其他 LLM)—— **仅当** 有维护者审核并批准输出时允许使用。
|
||||
- **自动维护代理**(如 Dependabot、Renovate、pre-commit.ci)—— 允许使用,但必须遵守严格的 PR 和 CI 政策。
|
||||
- **运行时/后台代理**(调度器、worker)—— 属于产品代码的一部分;必须遵守生命周期、并发和幂等性规范。
|
||||
|
||||
All changes produced or suggested by agents must comply with the rules below.
|
||||
所有由代理生成或建议的更改必须遵守以下规则。
|
||||
|
||||
---
|
||||
|
||||
## Rules for All Agents
|
||||
## 所有代理的规则
|
||||
|
||||
1. **Human review required.** Any code, configuration, or documentation generated by an AI or automation agent must be reviewed and approved by a human maintainer familiar with g0v0-server. Do not merge agent PRs without explicit human approval.
|
||||
2. **Single-responsibility PRs.** Agent PRs must address one concern only (one feature, one bugfix, or one dependency update). Use Angular-style commit messages (e.g. `feat(api): add ...`).
|
||||
3. **Lint & CI compliance.** Every PR (including agent-created ones) must pass `pyright`, `ruff`, `pre-commit` hooks, and the repository CI before merging. Include links to CI runs in the PR.
|
||||
4. **Never commit secrets.** Agents must not add keys, passwords, tokens, or real `.env` values. If a suspected secret is detected, the agent must abort and notify a designated human.
|
||||
5. **API location constraints.** Do not add new public endpoints under `app/router/v1` or `app/router/v2` unless the endpoints exist in the official v1/v2 specs. Custom or experimental endpoints must go under `app/router/private/`.
|
||||
6. **Stable public contracts.** Avoid changing response schemas, route prefixes, or other public contracts without an approved migration plan and explicit compatibility notes in the PR.
|
||||
1. **单一职责的 PR。** 代理的 PR 必须只解决一个问题(一个功能、一个 bug 修复或一次依赖更新)。提交信息应使用 Angular 风格(如 `feat(api): add ...`)。
|
||||
2. **通过 Lint 与 CI 检查。** 每个 PR(包括代理创建的)在合并前必须通过 `pyright`、`ruff`、`pre-commit` 钩子和仓库 CI。PR 中应附带 CI 运行结果链接。
|
||||
3. **绝不可提交敏感信息。** 代理不得提交密钥、密码、token 或真实 `.env` 值。如果检测到可能的敏感信息,代理必须中止并通知指定的维护者。
|
||||
4. **API 位置限制。** 不得在 `app/router/v1` 或 `app/router/v2` 下添加新的公开端点,除非该端点在官方 v1/v2 规范中存在。自定义或实验性端点必须放在 `app/router/private/`。
|
||||
5. **保持公共契约稳定。** 未经批准的迁移计划,不得随意修改响应 schema、路由前缀或其他公共契约。若有变更,PR 中必须包含明确的兼容性说明。
|
||||
|
||||
---
|
||||
|
||||
## Copilot / LLM Usage
|
||||
## Copilot / LLM 使用
|
||||
|
||||
> Consolidated guidance for using GitHub Copilot and other LLM-based helpers with this repository.
|
||||
> 关于在本仓库中使用 GitHub Copilot 和其他基于 LLM 的辅助工具的统一指导。
|
||||
|
||||
### Key project structure (what you should know)
|
||||
### 关键项目结构(需要了解的内容)
|
||||
|
||||
- **App entry:** `main.py` — FastAPI application with lifespan startup/shutdown orchestration (fetchers, GeoIP, schedulers, cache and health checks, Redis messaging, stats, achievements).
|
||||
- **应用入口:** `main.py` —— FastAPI 应用,包含启动/关闭生命周期管理(fetchers、GeoIP、调度器、缓存与健康检查、Redis 消息、统计、成就系统)。
|
||||
|
||||
- **Routers:** `app/router/` contains route groups. Important routers exposed by the project include:
|
||||
- **路由:** `app/router/` 包含所有路由组。主要的路由包括:
|
||||
- `v1/`(v1 端点)
|
||||
- `v2/`(v2 端点)
|
||||
- `notification/` 路由(聊天/通知子系统)
|
||||
- `auth.py`(认证/token 流程)
|
||||
- `private/`(自定义或实验性的端点)
|
||||
|
||||
- `api_v1_router` (v1 endpoints)
|
||||
- `api_v2_router` (v2 endpoints)
|
||||
- `notification` routers (chat/notification subsystems)
|
||||
- `auth_router` (authentication/token flows)
|
||||
- `private_router` (internal or server-specific endpoints)
|
||||
**规则:** `v1/` 和 `v2/` 必须与官方 API 对应。仅内部或实验端点应放在 `app/router/private/`。
|
||||
|
||||
**Rules:** `v1/` and `v2/` must mirror the official APIs. Put internal-only or experimental endpoints under `app/router/private/`.
|
||||
- **模型与数据库工具:**
|
||||
- SQLModel/ORM 模型在 `app/database/`。
|
||||
- 非数据库模型在 `app/models/`。
|
||||
- 修改模型/schema 时必须生成 Alembic 迁移,并手动检查生成的 SQL 与索引。
|
||||
|
||||
- **Models & DB helpers:**
|
||||
- **服务层:** `app/service/` 保存领域逻辑(如缓存工具、通知/邮件逻辑)。复杂逻辑应放在 service,而不是路由处理器中。
|
||||
|
||||
- SQLModel/ORM models live in `app/models/`.
|
||||
- DB access helpers and table-specific helpers live in `app/database/`.
|
||||
- For model/schema changes, draft an Alembic migration and manually review the generated SQL and indexes before applying.
|
||||
- **任务:** `app/tasks/` 保存任务(定时任务、启动任务、关闭任务)。
|
||||
- 均在 `__init__.py` 进行导出。
|
||||
- 对于启动任务/关闭任务,在 `main.py` 的 `lifespan` 调用。
|
||||
- 定时任务使用 APScheduler
|
||||
|
||||
- **Services:** `app/service/` holds domain logic (e.g., user ranking calculation, caching helpers, notification/email logic). Heavy logic belongs in services rather than in route handlers.
|
||||
- **缓存与依赖:** 使用 `app/dependencies/` 提供的 Redis 依赖和缓存服务(遵循现有 key 命名约定,如 `user:{id}:...`)。
|
||||
|
||||
- **Schedulers:** `app/scheduler/` contains scheduler starters; implement `start_*_scheduler()` and `stop_*_scheduler()` and register them in `main.py` lifespan handlers.
|
||||
- **日志:** 使用 `app/log` 提供的日志工具。
|
||||
|
||||
- **Caching & dependencies:** Use injected Redis dependencies from `app/dependencies/` and shared cache services (follow existing key naming conventions such as `user:{id}:...`).
|
||||
### 实用工作流(提示模式)
|
||||
|
||||
- **Rust/native extensions:** `packages/msgpack_lazer_api` is a native MessagePack encoder/decoder. When changing native code, run `maturin develop -R` and validate compatibility with Python bindings.
|
||||
- **添加 v2 端点(正确方式):** 在 `app/router/v2/` 下添加文件,导出路由,实现基于数据库与缓存依赖的异步处理函数。**不得**在 v1/v2 添加非官方端点。
|
||||
- **添加自定义端点:** 放在 `app/router/private/`,保持处理器精简,将业务逻辑放入 `app/service/`。
|
||||
- **鉴权:** 使用 [`app.dependencies.user`](./app/dependencies/user.py) 提供的依赖注入,如 `ClientUser` 和 `get_current_user`,参考下方。
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from fastapi import Security
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
|
||||
### Practical playbooks (prompt patterns)
|
||||
|
||||
- **Add a v2 endpoint (correct):** Add files under `app/router/v2/`, export the router, implement async path operations using DB and injected caching dependencies. Do **not** add non-official endpoints to v1/v2.
|
||||
- **Add an internal endpoint:** Add under `app/router/private/`; keep route handlers thin and move business logic into `app/service/`.
|
||||
- **Add a background job:** Put pure job logic in `app/service/_job.py` (idempotent, retry-safe). Add scheduler start/stop functions in `app/scheduler/_scheduler.py`, and register them in the app lifespan.
|
||||
- **DB schema changes:** Update SQLModel models in `app/models/`, run `alembic revision --autogenerate`, inspect the migration, and validate locally with `alembic upgrade head` before committing.
|
||||
- **Cache writes & responses:** Use existing `UserResp` patterns and `UserCacheService` where applicable; use background tasks for asynchronous cache writes.
|
||||
@router.get("/some-api")
|
||||
async def _(current_user: Annotated[User, Security(get_current_user, scopes=["public"])]):
|
||||
...
|
||||
|
||||
### Prompt guidance (what to include for LLMs/Copilot)
|
||||
|
||||
- Specify the exact file location and constraints (e.g. `Add an async endpoint under app/router/private/ ... DO NOT add to app/router/v1 or v2`).
|
||||
- Ask for asynchronous handlers, dependency injection for DB/Redis, reuse of existing services/helpers, type annotations, and a minimal pytest skeleton.
|
||||
- For native edits, require build instructions, ABI compatibility notes, and import validation steps.
|
||||
@router.get("/some-client-api")
|
||||
async def _(current_user: ClientUser):
|
||||
...
|
||||
```
|
||||
|
||||
### Conventions & quality expectations
|
||||
- **添加后台任务:** 将任务逻辑写在 `app/service/_job.py`(幂等、可重试)。调度器入口放在 `app/scheduler/_scheduler.py`,并在应用生命周期注册。
|
||||
- **数据库 schema 变更:** 修改 `app/models/` 中的 SQLModel 模型,运行 `alembic revision --autogenerate`,检查迁移并本地测试 `alembic upgrade head` 后再提交。
|
||||
- **缓存写入与响应:** 使用现有的 `UserResp` 模式和 `UserCacheService`;异步缓存写入应使用后台任务。
|
||||
|
||||
- **Commit message style:** `type(scope): subject` (Angular-style).
|
||||
- **Async-first:** Route handlers must be async; avoid blocking the event loop.
|
||||
- **Separation of concerns:** Business logic should live in services, not inside route handlers.
|
||||
- **Error handling:** Use `HTTPException` for client errors and structured logging for server-side issues.
|
||||
- **Types & linting:** Aim for `pyright`-clean, `ruff`-clean code before requesting review.
|
||||
- **Comments:** Avoid excessive inline comments. Add short, targeted comments to explain non-obvious or "magical" behavior.
|
||||
### 提示指导(给 LLM/Copilot 的输入)
|
||||
|
||||
### Human reviewer checklist
|
||||
- 明确文件位置和限制(如:`Add an async endpoint under app/router/private/... DO NOT add to app/router/v1 or v2`)。
|
||||
- 要求异步处理函数、依赖注入 DB/Redis、复用已有服务/工具、加上类型注解,并生成最小化 pytest 测试样例。
|
||||
|
||||
- Is the code async and non-blocking, with heavy logic in `app/service/`?
|
||||
- Are DB and Redis dependencies injected via the project's dependency utilities?
|
||||
- Are existing cache keys and services reused consistently?
|
||||
- Are tests or test skeletons present and runnable?
|
||||
- If models changed: is an Alembic migration drafted, reviewed, and applied locally?
|
||||
- If native code changed: was `maturin develop -R` executed and validated?
|
||||
- Do `pyright` and `ruff` pass locally?
|
||||
### 约定与质量要求
|
||||
|
||||
### Merge checklist
|
||||
- **使用 Annotated-style 依赖注入** 在路由处理器中。
|
||||
- **提交信息风格:** `type(scope): subject`(Angular 风格)。
|
||||
- **优先异步:** 路由必须为异步函数;避免阻塞事件循环。
|
||||
- **关注点分离:** 业务逻辑应放在 service,而不是路由中。
|
||||
- **错误处理:** 客户端错误用 `HTTPException`,服务端错误使用结构化日志。
|
||||
- **类型与 lint:** 在请求评审前,代码必须通过 `pyright` 和 `ruff` 检查。
|
||||
- **注释:** 避免过多注释,仅为晦涩逻辑添加简洁的“魔法注释”。
|
||||
- **日志:** 使用 `app.log` 提供的 `log` 函数获取 logger 实例。(服务、任务除外)
|
||||
|
||||
- Run `uv sync` to install/update dependencies.
|
||||
- Run `pre-commit` hooks and fix any failures.
|
||||
- Run `pyright` and `ruff` locally and resolve issues.
|
||||
- If native modules changed: run `maturin develop -R`.
|
||||
- If DB migrations changed: run `alembic upgrade head` locally to validate.
|
||||
|
||||
### Tooling reference
|
||||
### 工具参考
|
||||
|
||||
```
|
||||
uv sync
|
||||
@@ -119,34 +120,41 @@ pre-commit install
|
||||
pre-commit run --all-files
|
||||
pyright
|
||||
ruff .
|
||||
maturin develop -R # when native modules changed
|
||||
alembic revision --autogenerate -m "feat(db): ..."
|
||||
alembic upgrade head
|
||||
uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### PR scope guidance
|
||||
### PR 范围指导
|
||||
|
||||
- Keep PRs focused: one concern per PR (e.g., endpoint OR refactor, not both).
|
||||
- Update README/config docs when adding new environment variables.
|
||||
- If unsure about conventions, align with the closest existing service and leave a clarifying comment.
|
||||
- 保持 PR 专注:一次只做一件事(如端点或重构,不要混合)。
|
||||
- 不确定时,请参考现有服务,并添加简短说明性注释。
|
||||
|
||||
### PR 审核规则
|
||||
|
||||
> GitHub Copilot PR review 可参考。
|
||||
|
||||
1. 如果 PR 修改了端点,简要说明端点的用途和预期行为。同时检查是否满足上述的 API 位置限制。
|
||||
2. 如果 PR 修改了数据库模型,必须包含 Alembic 迁移。检查迁移的 SQL 语句和索引是否合理。
|
||||
3. 修改的其他功能需要提供简短的说明。
|
||||
4. 提供性能优化的建议(见下文)。
|
||||
|
||||
---
|
||||
|
||||
## Performance Tips
|
||||
## 性能优化提示
|
||||
|
||||
Below are practical, project-specific performance tips derived from this repository's architecture (FastAPI + SQLModel/SQLAlchemy, Redis caching, background schedulers, and a Rust-native messagepack module).
|
||||
以下为结合本仓库架构(FastAPI + SQLModel/SQLAlchemy、Redis 缓存、后台调度器)总结的性能优化建议:
|
||||
|
||||
### Database
|
||||
### 数据库
|
||||
|
||||
- **Select only required fields.** Fetch only the columns you need using `select(Model.col1, Model.col2)` instead of `select(Model)`.
|
||||
- **仅选择必要字段。** 使用 `select(Model.col1, Model.col2)`,避免 `select(Model)`。
|
||||
|
||||
```py
|
||||
stmt = select(User.id, User.username).where(User.active == True)
|
||||
rows = await session.execute(stmt)
|
||||
```
|
||||
|
||||
- **Use **``** for existence checks.** This avoids loading full rows:
|
||||
- **使用 `select(exists())` 检查存在性。** 避免加载整行:
|
||||
|
||||
```py
|
||||
from sqlalchemy import select, exists
|
||||
@@ -154,34 +162,21 @@ exists_stmt = select(exists().where(User.id == some_id))
|
||||
found = await session.scalar(exists_stmt)
|
||||
```
|
||||
|
||||
- **Avoid N+1 queries.** Use relationship loading strategies (`selectinload`, `joinedload`) when you need related objects.
|
||||
- **避免 N+1 查询。** 需要关联对象时用 `selectinload`、`joinedload`。
|
||||
|
||||
- **Batch operations.** For inserts/updates, use bulk or batched statements inside a single transaction rather than many small transactions.
|
||||
- **批量操作。** 插入/更新时应批量执行,并放在一个事务中,而不是多个小事务。
|
||||
|
||||
- **Indexes & EXPLAIN.** Add indexes on frequently filtered columns and use `EXPLAIN ANALYZE` to inspect slow queries.
|
||||
|
||||
- **Cursor / keyset pagination.** Prefer keyset pagination for large result sets instead of `OFFSET`/`LIMIT` to avoid high-cost scans.
|
||||
### 耗时任务
|
||||
|
||||
### Caching & Redis
|
||||
- 如果这个任务来自 API Router,请使用 FastAPI 提供的 [`BackgroundTasks`](https://fastapi.tiangolo.com/tutorial/background-tasks)
|
||||
- 其他情况,使用 `app.utils` 的 `bg_tasks`,它提供了与 FastAPI 的 `BackgroundTasks` 类似的功能。
|
||||
|
||||
- **Cache hot reads.** Use `UserCacheService` to cache heavy or frequently-requested responses and store compact serialized forms (e.g., messagepack via the native module).
|
||||
---
|
||||
|
||||
- **Use pipelines and multi/exec.** When performing multiple Redis commands, pipeline them to reduce roundtrips.
|
||||
## 部分 LLM 的额外要求
|
||||
|
||||
- **Set appropriate TTLs.** Avoid never-expiring keys; choose TTLs that balance freshness and read amplification.
|
||||
### Claude Code
|
||||
|
||||
- **Prevent cache stampedes.** Use early recompute with jitter or distributed locks (Redis `SET NX` or a small lock library) to avoid many processes rebuilding the same cache.
|
||||
- 禁止创建额外的测试脚本。
|
||||
|
||||
- **Atomic operations with Lua.** For complex multi-step Redis changes, consider a Lua script to keep operations atomic and fast.
|
||||
|
||||
### Background & Long-running Tasks
|
||||
|
||||
- **BackgroundTasks for lightweight work.** FastAPI's `BackgroundTasks` is fine for quick follow-up work (send email, async cache write). For heavy or long tasks, use a scheduler/worker (e.g., a dedicated async worker or job queue).
|
||||
|
||||
- **Use schedulers or workers for heavy jobs.** For expensive recalculations, use the repository's `app/scheduler/` pattern or an external worker system. Keep request handlers responsive — return quickly and delegate.
|
||||
|
||||
- **Throttling & batching.** When processing many items, batch them and apply concurrency limits (semaphore) to avoid saturating DB/Redis.
|
||||
|
||||
### API & Response Performance
|
||||
|
||||
- **Compress large payloads.** Enable gzip/deflate for large JSON responses
|
||||
|
||||
200
CONTRIBUTING.md
200
CONTRIBUTING.md
@@ -54,11 +54,155 @@ dotnet run --project osu.Server.Spectator --urls "http://0.0.0.0:8086"
|
||||
uv sync
|
||||
```
|
||||
|
||||
## 代码质量和代码检查
|
||||
## 开发规范
|
||||
|
||||
我们使用 `pre-commit` 在提交之前执行代码质量标准。这确保所有代码都通过 `ruff`(用于代码检查和格式化)和 `pyright`(用于类型检查)的检查。
|
||||
### 项目结构
|
||||
|
||||
### 设置
|
||||
以下是项目主要目录和文件的结构说明:
|
||||
|
||||
- `main.py`: FastAPI 应用的主入口点,负责初始化和启动服务器。
|
||||
- `pyproject.toml`: 项目配置文件,用于管理依赖项 (uv)、代码格式化 (Ruff) 和类型检查 (Pyright)。
|
||||
- `alembic.ini`: Alembic 数据库迁移工具的配置文件。
|
||||
- `app/`: 存放所有核心应用代码。
|
||||
- `router/`: 包含所有 API 端点的定义,根据 API 版本和功能进行组织。
|
||||
- `service/`: 存放核心业务逻辑,例如用户排名计算、每日挑战处理等。
|
||||
- `database/`: 定义数据库模型 (SQLModel) 和会话管理。
|
||||
- `models/`: 定义非数据库模型和其他模型。
|
||||
- `tasks/`: 包含由 APScheduler 调度的后台任务和启动/关闭任务。
|
||||
- `dependencies/`: 管理 FastAPI 的依赖项注入。
|
||||
- `achievements/`: 存放与成就相关的逻辑。
|
||||
- `storage/`: 存储服务代码。
|
||||
- `fetcher/`: 用于从外部服务(如 osu! 官网)获取数据的模块。
|
||||
- `middleware/`: 定义中间件,例如会话验证。
|
||||
- `helpers/`: 存放辅助函数和工具类。
|
||||
- `config.py`: 应用配置,使用 pydantic-settings 管理。
|
||||
- `calculator.py`: 存放所有的计算逻辑,例如 pp 和等级。
|
||||
- `log.py`: 日志记录模块,提供统一的日志接口。
|
||||
- `const.py`: 定义常量。
|
||||
- `path.py`: 定义跨文件使用的常量。
|
||||
- `migrations/`: 存放 Alembic 生成的数据库迁移脚本。
|
||||
- `static/`: 存放静态文件,如 `mods.json`。
|
||||
|
||||
### 数据库模型定义
|
||||
|
||||
所有的数据库模型定义在 `app.database` 里,并且在 `__init__.py` 中导出。
|
||||
|
||||
如果这个模型的数据表结构和响应不完全相同,遵循 `Base` - `Table` - `Resp` 结构:
|
||||
|
||||
```python
|
||||
class ModelBase(SQLModel):
|
||||
# 定义共有内容
|
||||
...
|
||||
|
||||
|
||||
class Model(ModelBase, table=True):
|
||||
# 定义数据库表内容
|
||||
...
|
||||
|
||||
|
||||
class ModelResp(ModelBase):
|
||||
# 定义响应内容
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db: Model) -> "ModelResp":
|
||||
# 从数据库模型转换
|
||||
...
|
||||
```
|
||||
|
||||
数据库模块名应与表名相同,定义了多个模型的除外。
|
||||
|
||||
如果你需要使用 Session,使用 `app.dependencies.database` 提供的 `with_db`,注意手动使用 `COMMIT`。
|
||||
|
||||
```python
|
||||
from app.dependencies.database import with_db
|
||||
|
||||
async with with_db() as session:
|
||||
...
|
||||
```
|
||||
|
||||
### Redis
|
||||
|
||||
根据你需要的用途选择对应的 Redis 客户端。如果你的用途较为复杂或趋向一个较大的系统,考虑再创建一个 Redis 连接。
|
||||
|
||||
- `redis_client` (db0):标准用途,存储字符串、哈希等常规数据。
|
||||
- `redis_message_client` (db1):用于消息缓存,存储聊天记录等。
|
||||
- `redis_binary_client` (db2):用于存储二进制数据,如音频文件等。
|
||||
- `redis_rate_limit_client` (db3):仅用于 FastAPI-Limiter 使用。
|
||||
|
||||
### API Router
|
||||
|
||||
所有的 API Router 定义在 `app.router` 里:
|
||||
|
||||
- `app/router/v2` 存放所有 osu! v2 API 实现,**不允许添加额外的,原 v2 API 不存在的 Endpoint**
|
||||
- `app/router/notification` **存放所有 osu! v2 API 聊天、通知和 BanchoBot 的实现,不允许添加额外的,原 v2 API 不存在的 Endpoint**
|
||||
- `app/router/v1` 存放所有 osu! v1 API 实现,**不允许添加额外的,原 v1 API 不存在的 Endpoint**
|
||||
- `app/router/auth.py` 存放账户鉴权/登录的 API
|
||||
- `app/router/private` 存放服务器自定义 API (g0v0 API),供其他服务使用
|
||||
|
||||
任何 Router 需要满足:
|
||||
|
||||
- 使用 Annotated-style 的依赖注入
|
||||
- 对于已经存在的依赖注入如 Database 和 Redis,使用 `app.dependencies` 中的实现
|
||||
- 需要拥有文档
|
||||
- 如果返回需要资源代理,使用 `app.helpers.asset_proxy_helper` 的 `asset_proxy_response` 装饰器。
|
||||
- 如果需要记录日志,请使用 `app.log` 提供的 `log` 函数获取一个 logger 实例
|
||||
|
||||
#### 鉴权
|
||||
|
||||
如果这个 Router 可以为公开使用(客户端、前端、OAuth 程序),考虑使用 `Security(get_current_user, scopes=["some_scope"])`,例如:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from fastapi import Security
|
||||
from app.dependencies.user import get_current_user
|
||||
|
||||
|
||||
@router.get("/some-api")
|
||||
async def _(current_user: Annotated[User, Security(get_current_user, scopes=["public"])]):
|
||||
...
|
||||
```
|
||||
|
||||
其中 scopes 选择请参考 [`app.dependencies.user`](./app/dependencies/user.py) 的 `oauth2_code` 中的 `scopes`。
|
||||
|
||||
如果这个 Router 仅限客户端和前端使用,请使用 `ClientUser` 依赖注入。
|
||||
|
||||
```python
|
||||
from app.dependencies.user import ClientUser
|
||||
|
||||
|
||||
@router.get("/some-api")
|
||||
async def _(current_user: ClientUser):
|
||||
...
|
||||
```
|
||||
|
||||
此外还存在 `get_current_user_and_token` 和 `get_client_user_and_token` 变种,用来同时获得当前用户的 token。
|
||||
|
||||
### Service
|
||||
|
||||
所有的核心业务逻辑放在 `app.service` 里:
|
||||
|
||||
- 业务逻辑需要要以类实现
|
||||
- 日志只需要使用 `app.log` 中的 `logger` 即可。服务器会对 Service 的日志进行包装。
|
||||
|
||||
### 定时任务/启动任务/关闭任务
|
||||
|
||||
均定义在 `app.tasks` 里。
|
||||
|
||||
- 均在 `__init__.py` 进行导出
|
||||
- 对于启动任务/关闭任务,在 `main.py` 的 `lifespan` 调用。
|
||||
- 定时任务使用 APScheduler
|
||||
|
||||
### 耗时任务
|
||||
|
||||
- 如果这个任务来自 API Router,请使用 FastAPI 提供的 [`BackgroundTasks`](https://fastapi.tiangolo.com/tutorial/background-tasks)
|
||||
- 其他情况,使用 `app.utils` 的 `bg_tasks`,它提供了与 FastAPI 的 `BackgroundTasks` 类似的功能。
|
||||
|
||||
### 代码质量和代码检查
|
||||
|
||||
使用 `pre-commit` 在提交之前执行代码质量标准。这确保所有代码都通过 `ruff`(用于代码检查和格式化)和 `pyright`(用于类型检查)的检查。
|
||||
|
||||
#### 设置
|
||||
|
||||
要设置 `pre-commit`,请运行以下命令:
|
||||
|
||||
@@ -70,19 +214,9 @@ pre-commit install
|
||||
|
||||
pre-commit 不提供 pyright 的 hook,您需要手动运行 `pyright` 检查类型错误。
|
||||
|
||||
## 提交信息指南
|
||||
### 提交信息指南
|
||||
|
||||
我们遵循 [AngularJS 提交规范](https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#commit-message-format) 来编写提交信息。这使得在查看项目历史记录时,信息更加可读且易于理解。
|
||||
|
||||
每条提交信息由 **标题**、**主体**和 **页脚** 三部分组成。
|
||||
|
||||
```
|
||||
<type>(<scope>): <subject>
|
||||
<BLANK LINE>
|
||||
<body>
|
||||
<BLANK LINE>
|
||||
<footer>
|
||||
```
|
||||
遵循 [AngularJS 提交规范](https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#commit-message-format) 来编写提交信息。
|
||||
|
||||
**类型** 必须是以下之一:
|
||||
|
||||
@@ -97,40 +231,16 @@ pre-commit 不提供 pyright 的 hook,您需要手动运行 `pyright` 检查
|
||||
* **ci**:持续集成相关的更改
|
||||
* **deploy**: 部署相关的更改
|
||||
|
||||
**范围** 可以是任何指定提交更改位置的内容。例如 `api`、`db`、`auth` 等等。
|
||||
**范围** 可以是任何指定提交更改位置的内容。例如 `api`、`db`、`auth` 等等。对整个项目的更改使用 `project`。
|
||||
|
||||
**主题** 包含对更改的简洁描述。
|
||||
|
||||
## 项目结构
|
||||
### 持续集成检查
|
||||
|
||||
以下是项目主要目录和文件的结构说明:
|
||||
所有提交应该通过以下 CI 检查:
|
||||
|
||||
- `main.py`: FastAPI 应用的主入口点,负责初始化和启动服务器。
|
||||
- `pyproject.toml`: 项目配置文件,用于管理依赖项 (uv)、代码格式化 (Ruff) 和类型检查 (Pyright)。
|
||||
- `alembic.ini`: Alembic 数据库迁移工具的配置文件。
|
||||
- `app/`: 存放所有核心应用代码。
|
||||
- `router/`: 包含所有 API 端点的定义,根据 API 版本和功能进行组织。
|
||||
- `service/`: 存放核心业务逻辑,例如用户排名计算、每日挑战处理等。
|
||||
- `database/`: 定义数据库模型 (SQLModel) 和会话管理。
|
||||
- `models/`: 定义非数据库模型和其他模型。
|
||||
- `scheduler/`: 包含由 APScheduler 调度的后台任务。
|
||||
- `dependencies/`: 管理 FastAPI 的依赖项注入。
|
||||
- `achievement.py`: 存放与成就相关的逻辑。
|
||||
- `storage/`: 存储服务代码。
|
||||
- `fetcher/`: 用于从外部服务(如 osu! 官网)获取数据的模块。
|
||||
- `config.py`: 应用配置,使用 pydantic-settings 管理。
|
||||
- `calculator.py`: 存放所有的计算逻辑,例如 pp 和等级。
|
||||
- `migrations/`: 存放 Alembic 生成的数据库迁移脚本。
|
||||
- `packages/`: 包含项目相关的独立包。
|
||||
- `msgpack_lazer_api/`: 基于 Rust 的高性能支持 lazer APIMod 的 MessagePack 解析模块,用于与 osu!lazer 客户端通信。
|
||||
- `static/`: 存放静态文件,如 `mods.json`。
|
||||
|
||||
## API Router 规范
|
||||
|
||||
- `app/router/v2` 存放所有 osu! v2 API 实现,不允许添加额外的,原 v2 API 不存在的 Endpoint
|
||||
- `app/router/notification` 存放所有 osu! v2 API 聊天和通知的实现,不允许添加额外的,原 v2 API 不存在的 Endpoint
|
||||
- `app/router/v1` 存放所有 osu! v1 API 实现,不允许添加额外的,原 v1 API 不存在的 Endpoint
|
||||
- `app/router/auth.py` 存放账户鉴权/登录的 API
|
||||
- `app/router/private` 存放服务器自定义 API,供其他服务使用
|
||||
- Ruff Lint
|
||||
- Pyright Lint
|
||||
- pre-commit
|
||||
|
||||
感谢您的贡献!
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
|
||||
from app.database.daily_challenge import DailyChallengeStats
|
||||
@@ -32,11 +30,9 @@ async def process_streak(
|
||||
).first()
|
||||
if not stats:
|
||||
return False
|
||||
if streak <= stats.daily_streak_best < next_streak:
|
||||
return True
|
||||
elif next_streak == 0 and stats.daily_streak_best >= streak:
|
||||
return True
|
||||
return False
|
||||
return bool(
|
||||
streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak)
|
||||
)
|
||||
|
||||
|
||||
MEDALS = {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from app.database.beatmap import calculate_beatmap_attributes
|
||||
@@ -68,9 +66,7 @@ async def to_the_core(
|
||||
if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
|
||||
return False
|
||||
mods_ = mod_to_save(score.mods)
|
||||
if "DT" not in mods_ or "NC" not in mods_:
|
||||
return False
|
||||
return True
|
||||
return not ("DT" not in mods_ or "NC" not in mods_)
|
||||
|
||||
|
||||
async def wysi(
|
||||
@@ -83,9 +79,7 @@ async def wysi(
|
||||
return False
|
||||
if str(round(score.accuracy, ndigits=4))[3:] != "727":
|
||||
return False
|
||||
if "xi" not in beatmap.beatmapset.artist:
|
||||
return False
|
||||
return True
|
||||
return "xi" in beatmap.beatmapset.artist
|
||||
|
||||
|
||||
async def prepared(
|
||||
@@ -97,9 +91,7 @@ async def prepared(
|
||||
if score.rank != Rank.X and score.rank != Rank.XH:
|
||||
return False
|
||||
mods_ = mod_to_save(score.mods)
|
||||
if "NF" not in mods_:
|
||||
return False
|
||||
return True
|
||||
return "NF" in mods_
|
||||
|
||||
|
||||
async def reckless_adandon(
|
||||
@@ -117,9 +109,7 @@ async def reckless_adandon(
|
||||
redis = get_redis()
|
||||
mods_ = score.mods.copy()
|
||||
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
|
||||
if attribute.star_rating < 3:
|
||||
return False
|
||||
return True
|
||||
return not attribute.star_rating < 3
|
||||
|
||||
|
||||
async def lights_out(
|
||||
@@ -413,11 +403,10 @@ async def by_the_skin_of_the_teeth(
|
||||
return False
|
||||
|
||||
for mod in score.mods:
|
||||
if mod.get("acronym") == "AC":
|
||||
if "settings" in mod and "minimum_accuracy" in mod["settings"]:
|
||||
target_accuracy = mod["settings"]["minimum_accuracy"]
|
||||
if isinstance(target_accuracy, int | float):
|
||||
return abs(score.accuracy - float(target_accuracy)) < 0.0001
|
||||
if mod.get("acronym") == "AC" and "settings" in mod and "minimum_accuracy" in mod["settings"]:
|
||||
target_accuracy = mod["settings"]["minimum_accuracy"]
|
||||
if isinstance(target_accuracy, int | float):
|
||||
return abs(score.accuracy - float(target_accuracy)) < 0.0001
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
|
||||
from app.database.score import Beatmap, Score
|
||||
@@ -19,9 +17,7 @@ async def process_mod(
|
||||
return False
|
||||
if not beatmap.beatmap_status.has_leaderboard():
|
||||
return False
|
||||
if len(score.mods) != 1 or score.mods[0]["acronym"] != mod:
|
||||
return False
|
||||
return True
|
||||
return not (len(score.mods) != 1 or score.mods[0]["acronym"] != mod)
|
||||
|
||||
|
||||
async def process_category_mod(
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
|
||||
from app.database.score import Beatmap, Score
|
||||
@@ -22,11 +20,7 @@ async def process_combo(
|
||||
return False
|
||||
if next_combo != 0 and combo >= next_combo:
|
||||
return False
|
||||
if combo <= score.max_combo < next_combo:
|
||||
return True
|
||||
elif next_combo == 0 and score.max_combo >= combo:
|
||||
return True
|
||||
return False
|
||||
return bool(combo <= score.max_combo < next_combo or (next_combo == 0 and score.max_combo >= combo))
|
||||
|
||||
|
||||
MEDALS: Medals = {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
|
||||
from app.database import UserStatistics
|
||||
@@ -35,11 +33,7 @@ async def process_playcount(
|
||||
).first()
|
||||
if not stats:
|
||||
return False
|
||||
if pc <= stats.play_count < next_pc:
|
||||
return True
|
||||
elif next_pc == 0 and stats.play_count >= pc:
|
||||
return True
|
||||
return False
|
||||
return bool(pc <= stats.play_count < next_pc or (next_pc == 0 and stats.play_count >= pc))
|
||||
|
||||
|
||||
MEDALS: Medals = {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Literal, cast
|
||||
|
||||
@@ -47,9 +45,7 @@ async def process_skill(
|
||||
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
|
||||
if attribute.star_rating < star or attribute.star_rating >= star + 1:
|
||||
return False
|
||||
if type == "fc" and not score.is_perfect_combo:
|
||||
return False
|
||||
return True
|
||||
return not (type == "fc" and not score.is_perfect_combo)
|
||||
|
||||
|
||||
MEDALS: Medals = {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
|
||||
from app.database.score import Beatmap, Score
|
||||
@@ -35,11 +33,7 @@ async def process_tth(
|
||||
).first()
|
||||
if not stats:
|
||||
return False
|
||||
if tth <= stats.total_hits < next_tth:
|
||||
return True
|
||||
elif next_tth == 0 and stats.play_count >= tth:
|
||||
return True
|
||||
return False
|
||||
return bool(tth <= stats.total_hits < next_tth or (next_tth == 0 and stats.play_count >= tth))
|
||||
|
||||
|
||||
MEDALS: Medals = {
|
||||
|
||||
25
app/auth.py
25
app/auth.py
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
import hashlib
|
||||
import re
|
||||
@@ -13,7 +11,7 @@ from app.database import (
|
||||
User,
|
||||
)
|
||||
from app.database.auth import TotpKeys
|
||||
from app.log import logger
|
||||
from app.log import log
|
||||
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
|
||||
from app.utils import utcnow
|
||||
|
||||
@@ -31,6 +29,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
# bcrypt 缓存(模拟应用状态缓存)
|
||||
bcrypt_cache = {}
|
||||
|
||||
logger = log("Auth")
|
||||
|
||||
|
||||
def validate_username(username: str) -> list[str]:
|
||||
"""验证用户名"""
|
||||
@@ -67,7 +67,7 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
|
||||
2. MD5哈希 -> bcrypt验证
|
||||
"""
|
||||
# 1. 明文密码转 MD5
|
||||
pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()
|
||||
pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode() # noqa: S324
|
||||
|
||||
# 2. 检查缓存
|
||||
if bcrypt_hash in bcrypt_cache:
|
||||
@@ -101,7 +101,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""生成密码哈希 - 使用 osu! 的方式"""
|
||||
# 1. 明文密码 -> MD5
|
||||
pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()
|
||||
pw_md5 = hashlib.md5(password.encode()).hexdigest().encode() # noqa: S324
|
||||
# 2. MD5 -> bcrypt
|
||||
pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt())
|
||||
return pw_bcrypt.decode()
|
||||
@@ -112,7 +112,7 @@ async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -
|
||||
验证用户身份 - 使用类似 from_login 的逻辑
|
||||
"""
|
||||
# 1. 明文密码转 MD5
|
||||
pw_md5 = hashlib.md5(password.encode()).hexdigest()
|
||||
pw_md5 = hashlib.md5(password.encode()).hexdigest() # noqa: S324
|
||||
|
||||
# 2. 根据用户名查找用户
|
||||
user = None
|
||||
@@ -253,7 +253,7 @@ async def store_token(
|
||||
tokens_to_delete = active_tokens[max_tokens_per_client - 1 :]
|
||||
for token in tokens_to_delete:
|
||||
await db.delete(token)
|
||||
logger.info(f"[Auth] Cleaned up {len(tokens_to_delete)} old tokens for user {user_id}")
|
||||
logger.info(f"Cleaned up {len(tokens_to_delete)} old tokens for user {user_id}")
|
||||
|
||||
# 检查是否有重复的 access_token
|
||||
duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first()
|
||||
@@ -274,9 +274,7 @@ async def store_token(
|
||||
await db.commit()
|
||||
await db.refresh(token_record)
|
||||
|
||||
logger.info(
|
||||
f"[Auth] Created new token for user {user_id}, client {client_id} (multi-device: {allow_multiple_devices})"
|
||||
)
|
||||
logger.info(f"Created new token for user {user_id}, client {client_id} (multi-device: {allow_multiple_devices})")
|
||||
return token_record
|
||||
|
||||
|
||||
@@ -325,12 +323,7 @@ def _generate_totp_account_label(user: User) -> str:
|
||||
|
||||
根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性
|
||||
"""
|
||||
if settings.totp_use_username_in_label:
|
||||
# 使用用户名作为主要标识
|
||||
primary_identifier = user.username
|
||||
else:
|
||||
# 使用邮箱作为标识
|
||||
primary_identifier = user.email
|
||||
primary_identifier = user.username if settings.totp_use_username_in_label else user.email
|
||||
|
||||
# 如果配置了服务名称,添加到标签中以便在认证器中区分
|
||||
if settings.totp_service_name:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
@@ -7,7 +5,7 @@ import math
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
from app.log import log
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod, parse_enum_to_str
|
||||
from app.models.score import GameMode
|
||||
@@ -18,6 +16,8 @@ from redis.asyncio import Redis
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
logger = log("Calculator")
|
||||
|
||||
try:
|
||||
import rosu_pp_py as rosu
|
||||
except ImportError:
|
||||
@@ -417,9 +417,8 @@ def too_dense(hit_objects: list[HitObject], per_1s: int, per_10s: int) -> bool:
|
||||
if len(hit_objects) > i + per_1s:
|
||||
if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000:
|
||||
return True
|
||||
elif len(hit_objects) > i + per_10s:
|
||||
if hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
|
||||
return True
|
||||
elif len(hit_objects) > i + per_10s and hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -446,10 +445,7 @@ def slider_is_sus(hit_objects: list[HitObject]) -> bool:
|
||||
|
||||
|
||||
def is_2b(hit_objects: list[HitObject]) -> bool:
|
||||
for i in range(0, len(hit_objects) - 1):
|
||||
if hit_objects[i] == hit_objects[i + 1].start_time:
|
||||
return True
|
||||
return False
|
||||
return any(hit_objects[i] == hit_objects[i + 1].start_time for i in range(0, len(hit_objects) - 1))
|
||||
|
||||
|
||||
def is_suspicious_beatmap(content: str) -> bool:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# ruff: noqa: I002
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any
|
||||
|
||||
@@ -142,7 +141,7 @@ STORAGE_SETTINGS='{
|
||||
]
|
||||
redis_url: Annotated[
|
||||
str,
|
||||
Field(default="redis://127.0.0.1:6379/0", description="Redis 连接 URL"),
|
||||
Field(default="redis://127.0.0.1:6379", description="Redis 连接 URL"),
|
||||
"数据库设置",
|
||||
]
|
||||
|
||||
@@ -217,7 +216,7 @@ STORAGE_SETTINGS='{
|
||||
# 服务器设置
|
||||
host: Annotated[
|
||||
str,
|
||||
Field(default="0.0.0.0", description="服务器监听地址"),
|
||||
Field(default="0.0.0.0", description="服务器监听地址"), # noqa: S104
|
||||
"服务器设置",
|
||||
]
|
||||
port: Annotated[
|
||||
@@ -266,18 +265,6 @@ STORAGE_SETTINGS='{
|
||||
else:
|
||||
return "/"
|
||||
|
||||
# SignalR 设置
|
||||
signalr_negotiate_timeout: Annotated[
|
||||
int,
|
||||
Field(default=30, description="SignalR 协商超时时间(秒)"),
|
||||
"SignalR 服务器设置",
|
||||
]
|
||||
signalr_ping_interval: Annotated[
|
||||
int,
|
||||
Field(default=15, description="SignalR ping 间隔(秒)"),
|
||||
"SignalR 服务器设置",
|
||||
]
|
||||
|
||||
# Fetcher 设置
|
||||
fetcher_client_id: Annotated[
|
||||
str,
|
||||
@@ -329,11 +316,6 @@ STORAGE_SETTINGS='{
|
||||
Field(default=False, description="是否启用邮件验证功能"),
|
||||
"验证服务设置",
|
||||
]
|
||||
enable_smart_verification: Annotated[
|
||||
bool,
|
||||
Field(default=True, description="是否启用智能验证(基于客户端类型和设备信任)"),
|
||||
"验证服务设置",
|
||||
]
|
||||
enable_session_verification: Annotated[
|
||||
bool,
|
||||
Field(default=True, description="是否启用会话验证中间件"),
|
||||
@@ -487,6 +469,12 @@ STORAGE_SETTINGS='{
|
||||
"缓存设置",
|
||||
"谱面缓存",
|
||||
]
|
||||
beatmapset_cache_expire_seconds: Annotated[
|
||||
int,
|
||||
Field(default=3600, description="Beatmapset 缓存过期时间(秒)"),
|
||||
"缓存设置",
|
||||
"谱面缓存",
|
||||
]
|
||||
|
||||
# 排行榜缓存设置
|
||||
enable_ranking_cache: Annotated[
|
||||
@@ -551,12 +539,6 @@ STORAGE_SETTINGS='{
|
||||
"缓存设置",
|
||||
"用户缓存",
|
||||
]
|
||||
user_cache_concurrent_limit: Annotated[
|
||||
int,
|
||||
Field(default=10, description="并发缓存用户的限制"),
|
||||
"缓存设置",
|
||||
"用户缓存",
|
||||
]
|
||||
|
||||
# 资源代理设置
|
||||
enable_asset_proxy: Annotated[
|
||||
@@ -621,26 +603,26 @@ STORAGE_SETTINGS='{
|
||||
]
|
||||
|
||||
@field_validator("fetcher_scopes", mode="before")
|
||||
@classmethod
|
||||
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
|
||||
if isinstance(v, str):
|
||||
return v.split(",")
|
||||
return v
|
||||
|
||||
@field_validator("storage_settings", mode="after")
|
||||
@classmethod
|
||||
def validate_storage_settings(
|
||||
cls,
|
||||
v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
|
||||
info: ValidationInfo,
|
||||
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
|
||||
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
|
||||
if not isinstance(v, CloudflareR2Settings):
|
||||
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
|
||||
elif info.data.get("storage_service") == StorageServiceType.LOCAL:
|
||||
if not isinstance(v, LocalStorageSettings):
|
||||
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
|
||||
elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
|
||||
if not isinstance(v, AWSS3StorageSettings):
|
||||
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
|
||||
service = info.data.get("storage_service")
|
||||
if service == StorageServiceType.CLOUDFLARE_R2 and not isinstance(v, CloudflareR2Settings):
|
||||
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
|
||||
if service == StorageServiceType.LOCAL and not isinstance(v, LocalStorageSettings):
|
||||
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
|
||||
if service == StorageServiceType.AWS_S3 and not isinstance(v, AWSS3StorageSettings):
|
||||
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
|
||||
return v
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
BANCHOBOT_ID = 2
|
||||
|
||||
BACKUP_CODE_LENGTH = 10
|
||||
|
||||
@@ -12,7 +12,7 @@ from .beatmapset import (
|
||||
BeatmapsetResp,
|
||||
)
|
||||
from .beatmapset_ratings import BeatmapRating
|
||||
from .best_score import BestScore
|
||||
from .best_scores import BestScore
|
||||
from .chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
@@ -28,22 +28,16 @@ from .counts import (
|
||||
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
|
||||
from .events import Event
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
from .lazer_user import (
|
||||
MeResp,
|
||||
User,
|
||||
UserResp,
|
||||
)
|
||||
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from .notification import Notification, UserNotification
|
||||
from .password_reset import PasswordReset
|
||||
from .playlist_attempts import (
|
||||
from .item_attempts_count import (
|
||||
ItemAttemptsCount,
|
||||
ItemAttemptsResp,
|
||||
PlaylistAggregateScore,
|
||||
)
|
||||
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from .notification import Notification, UserNotification
|
||||
from .password_reset import PasswordReset
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
from .pp_best_score import PPBestScore
|
||||
from .rank_history import RankHistory, RankHistoryResp, RankTop
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from .room import APIUploadedRoom, Room, RoomResp
|
||||
@@ -62,6 +56,12 @@ from .statistics import (
|
||||
UserStatisticsResp,
|
||||
)
|
||||
from .team import Team, TeamMember, TeamRequest
|
||||
from .total_score_best_scores import TotalScoreBestScore
|
||||
from .user import (
|
||||
MeResp,
|
||||
User,
|
||||
UserResp,
|
||||
)
|
||||
from .user_account_history import (
|
||||
UserAccountHistory,
|
||||
UserAccountHistoryResp,
|
||||
@@ -105,7 +105,6 @@ __all__ = [
|
||||
"Notification",
|
||||
"OAuthClient",
|
||||
"OAuthToken",
|
||||
"PPBestScore",
|
||||
"PasswordReset",
|
||||
"Playlist",
|
||||
"PlaylistAggregateScore",
|
||||
@@ -131,6 +130,7 @@ __all__ = [
|
||||
"Team",
|
||||
"TeamMember",
|
||||
"TeamRequest",
|
||||
"TotalScoreBestScore",
|
||||
"TotpKeys",
|
||||
"TrustedDevice",
|
||||
"TrustedDeviceResp",
|
||||
|
||||
@@ -24,7 +24,7 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class UserAchievementBase(SQLModel, UTCBaseModel):
|
||||
|
||||
@@ -19,7 +19,7 @@ from sqlmodel import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
|
||||
@@ -23,7 +23,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class BeatmapOwner(SQLModel):
|
||||
@@ -71,10 +71,10 @@ class Beatmap(BeatmapBase, table=True):
|
||||
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@classmethod
|
||||
async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
d = resp.model_dump()
|
||||
del d["beatmapset"]
|
||||
beatmap = Beatmap.model_validate(
|
||||
beatmap = cls.model_validate(
|
||||
{
|
||||
**d,
|
||||
"beatmapset_id": resp.beatmapset_id,
|
||||
@@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True):
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
|
||||
return beatmap
|
||||
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
|
||||
|
||||
@classmethod
|
||||
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
|
||||
@@ -250,7 +249,7 @@ async def calculate_beatmap_attributes(
|
||||
redis: Redis,
|
||||
fetcher: "Fetcher",
|
||||
):
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key))
|
||||
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||
|
||||
@@ -20,7 +20,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
if TYPE_CHECKING:
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
|
||||
|
||||
@classmethod
|
||||
async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
|
||||
async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
|
||||
d = resp.model_dump()
|
||||
if resp.nominations:
|
||||
d["nominations_required"] = resp.nominations.required
|
||||
@@ -158,10 +158,15 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
return beatmapset
|
||||
|
||||
@classmethod
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
|
||||
async def from_resp(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
resp: "BeatmapsetResp",
|
||||
from_: int = 0,
|
||||
) -> "Beatmapset":
|
||||
from .beatmap import Beatmap
|
||||
|
||||
beatmapset = await cls.from_resp_no_save(session, resp, from_=from_)
|
||||
beatmapset = await cls.from_resp_no_save(resp)
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
|
||||
session.add(beatmapset)
|
||||
await session.commit()
|
||||
@@ -334,5 +339,5 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
class SearchBeatmapsetsResp(SQLModel):
|
||||
beatmapsets: list[BeatmapsetResp]
|
||||
total: int
|
||||
cursor: dict[str, int | float] | None = None
|
||||
cursor: dict[str, int | float | str] | None = None
|
||||
cursor_string: str | None = None
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.lazer_user import User
|
||||
from app.database.user import User
|
||||
|
||||
from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
||||
from .score import Score
|
||||
|
||||
|
||||
class PPBestScore(SQLModel, table=True):
|
||||
class BestScore(SQLModel, table=True):
|
||||
__tablename__: str = "best_scores"
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
|
||||
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
|
||||
from app.database.user import RANKING_INCLUDES, User, UserResp
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.utils import utcnow
|
||||
|
||||
@@ -105,17 +105,11 @@ class ChatChannelResp(ChatChannelBase):
|
||||
)
|
||||
).first()
|
||||
|
||||
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
if last_msg and last_msg.isdigit():
|
||||
last_msg = int(last_msg)
|
||||
else:
|
||||
last_msg = None
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
|
||||
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
if last_read_id and last_read_id.isdigit():
|
||||
last_read_id = int(last_read_id)
|
||||
else:
|
||||
last_read_id = last_msg
|
||||
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
|
||||
|
||||
if silence is not None:
|
||||
attribute = ChatUserAttributes(
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlmodel import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class CountBase(SQLModel):
|
||||
|
||||
@@ -17,7 +17,7 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
|
||||
|
||||
@@ -18,7 +18,7 @@ from sqlmodel import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.lazer_user import User
|
||||
from app.database.user import User
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
"""
|
||||
数据库字段类型工具
|
||||
提供处理数据库和 Pydantic 之间类型转换的工具
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy import Boolean
|
||||
|
||||
|
||||
def bool_field_validator(field_name: str):
|
||||
"""为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
|
||||
|
||||
@field_validator(field_name, mode="before")
|
||||
@classmethod
|
||||
def validate_bool_field(cls, v: Any) -> bool:
|
||||
"""将整数 0/1 转换为布尔值"""
|
||||
if isinstance(v, int):
|
||||
return bool(v)
|
||||
return v
|
||||
|
||||
return validate_bool_field
|
||||
|
||||
|
||||
def create_bool_field(**kwargs):
|
||||
"""创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
|
||||
from sqlmodel import Column, Field
|
||||
|
||||
# 如果没有指定 sa_column,则使用 Boolean 类型
|
||||
if "sa_column" not in kwargs:
|
||||
# 处理 index 参数
|
||||
index = kwargs.pop("index", False)
|
||||
if index:
|
||||
kwargs["sa_column"] = Column(Boolean, index=True)
|
||||
else:
|
||||
kwargs["sa_column"] = Column(Boolean)
|
||||
|
||||
return Field(**kwargs)
|
||||
@@ -1,5 +1,5 @@
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .user import User, UserResp
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
@@ -2,8 +2,6 @@
|
||||
密码重置相关数据库模型
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from app.utils import utcnow
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
from app.models.multiplayer_hub import PlaylistItem
|
||||
from app.models.playlist import PlaylistItem
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
|
||||
@@ -72,7 +72,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
return result.one()
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
|
||||
async def from_model(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
|
||||
next_id = await cls.get_next_id_for_room(room_id, session=session)
|
||||
return cls(
|
||||
id=next_id,
|
||||
@@ -107,7 +107,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
|
||||
@classmethod
|
||||
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
db_playlist = await cls.from_model(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.commit()
|
||||
await session.refresh(db_playlist)
|
||||
|
||||
@@ -21,7 +21,7 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class RankHistory(SQLModel, table=True):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from .lazer_user import User, UserResp
|
||||
from .user import User, UserResp
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import (
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
from app.database.playlist_attempts import PlaylistAggregateScore
|
||||
from app.database.item_attempts_count import PlaylistAggregateScore
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.multiplayer_hub import ServerMultiplayerRoom
|
||||
from app.models.room import (
|
||||
MatchType,
|
||||
QueueMode,
|
||||
@@ -14,8 +13,8 @@ from app.models.room import (
|
||||
)
|
||||
from app.utils import utcnow
|
||||
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
from .user import User, UserResp
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
@@ -160,25 +159,6 @@ class RoomResp(RoomBase):
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
|
||||
room = server_room.room
|
||||
resp = cls(
|
||||
id=room.room_id,
|
||||
name=room.settings.name,
|
||||
type=room.settings.match_type,
|
||||
queue_mode=room.settings.queue_mode,
|
||||
auto_skip=room.settings.auto_skip,
|
||||
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
|
||||
status=server_room.status,
|
||||
category=server_room.category,
|
||||
# duration = room.settings.duration,
|
||||
starts_at=server_room.start_at,
|
||||
participant_count=len(room.users),
|
||||
channel_id=server_room.room.channel_id or 0,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class APIUploadedRoom(RoomBase):
|
||||
def to_room(self) -> Room:
|
||||
|
||||
@@ -15,8 +15,8 @@ from sqlmodel import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .room import Room
|
||||
from .user import User
|
||||
|
||||
|
||||
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
|
||||
|
||||
@@ -16,7 +16,7 @@ from app.calculator import (
|
||||
from app.config import settings
|
||||
from app.database.team import TeamMember
|
||||
from app.dependencies.database import get_redis
|
||||
from app.log import logger
|
||||
from app.log import log
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.model import (
|
||||
CurrentUserAttributes,
|
||||
@@ -38,17 +38,17 @@ from app.utils import utcnow
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
from .best_score import BestScore
|
||||
from .best_scores import BestScore
|
||||
from .counts import MonthlyPlaycounts
|
||||
from .events import Event, EventType
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .pp_best_score import PPBestScore
|
||||
from .relationship import (
|
||||
Relationship as DBRelationship,
|
||||
RelationshipType,
|
||||
)
|
||||
from .score_token import ScoreToken
|
||||
from .total_score_best_scores import TotalScoreBestScore
|
||||
from .user import User, UserResp
|
||||
|
||||
from pydantic import BaseModel, field_serializer, field_validator
|
||||
from redis.asyncio import Redis
|
||||
@@ -74,6 +74,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
logger = log("Score")
|
||||
|
||||
|
||||
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
# 基本字段
|
||||
@@ -193,13 +195,13 @@ class Score(ScoreBase, table=True):
|
||||
# optional
|
||||
beatmap: Mapped[Beatmap] = Relationship()
|
||||
user: Mapped[User] = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
best_score: Mapped[BestScore | None] = Relationship(
|
||||
best_score: Mapped[TotalScoreBestScore | None] = Relationship(
|
||||
back_populates="score",
|
||||
sa_relationship_kwargs={
|
||||
"cascade": "all, delete-orphan",
|
||||
},
|
||||
)
|
||||
ranked_score: Mapped[PPBestScore | None] = Relationship(
|
||||
ranked_score: Mapped[BestScore | None] = Relationship(
|
||||
back_populates="score",
|
||||
sa_relationship_kwargs={
|
||||
"cascade": "all, delete-orphan",
|
||||
@@ -479,10 +481,10 @@ class ScoreAround(SQLModel):
|
||||
async def get_best_id(session: AsyncSession, score_id: int) -> int | None:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(partition_by=(col(PPBestScore.user_id), col(PPBestScore.gamemode)), order_by=col(PPBestScore.pp).desc())
|
||||
.over(partition_by=(col(BestScore.user_id), col(BestScore.gamemode)), order_by=col(BestScore.pp).desc())
|
||||
.label("rn")
|
||||
)
|
||||
subq = select(PPBestScore, rownum).subquery()
|
||||
subq = select(BestScore, rownum).subquery()
|
||||
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
|
||||
result = await session.exec(stmt)
|
||||
return result.one_or_none()
|
||||
@@ -496,8 +498,8 @@ async def _score_where(
|
||||
user: User | None = None,
|
||||
) -> list[ColumnElement[bool] | TextClause] | None:
|
||||
wheres: list[ColumnElement[bool] | TextClause] = [
|
||||
col(BestScore.beatmap_id) == beatmap,
|
||||
col(BestScore.gamemode) == mode,
|
||||
col(TotalScoreBestScore.beatmap_id) == beatmap,
|
||||
col(TotalScoreBestScore.gamemode) == mode,
|
||||
]
|
||||
|
||||
if type == LeaderboardType.FRIENDS:
|
||||
@@ -510,20 +512,21 @@ async def _score_where(
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id)))
|
||||
wheres.append(col(TotalScoreBestScore.user_id).in_(select(subq.c.target_id)))
|
||||
else:
|
||||
return None
|
||||
elif type == LeaderboardType.COUNTRY:
|
||||
if user and user.is_supporter:
|
||||
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
|
||||
wheres.append(col(TotalScoreBestScore.user).has(col(User.country_code) == user.country_code))
|
||||
else:
|
||||
return None
|
||||
elif type == LeaderboardType.TEAM:
|
||||
if user:
|
||||
team_membership = await user.awaitable_attrs.team_membership
|
||||
if team_membership:
|
||||
team_id = team_membership.team_id
|
||||
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
|
||||
elif type == LeaderboardType.TEAM and user:
|
||||
team_membership = await user.awaitable_attrs.team_membership
|
||||
if team_membership:
|
||||
team_id = team_membership.team_id
|
||||
wheres.append(
|
||||
col(TotalScoreBestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id))
|
||||
)
|
||||
if mods:
|
||||
if user and user.is_supporter:
|
||||
wheres.append(
|
||||
@@ -557,10 +560,10 @@ async def get_leaderboard(
|
||||
max_score = sys.maxsize
|
||||
while limit > 0:
|
||||
query = (
|
||||
select(BestScore)
|
||||
.where(*wheres, BestScore.total_score < max_score)
|
||||
select(TotalScoreBestScore)
|
||||
.where(*wheres, TotalScoreBestScore.total_score < max_score)
|
||||
.limit(limit)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
.order_by(col(TotalScoreBestScore.total_score).desc())
|
||||
)
|
||||
extra_need = 0
|
||||
for s in await session.exec(query):
|
||||
@@ -579,13 +582,13 @@ async def get_leaderboard(
|
||||
user_score = None
|
||||
if user:
|
||||
self_query = (
|
||||
select(BestScore)
|
||||
.where(BestScore.user_id == user.id)
|
||||
select(TotalScoreBestScore)
|
||||
.where(TotalScoreBestScore.user_id == user.id)
|
||||
.where(
|
||||
col(BestScore.beatmap_id) == beatmap,
|
||||
col(BestScore.gamemode) == mode,
|
||||
col(TotalScoreBestScore.beatmap_id) == beatmap,
|
||||
col(TotalScoreBestScore.gamemode) == mode,
|
||||
)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
.order_by(col(TotalScoreBestScore.total_score).desc())
|
||||
.limit(1)
|
||||
)
|
||||
if mods:
|
||||
@@ -618,14 +621,14 @@ async def get_score_position_by_user(
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=(
|
||||
col(BestScore.beatmap_id),
|
||||
col(BestScore.gamemode),
|
||||
col(TotalScoreBestScore.beatmap_id),
|
||||
col(TotalScoreBestScore.gamemode),
|
||||
),
|
||||
order_by=col(BestScore.total_score).desc(),
|
||||
order_by=col(TotalScoreBestScore.total_score).desc(),
|
||||
)
|
||||
.label("row_number")
|
||||
)
|
||||
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
|
||||
subq = select(TotalScoreBestScore, rownum).join(Beatmap).where(*wheres).subquery()
|
||||
stmt = select(subq.c.row_number).where(subq.c.user_id == user.id)
|
||||
result = await session.exec(stmt)
|
||||
s = result.first()
|
||||
@@ -648,14 +651,14 @@ async def get_score_position_by_id(
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=(
|
||||
col(BestScore.beatmap_id),
|
||||
col(BestScore.gamemode),
|
||||
col(TotalScoreBestScore.beatmap_id),
|
||||
col(TotalScoreBestScore.gamemode),
|
||||
),
|
||||
order_by=col(BestScore.total_score).desc(),
|
||||
order_by=col(TotalScoreBestScore.total_score).desc(),
|
||||
)
|
||||
.label("row_number")
|
||||
)
|
||||
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
|
||||
subq = select(TotalScoreBestScore, rownum).join(Beatmap).where(*wheres).subquery()
|
||||
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
|
||||
result = await session.exec(stmt)
|
||||
s = result.one_or_none()
|
||||
@@ -667,16 +670,16 @@ async def get_user_best_score_in_beatmap(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
mode: GameMode | None = None,
|
||||
) -> BestScore | None:
|
||||
) -> TotalScoreBestScore | None:
|
||||
return (
|
||||
await session.exec(
|
||||
select(BestScore)
|
||||
select(TotalScoreBestScore)
|
||||
.where(
|
||||
BestScore.gamemode == mode if mode is not None else true(),
|
||||
BestScore.beatmap_id == beatmap,
|
||||
BestScore.user_id == user,
|
||||
TotalScoreBestScore.gamemode == mode if mode is not None else true(),
|
||||
TotalScoreBestScore.beatmap_id == beatmap,
|
||||
TotalScoreBestScore.user_id == user,
|
||||
)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
.order_by(col(TotalScoreBestScore.total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
@@ -687,32 +690,32 @@ async def get_user_best_score_with_mod_in_beatmap(
|
||||
user: int,
|
||||
mod: list[str],
|
||||
mode: GameMode | None = None,
|
||||
) -> BestScore | None:
|
||||
) -> TotalScoreBestScore | None:
|
||||
return (
|
||||
await session.exec(
|
||||
select(BestScore)
|
||||
select(TotalScoreBestScore)
|
||||
.where(
|
||||
BestScore.gamemode == mode if mode is not None else True,
|
||||
BestScore.beatmap_id == beatmap,
|
||||
BestScore.user_id == user,
|
||||
TotalScoreBestScore.gamemode == mode if mode is not None else True,
|
||||
TotalScoreBestScore.beatmap_id == beatmap,
|
||||
TotalScoreBestScore.user_id == user,
|
||||
text(
|
||||
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
|
||||
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
|
||||
).params(w=json.dumps(mod)),
|
||||
)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
.order_by(col(TotalScoreBestScore.total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
|
||||
async def get_user_first_scores(
|
||||
session: AsyncSession, user_id: int, mode: GameMode, limit: int = 5, offset: int = 0
|
||||
) -> list[BestScore]:
|
||||
) -> list[TotalScoreBestScore]:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=(col(BestScore.beatmap_id), col(BestScore.gamemode)),
|
||||
order_by=col(BestScore.total_score).desc(),
|
||||
partition_by=(col(TotalScoreBestScore.beatmap_id), col(TotalScoreBestScore.gamemode)),
|
||||
order_by=col(TotalScoreBestScore.total_score).desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
@@ -720,11 +723,11 @@ async def get_user_first_scores(
|
||||
# Step 1: Fetch top score_ids in Python
|
||||
subq = (
|
||||
select(
|
||||
col(BestScore.score_id).label("score_id"),
|
||||
col(BestScore.user_id).label("user_id"),
|
||||
col(TotalScoreBestScore.score_id).label("score_id"),
|
||||
col(TotalScoreBestScore.user_id).label("user_id"),
|
||||
rownum,
|
||||
)
|
||||
.where(col(BestScore.gamemode) == mode)
|
||||
.where(col(TotalScoreBestScore.gamemode) == mode)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
@@ -733,7 +736,11 @@ async def get_user_first_scores(
|
||||
top_ids = await session.exec(top_ids_stmt)
|
||||
top_ids = list(top_ids)
|
||||
|
||||
stmt = select(BestScore).where(col(BestScore.score_id).in_(top_ids)).order_by(col(BestScore.total_score).desc())
|
||||
stmt = (
|
||||
select(TotalScoreBestScore)
|
||||
.where(col(TotalScoreBestScore.score_id).in_(top_ids))
|
||||
.order_by(col(TotalScoreBestScore.total_score).desc())
|
||||
)
|
||||
|
||||
result = await session.exec(stmt)
|
||||
return list(result.all())
|
||||
@@ -743,18 +750,18 @@ async def get_user_first_score_count(session: AsyncSession, user_id: int, mode:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=(col(BestScore.beatmap_id), col(BestScore.gamemode)),
|
||||
order_by=col(BestScore.total_score).desc(),
|
||||
partition_by=(col(TotalScoreBestScore.beatmap_id), col(TotalScoreBestScore.gamemode)),
|
||||
order_by=col(TotalScoreBestScore.total_score).desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
subq = (
|
||||
select(
|
||||
col(BestScore.score_id).label("score_id"),
|
||||
col(BestScore.user_id).label("user_id"),
|
||||
col(TotalScoreBestScore.score_id).label("score_id"),
|
||||
col(TotalScoreBestScore.user_id).label("user_id"),
|
||||
rownum,
|
||||
)
|
||||
.where(col(BestScore.gamemode) == mode)
|
||||
.where(col(TotalScoreBestScore.gamemode) == mode)
|
||||
.subquery()
|
||||
)
|
||||
count_stmt = select(func.count()).where(subq.c.rn == 1, subq.c.user_id == user_id)
|
||||
@@ -768,13 +775,13 @@ async def get_user_best_pp_in_beatmap(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
mode: GameMode,
|
||||
) -> PPBestScore | None:
|
||||
) -> BestScore | None:
|
||||
return (
|
||||
await session.exec(
|
||||
select(PPBestScore).where(
|
||||
PPBestScore.beatmap_id == beatmap,
|
||||
PPBestScore.user_id == user,
|
||||
PPBestScore.gamemode == mode,
|
||||
select(BestScore).where(
|
||||
BestScore.beatmap_id == beatmap,
|
||||
BestScore.user_id == user,
|
||||
BestScore.gamemode == mode,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -799,12 +806,12 @@ async def get_user_best_pp(
|
||||
user: int,
|
||||
mode: GameMode,
|
||||
limit: int = 1000,
|
||||
) -> Sequence[PPBestScore]:
|
||||
) -> Sequence[BestScore]:
|
||||
return (
|
||||
await session.exec(
|
||||
select(PPBestScore)
|
||||
.where(PPBestScore.user_id == user, PPBestScore.gamemode == mode)
|
||||
.order_by(col(PPBestScore.pp).desc())
|
||||
select(BestScore)
|
||||
.where(BestScore.user_id == user, BestScore.gamemode == mode)
|
||||
.order_by(col(BestScore.pp).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
).all()
|
||||
@@ -854,8 +861,7 @@ async def process_score(
|
||||
) -> Score:
|
||||
gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods)
|
||||
logger.info(
|
||||
"[Score] Creating score for user {user_id} | beatmap={beatmap_id} "
|
||||
"ruleset={ruleset} passed={passed} total={total}",
|
||||
"Creating score for user {user_id} | beatmap={beatmap_id} ruleset={ruleset} passed={passed} total={total}",
|
||||
user_id=user.id,
|
||||
beatmap_id=beatmap_id,
|
||||
ruleset=gamemode,
|
||||
@@ -897,7 +903,7 @@ async def process_score(
|
||||
)
|
||||
session.add(score)
|
||||
logger.debug(
|
||||
"[Score] Score staged for commit | token={token} mods={mods} total_hits={hits}",
|
||||
"Score staged for commit | token={token} mods={mods} total_hits={hits}",
|
||||
token=score_token.id,
|
||||
mods=info.mods,
|
||||
hits=sum(info.statistics.values()) if info.statistics else 0,
|
||||
@@ -910,7 +916,7 @@ async def process_score(
|
||||
async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
|
||||
if score.pp != 0:
|
||||
logger.debug(
|
||||
"[Score] Skipping PP calculation for score {score_id} | already set {pp:.2f}",
|
||||
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
|
||||
score_id=score.id,
|
||||
pp=score.pp,
|
||||
)
|
||||
@@ -918,7 +924,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
|
||||
can_get_pp = score.passed and score.ranked and mods_can_get_pp(int(score.gamemode), score.mods)
|
||||
if not can_get_pp:
|
||||
logger.debug(
|
||||
"[Score] Skipping PP calculation for score {score_id} | passed={passed} ranked={ranked} mods={mods}",
|
||||
"Skipping PP calculation for score {score_id} | passed={passed} ranked={ranked} mods={mods}",
|
||||
score_id=score.id,
|
||||
passed=score.passed,
|
||||
ranked=score.ranked,
|
||||
@@ -928,15 +934,15 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
|
||||
pp, successed = await pre_fetch_and_calculate_pp(score, session, redis, fetcher)
|
||||
if not successed:
|
||||
await redis.rpush("score:need_recalculate", score.id) # pyright: ignore[reportGeneralTypeIssues]
|
||||
logger.warning("[Score] Queued score {score_id} for PP recalculation", score_id=score.id)
|
||||
logger.warning("Queued score {score_id} for PP recalculation", score_id=score.id)
|
||||
return
|
||||
score.pp = pp
|
||||
logger.info("[Score] Calculated PP for score {score_id} | pp={pp:.2f}", score_id=score.id, pp=pp)
|
||||
logger.info("Calculated PP for score {score_id} | pp={pp:.2f}", score_id=score.id, pp=pp)
|
||||
user_id = score.user_id
|
||||
beatmap_id = score.beatmap_id
|
||||
previous_pp_best = await get_user_best_pp_in_beatmap(session, beatmap_id, user_id, score.gamemode)
|
||||
if previous_pp_best is None or score.pp > previous_pp_best.pp:
|
||||
best_score = PPBestScore(
|
||||
best_score = BestScore(
|
||||
user_id=user_id,
|
||||
score_id=score.id,
|
||||
beatmap_id=beatmap_id,
|
||||
@@ -947,7 +953,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
|
||||
session.add(best_score)
|
||||
await session.delete(previous_pp_best) if previous_pp_best else None
|
||||
logger.info(
|
||||
"[Score] Updated PP best for user {user_id} | score_id={score_id} pp={pp:.2f}",
|
||||
"Updated PP best for user {user_id} | score_id={score_id} pp={pp:.2f}",
|
||||
user_id=user_id,
|
||||
score_id=score.id,
|
||||
pp=score.pp,
|
||||
@@ -966,15 +972,14 @@ async def _process_score_events(score: Score, session: AsyncSession):
|
||||
|
||||
if rank_global == 0 or total_users == 0:
|
||||
logger.debug(
|
||||
"[Score] Skipping event creation for score {score_id} | "
|
||||
"rank_global={rank_global} total_users={total_users}",
|
||||
"Skipping event creation for score {score_id} | rank_global={rank_global} total_users={total_users}",
|
||||
score_id=score.id,
|
||||
rank_global=rank_global,
|
||||
total_users=total_users,
|
||||
)
|
||||
return
|
||||
logger.debug(
|
||||
"[Score] Processing events for score {score_id} | rank_global={rank_global} total_users={total_users}",
|
||||
"Processing events for score {score_id} | rank_global={rank_global} total_users={total_users}",
|
||||
score_id=score.id,
|
||||
rank_global=rank_global,
|
||||
total_users=total_users,
|
||||
@@ -1003,7 +1008,7 @@ async def _process_score_events(score: Score, session: AsyncSession):
|
||||
}
|
||||
session.add(rank_event)
|
||||
logger.info(
|
||||
"[Score] Registered rank event for user {user_id} | score_id={score_id} rank={rank}",
|
||||
"Registered rank event for user {user_id} | score_id={score_id} rank={rank}",
|
||||
user_id=score.user_id,
|
||||
score_id=score.id,
|
||||
rank=rank_global,
|
||||
@@ -1011,12 +1016,12 @@ async def _process_score_events(score: Score, session: AsyncSession):
|
||||
if rank_global == 1:
|
||||
displaced_score = (
|
||||
await session.exec(
|
||||
select(BestScore)
|
||||
select(TotalScoreBestScore)
|
||||
.where(
|
||||
BestScore.beatmap_id == score.beatmap_id,
|
||||
BestScore.gamemode == score.gamemode,
|
||||
TotalScoreBestScore.beatmap_id == score.beatmap_id,
|
||||
TotalScoreBestScore.gamemode == score.gamemode,
|
||||
)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
.order_by(col(TotalScoreBestScore.total_score).desc())
|
||||
.limit(1)
|
||||
.offset(1)
|
||||
)
|
||||
@@ -1045,12 +1050,12 @@ async def _process_score_events(score: Score, session: AsyncSession):
|
||||
}
|
||||
session.add(rank_lost_event)
|
||||
logger.info(
|
||||
"[Score] Registered rank lost event | displaced_user={user_id} new_score_id={score_id}",
|
||||
"Registered rank lost event | displaced_user={user_id} new_score_id={score_id}",
|
||||
user_id=displaced_score.user_id,
|
||||
score_id=score.id,
|
||||
)
|
||||
logger.debug(
|
||||
"[Score] Event processing committed for score {score_id}",
|
||||
"Event processing committed for score {score_id}",
|
||||
score_id=score.id,
|
||||
)
|
||||
|
||||
@@ -1074,7 +1079,7 @@ async def _process_statistics(
|
||||
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
|
||||
)
|
||||
logger.debug(
|
||||
"[Score] Existing best scores for user {user_id} | global={global_id} mod={mod_id}",
|
||||
"Existing best scores for user {user_id} | global={global_id} mod={mod_id}",
|
||||
user_id=user.id,
|
||||
global_id=previous_score_best.score_id if previous_score_best else None,
|
||||
mod_id=previous_score_best_mod.score_id if previous_score_best_mod else None,
|
||||
@@ -1104,7 +1109,7 @@ async def _process_statistics(
|
||||
statistics.total_score += score.total_score
|
||||
difference = score.total_score - previous_score_best.total_score if previous_score_best else score.total_score
|
||||
logger.debug(
|
||||
"[Score] Score delta computed for {score_id}: {difference}",
|
||||
"Score delta computed for {score_id}: {difference}",
|
||||
score_id=score.id,
|
||||
difference=difference,
|
||||
)
|
||||
@@ -1140,7 +1145,7 @@ async def _process_statistics(
|
||||
# 情况2: 有最佳分数记录但没有该mod组合的记录,添加新记录
|
||||
if previous_score_best is None or previous_score_best_mod is None:
|
||||
session.add(
|
||||
BestScore(
|
||||
TotalScoreBestScore(
|
||||
user_id=user.id,
|
||||
beatmap_id=score.beatmap_id,
|
||||
gamemode=score.gamemode,
|
||||
@@ -1151,7 +1156,7 @@ async def _process_statistics(
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"[Score] Created new best score entry for user {user_id} | score_id={score_id} mods={mods}",
|
||||
"Created new best score entry for user {user_id} | score_id={score_id} mods={mods}",
|
||||
user_id=user.id,
|
||||
score_id=score.id,
|
||||
mods=mod_for_save,
|
||||
@@ -1163,7 +1168,7 @@ async def _process_statistics(
|
||||
previous_score_best.rank = score.rank
|
||||
previous_score_best.score_id = score.id
|
||||
logger.info(
|
||||
"[Score] Updated existing best score for user {user_id} | score_id={score_id} total={total}",
|
||||
"Updated existing best score for user {user_id} | score_id={score_id} total={total}",
|
||||
user_id=user.id,
|
||||
score_id=score.id,
|
||||
total=score.total_score,
|
||||
@@ -1175,7 +1180,7 @@ async def _process_statistics(
|
||||
if difference > 0:
|
||||
# 下方的 if 一定会触发。将高分设置为此分数,删除自己防止重复的 score_id
|
||||
logger.info(
|
||||
"[Score] Replacing global best score for user {user_id} | old_score_id={old_score_id}",
|
||||
"Replacing global best score for user {user_id} | old_score_id={old_score_id}",
|
||||
user_id=user.id,
|
||||
old_score_id=previous_score_best.score_id,
|
||||
)
|
||||
@@ -1188,7 +1193,7 @@ async def _process_statistics(
|
||||
previous_score_best_mod.rank = score.rank
|
||||
previous_score_best_mod.score_id = score.id
|
||||
logger.info(
|
||||
"[Score] Replaced mod-specific best for user {user_id} | mods={mods} score_id={score_id}",
|
||||
"Replaced mod-specific best for user {user_id} | mods={mods} score_id={score_id}",
|
||||
user_id=user.id,
|
||||
mods=mod_for_save,
|
||||
score_id=score.id,
|
||||
@@ -1202,14 +1207,14 @@ async def _process_statistics(
|
||||
mouthly_playcount.count += 1
|
||||
statistics.play_time += playtime
|
||||
logger.debug(
|
||||
"[Score] Recorded playtime {playtime}s for score {score_id} (user {user_id})",
|
||||
"Recorded playtime {playtime}s for score {score_id} (user {user_id})",
|
||||
playtime=playtime,
|
||||
score_id=score.id,
|
||||
user_id=user.id,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[Score] Playtime {playtime}s for score {score_id} did not meet validity checks",
|
||||
"Playtime {playtime}s for score {score_id} did not meet validity checks",
|
||||
playtime=playtime,
|
||||
score_id=score.id,
|
||||
)
|
||||
@@ -1242,7 +1247,7 @@ async def _process_statistics(
|
||||
if add_to_db:
|
||||
session.add(mouthly_playcount)
|
||||
logger.debug(
|
||||
"[Score] Created monthly playcount record for user {user_id} ({year}-{month})",
|
||||
"Created monthly playcount record for user {user_id} ({year}-{month})",
|
||||
user_id=user.id,
|
||||
year=mouthly_playcount.year,
|
||||
month=mouthly_playcount.month,
|
||||
@@ -1262,7 +1267,7 @@ async def process_user(
|
||||
score_id = score.id
|
||||
user_id = user.id
|
||||
logger.info(
|
||||
"[Score] Processing score {score_id} for user {user_id} on beatmap {beatmap_id}",
|
||||
"Processing score {score_id} for user {user_id} on beatmap {beatmap_id}",
|
||||
score_id=score_id,
|
||||
user_id=user_id,
|
||||
beatmap_id=score.beatmap_id,
|
||||
@@ -1287,14 +1292,14 @@ async def process_user(
|
||||
score_ = (await session.exec(select(Score).where(Score.id == score_id).options(joinedload(Score.beatmap)))).first()
|
||||
if score_ is None:
|
||||
logger.warning(
|
||||
"[Score] Score {score_id} disappeared after commit, skipping event processing",
|
||||
"Score {score_id} disappeared after commit, skipping event processing",
|
||||
score_id=score_id,
|
||||
)
|
||||
return
|
||||
await _process_score_events(score_, session)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"[Score] Finished processing score {score_id} for user {user_id}",
|
||||
"Finished processing score {score_id} for user {user_id}",
|
||||
score_id=score_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from app.models.score import GameMode
|
||||
from app.utils import utcnow
|
||||
|
||||
from .beatmap import Beatmap
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
@@ -23,7 +23,7 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User, UserResp
|
||||
from .user import User, UserResp
|
||||
|
||||
|
||||
class UserStatisticsBase(SQLModel):
|
||||
@@ -122,7 +122,7 @@ class UserStatisticsResp(UserStatisticsBase):
|
||||
"progress": int(math.fmod(obj.level_current, 1) * 100),
|
||||
}
|
||||
if "user" in include:
|
||||
from .lazer_user import RANKING_INCLUDES, UserResp
|
||||
from .user import RANKING_INCLUDES, UserResp
|
||||
|
||||
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
|
||||
s.user = user
|
||||
@@ -149,7 +149,7 @@ class UserStatisticsResp(UserStatisticsBase):
|
||||
|
||||
|
||||
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
query = select(
|
||||
UserStatistics.user_id,
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
|
||||
class Team(SQLModel, UTCBaseModel, table=True):
|
||||
|
||||
@@ -4,7 +4,7 @@ from app.calculator import calculate_score_to_level
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.models.score import GameMode, Rank
|
||||
|
||||
from .lazer_user import User
|
||||
from .user import User
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
|
||||
from .score import Score
|
||||
|
||||
|
||||
class BestScore(SQLModel, table=True):
|
||||
class TotalScoreBestScore(SQLModel, table=True):
|
||||
__tablename__: str = "total_score_best_scores"
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
|
||||
@@ -41,7 +41,7 @@ class BestScore(SQLModel, table=True):
|
||||
user: User = Relationship()
|
||||
score: "Score" = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[BestScore.score_id]",
|
||||
"foreign_keys": "[TotalScoreBestScore.score_id]",
|
||||
"lazy": "joined",
|
||||
},
|
||||
back_populates="best_score",
|
||||
@@ -75,7 +75,7 @@ class BestScore(SQLModel, table=True):
|
||||
await session.exec(
|
||||
select(func.max(Score.max_combo)).where(
|
||||
Score.user_id == self.user_id,
|
||||
col(Score.id).in_(select(BestScore.score_id)),
|
||||
col(Score.id).in_(select(TotalScoreBestScore.score_id)),
|
||||
Score.gamemode == self.gamemode,
|
||||
)
|
||||
)
|
||||
@@ -72,7 +72,7 @@ COUNTRIES = json.loads((STATIC_DIR / "iso3166.json").read_text())
|
||||
|
||||
|
||||
class UserBase(UTCBaseModel, SQLModel):
|
||||
avatar_url: str = ""
|
||||
avatar_url: str = "https://lazer-data.g0v0.top/default.jpg"
|
||||
country_code: str = Field(default="CN", max_length=2, index=True)
|
||||
# ? default_group: str|None
|
||||
is_active: bool = True
|
||||
@@ -256,16 +256,14 @@ class UserResp(UserBase):
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
ruleset: GameMode | None = None,
|
||||
*,
|
||||
token_id: int | None = None,
|
||||
) -> "UserResp":
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
from .best_score import BestScore
|
||||
from .best_scores import BestScore
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
from .pp_best_score import PPBestScore
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from .score import Score, get_user_first_score_count
|
||||
from .total_score_best_scores import TotalScoreBestScore
|
||||
|
||||
ruleset = ruleset or obj.playmode
|
||||
|
||||
@@ -286,9 +284,9 @@ class UserResp(UserBase):
|
||||
u.scores_best_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(BestScore)
|
||||
.select_from(TotalScoreBestScore)
|
||||
.where(
|
||||
BestScore.user_id == obj.id,
|
||||
TotalScoreBestScore.user_id == obj.id,
|
||||
)
|
||||
.limit(200)
|
||||
)
|
||||
@@ -310,16 +308,16 @@ class UserResp(UserBase):
|
||||
).all()
|
||||
]
|
||||
|
||||
if "team" in include:
|
||||
if team_membership := await obj.awaitable_attrs.team_membership:
|
||||
u.team = team_membership.team
|
||||
if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership):
|
||||
u.team = team_membership.team
|
||||
|
||||
if "account_history" in include:
|
||||
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
|
||||
|
||||
if "daily_challenge_user_stats":
|
||||
if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
|
||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
|
||||
if "daily_challenge_user_stats" in include and (
|
||||
daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats
|
||||
):
|
||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
|
||||
|
||||
if "statistics" in include:
|
||||
current_stattistics = None
|
||||
@@ -393,10 +391,10 @@ class UserResp(UserBase):
|
||||
u.scores_best_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(PPBestScore)
|
||||
.select_from(BestScore)
|
||||
.where(
|
||||
PPBestScore.user_id == obj.id,
|
||||
PPBestScore.gamemode == ruleset,
|
||||
BestScore.user_id == obj.id,
|
||||
BestScore.gamemode == ruleset,
|
||||
)
|
||||
.limit(200)
|
||||
)
|
||||
@@ -443,7 +441,7 @@ class MeResp(UserResp):
|
||||
from app.dependencies.database import get_redis
|
||||
from app.service.verification_service import LoginSessionService
|
||||
|
||||
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id)
|
||||
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset)
|
||||
u.session_verified = (
|
||||
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
|
||||
if token_id
|
||||
@@ -1,4 +1 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .database import get_db as get_db
|
||||
from .user import get_current_user as get_current_user
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, Header
|
||||
|
||||
|
||||
def get_api_version(version: int | None = Header(None, alias="x-api-version")) -> int:
|
||||
def get_api_version(version: int | None = Header(None, alias="x-api-version", include_in_schema=False)) -> int:
|
||||
if version is None:
|
||||
return 0
|
||||
if version < 1:
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from typing import Annotated
|
||||
|
||||
from app.service.beatmap_download_service import download_service
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService, download_service
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
|
||||
def get_beatmap_download_service():
|
||||
"""获取谱面下载服务实例"""
|
||||
return download_service
|
||||
|
||||
|
||||
DownloadService = Annotated[BeatmapDownloadService, Depends(get_beatmap_download_service)]
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
"""
|
||||
Beatmapset缓存服务依赖注入
|
||||
"""
|
||||
from typing import Annotated
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.service.beatmapset_cache_service import BeatmapsetCacheService, get_beatmapset_cache_service
|
||||
from app.dependencies.database import Redis
|
||||
from app.service.beatmapset_cache_service import (
|
||||
BeatmapsetCacheService as OriginBeatmapsetCacheService,
|
||||
get_beatmapset_cache_service,
|
||||
)
|
||||
|
||||
from fastapi import Depends
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
def get_beatmapset_cache_dependency(redis: Redis = Depends(get_redis)) -> BeatmapsetCacheService:
|
||||
def get_beatmapset_cache_dependency(redis: Redis) -> OriginBeatmapsetCacheService:
|
||||
"""获取beatmapset缓存服务依赖"""
|
||||
return get_beatmapset_cache_service(redis)
|
||||
|
||||
|
||||
BeatmapsetCacheService = Annotated[OriginBeatmapsetCacheService, Depends(get_beatmapset_cache_dependency)]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from contextvars import ContextVar
|
||||
@@ -11,7 +9,6 @@ from app.config import settings
|
||||
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
import redis as sync_redis
|
||||
import redis.asyncio as redis
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
@@ -38,13 +35,16 @@ engine = create_async_engine(
|
||||
)
|
||||
|
||||
# Redis 连接
|
||||
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
||||
redis_client = redis.from_url(settings.redis_url, decode_responses=True, db=0)
|
||||
|
||||
# Redis 二进制数据连接 (不自动解码响应,用于存储音频等二进制数据)
|
||||
redis_binary_client = redis.from_url(settings.redis_url, decode_responses=False)
|
||||
# Redis 消息缓存连接 (db1)
|
||||
redis_message_client = redis.from_url(settings.redis_url, decode_responses=True, db=1)
|
||||
|
||||
# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行
|
||||
redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
|
||||
# Redis 二进制数据连接 (不自动解码响应,用于存储音频等二进制数据,db2)
|
||||
redis_binary_client = redis.from_url(settings.redis_url, decode_responses=False, db=2)
|
||||
|
||||
# Redis 限流连接 (db3)
|
||||
redis_rate_limit_client = redis.from_url(settings.redis_url, decode_responses=True, db=3)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
@@ -91,12 +91,15 @@ def get_redis():
|
||||
return redis_client
|
||||
|
||||
|
||||
Redis = Annotated[redis.Redis, Depends(get_redis)]
|
||||
|
||||
|
||||
def get_redis_binary():
|
||||
"""获取二进制数据专用的 Redis 客户端 (不自动解码响应)"""
|
||||
return redis_binary_client
|
||||
|
||||
|
||||
def get_redis_message():
|
||||
def get_redis_message() -> redis.Redis:
|
||||
"""获取消息专用的 Redis 客户端 (db1)"""
|
||||
return redis_message_client
|
||||
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from __future__ import annotations
|
||||
from typing import Annotated
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies.database import get_redis
|
||||
from app.fetcher import Fetcher
|
||||
from app.log import logger
|
||||
from app.fetcher import Fetcher as OriginFetcher
|
||||
from app.log import fetcher_logger
|
||||
|
||||
fetcher: Fetcher | None = None
|
||||
from fastapi import Depends
|
||||
|
||||
fetcher: OriginFetcher | None = None
|
||||
|
||||
|
||||
async def get_fetcher() -> Fetcher:
|
||||
async def get_fetcher() -> OriginFetcher:
|
||||
global fetcher
|
||||
if fetcher is None:
|
||||
fetcher = Fetcher(
|
||||
fetcher = OriginFetcher(
|
||||
settings.fetcher_client_id,
|
||||
settings.fetcher_client_secret,
|
||||
settings.fetcher_scopes,
|
||||
@@ -25,5 +27,10 @@ async def get_fetcher() -> Fetcher:
|
||||
if refresh_token:
|
||||
fetcher.refresh_token = str(refresh_token)
|
||||
if not fetcher.access_token or not fetcher.refresh_token:
|
||||
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
|
||||
fetcher_logger("Fetcher").opt(colors=True).info(
|
||||
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
|
||||
)
|
||||
return fetcher
|
||||
|
||||
|
||||
Fetcher = Annotated[OriginFetcher, Depends(get_fetcher)]
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
GeoIP dependency for FastAPI
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
import ipaddress
|
||||
from typing import Annotated
|
||||
|
||||
from app.config import settings
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
|
||||
from fastapi import Depends, Request
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_geoip_helper() -> GeoIPHelper:
|
||||
@@ -26,7 +27,7 @@ def get_geoip_helper() -> GeoIPHelper:
|
||||
)
|
||||
|
||||
|
||||
def get_client_ip(request) -> str:
|
||||
def get_client_ip(request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实 IP 地址
|
||||
支持 IPv4 和 IPv6,考虑代理、负载均衡器等情况
|
||||
@@ -66,6 +67,10 @@ def get_client_ip(request) -> str:
|
||||
return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
|
||||
|
||||
|
||||
IPAddress = Annotated[str, Depends(get_client_ip)]
|
||||
GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)]
|
||||
|
||||
|
||||
def is_valid_ip(ip_str: str) -> bool:
|
||||
"""
|
||||
验证 IP 地址是否有效(支持 IPv4 和 IPv6)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
@@ -7,7 +5,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
|
||||
def BodyOrForm[T: BaseModel](model: type[T]):
|
||||
def BodyOrForm[T: BaseModel](model: type[T]): # noqa: N802
|
||||
async def dependency(
|
||||
request: Request,
|
||||
) -> T:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from typing import cast
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from typing import Annotated, cast
|
||||
|
||||
from app.config import (
|
||||
AWSS3StorageSettings,
|
||||
@@ -9,11 +7,13 @@ from app.config import (
|
||||
StorageServiceType,
|
||||
settings,
|
||||
)
|
||||
from app.storage import StorageService
|
||||
from app.storage import StorageService as OriginStorageService
|
||||
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
|
||||
from app.storage.local import LocalStorageService
|
||||
|
||||
storage: StorageService | None = None
|
||||
from fastapi import Depends
|
||||
|
||||
storage: OriginStorageService | None = None
|
||||
|
||||
|
||||
def init_storage_service():
|
||||
@@ -50,3 +50,6 @@ def get_storage_service():
|
||||
if storage is None:
|
||||
return init_storage_service()
|
||||
return storage
|
||||
|
||||
|
||||
StorageService = Annotated[OriginStorageService, Depends(get_storage_service)]
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.config import settings
|
||||
from app.const import SUPPORT_TOTP_VERIFICATION_VER
|
||||
from app.database import User
|
||||
from app.database.auth import OAuthToken, V1APIKeys
|
||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
@@ -11,7 +10,7 @@ from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
from .api_version import APIVersion
|
||||
from .database import Database, get_redis
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import Depends, HTTPException, Security
|
||||
from fastapi.security import (
|
||||
APIKeyQuery,
|
||||
HTTPBearer,
|
||||
@@ -112,16 +111,13 @@ async def get_client_user(
|
||||
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
|
||||
# 获取当前验证方式
|
||||
verify_method = None
|
||||
if api_version >= 20250913:
|
||||
if api_version >= SUPPORT_TOTP_VERIFICATION_VER:
|
||||
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
|
||||
|
||||
if verify_method is None:
|
||||
# 智能选择验证方式(有TOTP优先TOTP)
|
||||
totp_key = await user.awaitable_attrs.totp_key
|
||||
if totp_key is not None and api_version >= 20240101:
|
||||
verify_method = "totp"
|
||||
else:
|
||||
verify_method = "mail"
|
||||
verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail"
|
||||
|
||||
# 设置选择的验证方法到Redis中,避免重复选择
|
||||
if api_version >= 20250913:
|
||||
@@ -169,3 +165,6 @@ async def get_current_user(
|
||||
user_and_token: UserAndToken = Depends(get_current_user_and_token),
|
||||
) -> User:
|
||||
return user_and_token[0]
|
||||
|
||||
|
||||
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.models.model import UserAgentInfo as UserAgentInfoModel
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class SignalRException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvokeException(SignalRException):
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
@@ -1,49 +0,0 @@
|
||||
"""
|
||||
用户页面相关的异常类
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class UserpageError(Exception):
|
||||
"""用户页面处理错误基类"""
|
||||
|
||||
def __init__(self, message: str, code: str = "userpage_error"):
|
||||
self.message = message
|
||||
self.code = code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ContentTooLongError(UserpageError):
|
||||
"""内容过长错误"""
|
||||
|
||||
def __init__(self, current_length: int, max_length: int):
|
||||
message = f"Content too long. Maximum {max_length} characters allowed, got {current_length}."
|
||||
super().__init__(message, "content_too_long")
|
||||
self.current_length = current_length
|
||||
self.max_length = max_length
|
||||
|
||||
|
||||
class ContentEmptyError(UserpageError):
|
||||
"""内容为空错误"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Content cannot be empty.", "content_empty")
|
||||
|
||||
|
||||
class BBCodeValidationError(UserpageError):
|
||||
"""BBCode验证错误"""
|
||||
|
||||
def __init__(self, errors: list[str]):
|
||||
message = f"BBCode validation failed: {'; '.join(errors)}"
|
||||
super().__init__(message, "bbcode_validation_error")
|
||||
self.errors = errors
|
||||
|
||||
|
||||
class ForbiddenTagError(UserpageError):
|
||||
"""禁止标签错误"""
|
||||
|
||||
def __init__(self, tag: str):
|
||||
message = f"Forbidden tag '{tag}' is not allowed."
|
||||
super().__init__(message, "forbidden_tag")
|
||||
self.tag = tag
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .beatmap import BeatmapFetcher
|
||||
from .beatmap_raw import BeatmapRawFetcher
|
||||
from .beatmapset import BeatmapsetFetcher
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.log import logger
|
||||
from app.log import fetcher_logger
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
@@ -16,6 +14,9 @@ class TokenAuthError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
logger = fetcher_logger("Fetcher")
|
||||
|
||||
|
||||
class BaseFetcher:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.beatmap import BeatmapResp
|
||||
from app.log import logger
|
||||
from app.log import fetcher_logger
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
logger = fetcher_logger("BeatmapFetcher")
|
||||
|
||||
|
||||
class BeatmapFetcher(BaseFetcher):
|
||||
async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp:
|
||||
@@ -14,7 +14,7 @@ class BeatmapFetcher(BaseFetcher):
|
||||
params = {"checksum": beatmap_checksum}
|
||||
else:
|
||||
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
|
||||
logger.opt(colors=True).debug(f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>")
|
||||
logger.opt(colors=True).debug(f"get_beatmap: <y>{params}</y>")
|
||||
|
||||
return BeatmapResp.model_validate(
|
||||
await self.request_api(
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from app.log import fetcher_logger
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient, HTTPError
|
||||
from httpx._models import Response
|
||||
from loguru import logger
|
||||
import redis.asyncio as redis
|
||||
|
||||
urls = [
|
||||
@@ -13,12 +12,14 @@ urls = [
|
||||
"https://catboy.best/osu/{beatmap_id}",
|
||||
]
|
||||
|
||||
logger = fetcher_logger("BeatmapRawFetcher")
|
||||
|
||||
|
||||
class BeatmapRawFetcher(BaseFetcher):
|
||||
async def get_beatmap_raw(self, beatmap_id: int) -> str:
|
||||
for url in urls:
|
||||
req_url = url.format(beatmap_id=beatmap_id)
|
||||
logger.opt(colors=True).debug(f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>")
|
||||
logger.opt(colors=True).debug(f"get_beatmap_raw: <y>{req_url}</y>")
|
||||
resp = await self._request(req_url)
|
||||
if resp.status_code >= 400:
|
||||
continue
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
@@ -7,7 +5,7 @@ import json
|
||||
|
||||
from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
|
||||
from app.helpers.rate_limiter import osu_api_rate_limiter
|
||||
from app.log import logger
|
||||
from app.log import fetcher_logger
|
||||
from app.models.beatmap import SearchQueryModel
|
||||
from app.models.model import Cursor
|
||||
from app.utils import bg_tasks
|
||||
@@ -24,6 +22,9 @@ class RateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
logger = fetcher_logger("BeatmapsetFetcher")
|
||||
|
||||
|
||||
class BeatmapsetFetcher(BaseFetcher):
|
||||
@staticmethod
|
||||
def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]:
|
||||
@@ -113,7 +114,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
|
||||
# 序列化为 JSON 并生成 MD5 哈希
|
||||
cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
|
||||
cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
|
||||
cache_hash = hashlib.md5(cache_json.encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
|
||||
|
||||
@@ -135,7 +136,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
return {}
|
||||
|
||||
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
|
||||
logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>")
|
||||
logger.opt(colors=True).debug(f"get_beatmapset: <y>{beatmap_set_id}</y>")
|
||||
|
||||
return BeatmapsetResp.model_validate(
|
||||
await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
|
||||
@@ -144,7 +145,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
async def search_beatmapset(
|
||||
self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
|
||||
) -> SearchBeatmapsetsResp:
|
||||
logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>")
|
||||
logger.opt(colors=True).debug(f"search_beatmapset: <y>{query}</y>")
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = self._generate_cache_key(query, cursor)
|
||||
@@ -152,17 +153,15 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
# 尝试从缓存获取结果
|
||||
cached_result = await redis_client.get(cache_key)
|
||||
if cached_result:
|
||||
logger.opt(colors=True).debug(f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>")
|
||||
logger.opt(colors=True).debug(f"Cache hit for key: <y>{cache_key}</y>")
|
||||
try:
|
||||
cached_data = json.loads(cached_result)
|
||||
return SearchBeatmapsetsResp.model_validate(cached_data)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).warning(
|
||||
f"<yellow>[BeatmapsetFetcher]</yellow> Cache data invalid, fetching from API: {e}"
|
||||
)
|
||||
logger.warning(f"Cache data invalid, fetching from API: {e}")
|
||||
|
||||
# 缓存未命中,从 API 获取数据
|
||||
logger.opt(colors=True).debug("<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API")
|
||||
logger.debug("Cache miss, fetching from API")
|
||||
|
||||
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
|
||||
|
||||
@@ -186,9 +185,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
cache_ttl = 15 * 60 # 15 分钟
|
||||
await redis_client.set(cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl)
|
||||
|
||||
logger.opt(colors=True).debug(
|
||||
f"<green>[BeatmapsetFetcher]</green> Cached result for key: <y>{cache_key}</y> (TTL: {cache_ttl}s)"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"Cached result for key: <y>{cache_key}</y> (TTL: {cache_ttl}s)")
|
||||
|
||||
resp = SearchBeatmapsetsResp.model_validate(api_response)
|
||||
|
||||
@@ -204,9 +201,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
try:
|
||||
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
|
||||
except RateLimitError:
|
||||
logger.opt(colors=True).info(
|
||||
"<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit"
|
||||
)
|
||||
logger.info("Prefetch skipped due to rate limit")
|
||||
|
||||
bg_tasks.add_task(delayed_prefetch)
|
||||
|
||||
@@ -230,14 +225,14 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
# 使用当前 cursor 请求下一页
|
||||
next_query = query.model_copy()
|
||||
|
||||
logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}")
|
||||
logger.debug(f"Prefetching page {page + 1}")
|
||||
|
||||
# 生成下一页的缓存键
|
||||
next_cache_key = self._generate_cache_key(next_query, cursor)
|
||||
|
||||
# 检查是否已经缓存
|
||||
if await redis_client.exists(next_cache_key):
|
||||
logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached")
|
||||
logger.debug(f"Page {page + 1} already cached")
|
||||
# 尝试从缓存获取cursor继续预取
|
||||
cached_data = await redis_client.get(next_cache_key)
|
||||
if cached_data:
|
||||
@@ -247,7 +242,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
cursor = data["cursor"]
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
logger.warning("Failed to parse cached data for cursor")
|
||||
break
|
||||
|
||||
# 在预取页面之间添加延迟,避免突发请求
|
||||
@@ -282,22 +277,18 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
ex=prefetch_ttl,
|
||||
)
|
||||
|
||||
logger.opt(colors=True).debug(
|
||||
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} (TTL: {prefetch_ttl}s)"
|
||||
)
|
||||
logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
|
||||
|
||||
except RateLimitError:
|
||||
logger.opt(colors=True).info("<yellow>[BeatmapsetFetcher]</yellow> Prefetch stopped due to rate limit")
|
||||
logger.info("Prefetch stopped due to rate limit")
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).warning(f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}")
|
||||
logger.warning(f"Prefetch failed: {e}")
|
||||
|
||||
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
|
||||
"""预热主页缓存"""
|
||||
homepage_queries = self._get_homepage_queries()
|
||||
|
||||
logger.opt(colors=True).info(
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup ({len(homepage_queries)} queries)"
|
||||
)
|
||||
logger.info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
|
||||
|
||||
for i, (query, cursor) in enumerate(homepage_queries):
|
||||
try:
|
||||
@@ -309,9 +300,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
|
||||
# 检查是否已经缓存
|
||||
if await redis_client.exists(cache_key):
|
||||
logger.opt(colors=True).debug(
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> Query {query.sort} already cached"
|
||||
)
|
||||
logger.debug(f"Query {query.sort} already cached")
|
||||
continue
|
||||
|
||||
# 请求并缓存
|
||||
@@ -334,24 +323,15 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
ex=cache_ttl,
|
||||
)
|
||||
|
||||
logger.opt(colors=True).info(
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
|
||||
)
|
||||
logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
|
||||
|
||||
if api_response.get("cursor"):
|
||||
try:
|
||||
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
|
||||
except RateLimitError:
|
||||
logger.opt(colors=True).info(
|
||||
f"<yellow>[BeatmapsetFetcher]</yellow> Warmup prefetch "
|
||||
f"skipped for {query.sort} due to rate limit"
|
||||
)
|
||||
logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
|
||||
|
||||
except RateLimitError:
|
||||
logger.opt(colors=True).warning(
|
||||
f"<yellow>[BeatmapsetFetcher]</yellow> Warmup skipped for {query.sort} due to rate limit"
|
||||
)
|
||||
logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).error(
|
||||
f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}"
|
||||
)
|
||||
logger.error(f"Failed to warmup cache for {query.sort}: {e}")
|
||||
|
||||
0
app/helpers/__init__.py
Normal file
0
app/helpers/__init__.py
Normal file
106
app/helpers/asset_proxy_helper.py
Normal file
106
app/helpers/asset_proxy_helper.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""资源代理辅助方法与路由装饰器。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from functools import wraps
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from fastapi import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
Handler = Callable[..., Awaitable[Any]]
|
||||
|
||||
|
||||
def _replace_asset_urls_in_string(value: str) -> str:
|
||||
result = value
|
||||
custom_domain = settings.custom_asset_domain
|
||||
asset_prefix = settings.asset_proxy_prefix
|
||||
avatar_prefix = settings.avatar_proxy_prefix
|
||||
beatmap_prefix = settings.beatmap_proxy_prefix
|
||||
audio_proxy_base_url = f"{settings.server_url}api/private/audio/beatmapset"
|
||||
|
||||
result = re.sub(
|
||||
r"^https://assets\.ppy\.sh/",
|
||||
f"https://{asset_prefix}.{custom_domain}/",
|
||||
result,
|
||||
)
|
||||
|
||||
result = re.sub(
|
||||
r"^https://b\.ppy\.sh/preview/(\d+)\\.mp3",
|
||||
rf"{audio_proxy_base_url}/\1",
|
||||
result,
|
||||
)
|
||||
|
||||
result = re.sub(
|
||||
r"^//b\.ppy\.sh/preview/(\d+)\\.mp3",
|
||||
rf"{audio_proxy_base_url}/\1",
|
||||
result,
|
||||
)
|
||||
|
||||
result = re.sub(
|
||||
r"^https://a\.ppy\.sh/",
|
||||
f"https://{avatar_prefix}.{custom_domain}/",
|
||||
result,
|
||||
)
|
||||
|
||||
result = re.sub(
|
||||
r"https://b\.ppy\.sh/",
|
||||
f"https://{beatmap_prefix}.{custom_domain}/",
|
||||
result,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _replace_asset_urls_in_data(data: Any) -> Any:
|
||||
if isinstance(data, str):
|
||||
return _replace_asset_urls_in_string(data)
|
||||
if isinstance(data, list):
|
||||
return [_replace_asset_urls_in_data(item) for item in data]
|
||||
if isinstance(data, tuple):
|
||||
return tuple(_replace_asset_urls_in_data(item) for item in data)
|
||||
if isinstance(data, dict):
|
||||
return {key: _replace_asset_urls_in_data(value) for key, value in data.items()}
|
||||
return data
|
||||
|
||||
|
||||
async def replace_asset_urls(data: Any) -> Any:
|
||||
"""替换数据中的 osu! 资源 URL。"""
|
||||
|
||||
if not settings.enable_asset_proxy:
|
||||
return data
|
||||
|
||||
if hasattr(data, "model_dump"):
|
||||
raw = data.model_dump()
|
||||
processed = _replace_asset_urls_in_data(raw)
|
||||
try:
|
||||
return data.__class__(**processed)
|
||||
except Exception:
|
||||
return processed
|
||||
|
||||
if isinstance(data, (dict, list, tuple, str)):
|
||||
return _replace_asset_urls_in_data(data)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def asset_proxy_response(func: Handler) -> Handler:
|
||||
"""装饰器:在返回响应前替换资源 URL。"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
if not settings.enable_asset_proxy:
|
||||
return result
|
||||
|
||||
if isinstance(result, Response):
|
||||
return result
|
||||
|
||||
if isinstance(result, BaseModel):
|
||||
result = result.model_dump()
|
||||
|
||||
return _replace_asset_urls_in_data(result)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
@@ -1,19 +1,37 @@
|
||||
"""
|
||||
GeoLite2 Helper Class
|
||||
GeoLite2 Helper Class (asynchronous)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Required, TypedDict
|
||||
|
||||
from app.log import logger
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
import maxminddb
|
||||
|
||||
|
||||
class GeoIPLookupResult(TypedDict, total=False):
|
||||
ip: Required[str]
|
||||
country_iso: str
|
||||
country_name: str
|
||||
city_name: str
|
||||
latitude: str
|
||||
longitude: str
|
||||
time_zone: str
|
||||
postal_code: str
|
||||
asn: int | None
|
||||
organization: str
|
||||
|
||||
|
||||
BASE_URL = "https://download.maxmind.com/app/geoip_download"
|
||||
EDITIONS = {
|
||||
"City": "GeoLite2-City",
|
||||
@@ -25,163 +43,184 @@ EDITIONS = {
|
||||
class GeoIPHelper:
|
||||
def __init__(
|
||||
self,
|
||||
dest_dir="./geoip",
|
||||
license_key=None,
|
||||
editions=None,
|
||||
max_age_days=8,
|
||||
timeout=60.0,
|
||||
dest_dir: str | Path = Path("./geoip"),
|
||||
license_key: str | None = None,
|
||||
editions: list[str] | None = None,
|
||||
max_age_days: int = 8,
|
||||
timeout: float = 60.0,
|
||||
):
|
||||
self.dest_dir = dest_dir
|
||||
self.dest_dir = Path(dest_dir).expanduser()
|
||||
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
|
||||
self.editions = editions or ["City", "ASN"]
|
||||
self.editions = list(editions or ["City", "ASN"])
|
||||
self.max_age_days = max_age_days
|
||||
self.timeout = timeout
|
||||
self._readers = {}
|
||||
self._readers: dict[str, maxminddb.Reader] = {}
|
||||
self._update_lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _safe_extract(tar: tarfile.TarFile, path: str):
|
||||
base = Path(path).resolve()
|
||||
for m in tar.getmembers():
|
||||
target = (base / m.name).resolve()
|
||||
if not str(target).startswith(str(base)):
|
||||
def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
|
||||
base = path.resolve()
|
||||
for member in tar.getmembers():
|
||||
target = (base / member.name).resolve()
|
||||
if not target.is_relative_to(base): # py312
|
||||
raise RuntimeError("Unsafe path in tar file")
|
||||
tar.extractall(path=path, filter="data")
|
||||
tar.extractall(path=base, filter="data")
|
||||
|
||||
def _download_and_extract(self, edition_id: str) -> str:
|
||||
"""
|
||||
下载并解压 mmdb 文件到 dest_dir,仅保留 .mmdb
|
||||
- 跟随 302 重定向
|
||||
- 流式下载到临时文件
|
||||
- 临时目录退出后自动清理
|
||||
"""
|
||||
@staticmethod
|
||||
def _as_mapping(value: Any) -> dict[str, Any]:
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _as_str(value: Any, default: str = "") -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if value is None:
|
||||
return default
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _as_int(value: Any) -> int | None:
|
||||
return value if isinstance(value, int) else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_tarball(src: Path, dest: Path) -> None:
|
||||
with tarfile.open(src, "r:gz") as tar:
|
||||
GeoIPHelper._safe_extract(tar, dest)
|
||||
|
||||
@staticmethod
|
||||
def _find_mmdb(root: Path) -> Path | None:
|
||||
for candidate in root.rglob("*.mmdb"):
|
||||
return candidate
|
||||
return None
|
||||
|
||||
def _latest_file_sync(self, edition_id: str) -> Path | None:
|
||||
directory = self.dest_dir
|
||||
if not directory.is_dir():
|
||||
return None
|
||||
candidates = list(directory.glob(f"{edition_id}*.mmdb"))
|
||||
if not candidates:
|
||||
return None
|
||||
return max(candidates, key=lambda p: p.stat().st_mtime)
|
||||
|
||||
async def _latest_file(self, edition_id: str) -> Path | None:
|
||||
return await asyncio.to_thread(self._latest_file_sync, edition_id)
|
||||
|
||||
async def _download_and_extract(self, edition_id: str) -> Path:
|
||||
if not self.license_key:
|
||||
raise ValueError("缺少 MaxMind License Key,请传入或设置环境变量 MAXMIND_LICENSE_KEY")
|
||||
raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")
|
||||
|
||||
url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
|
||||
tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))
|
||||
|
||||
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
|
||||
with client.stream("GET", url) as resp:
|
||||
try:
|
||||
tgz_path = tmp_dir / "db.tgz"
|
||||
async with (
|
||||
httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
|
||||
client.stream("GET", url) as resp,
|
||||
):
|
||||
resp.raise_for_status()
|
||||
with tempfile.TemporaryDirectory() as tmpd:
|
||||
tgz_path = os.path.join(tmpd, "db.tgz")
|
||||
# 流式写入
|
||||
with open(tgz_path, "wb") as f:
|
||||
for chunk in resp.iter_bytes():
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
async with aiofiles.open(tgz_path, "wb") as download_file:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if chunk:
|
||||
await download_file.write(chunk)
|
||||
|
||||
# 解压并只移动 .mmdb
|
||||
with tarfile.open(tgz_path, "r:gz") as tar:
|
||||
# 先安全检查与解压
|
||||
self._safe_extract(tar, tmpd)
|
||||
await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
|
||||
mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
|
||||
if mmdb_path is None:
|
||||
raise RuntimeError("未在压缩包中找到 .mmdb 文件")
|
||||
|
||||
# 递归找 .mmdb
|
||||
mmdb_path = None
|
||||
for root, _, files in os.walk(tmpd):
|
||||
for fn in files:
|
||||
if fn.endswith(".mmdb"):
|
||||
mmdb_path = os.path.join(root, fn)
|
||||
break
|
||||
if mmdb_path:
|
||||
break
|
||||
await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
|
||||
dst = self.dest_dir / mmdb_path.name
|
||||
await asyncio.to_thread(shutil.move, mmdb_path, dst)
|
||||
return dst
|
||||
finally:
|
||||
await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)
|
||||
|
||||
if not mmdb_path:
|
||||
raise RuntimeError("未在压缩包中找到 .mmdb 文件")
|
||||
async def update(self, force: bool = False) -> None:
|
||||
async with self._update_lock:
|
||||
for edition in self.editions:
|
||||
edition_id = EDITIONS[edition]
|
||||
path = await self._latest_file(edition_id)
|
||||
need_download = force or path is None
|
||||
|
||||
os.makedirs(self.dest_dir, exist_ok=True)
|
||||
dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path))
|
||||
shutil.move(mmdb_path, dst)
|
||||
return dst
|
||||
|
||||
def _latest_file(self, edition_id: str):
|
||||
if not os.path.isdir(self.dest_dir):
|
||||
return None
|
||||
files = [
|
||||
os.path.join(self.dest_dir, f)
|
||||
for f in os.listdir(self.dest_dir)
|
||||
if f.startswith(edition_id) and f.endswith(".mmdb")
|
||||
]
|
||||
return max(files, key=os.path.getmtime) if files else None
|
||||
|
||||
def update(self, force=False):
|
||||
from app.log import logger
|
||||
|
||||
for ed in self.editions:
|
||||
eid = EDITIONS[ed]
|
||||
path = self._latest_file(eid)
|
||||
need = force or not path
|
||||
|
||||
if path:
|
||||
age_days = (time.time() - os.path.getmtime(path)) / 86400
|
||||
if age_days >= self.max_age_days:
|
||||
need = True
|
||||
logger.info(
|
||||
f"[GeoIP] {eid} database is {age_days:.1f} days old "
|
||||
f"(max: {self.max_age_days}), will download new version"
|
||||
)
|
||||
if path:
|
||||
mtime = await asyncio.to_thread(path.stat)
|
||||
age_days = (time.time() - mtime.st_mtime) / 86400
|
||||
if age_days >= self.max_age_days:
|
||||
need_download = True
|
||||
logger.info(
|
||||
f"{edition_id} database is {age_days:.1f} days old "
|
||||
f"(max: {self.max_age_days}), will download new version"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[GeoIP] {eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
|
||||
)
|
||||
else:
|
||||
logger.info(f"[GeoIP] {eid} database not found, will download")
|
||||
logger.info(f"{edition_id} database not found, will download")
|
||||
|
||||
if need:
|
||||
logger.info(f"[GeoIP] Downloading {eid} database...")
|
||||
path = self._download_and_extract(eid)
|
||||
logger.info(f"[GeoIP] {eid} database downloaded successfully")
|
||||
else:
|
||||
logger.info(f"[GeoIP] Using existing {eid} database")
|
||||
if need_download:
|
||||
logger.info(f"Downloading {edition_id} database...")
|
||||
path = await self._download_and_extract(edition_id)
|
||||
logger.info(f"{edition_id} database downloaded successfully")
|
||||
else:
|
||||
logger.info(f"Using existing {edition_id} database")
|
||||
|
||||
old = self._readers.get(ed)
|
||||
if old:
|
||||
try:
|
||||
old.close()
|
||||
except Exception:
|
||||
pass
|
||||
if path is not None:
|
||||
self._readers[ed] = maxminddb.open_database(path)
|
||||
old_reader = self._readers.get(edition)
|
||||
if old_reader:
|
||||
with suppress(Exception):
|
||||
old_reader.close()
|
||||
if path is not None:
|
||||
self._readers[edition] = maxminddb.open_database(str(path))
|
||||
|
||||
def lookup(self, ip: str):
|
||||
res = {"ip": ip}
|
||||
# City
|
||||
city_r = self._readers.get("City")
|
||||
if city_r:
|
||||
data = city_r.get(ip)
|
||||
if data:
|
||||
country = data.get("country") or {}
|
||||
res["country_iso"] = country.get("iso_code") or ""
|
||||
res["country_name"] = (country.get("names") or {}).get("en", "")
|
||||
city = data.get("city") or {}
|
||||
res["city_name"] = (city.get("names") or {}).get("en", "")
|
||||
loc = data.get("location") or {}
|
||||
res["latitude"] = str(loc.get("latitude") or "")
|
||||
res["longitude"] = str(loc.get("longitude") or "")
|
||||
res["time_zone"] = str(loc.get("time_zone") or "")
|
||||
postal = data.get("postal") or {}
|
||||
if "code" in postal:
|
||||
res["postal_code"] = postal["code"]
|
||||
# ASN
|
||||
asn_r = self._readers.get("ASN")
|
||||
if asn_r:
|
||||
data = asn_r.get(ip)
|
||||
if data:
|
||||
res["asn"] = data.get("autonomous_system_number")
|
||||
res["organization"] = data.get("autonomous_system_organization")
|
||||
def lookup(self, ip: str) -> GeoIPLookupResult:
|
||||
res: GeoIPLookupResult = {"ip": ip}
|
||||
city_reader = self._readers.get("City")
|
||||
if city_reader:
|
||||
data = city_reader.get(ip)
|
||||
if isinstance(data, dict):
|
||||
country = self._as_mapping(data.get("country"))
|
||||
res["country_iso"] = self._as_str(country.get("iso_code"))
|
||||
country_names = self._as_mapping(country.get("names"))
|
||||
res["country_name"] = self._as_str(country_names.get("en"))
|
||||
|
||||
city = self._as_mapping(data.get("city"))
|
||||
city_names = self._as_mapping(city.get("names"))
|
||||
res["city_name"] = self._as_str(city_names.get("en"))
|
||||
|
||||
location = self._as_mapping(data.get("location"))
|
||||
latitude = location.get("latitude")
|
||||
longitude = location.get("longitude")
|
||||
res["latitude"] = str(latitude) if latitude is not None else ""
|
||||
res["longitude"] = str(longitude) if longitude is not None else ""
|
||||
res["time_zone"] = self._as_str(location.get("time_zone"))
|
||||
|
||||
postal = self._as_mapping(data.get("postal"))
|
||||
postal_code = postal.get("code")
|
||||
if postal_code is not None:
|
||||
res["postal_code"] = self._as_str(postal_code)
|
||||
|
||||
asn_reader = self._readers.get("ASN")
|
||||
if asn_reader:
|
||||
data = asn_reader.get(ip)
|
||||
if isinstance(data, dict):
|
||||
res["asn"] = self._as_int(data.get("autonomous_system_number"))
|
||||
res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
|
||||
return res
|
||||
|
||||
def close(self):
|
||||
for r in self._readers.values():
|
||||
try:
|
||||
r.close()
|
||||
except Exception:
|
||||
pass
|
||||
def close(self) -> None:
|
||||
for reader in self._readers.values():
|
||||
with suppress(Exception):
|
||||
reader.close()
|
||||
self._readers = {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
|
||||
geo.update()
|
||||
print(geo.lookup("8.8.8.8"))
|
||||
geo.close()
|
||||
|
||||
async def _demo() -> None:
|
||||
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
|
||||
await geo.update()
|
||||
print(geo.lookup("8.8.8.8"))
|
||||
geo.close()
|
||||
|
||||
asyncio.run(_demo())
|
||||
|
||||
@@ -6,8 +6,6 @@ Rate limiter for osu! API requests to avoid abuse detection.
|
||||
- 建议:每分钟不超过 60 次请求以避免滥用检测
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
import time
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
"""
|
||||
会话验证接口
|
||||
|
||||
基于osu-web的SessionVerificationInterface实现
|
||||
用于标准化会话验证行为
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class SessionVerificationInterface(ABC):
|
||||
"""会话验证接口
|
||||
|
||||
定义了会话验证所需的基本操作,参考osu-web的实现
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def find_for_verification(cls, session_id: str) -> SessionVerificationInterface | None:
|
||||
"""根据会话ID查找会话用于验证
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
会话实例或None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥/ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_verification_method(self) -> str | None:
|
||||
"""获取当前验证方法
|
||||
|
||||
Returns:
|
||||
验证方法 ('totp', 'mail') 或 None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_verified(self) -> bool:
|
||||
"""检查会话是否已验证"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def mark_verified(self) -> None:
|
||||
"""标记会话为已验证"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_verification_method(self, method: str) -> None:
|
||||
"""设置验证方法
|
||||
|
||||
Args:
|
||||
method: 验证方法 ('totp', 'mail')
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def user_id(self) -> int | None:
|
||||
"""获取关联的用户ID"""
|
||||
pass
|
||||
120
app/log.py
120
app/log.py
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
from sys import stdout
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.config import settings
|
||||
from app.utils import snake_to_pascal
|
||||
|
||||
import loguru
|
||||
|
||||
@@ -37,16 +37,18 @@ class InterceptHandler(logging.Handler):
|
||||
depth += 1
|
||||
|
||||
message = record.getMessage()
|
||||
|
||||
_logger = logger
|
||||
if record.name == "uvicorn.access":
|
||||
message = self._format_uvicorn_access_log(message)
|
||||
color = True
|
||||
_logger = uvicorn_logger()
|
||||
elif record.name == "uvicorn.error":
|
||||
message = self._format_uvicorn_error_log(message)
|
||||
_logger = uvicorn_logger()
|
||||
color = True
|
||||
else:
|
||||
color = False
|
||||
logger.opt(depth=depth, exception=record.exc_info, colors=color).log(level, message)
|
||||
_logger.opt(depth=depth, exception=record.exc_info, colors=color).log(level, message)
|
||||
|
||||
def _format_uvicorn_error_log(self, message: str) -> str:
|
||||
websocket_pattern = r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
|
||||
@@ -93,9 +95,7 @@ class InterceptHandler(logging.Handler):
|
||||
status_color = "green"
|
||||
elif 300 <= status < 400:
|
||||
status_color = "yellow"
|
||||
elif 400 <= status < 500:
|
||||
status_color = "red"
|
||||
elif 500 <= status < 600:
|
||||
elif 400 <= status < 500 or 500 <= status < 600:
|
||||
status_color = "red"
|
||||
|
||||
return (
|
||||
@@ -107,11 +107,106 @@ class InterceptHandler(logging.Handler):
|
||||
return message
|
||||
|
||||
|
||||
def get_caller_class_name(module_prefix: str = "", just_last_part: bool = True) -> str | None:
|
||||
stack = inspect.stack()
|
||||
for frame_info in stack[2:]:
|
||||
module = frame_info.frame.f_globals.get("__name__", "")
|
||||
if module_prefix and not module.startswith(module_prefix):
|
||||
continue
|
||||
|
||||
local_vars = frame_info.frame.f_locals
|
||||
# 实例方法
|
||||
if "self" in local_vars:
|
||||
return local_vars["self"].__class__.__name__
|
||||
# 类方法
|
||||
if "cls" in local_vars:
|
||||
return local_vars["cls"].__name__
|
||||
|
||||
# 静态方法 / 普通函数 -> 尝试通过函数名匹配类
|
||||
func_name = frame_info.function
|
||||
for obj_name, obj in frame_info.frame.f_globals.items():
|
||||
if isinstance(obj, type): # 遍历模块内类
|
||||
cls = obj
|
||||
attr = getattr(cls, func_name, None)
|
||||
if isinstance(attr, (staticmethod, classmethod, FunctionType)):
|
||||
return cls.__name__
|
||||
|
||||
# 如果没找到类,返回模块名
|
||||
if just_last_part:
|
||||
return module.rsplit(".", 1)[-1]
|
||||
return module
|
||||
return None
|
||||
|
||||
|
||||
def service_logger(name: str) -> "Logger":
|
||||
return logger.bind(service=name)
|
||||
|
||||
|
||||
def fetcher_logger(name: str) -> "Logger":
|
||||
return logger.bind(fetcher=name)
|
||||
|
||||
|
||||
def task_logger(name: str) -> "Logger":
|
||||
return logger.bind(task=name)
|
||||
|
||||
|
||||
def system_logger(name: str) -> "Logger":
|
||||
return logger.bind(system=name)
|
||||
|
||||
|
||||
def uvicorn_logger() -> "Logger":
|
||||
return logger.bind(uvicorn="Uvicorn")
|
||||
|
||||
|
||||
def log(name: str) -> "Logger":
|
||||
return logger.bind(real_name=name)
|
||||
|
||||
|
||||
def dynamic_format(record):
|
||||
name = ""
|
||||
|
||||
uvicorn = record["extra"].get("uvicorn")
|
||||
if uvicorn:
|
||||
name = f"<fg #228B22>{uvicorn}</fg #228B22>"
|
||||
|
||||
service = record["extra"].get("service")
|
||||
if not service:
|
||||
service = get_caller_class_name("app.service")
|
||||
if service:
|
||||
name = f"<blue>{service}</blue>"
|
||||
|
||||
fetcher = record["extra"].get("fetcher")
|
||||
if not fetcher:
|
||||
fetcher = get_caller_class_name("app.fetcher")
|
||||
if fetcher:
|
||||
name = f"<magenta>{fetcher}</magenta>"
|
||||
|
||||
task = record["extra"].get("task")
|
||||
if not task:
|
||||
task = get_caller_class_name("app.tasks")
|
||||
if task:
|
||||
task = snake_to_pascal(task)
|
||||
name = f"<fg #FFD700>{task}</fg #FFD700>"
|
||||
|
||||
system = record["extra"].get("system")
|
||||
if system:
|
||||
name = f"<red>{system}</red>"
|
||||
|
||||
if name == "":
|
||||
real_name = record["extra"].get("real_name", "") or record["name"]
|
||||
name = f"<fg #FFC1C1>{real_name}</fg #FFC1C1>"
|
||||
|
||||
format = f"<green>{{time:YYYY-MM-DD HH:mm:ss}}</green> [<level>{{level}}</level>] | {name} | {{message}}\n"
|
||||
if record["exception"]:
|
||||
format += "{exception}\n"
|
||||
return format
|
||||
|
||||
|
||||
logger.remove()
|
||||
logger.add(
|
||||
stdout,
|
||||
colorize=True,
|
||||
format=("<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"),
|
||||
format=dynamic_format,
|
||||
level=settings.log_level,
|
||||
diagnose=settings.debug,
|
||||
)
|
||||
@@ -120,7 +215,7 @@ logger.add(
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
colorize=False,
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} {level} | {message}",
|
||||
format=dynamic_format,
|
||||
level=settings.log_level,
|
||||
diagnose=settings.debug,
|
||||
encoding="utf8",
|
||||
@@ -135,8 +230,9 @@ uvicorn_loggers = [
|
||||
]
|
||||
|
||||
for logger_name in uvicorn_loggers:
|
||||
uvicorn_logger = logging.getLogger(logger_name)
|
||||
uvicorn_logger.handlers = [InterceptHandler()]
|
||||
uvicorn_logger.propagate = False
|
||||
_uvicorn_logger = logging.getLogger(logger_name)
|
||||
_uvicorn_logger.handlers = [InterceptHandler()]
|
||||
_uvicorn_logger.propagate = False
|
||||
|
||||
logging.getLogger("httpx").setLevel("WARNING")
|
||||
logging.getLogger("apscheduler").setLevel("WARNING")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .verify_session import SessionState, VerifySessionMiddleware
|
||||
|
||||
__all__ = ["SessionState", "VerifySessionMiddleware"]
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
from app.middleware.verify_session import VerifySessionMiddleware
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def setup_session_verification_middleware(app: FastAPI) -> None:
|
||||
"""设置会话验证中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI应用实例
|
||||
"""
|
||||
# 只在启用会话验证时添加中间件
|
||||
if settings.enable_session_verification:
|
||||
app.add_middleware(VerifySessionMiddleware)
|
||||
|
||||
# 可以在这里添加中间件配置日志
|
||||
from app.log import logger
|
||||
|
||||
logger.info("[Middleware] Session verification middleware enabled")
|
||||
else:
|
||||
from app.log import logger
|
||||
|
||||
logger.info("[Middleware] Session verification middleware disabled")
|
||||
|
||||
|
||||
def setup_all_middlewares(app: FastAPI) -> None:
|
||||
"""设置所有中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI应用实例
|
||||
"""
|
||||
# 设置会话验证中间件
|
||||
setup_session_verification_middleware(app)
|
||||
|
||||
# 可以在这里添加其他中间件
|
||||
# app.add_middleware(OtherMiddleware)
|
||||
|
||||
from app.log import logger
|
||||
|
||||
logger.info("[Middleware] All middlewares configured")
|
||||
@@ -4,17 +4,15 @@ FastAPI会话验证中间件
|
||||
基于osu-web的会话验证系统,适配FastAPI框架
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import ClassVar
|
||||
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.const import SUPPORT_TOTP_VERIFICATION_VER
|
||||
from app.database.lazer_user import User
|
||||
from app.database.user import User
|
||||
from app.database.verification import LoginSession
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.log import logger
|
||||
from app.log import log
|
||||
from app.service.verification_service import LoginSessionService
|
||||
from app.utils import extract_user_agent
|
||||
|
||||
@@ -25,180 +23,7 @@ from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
|
||||
class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
"""会话验证中间件
|
||||
|
||||
参考osu-web的VerifyUser中间件,适配FastAPI
|
||||
"""
|
||||
|
||||
# 需要跳过验证的路由
|
||||
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
|
||||
"/api/v2/session/verify",
|
||||
"/api/v2/session/verify/reissue",
|
||||
"/api/v2/session/verify/mail-fallback",
|
||||
"/api/v2/me",
|
||||
"/api/v2/me/",
|
||||
"/api/v2/logout",
|
||||
"/oauth/token",
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/redoc",
|
||||
}
|
||||
|
||||
# 总是需要验证的路由前缀
|
||||
ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = {
|
||||
"/api/private/admin/",
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""中间件主处理逻辑"""
|
||||
try:
|
||||
# 检查是否跳过验证
|
||||
if self._should_skip_verification(request):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取当前用户
|
||||
user = await self._get_current_user(request)
|
||||
if not user:
|
||||
# 未登录用户跳过验证
|
||||
return await call_next(request)
|
||||
|
||||
# 获取会话状态
|
||||
session_state = await self._get_session_state(request, user)
|
||||
if not session_state:
|
||||
# 无会话状态,继续请求
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否已验证
|
||||
if session_state.is_verified():
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否需要验证
|
||||
if not self._requires_verification(request, user):
|
||||
return await call_next(request)
|
||||
|
||||
# 启动验证流程
|
||||
return await self._initiate_verification(request, session_state)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Verify Session Middleware] Error: {e}")
|
||||
# 出错时允许请求继续,避免阻塞
|
||||
return await call_next(request)
|
||||
|
||||
def _should_skip_verification(self, request: Request) -> bool:
|
||||
"""检查是否应该跳过验证"""
|
||||
path = request.url.path
|
||||
|
||||
# 完全匹配的跳过路由
|
||||
if path in self.SKIP_VERIFICATION_ROUTES:
|
||||
return True
|
||||
|
||||
# 非API请求跳过
|
||||
if not path.startswith("/api/"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _requires_verification(self, request: Request, user: User) -> bool:
|
||||
"""检查是否需要验证"""
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
# 检查是否为强制验证的路由
|
||||
for pattern in self.ALWAYS_VERIFY_PATTERNS:
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
if not user.is_active:
|
||||
return True
|
||||
|
||||
# 安全方法(GET/HEAD/OPTIONS)一般不需要验证
|
||||
safe_methods = {"GET", "HEAD", "OPTIONS"}
|
||||
if method in safe_methods:
|
||||
return False
|
||||
|
||||
# 修改操作(POST/PUT/DELETE/PATCH)需要验证
|
||||
return method in {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
async def _get_current_user(self, request: Request) -> User | None:
|
||||
"""获取当前用户"""
|
||||
try:
|
||||
# 从Authorization header提取token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header[7:] # 移除"Bearer "前缀
|
||||
|
||||
# 创建专用数据库会话
|
||||
async with with_db() as db:
|
||||
# 获取token记录
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
|
||||
# 获取用户
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"[Verify Session Middleware] Error getting user: {e}")
|
||||
return None
|
||||
|
||||
async def _get_session_state(self, request: Request, user: User) -> SessionState | None:
|
||||
"""获取会话状态"""
|
||||
try:
|
||||
# 提取会话token(这里简化为使用相同的auth token)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
api_version = 0
|
||||
raw_api_version = request.headers.get("x-api-version")
|
||||
if raw_api_version is not None:
|
||||
try:
|
||||
api_version = int(raw_api_version)
|
||||
except ValueError:
|
||||
api_version = 0
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
session_token = auth_header[7:]
|
||||
|
||||
# 获取数据库和Redis连接
|
||||
async with with_db() as db:
|
||||
redis = get_redis()
|
||||
|
||||
# 查找会话
|
||||
session = await LoginSessionService.find_for_verification(db, session_token)
|
||||
if not session or session.user_id != user.id:
|
||||
return None
|
||||
|
||||
return SessionState(session, user, redis, db, api_version)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Verify Session Middleware] Error getting session state: {e}")
|
||||
return None
|
||||
|
||||
async def _initiate_verification(self, request: Request, state: SessionState) -> Response:
|
||||
"""启动验证流程"""
|
||||
try:
|
||||
method = await state.get_method()
|
||||
if method == "mail":
|
||||
await state.issue_mail_if_needed()
|
||||
|
||||
# 返回验证要求响应
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"method": method, "message": "Session verification required"},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Verify Session Middleware] Error initiating verification: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"error": "Verification initiation failed"}
|
||||
)
|
||||
logger = log("Middleware")
|
||||
|
||||
|
||||
class SessionState:
|
||||
@@ -261,7 +86,7 @@ class SessionState:
|
||||
self.session.web_uuid,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session State] Error marking verified: {e}")
|
||||
logger.error(f"Error marking verified: {e}")
|
||||
|
||||
async def issue_mail_if_needed(self) -> None:
|
||||
"""如果需要,发送验证邮件"""
|
||||
@@ -274,7 +99,7 @@ class SessionState:
|
||||
self.db, self.redis, self.user.id, self.user.username, self.user.email, None, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session State] Error issuing mail: {e}")
|
||||
logger.error(f"Error issuing mail: {e}")
|
||||
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥"""
|
||||
@@ -289,3 +114,169 @@ class SessionState:
|
||||
def user_id(self) -> int:
|
||||
"""获取用户ID"""
|
||||
return self.user.id
|
||||
|
||||
|
||||
class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
"""会话验证中间件
|
||||
|
||||
参考osu-web的VerifyUser中间件,适配FastAPI
|
||||
"""
|
||||
|
||||
# 需要跳过验证的路由
|
||||
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
|
||||
"/api/v2/session/verify",
|
||||
"/api/v2/session/verify/reissue",
|
||||
"/api/v2/session/verify/mail-fallback",
|
||||
"/api/v2/me",
|
||||
"/api/v2/me/",
|
||||
"/api/v2/logout",
|
||||
"/oauth/token",
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/redoc",
|
||||
}
|
||||
|
||||
# 总是需要验证的路由前缀
|
||||
ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = {
|
||||
"/api/private/admin/",
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""中间件主处理逻辑"""
|
||||
# 检查是否跳过验证
|
||||
if self._should_skip_verification(request):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取当前用户
|
||||
user = await self._get_current_user(request)
|
||||
if not user:
|
||||
# 未登录用户跳过验证
|
||||
return await call_next(request)
|
||||
|
||||
# 获取会话状态
|
||||
session_state = await self._get_session_state(request, user)
|
||||
if not session_state:
|
||||
# 无会话状态,继续请求
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否已验证
|
||||
if session_state.is_verified():
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否需要验证
|
||||
if not self._requires_verification(request, user):
|
||||
return await call_next(request)
|
||||
|
||||
# 启动验证流程
|
||||
return await self._initiate_verification(session_state)
|
||||
|
||||
def _should_skip_verification(self, request: Request) -> bool:
|
||||
"""检查是否应该跳过验证"""
|
||||
path = request.url.path
|
||||
|
||||
# 完全匹配的跳过路由
|
||||
if path in self.SKIP_VERIFICATION_ROUTES:
|
||||
return True
|
||||
|
||||
# 非API请求跳过
|
||||
return bool(not path.startswith("/api/"))
|
||||
|
||||
def _requires_verification(self, request: Request, user: User) -> bool:
|
||||
"""检查是否需要验证"""
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
# 检查是否为强制验证的路由
|
||||
for pattern in self.ALWAYS_VERIFY_PATTERNS:
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
if not user.is_active:
|
||||
return True
|
||||
|
||||
# 安全方法(GET/HEAD/OPTIONS)一般不需要验证
|
||||
safe_methods = {"GET", "HEAD", "OPTIONS"}
|
||||
if method in safe_methods:
|
||||
return False
|
||||
|
||||
# 修改操作(POST/PUT/DELETE/PATCH)需要验证
|
||||
return method in {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
async def _get_current_user(self, request: Request) -> User | None:
|
||||
"""获取当前用户"""
|
||||
try:
|
||||
# 从Authorization header提取token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header[7:] # 移除"Bearer "前缀
|
||||
|
||||
# 创建专用数据库会话
|
||||
async with with_db() as db:
|
||||
# 获取token记录
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
|
||||
# 获取用户
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting user: {e}")
|
||||
return None
|
||||
|
||||
async def _get_session_state(self, request: Request, user: User) -> SessionState | None:
|
||||
"""获取会话状态"""
|
||||
try:
|
||||
# 提取会话token(这里简化为使用相同的auth token)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
api_version = 0
|
||||
raw_api_version = request.headers.get("x-api-version")
|
||||
if raw_api_version is not None:
|
||||
try:
|
||||
api_version = int(raw_api_version)
|
||||
except ValueError:
|
||||
api_version = 0
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
session_token = auth_header[7:]
|
||||
|
||||
# 获取数据库和Redis连接
|
||||
async with with_db() as db:
|
||||
redis = get_redis()
|
||||
|
||||
# 查找会话
|
||||
session = await LoginSessionService.find_for_verification(db, session_token)
|
||||
if not session or session.user_id != user.id:
|
||||
return None
|
||||
|
||||
return SessionState(session, user, redis, db, api_version)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session state: {e}")
|
||||
return None
|
||||
|
||||
async def _initiate_verification(self, state: SessionState) -> Response:
|
||||
"""启动验证流程"""
|
||||
try:
|
||||
method = await state.get_method()
|
||||
if method == "mail":
|
||||
await state.issue_mail_if_needed()
|
||||
|
||||
# 返回验证要求响应
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"method": method, "message": "Session verification required"},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initiating verification: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"error": "Verification initiation failed"}
|
||||
)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
@@ -204,3 +202,6 @@ class SearchQueryModel(BaseModel):
|
||||
default=None,
|
||||
description="游标字符串,用于分页",
|
||||
)
|
||||
|
||||
|
||||
SearchQueryModel.model_rebuild()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
扩展的 OAuth 响应模型,支持二次验证
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -11,7 +9,7 @@ class ExtendedTokenResponse(BaseModel):
|
||||
"""扩展的令牌响应,支持二次验证状态"""
|
||||
|
||||
access_token: str | None = None
|
||||
token_type: str = "Bearer"
|
||||
token_type: str = "Bearer" # noqa: S105
|
||||
expires_in: int | None = None
|
||||
refresh_token: str | None = None
|
||||
scope: str | None = None
|
||||
@@ -20,14 +18,3 @@ class ExtendedTokenResponse(BaseModel):
|
||||
requires_second_factor: bool = False
|
||||
verification_message: str | None = None
|
||||
user_id: int | None = None # 用于二次验证的用户ID
|
||||
|
||||
|
||||
class SessionState(BaseModel):
|
||||
"""会话状态"""
|
||||
|
||||
user_id: int
|
||||
username: str
|
||||
email: str
|
||||
requires_verification: bool
|
||||
session_token: str | None = None
|
||||
verification_sent: bool = False
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from app.models.signalr import SignalRUnionMessage, UserState
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS = 13
|
||||
|
||||
|
||||
class _UserActivity(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class ChoosingBeatmap(_UserActivity):
|
||||
union_type: ClassVar[Literal[11]] = 11
|
||||
|
||||
|
||||
class _InGame(_UserActivity):
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
ruleset_id: int
|
||||
ruleset_playing_verb: str
|
||||
|
||||
|
||||
class InSoloGame(_InGame):
|
||||
union_type: ClassVar[Literal[12]] = 12
|
||||
|
||||
|
||||
class InMultiplayerGame(_InGame):
|
||||
union_type: ClassVar[Literal[23]] = 23
|
||||
|
||||
|
||||
class SpectatingMultiplayerGame(_InGame):
|
||||
union_type: ClassVar[Literal[24]] = 24
|
||||
|
||||
|
||||
class InPlaylistGame(_InGame):
|
||||
union_type: ClassVar[Literal[31]] = 31
|
||||
|
||||
|
||||
class PlayingDailyChallenge(_InGame):
|
||||
union_type: ClassVar[Literal[52]] = 52
|
||||
|
||||
|
||||
class EditingBeatmap(_UserActivity):
|
||||
union_type: ClassVar[Literal[41]] = 41
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
|
||||
|
||||
class TestingBeatmap(EditingBeatmap):
|
||||
union_type: ClassVar[Literal[43]] = 43
|
||||
|
||||
|
||||
class ModdingBeatmap(EditingBeatmap):
|
||||
union_type: ClassVar[Literal[42]] = 42
|
||||
|
||||
|
||||
class WatchingReplay(_UserActivity):
|
||||
union_type: ClassVar[Literal[13]] = 13
|
||||
score_id: int
|
||||
player_name: str
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
|
||||
|
||||
class SpectatingUser(WatchingReplay):
|
||||
union_type: ClassVar[Literal[14]] = 14
|
||||
|
||||
|
||||
class SearchingForLobby(_UserActivity):
|
||||
union_type: ClassVar[Literal[21]] = 21
|
||||
|
||||
|
||||
class InLobby(_UserActivity):
|
||||
union_type: ClassVar[Literal[22]] = 22
|
||||
room_id: int
|
||||
room_name: str
|
||||
|
||||
|
||||
class InDailyChallengeLobby(_UserActivity):
|
||||
union_type: ClassVar[Literal[51]] = 51
|
||||
|
||||
|
||||
UserActivity = (
|
||||
ChoosingBeatmap
|
||||
| InSoloGame
|
||||
| WatchingReplay
|
||||
| SpectatingUser
|
||||
| SearchingForLobby
|
||||
| InLobby
|
||||
| InMultiplayerGame
|
||||
| SpectatingMultiplayerGame
|
||||
| InPlaylistGame
|
||||
| EditingBeatmap
|
||||
| ModdingBeatmap
|
||||
| TestingBeatmap
|
||||
| InDailyChallengeLobby
|
||||
| PlayingDailyChallenge
|
||||
)
|
||||
|
||||
|
||||
class UserPresence(BaseModel):
|
||||
activity: UserActivity | None = None
|
||||
|
||||
status: OnlineStatus | None = None
|
||||
|
||||
@property
|
||||
def pushable(self) -> bool:
|
||||
return self.status is not None and self.status != OnlineStatus.OFFLINE
|
||||
|
||||
@property
|
||||
def for_push(self) -> "UserPresence | None":
|
||||
return UserPresence(
|
||||
activity=self.activity,
|
||||
status=self.status,
|
||||
)
|
||||
|
||||
|
||||
class MetadataClientState(UserPresence, UserState): ...
|
||||
|
||||
|
||||
class OnlineStatus(IntEnum):
|
||||
OFFLINE = 0 # 隐身
|
||||
DO_NOT_DISTURB = 1
|
||||
ONLINE = 2
|
||||
|
||||
|
||||
class DailyChallengeInfo(BaseModel):
|
||||
room_id: int
|
||||
|
||||
|
||||
class MultiplayerPlaylistItemStats(BaseModel):
|
||||
playlist_item_id: int = 0
|
||||
total_score_distribution: list[int] = Field(
|
||||
default_factory=list,
|
||||
min_length=TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
max_length=TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
)
|
||||
cumulative_score: int = 0
|
||||
last_processed_score_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomStats(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class MultiplayerRoomScoreSetEvent(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_id: int
|
||||
score_id: int
|
||||
user_id: int
|
||||
total_score: int
|
||||
new_rank: int | None = None
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
|
||||
from app.config import settings as app_settings
|
||||
from app.log import logger
|
||||
from app.log import log
|
||||
from app.path import CONFIG_DIR, STATIC_DIR
|
||||
|
||||
from pydantic import ConfigDict, Field, create_model
|
||||
@@ -268,7 +266,7 @@ def generate_ranked_mod_settings(enable_all: bool = False):
|
||||
for mod_acronym in ruleset_mods:
|
||||
result[ruleset_id][mod_acronym] = {}
|
||||
if not enable_all:
|
||||
logger.info("ENABLE_ALL_MODS_PP is deprecated, transformed to config/ranked_mods.json")
|
||||
log("Mod").info("ENABLE_ALL_MODS_PP is deprecated, transformed to config/ranked_mods.json")
|
||||
result["$mods_checksum"] = checksum # pyright: ignore[reportArgumentType]
|
||||
ranked_mods_file.write_text(json.dumps(result, indent=4))
|
||||
|
||||
|
||||
@@ -1,840 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import IntEnum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
TypedDict,
|
||||
cast,
|
||||
override,
|
||||
)
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.dependencies.database import with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.exception import InvokeException
|
||||
from app.utils import utcnow
|
||||
|
||||
from .mods import API_MODS, APIMod
|
||||
from .room import (
|
||||
DownloadState,
|
||||
MatchType,
|
||||
MultiplayerRoomState,
|
||||
MultiplayerUserState,
|
||||
QueueMode,
|
||||
RoomCategory,
|
||||
RoomStatus,
|
||||
)
|
||||
from .signalr import (
|
||||
SignalRMeta,
|
||||
SignalRUnionMessage,
|
||||
UserState,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import col
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.room import Room
|
||||
from app.signalr.hub import MultiplayerHub
|
||||
|
||||
HOST_LIMIT = 50
|
||||
PER_USER_LIMIT = 3
|
||||
|
||||
|
||||
class MultiplayerClientState(UserState):
|
||||
room_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomSettings(BaseModel):
|
||||
name: str = "Unnamed Room"
|
||||
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
password: str = ""
|
||||
match_type: MatchType = MatchType.HEAD_TO_HEAD
|
||||
queue_mode: QueueMode = QueueMode.HOST_ONLY
|
||||
auto_start_duration: timedelta = timedelta(seconds=0)
|
||||
auto_skip: bool = False
|
||||
|
||||
@property
|
||||
def auto_start_enabled(self) -> bool:
|
||||
return self.auto_start_duration != timedelta(seconds=0)
|
||||
|
||||
|
||||
class BeatmapAvailability(BaseModel):
|
||||
state: DownloadState = DownloadState.UNKNOWN
|
||||
download_progress: float | None = None
|
||||
|
||||
|
||||
class _MatchUserState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class TeamVersusUserState(_MatchUserState):
|
||||
team_id: int
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchUserState = TeamVersusUserState
|
||||
|
||||
|
||||
class _MatchRoomState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class MultiplayerTeam(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class TeamVersusRoomState(_MatchRoomState):
|
||||
teams: list[MultiplayerTeam] = Field(
|
||||
default_factory=lambda: [
|
||||
MultiplayerTeam(id=0, name="Team Red"),
|
||||
MultiplayerTeam(id=1, name="Team Blue"),
|
||||
]
|
||||
)
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchRoomState = TeamVersusRoomState
|
||||
|
||||
|
||||
class PlaylistItem(BaseModel):
|
||||
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
owner_id: int
|
||||
beatmap_id: int
|
||||
beatmap_checksum: str
|
||||
ruleset_id: int
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
expired: bool
|
||||
playlist_order: int
|
||||
played_at: datetime | None = None
|
||||
star_rating: float
|
||||
freestyle: bool
|
||||
|
||||
def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None:
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
# Check if mod is valid for ruleset
|
||||
if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]:
|
||||
raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset")
|
||||
|
||||
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
|
||||
|
||||
# Check if mod is unplayable in multiplayer
|
||||
if mod_settings.get("UserPlayable", True) is False:
|
||||
raise InvokeException(f"{context} {mod['acronym']} is not playable by users")
|
||||
|
||||
if mod_settings.get("ValidForMultiplayer", True) is False:
|
||||
raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer")
|
||||
|
||||
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
for i, mod1 in enumerate(mods):
|
||||
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
|
||||
if mod1_settings:
|
||||
incompatible = set(mod1_settings.get("IncompatibleMods", []))
|
||||
for mod2 in mods[i + 1 :]:
|
||||
if mod2["acronym"] in incompatible:
|
||||
raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible")
|
||||
|
||||
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
|
||||
for req_mod in self.required_mods:
|
||||
req_acronym = req_mod["acronym"]
|
||||
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
|
||||
if req_settings:
|
||||
incompatible = set(req_settings.get("IncompatibleMods", []))
|
||||
conflicting_allowed = allowed_acronyms & incompatible
|
||||
if conflicting_allowed:
|
||||
conflict_list = ", ".join(conflicting_allowed)
|
||||
raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}")
|
||||
|
||||
def validate_playlist_item_mods(self) -> None:
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
|
||||
|
||||
# Validate required mods
|
||||
for mod in self.required_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
|
||||
|
||||
# Validate allowed mods
|
||||
for mod in self.allowed_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
|
||||
|
||||
# Check internal compatibility of required mods
|
||||
self._check_mod_compatibility(self.required_mods, ruleset_key)
|
||||
|
||||
# Check compatibility between required and allowed mods
|
||||
self._check_required_allowed_compatibility(ruleset_key)
|
||||
|
||||
def validate_user_mods(
|
||||
self,
|
||||
user: "MultiplayerRoomUser",
|
||||
proposed_mods: list[APIMod],
|
||||
) -> tuple[bool, list[APIMod]]:
|
||||
"""
|
||||
Validates user mods against playlist item rules and returns valid mods.
|
||||
Returns (is_valid, valid_mods).
|
||||
"""
|
||||
from typing import Literal, cast
|
||||
|
||||
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
|
||||
|
||||
valid_mods = []
|
||||
all_proposed_valid = True
|
||||
|
||||
# Check if mods are valid for the ruleset
|
||||
for mod in proposed_mods:
|
||||
if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]:
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
valid_mods.append(mod)
|
||||
|
||||
# Check mod compatibility within user mods
|
||||
incompatible_mods = set()
|
||||
final_valid_mods = []
|
||||
for mod in valid_mods:
|
||||
if mod["acronym"] in incompatible_mods:
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
|
||||
if setting_mods:
|
||||
incompatible_mods.update(setting_mods["IncompatibleMods"])
|
||||
final_valid_mods.append(mod)
|
||||
|
||||
# If not freestyle, check against allowed mods
|
||||
if not self.freestyle:
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
if mod["acronym"] not in allowed_acronyms:
|
||||
all_proposed_valid = False
|
||||
else:
|
||||
filtered_valid_mods.append(mod)
|
||||
final_valid_mods = filtered_valid_mods
|
||||
|
||||
# Check compatibility with required mods
|
||||
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
|
||||
all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms
|
||||
|
||||
# Check for incompatibility between required and user mods
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
mod_acronym = mod["acronym"]
|
||||
is_compatible = True
|
||||
|
||||
for other_acronym in all_mod_acronyms:
|
||||
if other_acronym == mod_acronym:
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
|
||||
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
|
||||
is_compatible = False
|
||||
all_proposed_valid = False
|
||||
break
|
||||
|
||||
if is_compatible:
|
||||
filtered_valid_mods.append(mod)
|
||||
|
||||
return all_proposed_valid, filtered_valid_mods
|
||||
|
||||
def clone(self) -> "PlaylistItem":
|
||||
copy = self.model_copy()
|
||||
copy.required_mods = list(self.required_mods)
|
||||
copy.allowed_mods = list(self.allowed_mods)
|
||||
copy.expired = False
|
||||
copy.played_at = None
|
||||
return copy
|
||||
|
||||
|
||||
class _MultiplayerCountdown(SignalRUnionMessage):
|
||||
id: int = 0
|
||||
time_remaining: timedelta
|
||||
is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True
|
||||
|
||||
|
||||
class MatchStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class ForceGameplayStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
class ServerShuttingDownCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
|
||||
|
||||
MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
|
||||
|
||||
|
||||
class MultiplayerRoomUser(BaseModel):
|
||||
user_id: int
|
||||
state: MultiplayerUserState = MultiplayerUserState.IDLE
|
||||
availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None)
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
match_state: MatchUserState | None = None
|
||||
ruleset_id: int | None = None # freestyle
|
||||
beatmap_id: int | None = None # freestyle
|
||||
|
||||
|
||||
class MultiplayerRoom(BaseModel):
|
||||
room_id: int
|
||||
state: MultiplayerRoomState
|
||||
settings: MultiplayerRoomSettings
|
||||
users: list[MultiplayerRoomUser] = Field(default_factory=list)
|
||||
host: MultiplayerRoomUser | None = None
|
||||
match_state: MatchRoomState | None = None
|
||||
playlist: list[PlaylistItem] = Field(default_factory=list)
|
||||
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
|
||||
channel_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, room: "Room") -> "MultiplayerRoom":
|
||||
"""
|
||||
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
|
||||
"""
|
||||
|
||||
# 用户列表
|
||||
users = [MultiplayerRoomUser(user_id=room.host_id)]
|
||||
host_user = MultiplayerRoomUser(user_id=room.host_id)
|
||||
# playlist 转换
|
||||
playlist = []
|
||||
if room.playlist:
|
||||
for item in room.playlist:
|
||||
playlist.append(
|
||||
PlaylistItem(
|
||||
id=item.id,
|
||||
owner_id=item.owner_id,
|
||||
beatmap_id=item.beatmap_id,
|
||||
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
|
||||
ruleset_id=item.ruleset_id,
|
||||
required_mods=item.required_mods,
|
||||
allowed_mods=item.allowed_mods,
|
||||
expired=item.expired,
|
||||
playlist_order=item.playlist_order,
|
||||
played_at=item.played_at,
|
||||
star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0,
|
||||
freestyle=item.freestyle,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
room_id=room.id,
|
||||
state=getattr(room, "state", MultiplayerRoomState.OPEN),
|
||||
settings=MultiplayerRoomSettings(
|
||||
name=room.name,
|
||||
playlist_item_id=playlist[0].id if playlist else 0,
|
||||
password=getattr(room, "password", ""),
|
||||
match_type=room.type,
|
||||
queue_mode=room.queue_mode,
|
||||
auto_start_duration=timedelta(seconds=room.auto_start_duration),
|
||||
auto_skip=room.auto_skip,
|
||||
),
|
||||
users=users,
|
||||
host=host_user,
|
||||
match_state=None,
|
||||
playlist=playlist,
|
||||
active_countdowns=[],
|
||||
channel_id=room.channel_id or 0,
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerQueue:
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.server_room = room
|
||||
self.current_index = 0
|
||||
|
||||
@property
|
||||
def hub(self) -> "MultiplayerHub":
|
||||
return self.server_room.hub
|
||||
|
||||
@property
|
||||
def upcoming_items(self):
|
||||
return sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda i: i.playlist_order,
|
||||
)
|
||||
|
||||
@property
|
||||
def room(self):
|
||||
return self.server_room.room
|
||||
|
||||
async def update_order(self):
|
||||
from app.database import Playlist
|
||||
|
||||
match self.room.settings.queue_mode:
|
||||
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
|
||||
ordered_active_items = []
|
||||
|
||||
is_first_set = True
|
||||
first_set_order_by_user_id = {}
|
||||
|
||||
active_items = [item for item in self.room.playlist if not item.expired]
|
||||
active_items.sort(key=lambda x: x.id)
|
||||
|
||||
user_item_groups = {}
|
||||
for item in active_items:
|
||||
if item.owner_id not in user_item_groups:
|
||||
user_item_groups[item.owner_id] = []
|
||||
user_item_groups[item.owner_id].append(item)
|
||||
|
||||
max_items = max((len(items) for items in user_item_groups.values()), default=0)
|
||||
|
||||
for i in range(max_items):
|
||||
current_set = []
|
||||
for user_id, items in user_item_groups.items():
|
||||
if i < len(items):
|
||||
current_set.append(items[i])
|
||||
|
||||
if is_first_set:
|
||||
current_set.sort(key=lambda item: (item.playlist_order, item.id))
|
||||
ordered_active_items.extend(current_set)
|
||||
first_set_order_by_user_id = {
|
||||
item.owner_id: idx for idx, item in enumerate(ordered_active_items)
|
||||
}
|
||||
else:
|
||||
current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0))
|
||||
ordered_active_items.extend(current_set)
|
||||
|
||||
is_first_set = False
|
||||
case _:
|
||||
ordered_active_items = sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
async with with_db() as session:
|
||||
for idx, item in enumerate(ordered_active_items):
|
||||
if item.playlist_order == idx:
|
||||
continue
|
||||
item.playlist_order = idx
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False)
|
||||
|
||||
async def update_current_item(self):
|
||||
upcoming_items = self.upcoming_items
|
||||
if upcoming_items:
|
||||
# 优先选择未过期的项目
|
||||
next_item = upcoming_items[0]
|
||||
else:
|
||||
# 如果所有项目都过期了,选择最近添加的项目(played_at 为 None 或最新的)
|
||||
# 优先选择 expired=False 的项目,然后是 played_at 最晚的
|
||||
next_item = max(
|
||||
self.room.playlist,
|
||||
key=lambda i: (not i.expired, i.played_at or datetime.min),
|
||||
)
|
||||
self.current_index = self.room.playlist.index(next_item)
|
||||
last_id = self.room.settings.playlist_item_id
|
||||
self.room.settings.playlist_item_id = next_item.id
|
||||
if last_id != next_item.id:
|
||||
await self.hub.setting_changed(self.server_room, True)
|
||||
|
||||
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
is_host = self.room.host and self.room.host.user_id == user.user_id
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
|
||||
raise InvokeException("You are not the host")
|
||||
|
||||
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
|
||||
if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit:
|
||||
raise InvokeException(f"You can only have {limit} items in the queue")
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with with_db() as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
|
||||
if beatmap is None:
|
||||
raise InvokeException("Beatmap not found")
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = beatmap.difficulty_rating
|
||||
await Playlist.add_to_db(item, self.room.room_id, session)
|
||||
self.room.playlist.append(item)
|
||||
await self.hub.playlist_added(self.server_room, item)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with with_db() as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
existing_item = next((i for i in self.room.playlist if i.id == item.id), None)
|
||||
if existing_item is None:
|
||||
raise InvokeException("Attempted to change an item that doesn't exist")
|
||||
|
||||
if existing_item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException("Attempted to change an item which is not owned by the user")
|
||||
|
||||
if existing_item.expired:
|
||||
raise InvokeException("Attempted to change an item which has already been played")
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(beatmap.difficulty_rating)
|
||||
item.playlist_order = existing_item.playlist_order
|
||||
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
|
||||
# Update item in playlist
|
||||
for idx, playlist_item in enumerate(self.room.playlist):
|
||||
if playlist_item.id == item.id:
|
||||
self.room.playlist[idx] = item
|
||||
break
|
||||
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room,
|
||||
item,
|
||||
beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum,
|
||||
)
|
||||
|
||||
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
item = next(
|
||||
(i for i in self.room.playlist if i.id == playlist_item_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if item is None:
|
||||
raise InvokeException("Item does not exist in the room")
|
||||
|
||||
# Check if it's the only item and current item
|
||||
if item == self.current_item:
|
||||
upcoming_items = [i for i in self.room.playlist if not i.expired]
|
||||
if len(upcoming_items) == 1:
|
||||
raise InvokeException("The only item in the room cannot be removed")
|
||||
|
||||
if item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException("Attempted to remove an item which is not owned by the user")
|
||||
|
||||
if item.expired:
|
||||
raise InvokeException("Attempted to remove an item which has already been played")
|
||||
|
||||
async with with_db() as session:
|
||||
await Playlist.delete_item(item.id, self.room.room_id, session)
|
||||
|
||||
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
|
||||
if found_item:
|
||||
self.room.playlist.remove(found_item)
|
||||
self.current_index = self.room.playlist.index(self.upcoming_items[0])
|
||||
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
await self.hub.playlist_removed(self.server_room, item.id)
|
||||
|
||||
async def finish_current_item(self):
|
||||
from app.database import Playlist
|
||||
|
||||
async with with_db() as session:
|
||||
played_at = utcnow()
|
||||
await session.execute(
|
||||
update(Playlist)
|
||||
.where(
|
||||
col(Playlist.id) == self.current_item.id,
|
||||
col(Playlist.room_id) == self.room.room_id,
|
||||
)
|
||||
.values(expired=True, played_at=played_at)
|
||||
)
|
||||
self.room.playlist[self.current_index].expired = True
|
||||
self.room.playlist[self.current_index].played_at = played_at
|
||||
await self.hub.playlist_changed(self.server_room, self.current_item, True)
|
||||
await self.update_order()
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_current_item()
|
||||
|
||||
async def update_queue_mode(self):
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
@property
|
||||
def current_item(self):
|
||||
return self.room.playlist[self.current_index]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountdownInfo:
|
||||
countdown: MultiplayerCountdown
|
||||
duration: timedelta
|
||||
task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, countdown: MultiplayerCountdown):
|
||||
self.countdown = countdown
|
||||
self.duration = (
|
||||
countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0)
|
||||
)
|
||||
|
||||
|
||||
class _MatchRequest(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class ChangeTeamRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
team_id: int
|
||||
|
||||
|
||||
class StartMatchCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
duration: timedelta
|
||||
|
||||
|
||||
class StopCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
id: int
|
||||
|
||||
|
||||
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
|
||||
|
||||
|
||||
class MatchTypeHandler(ABC):
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.room = room
|
||||
self.hub = room.hub
|
||||
|
||||
@abstractmethod
|
||||
async def handle_join(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_details(self) -> MatchStartedEventDetail: ...
|
||||
|
||||
|
||||
class HeadToHeadHandler(MatchTypeHandler):
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
if user.match_state is not None:
|
||||
user.match_state = None
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
|
||||
return detail
|
||||
|
||||
|
||||
class TeamVersusHandler(MatchTypeHandler):
|
||||
@override
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
super().__init__(room)
|
||||
self.state = TeamVersusRoomState()
|
||||
room.room.match_state = self.state
|
||||
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
|
||||
self.hub.tasks.add(task)
|
||||
task.add_done_callback(self.hub.tasks.discard)
|
||||
|
||||
def _get_best_available_team(self) -> int:
|
||||
for team in self.state.teams:
|
||||
if all(
|
||||
(
|
||||
user.match_state is None
|
||||
or not isinstance(user.match_state, TeamVersusUserState)
|
||||
or user.match_state.team_id != team.id
|
||||
)
|
||||
for user in self.room.room.users
|
||||
):
|
||||
return team.id
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
team_counts = defaultdict(int)
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
|
||||
team_counts[user.match_state.team_id] += 1
|
||||
|
||||
if team_counts:
|
||||
min_count = min(team_counts.values())
|
||||
for team_id, count in team_counts.items():
|
||||
if count == min_count:
|
||||
return team_id
|
||||
return self.state.teams[0].id if self.state.teams else 0
|
||||
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
best_team_id = self._get_best_available_team()
|
||||
user.match_state = TeamVersusUserState(team_id=best_team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
|
||||
if not isinstance(request, ChangeTeamRequest):
|
||||
return
|
||||
|
||||
if request.team_id not in [team.id for team in self.state.teams]:
|
||||
raise InvokeException("Invalid team ID")
|
||||
|
||||
user.match_state = TeamVersusUserState(team_id=request.team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
teams: dict[int, Literal["blue", "red"]] = {}
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
|
||||
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
|
||||
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
|
||||
return detail
|
||||
|
||||
|
||||
MATCH_TYPE_HANDLERS = {
|
||||
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
|
||||
MatchType.TEAM_VERSUS: TeamVersusHandler,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerMultiplayerRoom:
|
||||
room: MultiplayerRoom
|
||||
category: RoomCategory
|
||||
status: RoomStatus
|
||||
start_at: datetime
|
||||
hub: "MultiplayerHub"
|
||||
match_type_handler: MatchTypeHandler
|
||||
queue: MultiplayerQueue
|
||||
_next_countdown_id: int
|
||||
_countdown_id_lock: asyncio.Lock
|
||||
_tracked_countdown: dict[int, CountdownInfo]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
room: MultiplayerRoom,
|
||||
category: RoomCategory,
|
||||
start_at: datetime,
|
||||
hub: "MultiplayerHub",
|
||||
):
|
||||
self.room = room
|
||||
self.category = category
|
||||
self.status = RoomStatus.IDLE
|
||||
self.start_at = start_at
|
||||
self.hub = hub
|
||||
self.queue = MultiplayerQueue(self)
|
||||
self._next_countdown_id = 0
|
||||
self._countdown_id_lock = asyncio.Lock()
|
||||
self._tracked_countdown = {}
|
||||
|
||||
async def set_handler(self):
|
||||
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self)
|
||||
for i in self.room.users:
|
||||
await self.match_type_handler.handle_join(i)
|
||||
|
||||
async def get_next_countdown_id(self) -> int:
|
||||
async with self._countdown_id_lock:
|
||||
self._next_countdown_id += 1
|
||||
return self._next_countdown_id
|
||||
|
||||
async def start_countdown(
|
||||
self,
|
||||
countdown: MultiplayerCountdown,
|
||||
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
|
||||
):
|
||||
async def _countdown_task(self: "ServerMultiplayerRoom"):
|
||||
await asyncio.sleep(info.duration.total_seconds())
|
||||
if on_complete is not None:
|
||||
await on_complete(self)
|
||||
await self.stop_countdown(countdown)
|
||||
|
||||
if countdown.is_exclusive:
|
||||
await self.stop_all_countdowns(countdown.__class__)
|
||||
countdown.id = await self.get_next_countdown_id()
|
||||
info = CountdownInfo(countdown)
|
||||
self.room.active_countdowns.append(info.countdown)
|
||||
self._tracked_countdown[countdown.id] = info
|
||||
await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown))
|
||||
info.task = asyncio.create_task(_countdown_task(self))
|
||||
|
||||
async def stop_countdown(self, countdown: MultiplayerCountdown):
|
||||
info = self._tracked_countdown.get(countdown.id)
|
||||
if info is None:
|
||||
return
|
||||
del self._tracked_countdown[countdown.id]
|
||||
self.room.active_countdowns.remove(countdown)
|
||||
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
|
||||
if info.task is not None and not info.task.done():
|
||||
info.task.cancel()
|
||||
|
||||
async def stop_all_countdowns(self, typ: type[MultiplayerCountdown]):
|
||||
for countdown in list(self._tracked_countdown.values()):
|
||||
if isinstance(countdown.countdown, typ):
|
||||
await self.stop_countdown(countdown.countdown)
|
||||
|
||||
|
||||
class _MatchServerEvent(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class CountdownStartedEvent(_MatchServerEvent):
|
||||
countdown: MultiplayerCountdown
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class CountdownStoppedEvent(_MatchServerEvent):
|
||||
id: int
|
||||
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
|
||||
|
||||
|
||||
class GameplayAbortReason(IntEnum):
|
||||
LOAD_TOOK_TOO_LONG = 0
|
||||
HOST_ABORTED = 1
|
||||
|
||||
|
||||
class MatchStartedEventDetail(TypedDict):
|
||||
room_type: Literal["playlists", "head_to_head", "team_versus"]
|
||||
team: dict[int, Literal["blue", "red"]] | None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from __future__ import annotations
|
||||
# ruff: noqa: ARG002
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
@@ -118,10 +118,7 @@ class ChannelMessageBase(NotificationDetail):
|
||||
channel_type: "ChannelType",
|
||||
) -> Self:
|
||||
try:
|
||||
avatar_url = (
|
||||
getattr(user, "avatar_url", "https://lazer-data.g0v0.top/default.jpg")
|
||||
or "https://lazer-data.g0v0.top/default.jpg"
|
||||
)
|
||||
avatar_url = user.avatar_url or "https://lazer-data.g0v0.top/default.jpg"
|
||||
except Exception:
|
||||
avatar_url = "https://lazer-data.g0v0.top/default.jpg"
|
||||
instance = cls(
|
||||
@@ -160,7 +157,7 @@ class ChannelMessageTeam(ChannelMessageBase):
|
||||
cls,
|
||||
message: "ChatMessage",
|
||||
user: "User",
|
||||
) -> ChannelMessageTeam:
|
||||
) -> Self:
|
||||
from app.database import ChannelType
|
||||
|
||||
return super().init(message, user, [], ChannelType.TEAM)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# OAuth 相关模型 # noqa: I002
|
||||
# OAuth 相关模型
|
||||
from typing import Annotated, Any, cast
|
||||
from typing_extensions import Doc
|
||||
|
||||
@@ -22,7 +22,7 @@ class TokenRequest(BaseModel):
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "Bearer"
|
||||
token_type: str = "Bearer" # noqa: S105
|
||||
expires_in: int
|
||||
refresh_token: str
|
||||
scope: str = "*"
|
||||
@@ -67,7 +67,7 @@ class RegistrationRequestErrors(BaseModel):
|
||||
class OAuth2ClientCredentialsBearer(OAuth2):
|
||||
def __init__(
|
||||
self,
|
||||
tokenUrl: Annotated[
|
||||
tokenUrl: Annotated[ # noqa: N803
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
@@ -75,7 +75,7 @@ class OAuth2ClientCredentialsBearer(OAuth2):
|
||||
"""
|
||||
),
|
||||
],
|
||||
refreshUrl: Annotated[
|
||||
refreshUrl: Annotated[ # noqa: N803
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
|
||||
20
app/models/playlist.py
Normal file
20
app/models/playlist.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.mods import APIMod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PlaylistItem(BaseModel):
|
||||
id: int = Field(default=0, ge=-1)
|
||||
owner_id: int
|
||||
beatmap_id: int
|
||||
beatmap_checksum: str = ""
|
||||
ruleset_id: int = 0
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
expired: bool = False
|
||||
playlist_order: int = 0
|
||||
played_at: datetime | None = None
|
||||
star_rating: float = 0.0
|
||||
freestyle: bool = False
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, cast
|
||||
|
||||
|
||||
@@ -1,37 +1 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalRMeta:
|
||||
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
|
||||
json_ignore: bool = False # implement of JsonIgnore (json) attribute
|
||||
use_abbr: bool = True
|
||||
|
||||
|
||||
class SignalRUnionMessage(BaseModel):
|
||||
union_type: ClassVar[int]
|
||||
|
||||
|
||||
class Transport(BaseModel):
|
||||
transport: str
|
||||
transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats")
|
||||
|
||||
|
||||
class NegotiateResponse(BaseModel):
|
||||
connectionId: str
|
||||
connectionToken: str
|
||||
negotiateVersion: int = 1
|
||||
availableTransports: list[Transport]
|
||||
|
||||
|
||||
class UserState(BaseModel):
|
||||
connection_id: str
|
||||
connection_token: str
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from enum import IntEnum
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
|
||||
from .score import (
|
||||
ScoreStatistics,
|
||||
)
|
||||
from .signalr import SignalRMeta, UserState
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SpectatedUserState(IntEnum):
|
||||
Idle = 0
|
||||
Playing = 1
|
||||
Paused = 2
|
||||
Passed = 3
|
||||
Failed = 4
|
||||
Quit = 5
|
||||
|
||||
|
||||
class SpectatorState(BaseModel):
|
||||
beatmap_id: int | None = None
|
||||
ruleset_id: int | None = None # 0,1,2,3
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
state: SpectatedUserState
|
||||
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SpectatorState):
|
||||
return False
|
||||
return (
|
||||
self.beatmap_id == other.beatmap_id
|
||||
and self.ruleset_id == other.ruleset_id
|
||||
and self.mods == other.mods
|
||||
and self.state == other.state
|
||||
)
|
||||
|
||||
|
||||
class ScoreProcessorStatistics(BaseModel):
|
||||
base_score: float
|
||||
maximum_base_score: float
|
||||
accuracy_judgement_count: int
|
||||
combo_portion: float
|
||||
bonus_portion: float
|
||||
|
||||
|
||||
class FrameHeader(BaseModel):
|
||||
total_score: int
|
||||
accuracy: float
|
||||
combo: int
|
||||
max_combo: int
|
||||
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
score_processor_statistics: ScoreProcessorStatistics
|
||||
received_time: datetime.datetime
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
|
||||
@field_validator("received_time", mode="before")
|
||||
@classmethod
|
||||
def validate_timestamp(cls, v: Any) -> datetime.datetime:
|
||||
if isinstance(v, list):
|
||||
return v[0]
|
||||
if isinstance(v, datetime.datetime):
|
||||
return v
|
||||
if isinstance(v, int | float):
|
||||
return datetime.datetime.fromtimestamp(v, tz=datetime.UTC)
|
||||
if isinstance(v, str):
|
||||
return datetime.datetime.fromisoformat(v)
|
||||
raise ValueError(f"Cannot convert {type(v)} to datetime")
|
||||
|
||||
|
||||
# class ReplayButtonState(IntEnum):
|
||||
# NONE = 0
|
||||
# LEFT1 = 1
|
||||
# RIGHT1 = 2
|
||||
# LEFT2 = 4
|
||||
# RIGHT2 = 8
|
||||
# SMOKE = 16
|
||||
|
||||
|
||||
class LegacyReplayFrame(BaseModel):
|
||||
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
||||
mouse_x: float | None = None
|
||||
mouse_y: float | None = None
|
||||
button_state: int
|
||||
|
||||
header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)]
|
||||
|
||||
|
||||
class FrameDataBundle(BaseModel):
|
||||
header: FrameHeader
|
||||
frames: list[LegacyReplayFrame]
|
||||
|
||||
|
||||
# Use for server
|
||||
class APIUser(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class ScoreInfo(BaseModel):
|
||||
mods: list[APIMod]
|
||||
user: APIUser
|
||||
ruleset: int
|
||||
maximum_statistics: ScoreStatistics
|
||||
id: int | None = None
|
||||
total_score: int | None = None
|
||||
accuracy: float | None = None
|
||||
max_combo: int | None = None
|
||||
combo: int | None = None
|
||||
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StoreScore(BaseModel):
|
||||
score_info: ScoreInfo
|
||||
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StoreClientState(UserState):
|
||||
state: SpectatorState | None = None
|
||||
beatmap_status: BeatmapRankStatus | None = None
|
||||
checksum: str | None = None
|
||||
ruleset_id: int | None = None
|
||||
score_token: int | None = None
|
||||
watched_user: set[int] = Field(default_factory=set)
|
||||
score: StoreScore | None = None
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.log import logger
|
||||
from app.log import log
|
||||
from app.path import STATIC_DIR
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -16,6 +14,7 @@ class BeatmapTags(BaseModel):
|
||||
|
||||
|
||||
ALL_TAGS: dict[int, BeatmapTags] = {}
|
||||
logger = log("BeatmapTag")
|
||||
|
||||
|
||||
def load_tags() -> None:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
用户页面编辑相关的API模型
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
@@ -56,3 +54,47 @@ class ValidateBBCodeResponse(BaseModel):
|
||||
valid: bool = Field(description="BBCode是否有效")
|
||||
errors: list[str] = Field(default_factory=list, description="错误列表")
|
||||
preview: dict[str, str] = Field(description="预览内容")
|
||||
|
||||
|
||||
class UserpageError(Exception):
|
||||
"""用户页面处理错误基类"""
|
||||
|
||||
def __init__(self, message: str, code: str = "userpage_error"):
|
||||
self.message = message
|
||||
self.code = code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ContentTooLongError(UserpageError):
|
||||
"""内容过长错误"""
|
||||
|
||||
def __init__(self, current_length: int, max_length: int):
|
||||
message = f"Content too long. Maximum {max_length} characters allowed, got {current_length}."
|
||||
super().__init__(message, "content_too_long")
|
||||
self.current_length = current_length
|
||||
self.max_length = max_length
|
||||
|
||||
|
||||
class ContentEmptyError(UserpageError):
|
||||
"""内容为空错误"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Content cannot be empty.", "content_empty")
|
||||
|
||||
|
||||
class BBCodeValidationError(UserpageError):
|
||||
"""BBCode验证错误"""
|
||||
|
||||
def __init__(self, errors: list[str]):
|
||||
message = f"BBCode validation failed: {'; '.join(errors)}"
|
||||
super().__init__(message, "bbcode_validation_error")
|
||||
self.errors = errors
|
||||
|
||||
|
||||
class ForbiddenTagError(UserpageError):
|
||||
"""禁止标签错误"""
|
||||
|
||||
def __init__(self, tag: str):
|
||||
message = f"Forbidden tag '{tag}' is not allowed."
|
||||
super().__init__(message, "forbidden_tag")
|
||||
self.tag = tag
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""V1 API 用户相关模型"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -46,10 +44,10 @@ class PlayerStatsResponse(BaseModel):
|
||||
class PlayerEventItem(BaseModel):
|
||||
"""玩家事件项目"""
|
||||
|
||||
userId: int
|
||||
userId: int # noqa: N815
|
||||
name: str
|
||||
mapId: int | None = None
|
||||
setId: int | None = None
|
||||
mapId: int | None = None # noqa: N815
|
||||
setId: int | None = None # noqa: N815
|
||||
artist: str | None = None
|
||||
title: str | None = None
|
||||
version: str | None = None
|
||||
@@ -88,7 +86,7 @@ class PlayerInfo(BaseModel):
|
||||
custom_badge_icon: str
|
||||
custom_badge_color: str
|
||||
userpage_content: str
|
||||
recentFailed: int
|
||||
recentFailed: int # noqa: N815
|
||||
social_discord: str | None = None
|
||||
social_youtube: str | None = None
|
||||
social_twitter: str | None = None
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
STATIC_DIR = Path(__file__).parent.parent / "static"
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# from app.signalr import signalr_router as signalr_router
|
||||
from .auth import router as auth_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
from .file import file_router as file_router
|
||||
@@ -25,5 +22,4 @@ __all__ = [
|
||||
"private_router",
|
||||
"redirect_api_router",
|
||||
"redirect_router",
|
||||
# "signalr_router",
|
||||
]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
import re
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
@@ -19,11 +17,10 @@ from app.const import BANCHOBOT_ID
|
||||
from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.auth import TotpKeys
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.geoip import GeoIPService, IPAddress
|
||||
from app.dependencies.user_agent import UserAgentInfo
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.log import log
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
@@ -40,12 +37,13 @@ from app.service.verification_service import (
|
||||
)
|
||||
from app.utils import utcnow
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Header, Request
|
||||
from fastapi import APIRouter, Form, Header, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import exists, select
|
||||
|
||||
logger = log("Auth")
|
||||
|
||||
|
||||
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
|
||||
"""创建标准的 OAuth 错误响应"""
|
||||
@@ -93,11 +91,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
)
|
||||
async def register_user(
|
||||
db: Database,
|
||||
request: Request,
|
||||
user_username: str = Form(..., alias="user[username]", description="用户名"),
|
||||
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
||||
user_password: str = Form(..., alias="user[password]", description="密码"),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
user_username: Annotated[str, Form(..., alias="user[username]", description="用户名")],
|
||||
user_email: Annotated[str, Form(..., alias="user[user_email]", description="电子邮箱")],
|
||||
user_password: Annotated[str, Form(..., alias="user[password]", description="密码")],
|
||||
geoip: GeoIPService,
|
||||
client_ip: IPAddress,
|
||||
):
|
||||
username_errors = validate_username(user_username)
|
||||
email_errors = validate_email(user_email)
|
||||
@@ -126,22 +124,22 @@ async def register_user(
|
||||
|
||||
try:
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
client_ip = get_client_ip(request)
|
||||
country_code = "CN" # 默认国家代码
|
||||
country_code = None # 默认国家代码
|
||||
|
||||
try:
|
||||
# 查询 IP 地理位置
|
||||
geo_info = geoip.lookup(client_ip)
|
||||
if geo_info and geo_info.get("country_iso"):
|
||||
country_code = geo_info["country_iso"]
|
||||
if geo_info and (country_code := geo_info.get("country_iso")):
|
||||
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
|
||||
else:
|
||||
logger.warning(f"Could not determine country for IP {client_ip}")
|
||||
except Exception as e:
|
||||
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
|
||||
if country_code is None:
|
||||
country_code = "CN"
|
||||
|
||||
# 创建新用户
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=2是BanchoBot)
|
||||
result = await db.execute(
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
@@ -158,7 +156,7 @@ async def register_user(
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1, # 普通用户权限
|
||||
country_code=country_code, # 根据 IP 地理位置设置国家
|
||||
country_code=country_code,
|
||||
join_date=utcnow(),
|
||||
last_visit=utcnow(),
|
||||
is_supporter=settings.enable_supporter_for_all_users,
|
||||
@@ -201,19 +199,21 @@ async def oauth_token(
|
||||
db: Database,
|
||||
request: Request,
|
||||
user_agent: UserAgentInfo,
|
||||
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
|
||||
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
|
||||
),
|
||||
client_id: int = Form(..., description="客户端 ID"),
|
||||
client_secret: str = Form(..., description="客户端密钥"),
|
||||
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
|
||||
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*')"),
|
||||
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
|
||||
password: str | None = Form(None, description="密码(仅密码模式需要)"),
|
||||
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
|
||||
ip_address: IPAddress,
|
||||
grant_type: Annotated[
|
||||
Literal["authorization_code", "refresh_token", "password", "client_credentials"],
|
||||
Form(..., description="授权类型:密码、刷新令牌和授权码三种授权方式。"),
|
||||
],
|
||||
client_id: Annotated[int, Form(..., description="客户端 ID")],
|
||||
client_secret: Annotated[str, Form(..., description="客户端密钥")],
|
||||
redis: Redis,
|
||||
geoip: GeoIPService,
|
||||
code: Annotated[str | None, Form(description="授权码(仅授权码模式需要)")] = None,
|
||||
scope: Annotated[str, Form(description="权限范围(空格分隔,默认为 '*')")] = "*",
|
||||
username: Annotated[str | None, Form(description="用户名(仅密码模式需要)")] = None,
|
||||
password: Annotated[str | None, Form(description="密码(仅密码模式需要)")] = None,
|
||||
refresh_token: Annotated[str | None, Form(description="刷新令牌(仅刷新令牌模式需要)")] = None,
|
||||
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
|
||||
):
|
||||
scopes = scope.split(" ")
|
||||
|
||||
@@ -311,8 +311,6 @@ async def oauth_token(
|
||||
)
|
||||
token_id = token.id
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
# 获取国家代码
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
@@ -363,9 +361,7 @@ async def oauth_token(
|
||||
await LoginSessionService.mark_session_verified(
|
||||
db, redis, user_id, token_id, ip_address, user_agent, web_uuid
|
||||
)
|
||||
logger.debug(
|
||||
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
|
||||
)
|
||||
logger.debug(f"New location login detected but email verification disabled, auto-verifying user {user_id}")
|
||||
else:
|
||||
# 不是新设备登录,正常登录
|
||||
await LoginLogService.record_login(
|
||||
@@ -389,7 +385,7 @@ async def oauth_token(
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
token_type="Bearer", # noqa: S106
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=scope,
|
||||
@@ -442,7 +438,7 @@ async def oauth_token(
|
||||
)
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
token_type="Bearer", # noqa: S106
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=new_refresh_token,
|
||||
scope=scope,
|
||||
@@ -508,11 +504,11 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
# 打印jwt
|
||||
logger.info(f"[Auth] Generated JWT for user {user_id}: {access_token}")
|
||||
logger.info(f"Generated JWT for user {user_id}: {access_token}")
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
token_type="Bearer", # noqa: S106
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=" ".join(scopes),
|
||||
@@ -557,7 +553,7 @@ async def oauth_token(
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
token_type="Bearer", # noqa: S106
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=" ".join(scopes),
|
||||
@@ -571,16 +567,14 @@ async def oauth_token(
|
||||
)
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
email: Annotated[str, Form(..., description="邮箱地址")],
|
||||
redis: Redis,
|
||||
ip_address: IPAddress,
|
||||
):
|
||||
"""
|
||||
请求密码重置
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
# 请求密码重置
|
||||
@@ -599,20 +593,16 @@ async def request_password_reset(
|
||||
|
||||
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
reset_code: str = Form(..., description="重置验证码"),
|
||||
new_password: str = Form(..., description="新密码"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
email: Annotated[str, Form(..., description="邮箱地址")],
|
||||
reset_code: Annotated[str, Form(..., description="重置验证码")],
|
||||
new_password: Annotated[str, Form(..., description="新密码")],
|
||||
redis: Redis,
|
||||
ip_address: IPAddress,
|
||||
):
|
||||
"""
|
||||
重置密码
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
# 重置密码
|
||||
success, message = await password_reset_service.reset_password(
|
||||
email=email.lower().strip(),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user