chore(merge): merge #53

## Removed

- SignalR server
- `msgpack_lazer_api`
- Unused services

## Changed

- Move one-shot tasks and scheduled tasks into `app.tasks`
- Improve logs
  - Yellow: Tasks
  - Blue: Services
  - Magenta: Fetcher
  - Dark green: Uvicorn
  - Red: System
  - `#FFC1C1`: Normal
- Redis: use multiple logical databases
  - db0: general cache (`redis_client`)
  - db1: message cache (`redis_message_client`)
  - db2: binary storage (`redis_binary_client`)
  - db3: rate limiting (`redis_rate_limit_client`)
- API: move user page API (`/api/v2/users/{user_id}/page`, `/api/v2/me/validate-bbcode`) into private APIs (`/api/private/user/page`, `/api/private/user/validate-bbcode`)
- Remove `from __future__ import annotations` to avoid `ForwardRef` problems
- Assets Proxy: use decorators to simplify code
- Ruff: add rules
- API Router: use Annotated-style dependency injection
- Database: rename model files so the corresponding model is easy to find
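The Redis split above can be sketched as follows; the helper and the purpose-to-db mapping are illustrative, not code from the repository. Given the base `REDIS_URL` (now configured without a trailing database suffix), each client selects its logical database by index:

```python
# Illustrative only: maps each cache purpose to its logical database index
# and builds the per-client connection URL from a base REDIS_URL.
REDIS_DB_BY_PURPOSE = {
    "general": 0,      # redis_client
    "message": 1,      # redis_message_client
    "binary": 2,       # redis_binary_client
    "rate_limit": 3,   # redis_rate_limit_client
}


def build_redis_url(base_url: str, purpose: str) -> str:
    """Append the logical database index for the given purpose."""
    return f"{base_url.rstrip('/')}/{REDIS_DB_BY_PURPOSE[purpose]}"


print(build_redis_url("redis://127.0.0.1:6379", "message"))  # redis://127.0.0.1:6379/1
```

In the real application the resulting URLs would be handed to the respective Redis client constructors.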

## Documents

- CONTRIBUTING.md
- AGENTS.md
- copilot-instructions.md
This commit is contained in:
MingxuanGame
2025-10-04 16:37:40 +08:00
committed by GitHub
260 changed files with 3152 additions and 10093 deletions


@@ -107,6 +107,6 @@
80,
8080
],
"postCreateCommand": "uv sync --dev && uv run alembic upgrade head && uv run pre-commit install && cd packages/msgpack_lazer_api && cargo check && cd ../../spectator-server && dotnet restore",
"postCreateCommand": "uv sync --dev && uv run alembic upgrade head && uv run pre-commit install && cd spectator-server && dotnet restore",
"remoteUser": "vscode"
}


@@ -6,7 +6,7 @@ MYSQL_DATABASE="osu_api"
MYSQL_USER="osu_api"
MYSQL_PASSWORD="password"
MYSQL_ROOT_PASSWORD="password"
REDIS_URL="redis://127.0.0.1:6379/0"
REDIS_URL="redis://127.0.0.1:6379"
# JWT Settings
# Use `openssl rand -hex 32` to generate a secure key

.github/copilot-instructions.md (new file)

@@ -0,0 +1,184 @@
# copilot-instruction

> This file is a copy of AGENTS.md. If they differ, AGENTS.md is authoritative.

> Guidelines for using automation and AI agents (GitHub Copilot, dependency/CI bots, and in-repo runtime schedulers/workers) with the g0v0-server repository.

---

## API References

This project must stay compatible with the public osu! APIs. Use these references when adding or mapping endpoints:

- **v1 (legacy):** [https://github.com/ppy/osu-api/wiki](https://github.com/ppy/osu-api/wiki)
- **v2 (OpenAPI):** [https://osu.ppy.sh/docs/openapi.yaml](https://osu.ppy.sh/docs/openapi.yaml)

Any implementation in `app/router/v1/`, `app/router/v2/`, or `app/router/notification/` must match the official specifications. Custom or experimental endpoints belong in `app/router/private/`.
---
## Agent Categories

Agents are allowed in three categories:

- **Code authoring / completion agents** (e.g. GitHub Copilot or other LLMs): allowed **only** when a human maintainer reviews and approves the output.
- **Automated maintenance agents** (e.g. Dependabot, Renovate, pre-commit.ci): allowed, but they must follow strict PR and CI policies.
- **Runtime / background agents** (schedulers, workers): part of the product code; they must follow lifecycle, concurrency, and idempotency conventions.

All changes produced or suggested by agents must comply with the rules below.

---

## Rules for All Agents

1. **Single-responsibility PRs.** Agent PRs must address one concern only (one feature, one bugfix, or one dependency update). Use Angular-style commit messages (e.g. `feat(api): add ...`).
2. **Lint & CI compliance.** Every PR (including agent-created ones) must pass `pyright`, `ruff`, `pre-commit` hooks, and the repository CI before merging. Include links to CI runs in the PR.
3. **Never commit secrets.** Agents must not add keys, passwords, tokens, or real `.env` values. If a suspected secret is detected, the agent must abort and notify the designated maintainers.
4. **API location constraints.** Do not add new public endpoints under `app/router/v1` or `app/router/v2` unless they exist in the official v1/v2 specs. Custom or experimental endpoints must go under `app/router/private/`.
5. **Stable public contracts.** Do not change response schemas, route prefixes, or other public contracts without an approved migration plan; any change requires explicit compatibility notes in the PR.
---
## Copilot / LLM Usage

> Consolidated guidance for using GitHub Copilot and other LLM-based helpers with this repository.

### Key project structure (what you should know)

- **App entry:** `main.py`: the FastAPI application, with lifespan startup/shutdown orchestration (fetchers, GeoIP, schedulers, cache and health checks, Redis messaging, stats, achievements).
- **Routers:** `app/router/` contains all route groups. The main routers include:
  - `v1/` (v1 endpoints)
  - `v2/` (v2 endpoints)
  - `notification/` routers (chat/notification subsystem)
  - `auth.py` (authentication/token flows)
  - `private/` (custom or experimental endpoints)

  **Rule:** `v1/` and `v2/` must mirror the official APIs. Internal-only or experimental endpoints belong in `app/router/private/`.
- **Models & DB helpers:**
  - SQLModel/ORM models live in `app/database/`.
  - Non-database models live in `app/models/`.
  - For model/schema changes, generate an Alembic migration and manually review the generated SQL and indexes.
- **Services:** `app/service/` holds domain logic (e.g. caching helpers, notification/email logic). Complex logic belongs in services, not in route handlers.
- **Tasks:** `app/tasks/` holds scheduled, startup, and shutdown tasks.
  - All tasks are exported from `__init__.py`.
  - Startup/shutdown tasks are invoked from the `lifespan` handler in `main.py`.
  - Scheduled tasks use APScheduler.
- **Caching & dependencies:** Use the Redis dependencies and cache services provided by `app/dependencies/` (follow existing key naming conventions such as `user:{id}:...`).
- **Logging:** Use the logging helpers provided by `app/log`.

### Practical playbooks (prompt patterns)

- **Add a v2 endpoint (the right way):** Add a file under `app/router/v2/`, export the router, and implement an async handler that uses the DB and caching dependencies. Do **not** add non-official endpoints to v1/v2.
- **Add a custom endpoint:** Put it under `app/router/private/`; keep handlers thin and move business logic into `app/service/`.
- **Authentication:** Use the dependency injection provided by [`app.dependencies.user`](../app/dependencies/user.py), such as `ClientUser` and `get_current_user`; see below.
```python
from typing import Annotated

from fastapi import Security

from app.dependencies.user import ClientUser, get_current_user


@router.get("/some-api")
async def _(current_user: Annotated[User, Security(get_current_user, scopes=["public"])]):
    ...


@router.get("/some-client-api")
async def _(current_user: ClientUser):
    ...
```
- **Add a background job:** Put the job logic in `app/service/_job.py` (idempotent, retry-safe). Put the scheduler entry in `app/scheduler/_scheduler.py` and register it in the app lifespan.
- **DB schema changes:** Update the SQLModel models in `app/models/`, run `alembic revision --autogenerate`, inspect the migration, and validate locally with `alembic upgrade head` before committing.
- **Cache writes & responses:** Use the existing `UserResp` patterns and `UserCacheService`; use background tasks for asynchronous cache writes.

### Prompt guidance (what to give LLMs/Copilot)

- Specify the exact file location and constraints (e.g. `Add an async endpoint under app/router/private/... DO NOT add to app/router/v1 or v2`).
- Ask for async handlers, DB/Redis dependency injection, reuse of existing services/helpers, type annotations, and a minimal pytest skeleton.

### Conventions & quality expectations

- **Use Annotated-style dependency injection** in route handlers.
- **Commit message style:** `type(scope): subject` (Angular-style).
- **Async-first:** Route handlers must be async; avoid blocking the event loop.
- **Separation of concerns:** Business logic belongs in services, not in route handlers.
- **Error handling:** Use `HTTPException` for client errors and structured logging for server-side errors.
- **Types & linting:** Code must pass `pyright` and `ruff` before requesting review.
- **Comments:** Avoid excessive comments; add short, targeted comments only for non-obvious ("magical") logic.
- **Logging:** Use the `log` function from `app.log` to get a logger instance (services and tasks excepted).
### Tooling reference
```
uv sync
pre-commit install
pre-commit run --all-files
pyright
ruff .
alembic revision --autogenerate -m "feat(db): ..."
alembic upgrade head
uvicorn main:app --reload --host 0.0.0.0 --port 8000
```
### PR scope guidance

- Keep PRs focused: one concern per PR (e.g. an endpoint OR a refactor, not both).
- If unsure, align with the closest existing service and leave a short clarifying comment.

### PR review rules

> Reference for GitHub Copilot PR review.

1. If a PR changes an endpoint, briefly describe its purpose and expected behavior, and check that it satisfies the API location constraints above.
2. If a PR changes database models, it must include an Alembic migration; check that the generated SQL and indexes are reasonable.
3. Other functional changes need a short description.
4. Offer performance-optimization suggestions (see below).
---
## Performance Tips

Practical performance tips based on this repository's architecture (FastAPI + SQLModel/SQLAlchemy, Redis caching, background schedulers):

### Database

- **Select only required fields.** Use `select(Model.col1, Model.col2)` instead of `select(Model)`.
```py
stmt = select(User.id, User.username).where(User.active == True)
rows = await session.execute(stmt)
```
- **Use `select(exists())` for existence checks.** This avoids loading full rows:
```py
from sqlalchemy import select, exists
exists_stmt = select(exists().where(User.id == some_id))
found = await session.scalar(exists_stmt)
```
- **Avoid N+1 queries.** Use `selectinload` or `joinedload` when you need related objects.
- **Batch operations.** For inserts/updates, batch statements inside a single transaction rather than using many small transactions.

### Long-running tasks

- If the task originates from an API router, use FastAPI's [`BackgroundTasks`](https://fastapi.tiangolo.com/tutorial/background-tasks).
- Otherwise, use `bg_tasks` from `app.utils`, which provides functionality similar to FastAPI's `BackgroundTasks`.
---
## Additional requirements for specific LLMs

### Claude Code

- Do not create extra test scripts.


@@ -1,5 +1,3 @@
from __future__ import annotations
import datetime
from enum import Enum
import importlib.util

.gitignore

@@ -222,7 +222,6 @@ newrelic.ini
logs/
osu-server-spectator-master/*
spectator-server/
.github/copilot-instructions.md
osu-web-master/*
osu-web-master/.env.dusk.local.example
osu-web-master/.env.example

AGENTS.md

@@ -1,117 +1,118 @@
# AGENTS

> Guidelines for using automation and AI agents (GitHub Copilot, dependency/CI bots, and in-repo runtime schedulers/workers) with the g0v0-server repository.
---
## API References

This project must stay compatible with the public osu! APIs. Use these references when adding or mapping endpoints:

- **v1 (legacy):** [https://github.com/ppy/osu-api/wiki](https://github.com/ppy/osu-api/wiki)
- **v2 (OpenAPI):** [https://osu.ppy.sh/docs/openapi.yaml](https://osu.ppy.sh/docs/openapi.yaml)

Any implementation in `app/router/v1/`, `app/router/v2/`, or `app/router/notification/` must match official endpoints from the corresponding specification above. Custom or experimental endpoints belong in `app/router/private/`.
---
## Agent Categories

Agents are allowed in three categories:

- **Code authoring / completion agents** (e.g. GitHub Copilot or other LLMs) — allowed **only** when a human maintainer reviews and approves the output.
- **Automated maintenance agents** (e.g. Dependabot, Renovate, pre-commit.ci) — allowed but must follow strict PR and CI policies.
- **Runtime / background agents** (schedulers, workers) — part of the product code; must follow lifecycle, concurrency, and idempotency conventions.

All changes produced or suggested by agents must comply with the rules below.
---
## Rules for All Agents

1. **Human review required.** Any code, configuration, or documentation generated by an AI or automation agent must be reviewed and approved by a human maintainer familiar with g0v0-server. Do not merge agent PRs without explicit human approval.
2. **Single-responsibility PRs.** Agent PRs must address one concern only (one feature, one bugfix, or one dependency update). Use Angular-style commit messages (e.g. `feat(api): add ...`).
3. **Lint & CI compliance.** Every PR (including agent-created ones) must pass `pyright`, `ruff`, `pre-commit` hooks, and the repository CI before merging. Include links to CI runs in the PR.
4. **Never commit secrets.** Agents must not add keys, passwords, tokens, or real `.env` values. If a suspected secret is detected, the agent must abort and notify a designated human.
5. **API location constraints.** Do not add new public endpoints under `app/router/v1` or `app/router/v2` unless the endpoints exist in the official v1/v2 specs. Custom or experimental endpoints must go under `app/router/private/`.
6. **Stable public contracts.** Avoid changing response schemas, route prefixes, or other public contracts without an approved migration plan and explicit compatibility notes in the PR.
---
## Copilot / LLM Usage

> Consolidated guidance for using GitHub Copilot and other LLM-based helpers with this repository.

### Key project structure (what you should know)

- **App entry:** `main.py` — FastAPI application with lifespan startup/shutdown orchestration (fetchers, GeoIP, schedulers, cache and health checks, Redis messaging, stats, achievements).
- **Routers:** `app/router/` contains route groups. Important routers include:
  - `v1/` (v1 endpoints)
  - `v2/` (v2 endpoints)
  - `notification/` routers (chat/notification subsystem)
  - `auth.py` (authentication/token flows)
  - `private/` (custom or experimental endpoints)

  **Rules:** `v1/` and `v2/` must mirror the official APIs. Put internal-only or experimental endpoints under `app/router/private/`.
- **Models & DB helpers:**
  - SQLModel/ORM models live in `app/database/`.
  - Non-database models live in `app/models/`.
  - For model/schema changes, draft an Alembic migration and manually review the generated SQL and indexes before applying.
- **Services:** `app/service/` holds domain logic (e.g., user ranking calculation, caching helpers, notification/email logic). Heavy logic belongs in services rather than in route handlers.
- **Tasks:** `app/tasks/` holds scheduled, startup, and shutdown tasks.
  - All tasks are exported from `__init__.py`.
  - Startup/shutdown tasks are invoked from the `lifespan` handler in `main.py`.
  - Scheduled tasks use APScheduler.
- **Caching & dependencies:** Use injected Redis dependencies from `app/dependencies/` and shared cache services (follow existing key naming conventions such as `user:{id}:...`).
- **Logging:** Use the logging helpers provided by `app/log`.
### Practical playbooks (prompt patterns)

- **Add a v2 endpoint (correct):** Add files under `app/router/v2/`, export the router, implement async path operations using DB and injected caching dependencies. Do **not** add non-official endpoints to v1/v2.
- **Add an internal endpoint:** Add under `app/router/private/`; keep route handlers thin and move business logic into `app/service/`.
- **Authentication:** Use the dependency injection provided by [`app.dependencies.user`](./app/dependencies/user.py), such as `ClientUser` and `get_current_user`; see below.

```python
from typing import Annotated

from fastapi import Security

from app.dependencies.user import ClientUser, get_current_user


@router.get("/some-api")
async def _(current_user: Annotated[User, Security(get_current_user, scopes=["public"])]):
    ...


@router.get("/some-client-api")
async def _(current_user: ClientUser):
    ...
```
- **Add a background job:** Put pure job logic in `app/service/_job.py` (idempotent, retry-safe). Add scheduler start/stop functions in `app/scheduler/_scheduler.py`, and register them in the app lifespan.
- **DB schema changes:** Update SQLModel models in `app/models/`, run `alembic revision --autogenerate`, inspect the migration, and validate locally with `alembic upgrade head` before committing.
- **Cache writes & responses:** Use existing `UserResp` patterns and `UserCacheService` where applicable; use background tasks for asynchronous cache writes.

### Prompt guidance (what to include for LLMs/Copilot)

- Specify the exact file location and constraints (e.g. `Add an async endpoint under app/router/private/ ... DO NOT add to app/router/v1 or v2`).
- Ask for asynchronous handlers, dependency injection for DB/Redis, reuse of existing services/helpers, type annotations, and a minimal pytest skeleton.

### Human reviewer checklist

- Is the code async and non-blocking, with heavy logic in `app/service/`?
- Are DB and Redis dependencies injected via the project's dependency utilities?
- Are existing cache keys and services reused consistently?
- Are tests or test skeletons present and runnable?
- If models changed: is an Alembic migration drafted, reviewed, and applied locally?
- If native code changed: was `maturin develop -R` executed and validated?
- Do `pyright` and `ruff` pass locally?
### Conventions & quality expectations

- **Use Annotated-style dependency injection** in route handlers.
- **Commit message style:** `type(scope): subject` (Angular-style).
- **Async-first:** Route handlers must be async; avoid blocking the event loop.
- **Separation of concerns:** Business logic should live in services, not inside route handlers.
- **Error handling:** Use `HTTPException` for client errors and structured logging for server-side issues.
- **Types & linting:** Code must pass `pyright` and `ruff` before requesting review.
- **Comments:** Avoid excessive inline comments. Add short, targeted comments to explain non-obvious or "magical" behavior.
- **Logging:** Use the `log` function from `app.log` to get a logger instance (services and tasks excepted).

### Merge checklist

- Run `uv sync` to install/update dependencies.
- Run `pre-commit` hooks and fix any failures.
- Run `pyright` and `ruff` locally and resolve issues.
- If native modules changed: run `maturin develop -R`.
- If DB migrations changed: run `alembic upgrade head` locally to validate.
### Tooling reference

```
uv sync
pre-commit install
pre-commit run --all-files
pyright
ruff .
maturin develop -R  # when native modules changed
alembic revision --autogenerate -m "feat(db): ..."
alembic upgrade head
uvicorn main:app --reload --host 0.0.0.0 --port 8000
```
### PR scope guidance

- Keep PRs focused: one concern per PR (e.g., endpoint OR refactor, not both).
- Update README/config docs when adding new environment variables.
- If unsure about conventions, align with the closest existing service and leave a clarifying comment.

### PR review rules

> Reference for GitHub Copilot PR review.

1. If a PR changes an endpoint, briefly describe its purpose and expected behavior, and check that it satisfies the API location constraints above.
2. If a PR changes database models, it must include an Alembic migration; check that the generated SQL and indexes are reasonable.
3. Other functional changes need a short description.
4. Offer performance-optimization suggestions (see below).
---
## Performance Tips

Below are practical, project-specific performance tips derived from this repository's architecture (FastAPI + SQLModel/SQLAlchemy, Redis caching, and background schedulers).

### Database

- **Select only required fields.** Fetch only the columns you need using `select(Model.col1, Model.col2)` instead of `select(Model)`.

```py
stmt = select(User.id, User.username).where(User.active == True)
rows = await session.execute(stmt)
```
- **Use `select(exists())` for existence checks.** This avoids loading full rows:

```py
from sqlalchemy import select, exists

exists_stmt = select(exists().where(User.id == some_id))
found = await session.scalar(exists_stmt)
```

- **Avoid N+1 queries.** Use relationship loading strategies (`selectinload`, `joinedload`) when you need related objects.
- **Batch operations.** For inserts/updates, use bulk or batched statements inside a single transaction rather than many small transactions.
- **Indexes & EXPLAIN.** Add indexes on frequently filtered columns and use `EXPLAIN ANALYZE` to inspect slow queries.
- **Cursor / keyset pagination.** Prefer keyset pagination for large result sets instead of `OFFSET`/`LIMIT` to avoid high-cost scans.
### Long-running tasks

- If the task originates from an API router, use FastAPI's [`BackgroundTasks`](https://fastapi.tiangolo.com/tutorial/background-tasks).
- Otherwise, use `bg_tasks` from `app.utils`, which provides functionality similar to FastAPI's `BackgroundTasks`.

### Caching & Redis

- **Cache hot reads.** Use `UserCacheService` to cache heavy or frequently-requested responses and store compact serialized forms (e.g., messagepack via the native module).
- **Use pipelines and multi/exec.** When performing multiple Redis commands, pipeline them to reduce roundtrips.
- **Set appropriate TTLs.** Avoid never-expiring keys; choose TTLs that balance freshness and read amplification.
- **Prevent cache stampedes.** Use early recompute with jitter or distributed locks (Redis `SET NX` or a small lock library) to avoid many processes rebuilding the same cache.
- **Atomic operations with Lua.** For complex multi-step Redis changes, consider a Lua script to keep operations atomic and fast.

### Background & Long-running Tasks

- **BackgroundTasks for lightweight work.** FastAPI's `BackgroundTasks` is fine for quick follow-up work (send email, async cache write). For heavy or long tasks, use a scheduler/worker (e.g., a dedicated async worker or job queue).
- **Use schedulers or workers for heavy jobs.** For expensive recalculations, use the repository's `app/scheduler/` pattern or an external worker system. Keep request handlers responsive; return quickly and delegate.
- **Throttling & batching.** When processing many items, batch them and apply concurrency limits (semaphore) to avoid saturating DB/Redis.

### API & Response Performance

- **Compress large payloads.** Enable gzip/deflate for large JSON responses.

---

## Additional requirements for specific LLMs

### Claude Code

- Do not create extra test scripts.


@@ -54,11 +54,155 @@ dotnet run --project osu.Server.Spectator --urls "http://0.0.0.0:8086"
uv sync
```
## Development Conventions

### Project structure
The main directories and files:

- `main.py`: main entry point of the FastAPI application; initializes and starts the server.
- `pyproject.toml`: project configuration for dependency management (uv), formatting (Ruff), and type checking (Pyright).
- `alembic.ini`: configuration for the Alembic database migration tool.
- `app/`: all core application code.
  - `router/`: all API endpoint definitions, organized by API version and feature.
  - `service/`: core business logic, e.g. user ranking calculation and daily challenge handling.
  - `database/`: database models (SQLModel) and session management.
  - `models/`: non-database models and other models.
  - `tasks/`: background tasks scheduled by APScheduler, plus startup/shutdown tasks.
  - `dependencies/`: FastAPI dependency injection.
  - `achievements/`: achievement-related logic.
  - `storage/`: storage service code.
  - `fetcher/`: modules for fetching data from external services (e.g. the osu! website).
  - `middleware/`: middleware, e.g. session verification.
  - `helpers/`: helper functions and utility classes.
  - `config.py`: application configuration, managed with pydantic-settings.
  - `calculator.py`: all calculation logic, e.g. pp and levels.
  - `log.py`: logging module providing a unified logging interface.
  - `const.py`: constants.
  - `path.py`: constants used across files.
- `migrations/`: Alembic-generated database migration scripts.
- `static/`: static files such as `mods.json`.
### Database model definitions

All database models are defined in `app.database` and exported from `__init__.py`.

If a model's table structure and its response shape are not identical, follow the `Base` - `Table` - `Resp` pattern:
```python
class ModelBase(SQLModel):
    # shared fields
    ...


class Model(ModelBase, table=True):
    # database table fields
    ...


class ModelResp(ModelBase):
    # response fields
    ...

    @classmethod
    def from_db(cls, db: Model) -> "ModelResp":
        # convert from the database model
        ...
```
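The `from_db` conversion in this pattern can be illustrated without SQLModel; the plain dataclasses below are hypothetical stand-ins for the real base classes, not repository code:

```python
from dataclasses import dataclass


@dataclass
class ModelBase:
    # shared fields exposed in both the table and the response
    id: int
    name: str


@dataclass
class Model(ModelBase):
    # stands in for the table model; has a column the response must not leak
    secret_column: str = ""


@dataclass
class ModelResp(ModelBase):
    # stands in for the response model
    @classmethod
    def from_db(cls, db: "Model") -> "ModelResp":
        # copy only the shared fields; table-only columns stay private
        return cls(id=db.id, name=db.name)


resp = ModelResp.from_db(Model(id=1, name="peppy", secret_column="hash"))
print(resp)  # ModelResp(id=1, name='peppy')
```

The point of the split is exactly this: table-only columns never reach the response type.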
The database module name should match its table name, except for modules that define multiple models.

If you need a Session, use `with_db` from `app.dependencies.database`, and remember to `COMMIT` manually:

```python
from app.dependencies.database import with_db

async with with_db() as session:
    ...
```
### Redis

Choose the Redis client that matches your use case. If your use case is complex or growing into a larger subsystem, consider creating another Redis connection.

- `redis_client` (db0): standard usage; stores strings, hashes, and other regular data.
- `redis_message_client` (db1): message cache; stores chat history and the like.
- `redis_binary_client` (db2): binary data, such as audio files.
- `redis_rate_limit_client` (db3): reserved for FastAPI-Limiter.
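As a rough illustration of why the clients stay separate (the `FakeRedis` class below is a hypothetical in-memory stand-in, not the real client), each logical database is an isolated keyspace:

```python
# Illustrative in-memory stand-in: each logical database is an isolated keyspace.
class FakeRedis:
    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    def set(self, key: str, value: str) -> None:
        self._data[key] = value

    def get(self, key: str):
        return self._data.get(key)


redis_client = FakeRedis()          # db0: general cache
redis_message_client = FakeRedis()  # db1: message cache

redis_client.set("user:1:profile", "cached profile")
print(redis_message_client.get("user:1:profile"))  # None: keyspaces are isolated
```

Because the keyspaces never overlap, a `FLUSHDB` or key collision in one purpose cannot affect another.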
### API Router

All API routers are defined in `app.router`:

- `app/router/v2` holds all osu! v2 API implementations. **Do not add extra endpoints that do not exist in the original v2 API.**
- `app/router/notification` **holds all osu! v2 chat, notification, and BanchoBot implementations. Do not add extra endpoints that do not exist in the original v2 API.**
- `app/router/v1` holds all osu! v1 API implementations. **Do not add extra endpoints that do not exist in the original v1 API.**
- `app/router/auth.py` holds account authentication/login APIs.
- `app/router/private` holds server-specific custom APIs (the g0v0 API) for use by other services.

Every router must:

- Use Annotated-style dependency injection.
- Reuse the implementations in `app.dependencies` for existing dependencies such as Database and Redis.
- Be documented.
- Use the `asset_proxy_response` decorator from `app.helpers.asset_proxy_helper` if the response needs asset proxying.
- Use the `log` function from `app.log` to obtain a logger instance if logging is needed.
#### Authentication

If the router can be used publicly (clients, frontend, OAuth applications), consider `Security(get_current_user, scopes=["some_scope"])`, for example:
```python
from typing import Annotated

from fastapi import Security

from app.dependencies.user import get_current_user


@router.get("/some-api")
async def _(current_user: Annotated[User, Security(get_current_user, scopes=["public"])]):
    ...
```
For the choice of scopes, see `scopes` in `oauth2_code` in [`app.dependencies.user`](./app/dependencies/user.py).

If the router is only for the client and frontend, use the `ClientUser` dependency:
```python
from app.dependencies.user import ClientUser


@router.get("/some-api")
async def _(current_user: ClientUser):
    ...
```
There are also `get_current_user_and_token` and `get_client_user_and_token` variants for obtaining the current user's token at the same time.
### Service

All core business logic lives in `app.service`:

- Business logic must be implemented as classes.
- For logging, just use `logger` from `app.log`; the server wraps Service logs.

### Scheduled / startup / shutdown tasks

All are defined in `app.tasks`.

- All tasks are exported from `__init__.py`.
- Startup/shutdown tasks are invoked from the `lifespan` handler in `main.py`.
- Scheduled tasks use APScheduler.

### Long-running tasks

- If the task originates from an API router, use FastAPI's [`BackgroundTasks`](https://fastapi.tiangolo.com/tutorial/background-tasks).
- Otherwise, use `bg_tasks` from `app.utils`, which provides functionality similar to FastAPI's `BackgroundTasks`.
### Code quality and linting

Use `pre-commit` to enforce code quality standards before committing. This ensures all code passes `ruff` (linting and formatting) and `pyright` (type checking).

#### Setup

To set up `pre-commit`, run:
@@ -70,19 +214,9 @@ pre-commit install
pre-commit does not provide a pyright hook; you need to run `pyright` manually to check for type errors.
### Commit message guidelines

Follow the [AngularJS commit conventions](https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#commit-message-format) when writing commit messages. This keeps the project history readable and easy to understand.

Each commit message consists of a **header**, a **body**, and a **footer**:

```
<type>(<scope>): <subject>
<BLANK LINE>
<body>
<BLANK LINE>
<footer>
```
The **type** must be one of the following:

* **ci**: continuous-integration changes
* **deploy**: deployment changes

The **scope** can be anything specifying the location of the change, e.g. `api`, `db`, `auth`. Use `project` for changes affecting the whole project.

The **subject** contains a concise description of the change.
### Continuous integration checks

All commits should pass the following CI checks:

- Ruff Lint
- Pyright Lint
- pre-commit

Thank you for contributing!


@@ -1,5 +1,3 @@
from __future__ import annotations
from functools import partial
from app.database.daily_challenge import DailyChallengeStats
@@ -32,11 +30,9 @@ async def process_streak(
).first()
if not stats:
return False
if streak <= stats.daily_streak_best < next_streak:
return True
elif next_streak == 0 and stats.daily_streak_best >= streak:
return True
return False
return bool(
streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak)
)
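The same threshold-window check recurs across the medal processors (streak, play count, combo, total hits). A standalone sketch of the shared pattern, with a hypothetical helper name:

```python
def in_threshold_window(value: int, threshold: int, next_threshold: int) -> bool:
    """True when `value` has reached `threshold` but not the next tier.

    `next_threshold == 0` marks the final tier, where any value at or
    above `threshold` qualifies.
    """
    if next_threshold == 0:
        return value >= threshold
    return threshold <= value < next_threshold


print(in_threshold_window(12, 10, 30))  # True: inside the 10..30 window
```

This mirrors `streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak)` above, just with the two branches named.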
MEDALS = {


@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime
from app.database.beatmap import calculate_beatmap_attributes
@@ -68,9 +66,7 @@ async def to_the_core(
if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
return False
mods_ = mod_to_save(score.mods)
if "DT" not in mods_ or "NC" not in mods_:
return False
return True
return not ("DT" not in mods_ or "NC" not in mods_)
async def wysi(
@@ -83,9 +79,7 @@ async def wysi(
return False
if str(round(score.accuracy, ndigits=4))[3:] != "727":
return False
if "xi" not in beatmap.beatmapset.artist:
return False
return True
return "xi" in beatmap.beatmapset.artist
async def prepared(
@@ -97,9 +91,7 @@ async def prepared(
if score.rank != Rank.X and score.rank != Rank.XH:
return False
mods_ = mod_to_save(score.mods)
if "NF" not in mods_:
return False
return True
return "NF" in mods_
async def reckless_adandon(
@@ -117,9 +109,7 @@ async def reckless_adandon(
redis = get_redis()
mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < 3:
return False
return True
return not attribute.star_rating < 3
async def lights_out(
@@ -413,11 +403,10 @@ async def by_the_skin_of_the_teeth(
return False
for mod in score.mods:
if mod.get("acronym") == "AC":
if "settings" in mod and "minimum_accuracy" in mod["settings"]:
target_accuracy = mod["settings"]["minimum_accuracy"]
if isinstance(target_accuracy, int | float):
return abs(score.accuracy - float(target_accuracy)) < 0.0001
if mod.get("acronym") == "AC" and "settings" in mod and "minimum_accuracy" in mod["settings"]:
target_accuracy = mod["settings"]["minimum_accuracy"]
if isinstance(target_accuracy, int | float):
return abs(score.accuracy - float(target_accuracy)) < 0.0001
return False


@@ -1,5 +1,3 @@
from __future__ import annotations
from functools import partial
from app.database.score import Beatmap, Score
@@ -19,9 +17,7 @@ async def process_mod(
return False
if not beatmap.beatmap_status.has_leaderboard():
return False
if len(score.mods) != 1 or score.mods[0]["acronym"] != mod:
return False
return True
return not (len(score.mods) != 1 or score.mods[0]["acronym"] != mod)
async def process_category_mod(


@@ -1,5 +1,3 @@
from __future__ import annotations
from functools import partial
from app.database.score import Beatmap, Score
@@ -22,11 +20,7 @@ async def process_combo(
return False
if next_combo != 0 and combo >= next_combo:
return False
if combo <= score.max_combo < next_combo:
return True
elif next_combo == 0 and score.max_combo >= combo:
return True
return False
return bool(combo <= score.max_combo < next_combo or (next_combo == 0 and score.max_combo >= combo))
MEDALS: Medals = {


@@ -1,5 +1,3 @@
from __future__ import annotations
from functools import partial
from app.database import UserStatistics
@@ -35,11 +33,7 @@ async def process_playcount(
).first()
if not stats:
return False
if pc <= stats.play_count < next_pc:
return True
elif next_pc == 0 and stats.play_count >= pc:
return True
return False
return bool(pc <= stats.play_count < next_pc or (next_pc == 0 and stats.play_count >= pc))
MEDALS: Medals = {


@@ -1,5 +1,3 @@
from __future__ import annotations
from functools import partial
from typing import Literal, cast
@@ -47,9 +45,7 @@ async def process_skill(
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < star or attribute.star_rating >= star + 1:
return False
if type == "fc" and not score.is_perfect_combo:
return False
return True
return not (type == "fc" and not score.is_perfect_combo)
MEDALS: Medals = {


@@ -1,5 +1,3 @@
from __future__ import annotations
from functools import partial
from app.database.score import Beatmap, Score
@@ -35,11 +33,7 @@ async def process_tth(
).first()
if not stats:
return False
if tth <= stats.total_hits < next_tth:
return True
elif next_tth == 0 and stats.play_count >= tth:
return True
return False
return bool(tth <= stats.total_hits < next_tth or (next_tth == 0 and stats.play_count >= tth))
MEDALS: Medals = {


@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import timedelta
import hashlib
import re
@@ -13,7 +11,7 @@ from app.database import (
User,
)
from app.database.auth import TotpKeys
from app.log import logger
from app.log import log
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
from app.utils import utcnow
@@ -31,6 +29,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# bcrypt 缓存(模拟应用状态缓存)
bcrypt_cache = {}
logger = log("Auth")
def validate_username(username: str) -> list[str]:
"""验证用户名"""
@@ -67,7 +67,7 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
2. MD5哈希 -> bcrypt验证
"""
# 1. 明文密码转 MD5
pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()
pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode() # noqa: S324
# 2. 检查缓存
if bcrypt_hash in bcrypt_cache:
@@ -101,7 +101,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def get_password_hash(password: str) -> str:
"""生成密码哈希 - 使用 osu! 的方式"""
# 1. 明文密码 -> MD5
pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()
pw_md5 = hashlib.md5(password.encode()).hexdigest().encode() # noqa: S324
# 2. MD5 -> bcrypt
pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt())
return pw_bcrypt.decode()
@@ -112,7 +112,7 @@ async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -
验证用户身份 - 使用类似 from_login 的逻辑
"""
# 1. 明文密码转 MD5
pw_md5 = hashlib.md5(password.encode()).hexdigest()
pw_md5 = hashlib.md5(password.encode()).hexdigest() # noqa: S324
# 2. 根据用户名查找用户
user = None
@@ -253,7 +253,7 @@ async def store_token(
tokens_to_delete = active_tokens[max_tokens_per_client - 1 :]
for token in tokens_to_delete:
await db.delete(token)
logger.info(f"[Auth] Cleaned up {len(tokens_to_delete)} old tokens for user {user_id}")
logger.info(f"Cleaned up {len(tokens_to_delete)} old tokens for user {user_id}")
# 检查是否有重复的 access_token
duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first()
@@ -274,9 +274,7 @@ async def store_token(
await db.commit()
await db.refresh(token_record)
logger.info(
f"[Auth] Created new token for user {user_id}, client {client_id} (multi-device: {allow_multiple_devices})"
)
logger.info(f"Created new token for user {user_id}, client {client_id} (multi-device: {allow_multiple_devices})")
return token_record
@@ -325,12 +323,7 @@ def _generate_totp_account_label(user: User) -> str:
根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性
"""
if settings.totp_use_username_in_label:
# 使用用户名作为主要标识
primary_identifier = user.username
else:
# 使用邮箱作为标识
primary_identifier = user.email
primary_identifier = user.username if settings.totp_use_username_in_label else user.email
# 如果配置了服务名称,添加到标签中以便在认证器中区分
if settings.totp_service_name:
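The `# noqa: S324` annotations added above suppress Ruff's insecure-hash warning because the MD5 here is not used as a password hash on its own: the legacy osu! scheme pre-hashes the plaintext with MD5 and then feeds that digest to bcrypt. A sketch of step 1 only (step 2 requires the third-party `bcrypt` package, shown as a comment):

```python
import hashlib


def md5_prehash(plain_password: str) -> bytes:
    # Step 1 of the legacy scheme: plaintext -> lowercase hex MD5, as bytes.
    # Step 2 (not executed here) would be:
    #   pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt())
    return hashlib.md5(plain_password.encode()).hexdigest().encode()  # noqa: S324
```

The output is always a 32-byte hex string, so the value handed to bcrypt has a fixed length regardless of the original password.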


@@ -1,5 +1,3 @@
from __future__ import annotations
import asyncio
from copy import deepcopy
from enum import Enum
@@ -7,7 +5,7 @@ import math
from typing import TYPE_CHECKING
from app.config import settings
from app.log import logger
from app.log import log
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod, parse_enum_to_str
from app.models.score import GameMode
@@ -18,6 +16,8 @@ from redis.asyncio import Redis
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
logger = log("Calculator")
try:
import rosu_pp_py as rosu
except ImportError:
@@ -417,9 +417,8 @@ def too_dense(hit_objects: list[HitObject], per_1s: int, per_10s: int) -> bool:
if len(hit_objects) > i + per_1s:
if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000:
return True
elif len(hit_objects) > i + per_10s:
if hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
return True
elif len(hit_objects) > i + per_10s and hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
return True
return False
@@ -446,10 +445,7 @@ def slider_is_sus(hit_objects: list[HitObject]) -> bool:
def is_2b(hit_objects: list[HitObject]) -> bool:
for i in range(0, len(hit_objects) - 1):
if hit_objects[i] == hit_objects[i + 1].start_time:
return True
return False
return any(hit_objects[i] == hit_objects[i + 1].start_time for i in range(0, len(hit_objects) - 1))
def is_suspicious_beatmap(content: str) -> bool:


@@ -1,4 +1,3 @@
# ruff: noqa: I002
from enum import Enum
from typing import Annotated, Any
@@ -142,7 +141,7 @@ STORAGE_SETTINGS='{
]
redis_url: Annotated[
str,
Field(default="redis://127.0.0.1:6379/0", description="Redis 连接 URL"),
Field(default="redis://127.0.0.1:6379", description="Redis 连接 URL"),
"数据库设置",
]
@@ -217,7 +216,7 @@ STORAGE_SETTINGS='{
# 服务器设置
host: Annotated[
str,
Field(default="0.0.0.0", description="服务器监听地址"),
Field(default="0.0.0.0", description="服务器监听地址"), # noqa: S104
"服务器设置",
]
port: Annotated[
@@ -266,18 +265,6 @@ STORAGE_SETTINGS='{
else:
return "/"
# SignalR 设置
signalr_negotiate_timeout: Annotated[
int,
Field(default=30, description="SignalR 协商超时时间(秒)"),
"SignalR 服务器设置",
]
signalr_ping_interval: Annotated[
int,
Field(default=15, description="SignalR ping 间隔(秒)"),
"SignalR 服务器设置",
]
# Fetcher 设置
fetcher_client_id: Annotated[
str,
@@ -329,11 +316,6 @@ STORAGE_SETTINGS='{
Field(default=False, description="是否启用邮件验证功能"),
"验证服务设置",
]
enable_smart_verification: Annotated[
bool,
Field(default=True, description="是否启用智能验证(基于客户端类型和设备信任)"),
"验证服务设置",
]
enable_session_verification: Annotated[
bool,
Field(default=True, description="是否启用会话验证中间件"),
@@ -487,6 +469,12 @@ STORAGE_SETTINGS='{
"缓存设置",
"谱面缓存",
]
beatmapset_cache_expire_seconds: Annotated[
int,
Field(default=3600, description="Beatmapset 缓存过期时间(秒)"),
"缓存设置",
"谱面缓存",
]
# 排行榜缓存设置
enable_ranking_cache: Annotated[
@@ -551,12 +539,6 @@ STORAGE_SETTINGS='{
"缓存设置",
"用户缓存",
]
user_cache_concurrent_limit: Annotated[
int,
Field(default=10, description="并发缓存用户的限制"),
"缓存设置",
"用户缓存",
]
# 资源代理设置
enable_asset_proxy: Annotated[
@@ -621,26 +603,26 @@ STORAGE_SETTINGS='{
]
@field_validator("fetcher_scopes", mode="before")
@classmethod
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
if isinstance(v, str):
return v.split(",")
return v
@field_validator("storage_settings", mode="after")
@classmethod
def validate_storage_settings(
cls,
v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
info: ValidationInfo,
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
if not isinstance(v, CloudflareR2Settings):
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
elif info.data.get("storage_service") == StorageServiceType.LOCAL:
if not isinstance(v, LocalStorageSettings):
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
if not isinstance(v, AWSS3StorageSettings):
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
service = info.data.get("storage_service")
if service == StorageServiceType.CLOUDFLARE_R2 and not isinstance(v, CloudflareR2Settings):
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
if service == StorageServiceType.LOCAL and not isinstance(v, LocalStorageSettings):
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
if service == StorageServiceType.AWS_S3 and not isinstance(v, AWSS3StorageSettings):
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
return v
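The default `redis_url` loses its `/0` suffix above because, per this merge, the server derives several logical databases from one base URL (db0 general cache, db1 message cache, db2 binary storage, db3 rate limiting, matching the four `redis_*_client` names in the changelog). A hypothetical sketch of the URL derivation; the helper name is an assumption, not the project's actual API:

```python
# Logical database assignments after this merge (from the changelog).
LOGICAL_DBS = {"cache": 0, "messages": 1, "binary": 2, "rate_limit": 3}


def logical_db_url(base_url: str, db: int) -> str:
    # Append the logical database index to a base redis:// URL;
    # redis clients select the database from the URL path.
    return f"{base_url.rstrip('/')}/{db}"
```

For example, `logical_db_url("redis://127.0.0.1:6379", LOGICAL_DBS["rate_limit"])` yields the URL a rate-limiting client would connect with.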


@@ -1,5 +1,3 @@
from __future__ import annotations
BANCHOBOT_ID = 2
BACKUP_CODE_LENGTH = 10


@@ -12,7 +12,7 @@ from .beatmapset import (
BeatmapsetResp,
)
from .beatmapset_ratings import BeatmapRating
from .best_score import BestScore
from .best_scores import BestScore
from .chat import (
ChannelType,
ChatChannel,
@@ -28,22 +28,16 @@ from .counts import (
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .events import Event
from .favourite_beatmapset import FavouriteBeatmapset
from .lazer_user import (
MeResp,
User,
UserResp,
)
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from .notification import Notification, UserNotification
from .password_reset import PasswordReset
from .playlist_attempts import (
from .item_attempts_count import (
ItemAttemptsCount,
ItemAttemptsResp,
PlaylistAggregateScore,
)
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from .notification import Notification, UserNotification
from .password_reset import PasswordReset
from .playlist_best_score import PlaylistBestScore
from .playlists import Playlist, PlaylistResp
from .pp_best_score import PPBestScore
from .rank_history import RankHistory, RankHistoryResp, RankTop
from .relationship import Relationship, RelationshipResp, RelationshipType
from .room import APIUploadedRoom, Room, RoomResp
@@ -62,6 +56,12 @@ from .statistics import (
UserStatisticsResp,
)
from .team import Team, TeamMember, TeamRequest
from .total_score_best_scores import TotalScoreBestScore
from .user import (
MeResp,
User,
UserResp,
)
from .user_account_history import (
UserAccountHistory,
UserAccountHistoryResp,
@@ -105,7 +105,6 @@ __all__ = [
"Notification",
"OAuthClient",
"OAuthToken",
"PPBestScore",
"PasswordReset",
"Playlist",
"PlaylistAggregateScore",
@@ -131,6 +130,7 @@ __all__ = [
"Team",
"TeamMember",
"TeamRequest",
"TotalScoreBestScore",
"TotpKeys",
"TrustedDevice",
"TrustedDeviceResp",


@@ -24,7 +24,7 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .lazer_user import User
from .user import User
class UserAchievementBase(SQLModel, UTCBaseModel):


@@ -19,7 +19,7 @@ from sqlmodel import (
)
if TYPE_CHECKING:
from .lazer_user import User
from .user import User
class OAuthToken(UTCBaseModel, SQLModel, table=True):


@@ -23,7 +23,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .lazer_user import User
from .user import User
class BeatmapOwner(SQLModel):
@@ -71,10 +71,10 @@ class Beatmap(BeatmapBase, table=True):
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump()
del d["beatmapset"]
beatmap = Beatmap.model_validate(
beatmap = cls.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
@@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True):
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
session.add(beatmap)
await session.commit()
beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
return beatmap
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
@classmethod
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
@@ -250,7 +249,7 @@ async def calculate_beatmap_attributes(
redis: Redis,
fetcher: "Fetcher",
):
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key))
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
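The hunk above swaps MD5 for SHA-256 when hashing the mod list into the attribute cache key. A standalone sketch of the key format (same shape as the key string in the hunk; the function name is illustrative):

```python
import hashlib


def attribute_cache_key(beatmap_id: int, ruleset: str, mods: list[dict]) -> str:
    # Hash the stringified mod list so arbitrary mod combinations
    # collapse into a fixed-length, Redis-safe key segment.
    digest = hashlib.sha256(str(mods).encode()).hexdigest()
    return f"beatmap:{beatmap_id}:{ruleset}:{digest}:attributes"
```

The key is deterministic for a given mod list, so repeated lookups for the same beatmap/ruleset/mods hit the same Redis entry.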


@@ -20,7 +20,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import BeatmapsetResp
from .lazer_user import User
from .user import User
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):


@@ -1,5 +1,3 @@
from __future__ import annotations
from sqlmodel import Field, SQLModel


@@ -130,7 +130,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
d = resp.model_dump()
if resp.nominations:
d["nominations_required"] = resp.nominations.required
@@ -158,10 +158,15 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
return beatmapset
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
async def from_resp(
cls,
session: AsyncSession,
resp: "BeatmapsetResp",
from_: int = 0,
) -> "Beatmapset":
from .beatmap import Beatmap
beatmapset = await cls.from_resp_no_save(session, resp, from_=from_)
beatmapset = await cls.from_resp_no_save(resp)
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
session.add(beatmapset)
await session.commit()
@@ -334,5 +339,5 @@ class BeatmapsetResp(BeatmapsetBase):
class SearchBeatmapsetsResp(SQLModel):
beatmapsets: list[BeatmapsetResp]
total: int
cursor: dict[str, int | float] | None = None
cursor: dict[str, int | float | str] | None = None
cursor_string: str | None = None


@@ -1,7 +1,5 @@
from __future__ import annotations
from app.database.beatmapset import Beatmapset
from app.database.lazer_user import User
from app.database.user import User
from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel


@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from app.database.statistics import UserStatistics
from app.models.score import GameMode
from .lazer_user import User
from .user import User
from sqlmodel import (
BigInteger,
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from .score import Score
class PPBestScore(SQLModel, table=True):
class BestScore(SQLModel, table=True):
__tablename__: str = "best_scores"
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))


@@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum
from typing import Self
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
from app.database.user import RANKING_INCLUDES, User, UserResp
from app.models.model import UTCBaseModel
from app.utils import utcnow
@@ -105,17 +105,11 @@ class ChatChannelResp(ChatChannelBase):
)
).first()
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
if last_msg and last_msg.isdigit():
last_msg = int(last_msg)
else:
last_msg = None
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
if last_read_id and last_read_id.isdigit():
last_read_id = int(last_read_id)
else:
last_read_id = last_msg
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
if silence is not None:
attribute = ChatUserAttributes(


@@ -11,7 +11,7 @@ from sqlmodel import (
)
if TYPE_CHECKING:
from .lazer_user import User
from .user import User
class CountBase(SQLModel):


@@ -17,7 +17,7 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .lazer_user import User
from .user import User
class DailyChallengeStatsBase(SQLModel, UTCBaseModel):


@@ -18,7 +18,7 @@ from sqlmodel import (
)
if TYPE_CHECKING:
from .lazer_user import User
from .user import User
class EventType(str, Enum):


@@ -1,7 +1,7 @@
import datetime
from app.database.beatmapset import Beatmapset
from app.database.lazer_user import User
from app.database.user import User
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (


@@ -1,39 +0,0 @@
"""
数据库字段类型工具
提供处理数据库和 Pydantic 之间类型转换的工具
"""
from typing import Any
from pydantic import field_validator
from sqlalchemy import Boolean
def bool_field_validator(field_name: str):
"""为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
@field_validator(field_name, mode="before")
@classmethod
def validate_bool_field(cls, v: Any) -> bool:
"""将整数 0/1 转换为布尔值"""
if isinstance(v, int):
return bool(v)
return v
return validate_bool_field
def create_bool_field(**kwargs):
"""创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
from sqlmodel import Column, Field
# 如果没有指定 sa_column则使用 Boolean 类型
if "sa_column" not in kwargs:
# 处理 index 参数
index = kwargs.pop("index", False)
if index:
kwargs["sa_column"] = Column(Boolean, index=True)
else:
kwargs["sa_column"] = Column(Boolean)
return Field(**kwargs)


@@ -1,5 +1,5 @@
from .lazer_user import User, UserResp
from .playlist_best_score import PlaylistBestScore
from .user import User, UserResp
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncAttrs


@@ -2,8 +2,6 @@
密码重置相关数据库模型
"""
from __future__ import annotations
from datetime import datetime
from app.utils import utcnow


@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from .lazer_user import User
from .user import User
from redis.asyncio import Redis
from sqlmodel import (


@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.multiplayer_hub import PlaylistItem
from app.models.playlist import PlaylistItem
from .beatmap import Beatmap, BeatmapResp
@@ -72,7 +72,7 @@ class Playlist(PlaylistBase, table=True):
return result.one()
@classmethod
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
async def from_model(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls(
id=next_id,
@@ -107,7 +107,7 @@ class Playlist(PlaylistBase, table=True):
@classmethod
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await cls.from_hub(playlist, room_id, session)
db_playlist = await cls.from_model(playlist, room_id, session)
session.add(db_playlist)
await session.commit()
await session.refresh(db_playlist)

View File

@@ -21,7 +21,7 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .lazer_user import User
from .user import User
class RankHistory(SQLModel, table=True):


@@ -1,6 +1,6 @@
from enum import Enum
from .lazer_user import User, UserResp
from .user import User, UserResp
from pydantic import BaseModel
from sqlmodel import (


@@ -1,9 +1,8 @@
from datetime import datetime
from app.database.playlist_attempts import PlaylistAggregateScore
from app.database.item_attempts_count import PlaylistAggregateScore
from app.database.room_participated_user import RoomParticipatedUser
from app.models.model import UTCBaseModel
from app.models.multiplayer_hub import ServerMultiplayerRoom
from app.models.room import (
MatchType,
QueueMode,
@@ -14,8 +13,8 @@ from app.models.room import (
)
from app.utils import utcnow
from .lazer_user import User, UserResp
from .playlists import Playlist, PlaylistResp
from .user import User, UserResp
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
@@ -160,25 +159,6 @@ class RoomResp(RoomBase):
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
return resp
@classmethod
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
room = server_room.room
resp = cls(
id=room.room_id,
name=room.settings.name,
type=room.settings.match_type,
queue_mode=room.settings.queue_mode,
auto_skip=room.settings.auto_skip,
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
status=server_room.status,
category=server_room.category,
# duration = room.settings.duration,
starts_at=server_room.start_at,
participant_count=len(room.users),
channel_id=server_room.room.channel_id or 0,
)
return resp
class APIUploadedRoom(RoomBase):
def to_room(self) -> Room:


@@ -15,8 +15,8 @@ from sqlmodel import (
)
if TYPE_CHECKING:
from .lazer_user import User
from .room import Room
from .user import User
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):


@@ -16,7 +16,7 @@ from app.calculator import (
from app.config import settings
from app.database.team import TeamMember
from app.dependencies.database import get_redis
from app.log import logger
from app.log import log
from app.models.beatmap import BeatmapRankStatus
from app.models.model import (
CurrentUserAttributes,
@@ -38,17 +38,17 @@ from app.utils import utcnow
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import BeatmapsetResp
from .best_score import BestScore
from .best_scores import BestScore
from .counts import MonthlyPlaycounts
from .events import Event, EventType
from .lazer_user import User, UserResp
from .playlist_best_score import PlaylistBestScore
from .pp_best_score import PPBestScore
from .relationship import (
Relationship as DBRelationship,
RelationshipType,
)
from .score_token import ScoreToken
from .total_score_best_scores import TotalScoreBestScore
from .user import User, UserResp
from pydantic import BaseModel, field_serializer, field_validator
from redis.asyncio import Redis
@@ -74,6 +74,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
logger = log("Score")
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
@@ -193,13 +195,13 @@ class Score(ScoreBase, table=True):
# optional
beatmap: Mapped[Beatmap] = Relationship()
user: Mapped[User] = Relationship(sa_relationship_kwargs={"lazy": "joined"})
best_score: Mapped[BestScore | None] = Relationship(
best_score: Mapped[TotalScoreBestScore | None] = Relationship(
back_populates="score",
sa_relationship_kwargs={
"cascade": "all, delete-orphan",
},
)
ranked_score: Mapped[PPBestScore | None] = Relationship(
ranked_score: Mapped[BestScore | None] = Relationship(
back_populates="score",
sa_relationship_kwargs={
"cascade": "all, delete-orphan",
@@ -479,10 +481,10 @@ class ScoreAround(SQLModel):
async def get_best_id(session: AsyncSession, score_id: int) -> int | None:
rownum = (
func.row_number()
.over(partition_by=(col(PPBestScore.user_id), col(PPBestScore.gamemode)), order_by=col(PPBestScore.pp).desc())
.over(partition_by=(col(BestScore.user_id), col(BestScore.gamemode)), order_by=col(BestScore.pp).desc())
.label("rn")
)
subq = select(PPBestScore, rownum).subquery()
subq = select(BestScore, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
return result.one_or_none()
@@ -496,8 +498,8 @@ async def _score_where(
user: User | None = None,
) -> list[ColumnElement[bool] | TextClause] | None:
wheres: list[ColumnElement[bool] | TextClause] = [
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
col(TotalScoreBestScore.beatmap_id) == beatmap,
col(TotalScoreBestScore.gamemode) == mode,
]
if type == LeaderboardType.FRIENDS:
@@ -510,20 +512,21 @@ async def _score_where(
)
.subquery()
)
wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id)))
wheres.append(col(TotalScoreBestScore.user_id).in_(select(subq.c.target_id)))
else:
return None
elif type == LeaderboardType.COUNTRY:
if user and user.is_supporter:
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
wheres.append(col(TotalScoreBestScore.user).has(col(User.country_code) == user.country_code))
else:
return None
elif type == LeaderboardType.TEAM:
if user:
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
elif type == LeaderboardType.TEAM and user:
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(
col(TotalScoreBestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id))
)
if mods:
if user and user.is_supporter:
wheres.append(
@@ -557,10 +560,10 @@ async def get_leaderboard(
max_score = sys.maxsize
while limit > 0:
query = (
select(BestScore)
.where(*wheres, BestScore.total_score < max_score)
select(TotalScoreBestScore)
.where(*wheres, TotalScoreBestScore.total_score < max_score)
.limit(limit)
.order_by(col(BestScore.total_score).desc())
.order_by(col(TotalScoreBestScore.total_score).desc())
)
extra_need = 0
for s in await session.exec(query):
@@ -579,13 +582,13 @@ async def get_leaderboard(
user_score = None
if user:
self_query = (
select(BestScore)
.where(BestScore.user_id == user.id)
select(TotalScoreBestScore)
.where(TotalScoreBestScore.user_id == user.id)
.where(
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
col(TotalScoreBestScore.beatmap_id) == beatmap,
col(TotalScoreBestScore.gamemode) == mode,
)
.order_by(col(BestScore.total_score).desc())
.order_by(col(TotalScoreBestScore.total_score).desc())
.limit(1)
)
if mods:
@@ -618,14 +621,14 @@ async def get_score_position_by_user(
func.row_number()
.over(
partition_by=(
col(BestScore.beatmap_id),
col(BestScore.gamemode),
col(TotalScoreBestScore.beatmap_id),
col(TotalScoreBestScore.gamemode),
),
order_by=col(BestScore.total_score).desc(),
order_by=col(TotalScoreBestScore.total_score).desc(),
)
.label("row_number")
)
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
subq = select(TotalScoreBestScore, rownum).join(Beatmap).where(*wheres).subquery()
stmt = select(subq.c.row_number).where(subq.c.user_id == user.id)
result = await session.exec(stmt)
s = result.first()
@@ -648,14 +651,14 @@ async def get_score_position_by_id(
func.row_number()
.over(
partition_by=(
col(BestScore.beatmap_id),
col(BestScore.gamemode),
col(TotalScoreBestScore.beatmap_id),
col(TotalScoreBestScore.gamemode),
),
order_by=col(BestScore.total_score).desc(),
order_by=col(TotalScoreBestScore.total_score).desc(),
)
.label("row_number")
)
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
subq = select(TotalScoreBestScore, rownum).join(Beatmap).where(*wheres).subquery()
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
s = result.one_or_none()
@@ -667,16 +670,16 @@ async def get_user_best_score_in_beatmap(
beatmap: int,
user: int,
mode: GameMode | None = None,
) -> BestScore | None:
) -> TotalScoreBestScore | None:
return (
await session.exec(
select(BestScore)
select(TotalScoreBestScore)
.where(
BestScore.gamemode == mode if mode is not None else true(),
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
TotalScoreBestScore.gamemode == mode if mode is not None else true(),
TotalScoreBestScore.beatmap_id == beatmap,
TotalScoreBestScore.user_id == user,
)
.order_by(col(BestScore.total_score).desc())
.order_by(col(TotalScoreBestScore.total_score).desc())
)
).first()
@@ -687,32 +690,32 @@ async def get_user_best_score_with_mod_in_beatmap(
user: int,
mod: list[str],
mode: GameMode | None = None,
) -> BestScore | None:
) -> TotalScoreBestScore | None:
return (
await session.exec(
select(BestScore)
select(TotalScoreBestScore)
.where(
BestScore.gamemode == mode if mode is not None else True,
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
TotalScoreBestScore.gamemode == mode if mode is not None else True,
TotalScoreBestScore.beatmap_id == beatmap,
TotalScoreBestScore.user_id == user,
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
).params(w=json.dumps(mod)),
)
.order_by(col(BestScore.total_score).desc())
.order_by(col(TotalScoreBestScore.total_score).desc())
)
).first()
async def get_user_first_scores(
session: AsyncSession, user_id: int, mode: GameMode, limit: int = 5, offset: int = 0
) -> list[BestScore]:
) -> list[TotalScoreBestScore]:
rownum = (
func.row_number()
.over(
partition_by=(col(BestScore.beatmap_id), col(BestScore.gamemode)),
order_by=col(BestScore.total_score).desc(),
partition_by=(col(TotalScoreBestScore.beatmap_id), col(TotalScoreBestScore.gamemode)),
order_by=col(TotalScoreBestScore.total_score).desc(),
)
.label("rn")
)
@@ -720,11 +723,11 @@ async def get_user_first_scores(
# Step 1: Fetch top score_ids in Python
subq = (
select(
col(BestScore.score_id).label("score_id"),
col(BestScore.user_id).label("user_id"),
col(TotalScoreBestScore.score_id).label("score_id"),
col(TotalScoreBestScore.user_id).label("user_id"),
rownum,
)
.where(col(BestScore.gamemode) == mode)
.where(col(TotalScoreBestScore.gamemode) == mode)
.subquery()
)
@@ -733,7 +736,11 @@ async def get_user_first_scores(
top_ids = await session.exec(top_ids_stmt)
top_ids = list(top_ids)
stmt = select(BestScore).where(col(BestScore.score_id).in_(top_ids)).order_by(col(BestScore.total_score).desc())
stmt = (
select(TotalScoreBestScore)
.where(col(TotalScoreBestScore.score_id).in_(top_ids))
.order_by(col(TotalScoreBestScore.total_score).desc())
)
result = await session.exec(stmt)
return list(result.all())
@@ -743,18 +750,18 @@ async def get_user_first_score_count(session: AsyncSession, user_id: int, mode:
rownum = (
func.row_number()
.over(
partition_by=(col(BestScore.beatmap_id), col(BestScore.gamemode)),
order_by=col(BestScore.total_score).desc(),
partition_by=(col(TotalScoreBestScore.beatmap_id), col(TotalScoreBestScore.gamemode)),
order_by=col(TotalScoreBestScore.total_score).desc(),
)
.label("rn")
)
subq = (
select(
col(BestScore.score_id).label("score_id"),
col(BestScore.user_id).label("user_id"),
col(TotalScoreBestScore.score_id).label("score_id"),
col(TotalScoreBestScore.user_id).label("user_id"),
rownum,
)
.where(col(BestScore.gamemode) == mode)
.where(col(TotalScoreBestScore.gamemode) == mode)
.subquery()
)
count_stmt = select(func.count()).where(subq.c.rn == 1, subq.c.user_id == user_id)
@@ -768,13 +775,13 @@ async def get_user_best_pp_in_beatmap(
beatmap: int,
user: int,
mode: GameMode,
) -> PPBestScore | None:
) -> BestScore | None:
return (
await session.exec(
select(PPBestScore).where(
PPBestScore.beatmap_id == beatmap,
PPBestScore.user_id == user,
PPBestScore.gamemode == mode,
select(BestScore).where(
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
BestScore.gamemode == mode,
)
)
).first()
@@ -799,12 +806,12 @@ async def get_user_best_pp(
user: int,
mode: GameMode,
limit: int = 1000,
) -> Sequence[PPBestScore]:
) -> Sequence[BestScore]:
return (
await session.exec(
select(PPBestScore)
.where(PPBestScore.user_id == user, PPBestScore.gamemode == mode)
.order_by(col(PPBestScore.pp).desc())
select(BestScore)
.where(BestScore.user_id == user, BestScore.gamemode == mode)
.order_by(col(BestScore.pp).desc())
.limit(limit)
)
).all()
@@ -854,8 +861,7 @@ async def process_score(
) -> Score:
gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods)
logger.info(
"[Score] Creating score for user {user_id} | beatmap={beatmap_id} "
"ruleset={ruleset} passed={passed} total={total}",
"Creating score for user {user_id} | beatmap={beatmap_id} ruleset={ruleset} passed={passed} total={total}",
user_id=user.id,
beatmap_id=beatmap_id,
ruleset=gamemode,
@@ -897,7 +903,7 @@ async def process_score(
)
session.add(score)
logger.debug(
"[Score] Score staged for commit | token={token} mods={mods} total_hits={hits}",
"Score staged for commit | token={token} mods={mods} total_hits={hits}",
token=score_token.id,
mods=info.mods,
hits=sum(info.statistics.values()) if info.statistics else 0,
@@ -910,7 +916,7 @@ async def process_score(
async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
if score.pp != 0:
logger.debug(
"[Score] Skipping PP calculation for score {score_id} | already set {pp:.2f}",
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
score_id=score.id,
pp=score.pp,
)
@@ -918,7 +924,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
can_get_pp = score.passed and score.ranked and mods_can_get_pp(int(score.gamemode), score.mods)
if not can_get_pp:
logger.debug(
"[Score] Skipping PP calculation for score {score_id} | passed={passed} ranked={ranked} mods={mods}",
"Skipping PP calculation for score {score_id} | passed={passed} ranked={ranked} mods={mods}",
score_id=score.id,
passed=score.passed,
ranked=score.ranked,
@@ -928,15 +934,15 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
pp, successed = await pre_fetch_and_calculate_pp(score, session, redis, fetcher)
if not successed:
await redis.rpush("score:need_recalculate", score.id) # pyright: ignore[reportGeneralTypeIssues]
logger.warning("[Score] Queued score {score_id} for PP recalculation", score_id=score.id)
logger.warning("Queued score {score_id} for PP recalculation", score_id=score.id)
return
score.pp = pp
logger.info("[Score] Calculated PP for score {score_id} | pp={pp:.2f}", score_id=score.id, pp=pp)
logger.info("Calculated PP for score {score_id} | pp={pp:.2f}", score_id=score.id, pp=pp)
user_id = score.user_id
beatmap_id = score.beatmap_id
previous_pp_best = await get_user_best_pp_in_beatmap(session, beatmap_id, user_id, score.gamemode)
if previous_pp_best is None or score.pp > previous_pp_best.pp:
best_score = PPBestScore(
best_score = BestScore(
user_id=user_id,
score_id=score.id,
beatmap_id=beatmap_id,
@@ -947,7 +953,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
session.add(best_score)
await session.delete(previous_pp_best) if previous_pp_best else None
logger.info(
"[Score] Updated PP best for user {user_id} | score_id={score_id} pp={pp:.2f}",
"Updated PP best for user {user_id} | score_id={score_id} pp={pp:.2f}",
user_id=user_id,
score_id=score.id,
pp=score.pp,
@@ -966,15 +972,14 @@ async def _process_score_events(score: Score, session: AsyncSession):
if rank_global == 0 or total_users == 0:
logger.debug(
"[Score] Skipping event creation for score {score_id} | "
"rank_global={rank_global} total_users={total_users}",
"Skipping event creation for score {score_id} | rank_global={rank_global} total_users={total_users}",
score_id=score.id,
rank_global=rank_global,
total_users=total_users,
)
return
logger.debug(
"[Score] Processing events for score {score_id} | rank_global={rank_global} total_users={total_users}",
"Processing events for score {score_id} | rank_global={rank_global} total_users={total_users}",
score_id=score.id,
rank_global=rank_global,
total_users=total_users,
@@ -1003,7 +1008,7 @@ async def _process_score_events(score: Score, session: AsyncSession):
}
session.add(rank_event)
logger.info(
"[Score] Registered rank event for user {user_id} | score_id={score_id} rank={rank}",
"Registered rank event for user {user_id} | score_id={score_id} rank={rank}",
user_id=score.user_id,
score_id=score.id,
rank=rank_global,
@@ -1011,12 +1016,12 @@ async def _process_score_events(score: Score, session: AsyncSession):
if rank_global == 1:
displaced_score = (
await session.exec(
select(BestScore)
select(TotalScoreBestScore)
.where(
BestScore.beatmap_id == score.beatmap_id,
BestScore.gamemode == score.gamemode,
TotalScoreBestScore.beatmap_id == score.beatmap_id,
TotalScoreBestScore.gamemode == score.gamemode,
)
.order_by(col(BestScore.total_score).desc())
.order_by(col(TotalScoreBestScore.total_score).desc())
.limit(1)
.offset(1)
)
@@ -1045,12 +1050,12 @@ async def _process_score_events(score: Score, session: AsyncSession):
}
session.add(rank_lost_event)
logger.info(
"[Score] Registered rank lost event | displaced_user={user_id} new_score_id={score_id}",
"Registered rank lost event | displaced_user={user_id} new_score_id={score_id}",
user_id=displaced_score.user_id,
score_id=score.id,
)
logger.debug(
"[Score] Event processing committed for score {score_id}",
"Event processing committed for score {score_id}",
score_id=score.id,
)
@@ -1074,7 +1079,7 @@ async def _process_statistics(
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
)
logger.debug(
"[Score] Existing best scores for user {user_id} | global={global_id} mod={mod_id}",
"Existing best scores for user {user_id} | global={global_id} mod={mod_id}",
user_id=user.id,
global_id=previous_score_best.score_id if previous_score_best else None,
mod_id=previous_score_best_mod.score_id if previous_score_best_mod else None,
@@ -1104,7 +1109,7 @@ async def _process_statistics(
statistics.total_score += score.total_score
difference = score.total_score - previous_score_best.total_score if previous_score_best else score.total_score
logger.debug(
"[Score] Score delta computed for {score_id}: {difference}",
"Score delta computed for {score_id}: {difference}",
score_id=score.id,
difference=difference,
)
@@ -1140,7 +1145,7 @@ async def _process_statistics(
# Case 2: a best score exists but no record for this mod combination; add a new record
if previous_score_best is None or previous_score_best_mod is None:
session.add(
BestScore(
TotalScoreBestScore(
user_id=user.id,
beatmap_id=score.beatmap_id,
gamemode=score.gamemode,
@@ -1151,7 +1156,7 @@ async def _process_statistics(
)
)
logger.info(
"[Score] Created new best score entry for user {user_id} | score_id={score_id} mods={mods}",
"Created new best score entry for user {user_id} | score_id={score_id} mods={mods}",
user_id=user.id,
score_id=score.id,
mods=mod_for_save,
@@ -1163,7 +1168,7 @@ async def _process_statistics(
previous_score_best.rank = score.rank
previous_score_best.score_id = score.id
logger.info(
"[Score] Updated existing best score for user {user_id} | score_id={score_id} total={total}",
"Updated existing best score for user {user_id} | score_id={score_id} total={total}",
user_id=user.id,
score_id=score.id,
total=score.total_score,
@@ -1175,7 +1180,7 @@ async def _process_statistics(
if difference > 0:
# The if below always fires: promote this score to the best and delete the old row to avoid a duplicate score_id
logger.info(
"[Score] Replacing global best score for user {user_id} | old_score_id={old_score_id}",
"Replacing global best score for user {user_id} | old_score_id={old_score_id}",
user_id=user.id,
old_score_id=previous_score_best.score_id,
)
@@ -1188,7 +1193,7 @@ async def _process_statistics(
previous_score_best_mod.rank = score.rank
previous_score_best_mod.score_id = score.id
logger.info(
"[Score] Replaced mod-specific best for user {user_id} | mods={mods} score_id={score_id}",
"Replaced mod-specific best for user {user_id} | mods={mods} score_id={score_id}",
user_id=user.id,
mods=mod_for_save,
score_id=score.id,
@@ -1202,14 +1207,14 @@ async def _process_statistics(
mouthly_playcount.count += 1
statistics.play_time += playtime
logger.debug(
"[Score] Recorded playtime {playtime}s for score {score_id} (user {user_id})",
"Recorded playtime {playtime}s for score {score_id} (user {user_id})",
playtime=playtime,
score_id=score.id,
user_id=user.id,
)
else:
logger.debug(
"[Score] Playtime {playtime}s for score {score_id} did not meet validity checks",
"Playtime {playtime}s for score {score_id} did not meet validity checks",
playtime=playtime,
score_id=score.id,
)
@@ -1242,7 +1247,7 @@ async def _process_statistics(
if add_to_db:
session.add(mouthly_playcount)
logger.debug(
"[Score] Created monthly playcount record for user {user_id} ({year}-{month})",
"Created monthly playcount record for user {user_id} ({year}-{month})",
user_id=user.id,
year=mouthly_playcount.year,
month=mouthly_playcount.month,
@@ -1262,7 +1267,7 @@ async def process_user(
score_id = score.id
user_id = user.id
logger.info(
"[Score] Processing score {score_id} for user {user_id} on beatmap {beatmap_id}",
"Processing score {score_id} for user {user_id} on beatmap {beatmap_id}",
score_id=score_id,
user_id=user_id,
beatmap_id=score.beatmap_id,
@@ -1287,14 +1292,14 @@ async def process_user(
score_ = (await session.exec(select(Score).where(Score.id == score_id).options(joinedload(Score.beatmap)))).first()
if score_ is None:
logger.warning(
"[Score] Score {score_id} disappeared after commit, skipping event processing",
"Score {score_id} disappeared after commit, skipping event processing",
score_id=score_id,
)
return
await _process_score_events(score_, session)
await session.commit()
logger.info(
"[Score] Finished processing score {score_id} for user {user_id}",
"Finished processing score {score_id} for user {user_id}",
score_id=score_id,
user_id=user_id,
)

View File

@@ -5,7 +5,7 @@ from app.models.score import GameMode
from app.utils import utcnow
from .beatmap import Beatmap
from .lazer_user import User
from .user import User
from sqlalchemy import Column, DateTime, Index
from sqlalchemy.orm import Mapped

View File

@@ -23,7 +23,7 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .lazer_user import User, UserResp
from .user import User, UserResp
class UserStatisticsBase(SQLModel):
@@ -122,7 +122,7 @@ class UserStatisticsResp(UserStatisticsBase):
"progress": int(math.fmod(obj.level_current, 1) * 100),
}
if "user" in include:
from .lazer_user import RANKING_INCLUDES, UserResp
from .user import RANKING_INCLUDES, UserResp
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
s.user = user
@@ -149,7 +149,7 @@ class UserStatisticsResp(UserStatisticsBase):
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
from .lazer_user import User
from .user import User
query = select(
UserStatistics.user_id,

View File

@@ -8,7 +8,7 @@ from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .lazer_user import User
from .user import User
class Team(SQLModel, UTCBaseModel, table=True):

View File

@@ -4,7 +4,7 @@ from app.calculator import calculate_score_to_level
from app.database.statistics import UserStatistics
from app.models.score import GameMode, Rank
from .lazer_user import User
from .user import User
from sqlmodel import (
JSON,
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
from .score import Score
class BestScore(SQLModel, table=True):
class TotalScoreBestScore(SQLModel, table=True):
__tablename__: str = "total_score_best_scores"
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
@@ -41,7 +41,7 @@ class BestScore(SQLModel, table=True):
user: User = Relationship()
score: "Score" = Relationship(
sa_relationship_kwargs={
"foreign_keys": "[BestScore.score_id]",
"foreign_keys": "[TotalScoreBestScore.score_id]",
"lazy": "joined",
},
back_populates="best_score",
@@ -75,7 +75,7 @@ class BestScore(SQLModel, table=True):
await session.exec(
select(func.max(Score.max_combo)).where(
Score.user_id == self.user_id,
col(Score.id).in_(select(BestScore.score_id)),
col(Score.id).in_(select(TotalScoreBestScore.score_id)),
Score.gamemode == self.gamemode,
)
)

View File

@@ -72,7 +72,7 @@ COUNTRIES = json.loads((STATIC_DIR / "iso3166.json").read_text())
class UserBase(UTCBaseModel, SQLModel):
avatar_url: str = ""
avatar_url: str = "https://lazer-data.g0v0.top/default.jpg"
country_code: str = Field(default="CN", max_length=2, index=True)
# ? default_group: str|None
is_active: bool = True
@@ -256,16 +256,14 @@ class UserResp(UserBase):
session: AsyncSession,
include: list[str] = [],
ruleset: GameMode | None = None,
*,
token_id: int | None = None,
) -> "UserResp":
from app.dependencies.database import get_redis
from .best_score import BestScore
from .best_scores import BestScore
from .favourite_beatmapset import FavouriteBeatmapset
from .pp_best_score import PPBestScore
from .relationship import Relationship, RelationshipResp, RelationshipType
from .score import Score, get_user_first_score_count
from .total_score_best_scores import TotalScoreBestScore
ruleset = ruleset or obj.playmode
@@ -286,9 +284,9 @@ class UserResp(UserBase):
u.scores_best_count = (
await session.exec(
select(func.count())
.select_from(BestScore)
.select_from(TotalScoreBestScore)
.where(
BestScore.user_id == obj.id,
TotalScoreBestScore.user_id == obj.id,
)
.limit(200)
)
@@ -310,16 +308,16 @@ class UserResp(UserBase):
).all()
]
if "team" in include:
if team_membership := await obj.awaitable_attrs.team_membership:
u.team = team_membership.team
if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership):
u.team = team_membership.team
if "account_history" in include:
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
if "daily_challenge_user_stats":
if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
if "daily_challenge_user_stats" in include and (
daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats
):
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
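The refactor above folds a membership check and a nested truthiness check into one condition via the walrus operator, and also fixes the old `if "daily_challenge_user_stats":`, which tested a non-empty string and was therefore always true. A tiny illustrative sketch (names hypothetical):

```python
include = ["team", "statistics"]
payload = {"team": {"name": "Sample Team"}, "daily_challenge_stats": None}

team_name = None
# combined test: assign only when the key is requested AND the value is truthy
if "team" in include and (team := payload.get("team")):
    team_name = team["name"]

stats = None
# the fixed form checks membership first instead of a bare string literal
if "daily_challenge_user_stats" in include and (dc := payload.get("daily_challenge_stats")):
    stats = dc

print(team_name, stats)  # Sample Team None
```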
if "statistics" in include:
current_stattistics = None
@@ -393,10 +391,10 @@ class UserResp(UserBase):
u.scores_best_count = (
await session.exec(
select(func.count())
.select_from(PPBestScore)
.select_from(BestScore)
.where(
PPBestScore.user_id == obj.id,
PPBestScore.gamemode == ruleset,
BestScore.user_id == obj.id,
BestScore.gamemode == ruleset,
)
.limit(200)
)
@@ -443,7 +441,7 @@ class MeResp(UserResp):
from app.dependencies.database import get_redis
from app.service.verification_service import LoginSessionService
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id)
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset)
u.session_verified = (
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
if token_id

View File

@@ -1,4 +1 @@
from __future__ import annotations
from .database import get_db as get_db
from .user import get_current_user as get_current_user
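As a quick illustration of why dropping `from __future__ import annotations` matters here: with the future import, every annotation is stored as a string and has to be resolved later, which is exactly where `ForwardRef` resolution can fail for frameworks that inspect signatures at import time (minimal sketch):

```python
from __future__ import annotations

import typing

def greet(name: str) -> str:
    return f"hi {name}"

# with postponed evaluation the raw annotations are strings, not types
assert greet.__annotations__["name"] == "str"

# frameworks must resolve them explicitly; this is the step that can
# raise ForwardRef/NameError failures when the name is not importable
hints = typing.get_type_hints(greet)
assert hints["name"] is str
```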

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends, Header
def get_api_version(version: int | None = Header(None, alias="x-api-version")) -> int:
def get_api_version(version: int | None = Header(None, alias="x-api-version", include_in_schema=False)) -> int:
if version is None:
return 0
if version < 1:

View File

@@ -1,8 +1,13 @@
from __future__ import annotations
from typing import Annotated
from app.service.beatmap_download_service import download_service
from app.service.beatmap_download_service import BeatmapDownloadService, download_service
from fastapi import Depends
def get_beatmap_download_service():
"""Return the beatmap download service instance."""
return download_service
DownloadService = Annotated[BeatmapDownloadService, Depends(get_beatmap_download_service)]

View File

@@ -1,16 +1,17 @@
"""
Dependency injection for the beatmapset cache service
"""
from typing import Annotated
from __future__ import annotations
from app.dependencies.database import get_redis
from app.service.beatmapset_cache_service import BeatmapsetCacheService, get_beatmapset_cache_service
from app.dependencies.database import Redis
from app.service.beatmapset_cache_service import (
BeatmapsetCacheService as OriginBeatmapsetCacheService,
get_beatmapset_cache_service,
)
from fastapi import Depends
from redis.asyncio import Redis
def get_beatmapset_cache_dependency(redis: Redis = Depends(get_redis)) -> BeatmapsetCacheService:
def get_beatmapset_cache_dependency(redis: Redis) -> OriginBeatmapsetCacheService:
"""Return the beatmapset cache service dependency."""
return get_beatmapset_cache_service(redis)
BeatmapsetCacheService = Annotated[OriginBeatmapsetCacheService, Depends(get_beatmapset_cache_dependency)]

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from contextvars import ContextVar
@@ -11,7 +9,6 @@ from app.config import settings
from fastapi import Depends
from pydantic import BaseModel
import redis as sync_redis
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
@@ -38,13 +35,16 @@ engine = create_async_engine(
)
# Redis connection
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
redis_client = redis.from_url(settings.redis_url, decode_responses=True, db=0)
# Redis binary-data connection (responses not auto-decoded; stores audio and other binary data)
redis_binary_client = redis.from_url(settings.redis_url, decode_responses=False)
# Redis message cache connection (db1)
redis_message_client = redis.from_url(settings.redis_url, decode_responses=True, db=1)
# Redis message cache connection (db1) - sync client, executed in a thread pool
redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
# Redis binary-data connection (responses not auto-decoded; stores audio and other binary data; db2)
redis_binary_client = redis.from_url(settings.redis_url, decode_responses=False, db=2)
# Redis rate-limiting connection (db3)
redis_rate_limit_client = redis.from_url(settings.redis_url, decode_responses=True, db=3)
# Database dependency
@@ -91,12 +91,15 @@ def get_redis():
return redis_client
Redis = Annotated[redis.Redis, Depends(get_redis)]
def get_redis_binary():
"""Return the Redis client dedicated to binary data (responses not auto-decoded)."""
return redis_binary_client
def get_redis_message():
def get_redis_message() -> redis.Redis:
"""Return the Redis client dedicated to messages (db1)."""
return redis_message_client
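The `Annotated`-style aliases this file exposes (e.g. `Redis = Annotated[redis.Redis, Depends(get_redis)]`) attach the provider to the type itself. A framework-free sketch of how such metadata is read back; the `Dep` marker is a hypothetical stand-in for FastAPI's `Depends`:

```python
from typing import Annotated, get_type_hints

class Dep:
    """Hypothetical stand-in for fastapi.Depends: records a provider callable."""
    def __init__(self, provider):
        self.provider = provider

def get_redis():
    return {"client": "redis-db0"}  # placeholder for a real client

Redis = Annotated[dict, Dep(get_redis)]

def handler(redis: Redis):
    return redis["client"]

# a framework resolves the parameter by reading the Annotated metadata
hints = get_type_hints(handler, include_extras=True)
dep = hints["redis"].__metadata__[0]
print(handler(dep.provider()))  # redis-db0
```

Declaring the dependency once in the alias is what lets route signatures shrink to `redis: Redis` instead of repeating `Depends(get_redis)` everywhere.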

View File

@@ -1,17 +1,19 @@
from __future__ import annotations
from typing import Annotated
from app.config import settings
from app.dependencies.database import get_redis
from app.fetcher import Fetcher
from app.log import logger
from app.fetcher import Fetcher as OriginFetcher
from app.log import fetcher_logger
fetcher: Fetcher | None = None
from fastapi import Depends
fetcher: OriginFetcher | None = None
async def get_fetcher() -> Fetcher:
async def get_fetcher() -> OriginFetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
fetcher = OriginFetcher(
settings.fetcher_client_id,
settings.fetcher_client_secret,
settings.fetcher_scopes,
@@ -25,5 +27,10 @@ async def get_fetcher() -> Fetcher:
if refresh_token:
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
fetcher_logger("Fetcher").opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
return fetcher
Fetcher = Annotated[OriginFetcher, Depends(get_fetcher)]

View File

@@ -2,14 +2,15 @@
GeoIP dependency for FastAPI
"""
from __future__ import annotations
from functools import lru_cache
import ipaddress
from typing import Annotated
from app.config import settings
from app.helpers.geoip_helper import GeoIPHelper
from fastapi import Depends, Request
@lru_cache
def get_geoip_helper() -> GeoIPHelper:
@@ -26,7 +27,7 @@ def get_geoip_helper() -> GeoIPHelper:
)
def get_client_ip(request) -> str:
def get_client_ip(request: Request) -> str:
"""
Get the client's real IP address.
Supports IPv4 and IPv6; accounts for proxies, load balancers, etc.
@@ -66,6 +67,10 @@ def get_client_ip(request) -> str:
return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
IPAddress = Annotated[str, Depends(get_client_ip)]
GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)]
def is_valid_ip(ip_str: str) -> bool:
"""
Check whether an IP address is valid (supports IPv4 and IPv6

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Any
from fastapi import Request
@@ -7,7 +5,7 @@ from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, ValidationError
def BodyOrForm[T: BaseModel](model: type[T]):
def BodyOrForm[T: BaseModel](model: type[T]): # noqa: N802
async def dependency(
request: Request,
) -> T:

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from app.config import settings
from fastapi import Depends

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import UTC
from typing import cast

View File

@@ -1,6 +1,4 @@
from __future__ import annotations
from typing import cast
from typing import Annotated, cast
from app.config import (
AWSS3StorageSettings,
@@ -9,11 +7,13 @@ from app.config import (
StorageServiceType,
settings,
)
from app.storage import StorageService
from app.storage import StorageService as OriginStorageService
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
from app.storage.local import LocalStorageService
storage: StorageService | None = None
from fastapi import Depends
storage: OriginStorageService | None = None
def init_storage_service():
@@ -50,3 +50,6 @@ def get_storage_service():
if storage is None:
return init_storage_service()
return storage
StorageService = Annotated[OriginStorageService, Depends(get_storage_service)]

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
from typing import Annotated
from app.auth import get_token_by_access_token
from app.config import settings
from app.const import SUPPORT_TOTP_VERIFICATION_VER
from app.database import User
from app.database.auth import OAuthToken, V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer
@@ -11,7 +10,7 @@ from app.models.oauth import OAuth2ClientCredentialsBearer
from .api_version import APIVersion
from .database import Database, get_redis
from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Security
from fastapi.security import (
APIKeyQuery,
HTTPBearer,
@@ -112,16 +111,13 @@ async def get_client_user(
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
# Get the current verification method
verify_method = None
if api_version >= 20250913:
if api_version >= SUPPORT_TOTP_VERIFICATION_VER:
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
if verify_method is None:
# Smart method selection: prefer TOTP when the user has one
totp_key = await user.awaitable_attrs.totp_key
if totp_key is not None and api_version >= 20240101:
verify_method = "totp"
else:
verify_method = "mail"
verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail"
# Store the chosen verification method in Redis to avoid re-selecting
if api_version >= 20250913:
@@ -169,3 +165,6 @@ async def get_current_user(
user_and_token: UserAndToken = Depends(get_current_user_and_token),
) -> User:
return user_and_token[0]
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.models.model import UserAgentInfo as UserAgentInfoModel

View File

@@ -1,10 +0,0 @@
from __future__ import annotations
class SignalRException(Exception):
pass
class InvokeException(SignalRException):
def __init__(self, message: str) -> None:
self.message = message

View File

@@ -1,49 +0,0 @@
"""
Exception classes for user pages
"""
from __future__ import annotations
class UserpageError(Exception):
"""Base class for userpage processing errors"""
def __init__(self, message: str, code: str = "userpage_error"):
self.message = message
self.code = code
super().__init__(message)
class ContentTooLongError(UserpageError):
"""Content-too-long error"""
def __init__(self, current_length: int, max_length: int):
message = f"Content too long. Maximum {max_length} characters allowed, got {current_length}."
super().__init__(message, "content_too_long")
self.current_length = current_length
self.max_length = max_length
class ContentEmptyError(UserpageError):
"""Empty-content error"""
def __init__(self):
super().__init__("Content cannot be empty.", "content_empty")
class BBCodeValidationError(UserpageError):
"""BBCode validation error"""
def __init__(self, errors: list[str]):
message = f"BBCode validation failed: {'; '.join(errors)}"
super().__init__(message, "bbcode_validation_error")
self.errors = errors
class ForbiddenTagError(UserpageError):
"""Forbidden-tag error"""
def __init__(self, tag: str):
message = f"Forbidden tag '{tag}' is not allowed."
super().__init__(message, "forbidden_tag")
self.tag = tag

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from .beatmap import BeatmapFetcher
from .beatmap_raw import BeatmapRawFetcher
from .beatmapset import BeatmapsetFetcher

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
import asyncio
import time
from urllib.parse import quote
from app.dependencies.database import get_redis
from app.log import logger
from app.log import fetcher_logger
from httpx import AsyncClient
@@ -16,6 +14,9 @@ class TokenAuthError(Exception):
pass
logger = fetcher_logger("Fetcher")
class BaseFetcher:
def __init__(
self,

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
from app.database.beatmap import BeatmapResp
from app.log import logger
from app.log import fetcher_logger
from ._base import BaseFetcher
logger = fetcher_logger("BeatmapFetcher")
class BeatmapFetcher(BaseFetcher):
async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp:
@@ -14,7 +14,7 @@ class BeatmapFetcher(BaseFetcher):
params = {"checksum": beatmap_checksum}
else:
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
logger.opt(colors=True).debug(f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>")
logger.opt(colors=True).debug(f"get_beatmap: <y>{params}</y>")
return BeatmapResp.model_validate(
await self.request_api(

View File

@@ -1,10 +1,9 @@
from __future__ import annotations
from app.log import fetcher_logger
from ._base import BaseFetcher
from httpx import AsyncClient, HTTPError
from httpx._models import Response
from loguru import logger
import redis.asyncio as redis
urls = [
@@ -13,12 +12,14 @@ urls = [
"https://catboy.best/osu/{beatmap_id}",
]
logger = fetcher_logger("BeatmapRawFetcher")
class BeatmapRawFetcher(BaseFetcher):
async def get_beatmap_raw(self, beatmap_id: int) -> str:
for url in urls:
req_url = url.format(beatmap_id=beatmap_id)
logger.opt(colors=True).debug(f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>")
logger.opt(colors=True).debug(f"get_beatmap_raw: <y>{req_url}</y>")
resp = await self._request(req_url)
if resp.status_code >= 400:
continue

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import asyncio
import base64
import hashlib
@@ -7,7 +5,7 @@ import json
from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import logger
from app.log import fetcher_logger
from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor
from app.utils import bg_tasks
@@ -24,6 +22,9 @@ class RateLimitError(Exception):
pass
logger = fetcher_logger("BeatmapsetFetcher")
class BeatmapsetFetcher(BaseFetcher):
@staticmethod
def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]:
@@ -113,7 +114,7 @@ class BeatmapsetFetcher(BaseFetcher):
# Serialize to JSON and derive an MD5 hash
cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
cache_hash = hashlib.md5(cache_json.encode(), usedforsecurity=False).hexdigest()
logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
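The cache-key scheme above, sketched standalone. `usedforsecurity=False` (Python 3.9+) marks the MD5 use as non-cryptographic so FIPS-restricted builds allow it; the `beatmapset:search:` prefix here is a hypothetical illustration, not the service's actual key format:

```python
import hashlib
import json

def search_cache_key(query: dict, cursor: dict) -> str:
    # deterministic serialization: sorted keys, compact separators
    payload = json.dumps(
        {"query": query, "cursor": cursor},
        sort_keys=True, separators=(",", ":"),
    )
    return "beatmapset:search:" + hashlib.md5(
        payload.encode(), usedforsecurity=False
    ).hexdigest()

k1 = search_cache_key({"sort": "ranked_desc"}, {"page": 2})
k2 = search_cache_key({"sort": "ranked_desc"}, {"page": 2})
assert k1 == k2  # identical query + cursor -> identical key
```

Sorting keys before hashing is what makes semantically equal queries hit the same Redis entry regardless of dict insertion order.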
@@ -135,7 +136,7 @@ class BeatmapsetFetcher(BaseFetcher):
return {}
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>")
logger.opt(colors=True).debug(f"get_beatmapset: <y>{beatmap_set_id}</y>")
return BeatmapsetResp.model_validate(
await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
@@ -144,7 +145,7 @@ class BeatmapsetFetcher(BaseFetcher):
async def search_beatmapset(
self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
) -> SearchBeatmapsetsResp:
logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>")
logger.opt(colors=True).debug(f"search_beatmapset: <y>{query}</y>")
# Build the cache key
cache_key = self._generate_cache_key(query, cursor)
@@ -152,17 +153,15 @@ class BeatmapsetFetcher(BaseFetcher):
# Try the cache first
cached_result = await redis_client.get(cache_key)
if cached_result:
logger.opt(colors=True).debug(f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>")
logger.opt(colors=True).debug(f"Cache hit for key: <y>{cache_key}</y>")
try:
cached_data = json.loads(cached_result)
return SearchBeatmapsetsResp.model_validate(cached_data)
except Exception as e:
logger.opt(colors=True).warning(
f"<yellow>[BeatmapsetFetcher]</yellow> Cache data invalid, fetching from API: {e}"
)
logger.warning(f"Cache data invalid, fetching from API: {e}")
# Cache miss: fetch from the API
logger.opt(colors=True).debug("<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API")
logger.debug("Cache miss, fetching from API")
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
@@ -186,9 +185,7 @@ class BeatmapsetFetcher(BaseFetcher):
cache_ttl = 15 * 60 # 15 分钟
await redis_client.set(cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl)
logger.opt(colors=True).debug(
f"<green>[BeatmapsetFetcher]</green> Cached result for key: <y>{cache_key}</y> (TTL: {cache_ttl}s)"
)
logger.opt(colors=True).debug(f"Cached result for key: <y>{cache_key}</y> (TTL: {cache_ttl}s)")
resp = SearchBeatmapsetsResp.model_validate(api_response)
@@ -204,9 +201,7 @@ class BeatmapsetFetcher(BaseFetcher):
try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
except RateLimitError:
logger.opt(colors=True).info(
"<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit"
)
logger.info("Prefetch skipped due to rate limit")
bg_tasks.add_task(delayed_prefetch)
@@ -230,14 +225,14 @@ class BeatmapsetFetcher(BaseFetcher):
# Request the next page using the current cursor
next_query = query.model_copy()
logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}")
logger.debug(f"Prefetching page {page + 1}")
# Build the cache key for the next page
next_cache_key = self._generate_cache_key(next_query, cursor)
# Skip if already cached
if await redis_client.exists(next_cache_key):
logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached")
logger.debug(f"Page {page + 1} already cached")
# Try to read the cursor from cache to keep prefetching
cached_data = await redis_client.get(next_cache_key)
if cached_data:
@@ -247,7 +242,7 @@ class BeatmapsetFetcher(BaseFetcher):
cursor = data["cursor"]
continue
except Exception:
pass
logger.warning("Failed to parse cached data for cursor")
break
# Delay between prefetched pages to avoid request bursts
@@ -282,22 +277,18 @@ class BeatmapsetFetcher(BaseFetcher):
ex=prefetch_ttl,
)
logger.opt(colors=True).debug(
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} (TTL: {prefetch_ttl}s)"
)
logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
except RateLimitError:
logger.opt(colors=True).info("<yellow>[BeatmapsetFetcher]</yellow> Prefetch stopped due to rate limit")
logger.info("Prefetch stopped due to rate limit")
except Exception as e:
logger.opt(colors=True).warning(f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}")
logger.warning(f"Prefetch failed: {e}")
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
"""Warm up the homepage cache"""
homepage_queries = self._get_homepage_queries()
logger.opt(colors=True).info(
f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup ({len(homepage_queries)} queries)"
)
logger.info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
for i, (query, cursor) in enumerate(homepage_queries):
try:
@@ -309,9 +300,7 @@ class BeatmapsetFetcher(BaseFetcher):
# Skip if already cached
if await redis_client.exists(cache_key):
logger.opt(colors=True).debug(
f"<magenta>[BeatmapsetFetcher]</magenta> Query {query.sort} already cached"
)
logger.debug(f"Query {query.sort} already cached")
continue
# 请求并缓存
@@ -334,24 +323,15 @@ class BeatmapsetFetcher(BaseFetcher):
ex=cache_ttl,
)
logger.opt(colors=True).info(
f"<magenta>[BeatmapsetFetcher]</magenta> Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
)
logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
if api_response.get("cursor"):
try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
except RateLimitError:
logger.opt(colors=True).info(
f"<yellow>[BeatmapsetFetcher]</yellow> Warmup prefetch "
f"skipped for {query.sort} due to rate limit"
)
logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
except RateLimitError:
logger.opt(colors=True).warning(
f"<yellow>[BeatmapsetFetcher]</yellow> Warmup skipped for {query.sort} due to rate limit"
)
logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
except Exception as e:
logger.opt(colors=True).error(
f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}"
)
logger.error(f"Failed to warmup cache for {query.sort}: {e}")

app/helpers/__init__.py Normal file

View File

@@ -0,0 +1,106 @@
"""资源代理辅助方法与路由装饰器。"""
from collections.abc import Awaitable, Callable
from functools import wraps
import re
from typing import Any
from app.config import settings
from fastapi import Response
from pydantic import BaseModel
Handler = Callable[..., Awaitable[Any]]
def _replace_asset_urls_in_string(value: str) -> str:
result = value
custom_domain = settings.custom_asset_domain
asset_prefix = settings.asset_proxy_prefix
avatar_prefix = settings.avatar_proxy_prefix
beatmap_prefix = settings.beatmap_proxy_prefix
audio_proxy_base_url = f"{settings.server_url}api/private/audio/beatmapset"
result = re.sub(
r"^https://assets\.ppy\.sh/",
f"https://{asset_prefix}.{custom_domain}/",
result,
)
result = re.sub(
r"^https://b\.ppy\.sh/preview/(\d+)\.mp3",
rf"{audio_proxy_base_url}/\1",
result,
)
result = re.sub(
r"^//b\.ppy\.sh/preview/(\d+)\.mp3",
rf"{audio_proxy_base_url}/\1",
result,
)
result = re.sub(
r"^https://a\.ppy\.sh/",
f"https://{avatar_prefix}.{custom_domain}/",
result,
)
result = re.sub(
r"https://b\.ppy\.sh/",
f"https://{beatmap_prefix}.{custom_domain}/",
result,
)
return result
def _replace_asset_urls_in_data(data: Any) -> Any:
if isinstance(data, str):
return _replace_asset_urls_in_string(data)
if isinstance(data, list):
return [_replace_asset_urls_in_data(item) for item in data]
if isinstance(data, tuple):
return tuple(_replace_asset_urls_in_data(item) for item in data)
if isinstance(data, dict):
return {key: _replace_asset_urls_in_data(value) for key, value in data.items()}
return data
async def replace_asset_urls(data: Any) -> Any:
"""替换数据中的 osu! 资源 URL。"""
if not settings.enable_asset_proxy:
return data
if hasattr(data, "model_dump"):
raw = data.model_dump()
processed = _replace_asset_urls_in_data(raw)
try:
return data.__class__(**processed)
except Exception:
return processed
if isinstance(data, (dict, list, tuple, str)):
return _replace_asset_urls_in_data(data)
return data
def asset_proxy_response(func: Handler) -> Handler:
"""装饰器:在返回响应前替换资源 URL。"""
@wraps(func)
async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs)
if not settings.enable_asset_proxy:
return result
if isinstance(result, Response):
return result
if isinstance(result, BaseModel):
result = result.model_dump()
return _replace_asset_urls_in_data(result)
return wrapper # type: ignore[return-value]
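The URL-rewriting rules in `_replace_asset_urls_in_string` can be exercised in isolation. A minimal sketch, assuming hypothetical domain and prefix values (the real ones come from `app.config.settings`):

```python
import re

# Hypothetical values for illustration; the real ones come from app.config.settings.
CUSTOM_ASSET_DOMAIN = "example.com"
ASSET_PREFIX = "assets"
AUDIO_PROXY_BASE = "https://api.example.com/api/private/audio/beatmapset"

def replace_asset_urls(value: str) -> str:
    """Rewrite osu! asset URLs to their proxied equivalents (simplified)."""
    # Static assets go through the asset proxy subdomain.
    value = re.sub(
        r"^https://assets\.ppy\.sh/",
        f"https://{ASSET_PREFIX}.{CUSTOM_ASSET_DOMAIN}/",
        value,
    )
    # Beatmap preview audio is routed through the private audio endpoint.
    value = re.sub(
        r"^https://b\.ppy\.sh/preview/(\d+)\.mp3",
        rf"{AUDIO_PROXY_BASE}/\1",
        value,
    )
    return value

print(replace_asset_urls("https://assets.ppy.sh/img/logo.png"))
# → https://assets.example.com/img/logo.png
print(replace_asset_urls("https://b.ppy.sh/preview/123.mp3"))
# → https://api.example.com/api/private/audio/beatmapset/123
```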

View File

@@ -1,19 +1,37 @@
"""
GeoLite2 Helper Class
GeoLite2 Helper Class (asynchronous)
"""
from __future__ import annotations
import asyncio
from contextlib import suppress
import os
from pathlib import Path
import shutil
import tarfile
import tempfile
import time
from typing import Any, Required, TypedDict
from app.log import logger
import aiofiles
import httpx
import maxminddb
class GeoIPLookupResult(TypedDict, total=False):
ip: Required[str]
country_iso: str
country_name: str
city_name: str
latitude: str
longitude: str
time_zone: str
postal_code: str
asn: int | None
organization: str
BASE_URL = "https://download.maxmind.com/app/geoip_download"
EDITIONS = {
"City": "GeoLite2-City",
@@ -25,163 +43,184 @@ EDITIONS = {
class GeoIPHelper:
def __init__(
self,
dest_dir="./geoip",
license_key=None,
editions=None,
max_age_days=8,
timeout=60.0,
dest_dir: str | Path = Path("./geoip"),
license_key: str | None = None,
editions: list[str] | None = None,
max_age_days: int = 8,
timeout: float = 60.0,
):
self.dest_dir = dest_dir
self.dest_dir = Path(dest_dir).expanduser()
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
self.editions = editions or ["City", "ASN"]
self.editions = list(editions or ["City", "ASN"])
self.max_age_days = max_age_days
self.timeout = timeout
self._readers = {}
self._readers: dict[str, maxminddb.Reader] = {}
self._update_lock = asyncio.Lock()
@staticmethod
def _safe_extract(tar: tarfile.TarFile, path: str):
base = Path(path).resolve()
for m in tar.getmembers():
target = (base / m.name).resolve()
if not str(target).startswith(str(base)):
def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
base = path.resolve()
for member in tar.getmembers():
target = (base / member.name).resolve()
if not target.is_relative_to(base): # py312
raise RuntimeError("Unsafe path in tar file")
tar.extractall(path=path, filter="data")
tar.extractall(path=base, filter="data")
def _download_and_extract(self, edition_id: str) -> str:
"""
下载并解压 mmdb 文件到 dest_dir仅保留 .mmdb
- 跟随 302 重定向
- 流式下载到临时文件
- 临时目录退出后自动清理
"""
@staticmethod
def _as_mapping(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
@staticmethod
def _as_str(value: Any, default: str = "") -> str:
if isinstance(value, str):
return value
if value is None:
return default
return str(value)
@staticmethod
def _as_int(value: Any) -> int | None:
return value if isinstance(value, int) else None
@staticmethod
def _extract_tarball(src: Path, dest: Path) -> None:
with tarfile.open(src, "r:gz") as tar:
GeoIPHelper._safe_extract(tar, dest)
@staticmethod
def _find_mmdb(root: Path) -> Path | None:
for candidate in root.rglob("*.mmdb"):
return candidate
return None
def _latest_file_sync(self, edition_id: str) -> Path | None:
directory = self.dest_dir
if not directory.is_dir():
return None
candidates = list(directory.glob(f"{edition_id}*.mmdb"))
if not candidates:
return None
return max(candidates, key=lambda p: p.stat().st_mtime)
async def _latest_file(self, edition_id: str) -> Path | None:
return await asyncio.to_thread(self._latest_file_sync, edition_id)
async def _download_and_extract(self, edition_id: str) -> Path:
if not self.license_key:
raise ValueError("缺少 MaxMind License Key,请传入或设置环境变量 MAXMIND_LICENSE_KEY")
raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")
url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
with client.stream("GET", url) as resp:
try:
tgz_path = tmp_dir / "db.tgz"
async with (
httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
client.stream("GET", url) as resp,
):
resp.raise_for_status()
with tempfile.TemporaryDirectory() as tmpd:
tgz_path = os.path.join(tmpd, "db.tgz")
# 流式写入
with open(tgz_path, "wb") as f:
for chunk in resp.iter_bytes():
if chunk:
f.write(chunk)
async with aiofiles.open(tgz_path, "wb") as download_file:
async for chunk in resp.aiter_bytes():
if chunk:
await download_file.write(chunk)
# 解压并只移动 .mmdb
with tarfile.open(tgz_path, "r:gz") as tar:
# 先安全检查与解压
self._safe_extract(tar, tmpd)
await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
if mmdb_path is None:
raise RuntimeError("未在压缩包中找到 .mmdb 文件")
# 递归找 .mmdb
mmdb_path = None
for root, _, files in os.walk(tmpd):
for fn in files:
if fn.endswith(".mmdb"):
mmdb_path = os.path.join(root, fn)
break
if mmdb_path:
break
await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
dst = self.dest_dir / mmdb_path.name
await asyncio.to_thread(shutil.move, mmdb_path, dst)
return dst
finally:
await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)
if not mmdb_path:
raise RuntimeError("未在压缩包中找到 .mmdb 文件")
async def update(self, force: bool = False) -> None:
async with self._update_lock:
for edition in self.editions:
edition_id = EDITIONS[edition]
path = await self._latest_file(edition_id)
need_download = force or path is None
os.makedirs(self.dest_dir, exist_ok=True)
dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path))
shutil.move(mmdb_path, dst)
return dst
def _latest_file(self, edition_id: str):
if not os.path.isdir(self.dest_dir):
return None
files = [
os.path.join(self.dest_dir, f)
for f in os.listdir(self.dest_dir)
if f.startswith(edition_id) and f.endswith(".mmdb")
]
return max(files, key=os.path.getmtime) if files else None
def update(self, force=False):
from app.log import logger
for ed in self.editions:
eid = EDITIONS[ed]
path = self._latest_file(eid)
need = force or not path
if path:
age_days = (time.time() - os.path.getmtime(path)) / 86400
if age_days >= self.max_age_days:
need = True
logger.info(
f"[GeoIP] {eid} database is {age_days:.1f} days old "
f"(max: {self.max_age_days}), will download new version"
)
if path:
mtime = await asyncio.to_thread(path.stat)
age_days = (time.time() - mtime.st_mtime) / 86400
if age_days >= self.max_age_days:
need_download = True
logger.info(
f"{edition_id} database is {age_days:.1f} days old "
f"(max: {self.max_age_days}), will download new version"
)
else:
logger.info(
f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
)
else:
logger.info(
f"[GeoIP] {eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
)
else:
logger.info(f"[GeoIP] {eid} database not found, will download")
logger.info(f"{edition_id} database not found, will download")
if need:
logger.info(f"[GeoIP] Downloading {eid} database...")
path = self._download_and_extract(eid)
logger.info(f"[GeoIP] {eid} database downloaded successfully")
else:
logger.info(f"[GeoIP] Using existing {eid} database")
if need_download:
logger.info(f"Downloading {edition_id} database...")
path = await self._download_and_extract(edition_id)
logger.info(f"{edition_id} database downloaded successfully")
else:
logger.info(f"Using existing {edition_id} database")
old = self._readers.get(ed)
if old:
try:
old.close()
except Exception:
pass
if path is not None:
self._readers[ed] = maxminddb.open_database(path)
old_reader = self._readers.get(edition)
if old_reader:
with suppress(Exception):
old_reader.close()
if path is not None:
self._readers[edition] = maxminddb.open_database(str(path))
def lookup(self, ip: str):
res = {"ip": ip}
# City
city_r = self._readers.get("City")
if city_r:
data = city_r.get(ip)
if data:
country = data.get("country") or {}
res["country_iso"] = country.get("iso_code") or ""
res["country_name"] = (country.get("names") or {}).get("en", "")
city = data.get("city") or {}
res["city_name"] = (city.get("names") or {}).get("en", "")
loc = data.get("location") or {}
res["latitude"] = str(loc.get("latitude") or "")
res["longitude"] = str(loc.get("longitude") or "")
res["time_zone"] = str(loc.get("time_zone") or "")
postal = data.get("postal") or {}
if "code" in postal:
res["postal_code"] = postal["code"]
# ASN
asn_r = self._readers.get("ASN")
if asn_r:
data = asn_r.get(ip)
if data:
res["asn"] = data.get("autonomous_system_number")
res["organization"] = data.get("autonomous_system_organization")
def lookup(self, ip: str) -> GeoIPLookupResult:
res: GeoIPLookupResult = {"ip": ip}
city_reader = self._readers.get("City")
if city_reader:
data = city_reader.get(ip)
if isinstance(data, dict):
country = self._as_mapping(data.get("country"))
res["country_iso"] = self._as_str(country.get("iso_code"))
country_names = self._as_mapping(country.get("names"))
res["country_name"] = self._as_str(country_names.get("en"))
city = self._as_mapping(data.get("city"))
city_names = self._as_mapping(city.get("names"))
res["city_name"] = self._as_str(city_names.get("en"))
location = self._as_mapping(data.get("location"))
latitude = location.get("latitude")
longitude = location.get("longitude")
res["latitude"] = str(latitude) if latitude is not None else ""
res["longitude"] = str(longitude) if longitude is not None else ""
res["time_zone"] = self._as_str(location.get("time_zone"))
postal = self._as_mapping(data.get("postal"))
postal_code = postal.get("code")
if postal_code is not None:
res["postal_code"] = self._as_str(postal_code)
asn_reader = self._readers.get("ASN")
if asn_reader:
data = asn_reader.get(ip)
if isinstance(data, dict):
res["asn"] = self._as_int(data.get("autonomous_system_number"))
res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
return res
def close(self):
for r in self._readers.values():
try:
r.close()
except Exception:
pass
def close(self) -> None:
for reader in self._readers.values():
with suppress(Exception):
reader.close()
self._readers = {}
if __name__ == "__main__":
# 示例用法
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
geo.update()
print(geo.lookup("8.8.8.8"))
geo.close()
async def _demo() -> None:
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
await geo.update()
print(geo.lookup("8.8.8.8"))
geo.close()
asyncio.run(_demo())
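The `_safe_extract` guard above defends against tar path traversal by resolving each member path before extraction. A minimal standalone sketch of the check, using `Path.is_relative_to` (available since Python 3.9):

```python
from pathlib import Path

def is_safe_member(base: Path, name: str) -> bool:
    """Reject tar member names that would escape the extraction directory."""
    target = (base / name).resolve()
    return target.is_relative_to(base.resolve())

base = Path("/tmp/geoip")  # hypothetical extraction directory
print(is_safe_member(base, "GeoLite2-City/GeoLite2-City.mmdb"))  # → True
print(is_safe_member(base, "../../etc/passwd"))  # → False
```

Combined with `tar.extractall(filter="data")`, this rejects both `..` components and absolute paths before any file is written.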

View File

@@ -6,8 +6,6 @@ Rate limiter for osu! API requests to avoid abuse detection.
- 建议:每分钟不超过 60 次请求以避免滥用检测
"""
from __future__ import annotations
import asyncio
from collections import deque
import time

View File

@@ -1,73 +0,0 @@
"""
会话验证接口
基于osu-web的SessionVerificationInterface实现
用于标准化会话验证行为
"""
from __future__ import annotations
from abc import ABC, abstractmethod
class SessionVerificationInterface(ABC):
"""会话验证接口
定义了会话验证所需的基本操作参考osu-web的实现
"""
@classmethod
@abstractmethod
async def find_for_verification(cls, session_id: str) -> SessionVerificationInterface | None:
"""根据会话ID查找会话用于验证
Args:
session_id: 会话ID
Returns:
会话实例或None
"""
pass
@abstractmethod
def get_key(self) -> str:
"""获取会话密钥/ID"""
pass
@abstractmethod
def get_key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
pass
@abstractmethod
def get_verification_method(self) -> str | None:
"""获取当前验证方法
Returns:
验证方法 ('totp', 'mail') 或 None
"""
pass
@abstractmethod
def is_verified(self) -> bool:
"""检查会话是否已验证"""
pass
@abstractmethod
async def mark_verified(self) -> None:
"""标记会话为已验证"""
pass
@abstractmethod
async def set_verification_method(self, method: str) -> None:
"""设置验证方法
Args:
method: 验证方法 ('totp', 'mail')
"""
pass
@abstractmethod
def user_id(self) -> int | None:
"""获取关联的用户ID"""
pass

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
import http
import inspect
import logging
import re
from sys import stdout
from types import FunctionType
from typing import TYPE_CHECKING
from app.config import settings
from app.utils import snake_to_pascal
import loguru
@@ -37,16 +37,18 @@ class InterceptHandler(logging.Handler):
depth += 1
message = record.getMessage()
_logger = logger
if record.name == "uvicorn.access":
message = self._format_uvicorn_access_log(message)
color = True
_logger = uvicorn_logger()
elif record.name == "uvicorn.error":
message = self._format_uvicorn_error_log(message)
_logger = uvicorn_logger()
color = True
else:
color = False
logger.opt(depth=depth, exception=record.exc_info, colors=color).log(level, message)
_logger.opt(depth=depth, exception=record.exc_info, colors=color).log(level, message)
def _format_uvicorn_error_log(self, message: str) -> str:
websocket_pattern = r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
@@ -93,9 +95,7 @@ class InterceptHandler(logging.Handler):
status_color = "green"
elif 300 <= status < 400:
status_color = "yellow"
elif 400 <= status < 500:
status_color = "red"
elif 500 <= status < 600:
elif 400 <= status < 500 or 500 <= status < 600:
status_color = "red"
return (
@@ -107,11 +107,106 @@ class InterceptHandler(logging.Handler):
return message
def get_caller_class_name(module_prefix: str = "", just_last_part: bool = True) -> str | None:
stack = inspect.stack()
for frame_info in stack[2:]:
module = frame_info.frame.f_globals.get("__name__", "")
if module_prefix and not module.startswith(module_prefix):
continue
local_vars = frame_info.frame.f_locals
# 实例方法
if "self" in local_vars:
return local_vars["self"].__class__.__name__
# 类方法
if "cls" in local_vars:
return local_vars["cls"].__name__
# 静态方法 / 普通函数 -> 尝试通过函数名匹配类
func_name = frame_info.function
for obj_name, obj in frame_info.frame.f_globals.items():
if isinstance(obj, type): # 遍历模块内类
cls = obj
attr = getattr(cls, func_name, None)
if isinstance(attr, (staticmethod, classmethod, FunctionType)):
return cls.__name__
# 如果没找到类,返回模块名
if just_last_part:
return module.rsplit(".", 1)[-1]
return module
return None
def service_logger(name: str) -> "Logger":
return logger.bind(service=name)
def fetcher_logger(name: str) -> "Logger":
return logger.bind(fetcher=name)
def task_logger(name: str) -> "Logger":
return logger.bind(task=name)
def system_logger(name: str) -> "Logger":
return logger.bind(system=name)
def uvicorn_logger() -> "Logger":
return logger.bind(uvicorn="Uvicorn")
def log(name: str) -> "Logger":
return logger.bind(real_name=name)
def dynamic_format(record):
name = ""
uvicorn = record["extra"].get("uvicorn")
if uvicorn:
name = f"<fg #228B22>{uvicorn}</fg #228B22>"
service = record["extra"].get("service")
if not service:
service = get_caller_class_name("app.service")
if service:
name = f"<blue>{service}</blue>"
fetcher = record["extra"].get("fetcher")
if not fetcher:
fetcher = get_caller_class_name("app.fetcher")
if fetcher:
name = f"<magenta>{fetcher}</magenta>"
task = record["extra"].get("task")
if not task:
task = get_caller_class_name("app.tasks")
if task:
task = snake_to_pascal(task)
name = f"<fg #FFD700>{task}</fg #FFD700>"
system = record["extra"].get("system")
if system:
name = f"<red>{system}</red>"
if name == "":
real_name = record["extra"].get("real_name", "") or record["name"]
name = f"<fg #FFC1C1>{real_name}</fg #FFC1C1>"
format = f"<green>{{time:YYYY-MM-DD HH:mm:ss}}</green> [<level>{{level}}</level>] | {name} | {{message}}\n"
if record["exception"]:
format += "{exception}\n"
return format
logger.remove()
logger.add(
stdout,
colorize=True,
format=("<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"),
format=dynamic_format,
level=settings.log_level,
diagnose=settings.debug,
)
@@ -120,7 +215,7 @@ logger.add(
rotation="00:00",
retention="30 days",
colorize=False,
format="{time:YYYY-MM-DD HH:mm:ss} {level} | {message}",
format=dynamic_format,
level=settings.log_level,
diagnose=settings.debug,
encoding="utf8",
@@ -135,8 +230,9 @@ uvicorn_loggers = [
]
for logger_name in uvicorn_loggers:
uvicorn_logger = logging.getLogger(logger_name)
uvicorn_logger.handlers = [InterceptHandler()]
uvicorn_logger.propagate = False
_uvicorn_logger = logging.getLogger(logger_name)
_uvicorn_logger.handlers = [InterceptHandler()]
_uvicorn_logger.propagate = False
logging.getLogger("httpx").setLevel("WARNING")
logging.getLogger("apscheduler").setLevel("WARNING")

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from .verify_session import SessionState, VerifySessionMiddleware
__all__ = ["SessionState", "VerifySessionMiddleware"]

View File

@@ -1,43 +0,0 @@
from __future__ import annotations
from app.config import settings
from app.middleware.verify_session import VerifySessionMiddleware
from fastapi import FastAPI
def setup_session_verification_middleware(app: FastAPI) -> None:
"""设置会话验证中间件
Args:
app: FastAPI应用实例
"""
# 只在启用会话验证时添加中间件
if settings.enable_session_verification:
app.add_middleware(VerifySessionMiddleware)
# 可以在这里添加中间件配置日志
from app.log import logger
logger.info("[Middleware] Session verification middleware enabled")
else:
from app.log import logger
logger.info("[Middleware] Session verification middleware disabled")
def setup_all_middlewares(app: FastAPI) -> None:
"""设置所有中间件
Args:
app: FastAPI应用实例
"""
# 设置会话验证中间件
setup_session_verification_middleware(app)
# 可以在这里添加其他中间件
# app.add_middleware(OtherMiddleware)
from app.log import logger
logger.info("[Middleware] All middlewares configured")

View File

@@ -4,17 +4,15 @@ FastAPI会话验证中间件
基于osu-web的会话验证系统适配FastAPI框架
"""
from __future__ import annotations
from collections.abc import Callable
from typing import ClassVar
from app.auth import get_token_by_access_token
from app.const import SUPPORT_TOTP_VERIFICATION_VER
from app.database.lazer_user import User
from app.database.user import User
from app.database.verification import LoginSession
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.log import log
from app.service.verification_service import LoginSessionService
from app.utils import extract_user_agent
@@ -25,180 +23,7 @@ from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware
class VerifySessionMiddleware(BaseHTTPMiddleware):
"""会话验证中间件
参考osu-web的VerifyUser中间件适配FastAPI
"""
# 需要跳过验证的路由
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/session/verify/mail-fallback",
"/api/v2/me",
"/api/v2/me/",
"/api/v2/logout",
"/oauth/token",
"/health",
"/metrics",
"/docs",
"/openapi.json",
"/redoc",
}
# 总是需要验证的路由前缀
ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = {
"/api/private/admin/",
}
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""中间件主处理逻辑"""
try:
# 检查是否跳过验证
if self._should_skip_verification(request):
return await call_next(request)
# 获取当前用户
user = await self._get_current_user(request)
if not user:
# 未登录用户跳过验证
return await call_next(request)
# 获取会话状态
session_state = await self._get_session_state(request, user)
if not session_state:
# 无会话状态,继续请求
return await call_next(request)
# 检查是否已验证
if session_state.is_verified():
return await call_next(request)
# 检查是否需要验证
if not self._requires_verification(request, user):
return await call_next(request)
# 启动验证流程
return await self._initiate_verification(request, session_state)
except Exception as e:
logger.error(f"[Verify Session Middleware] Error: {e}")
# 出错时允许请求继续,避免阻塞
return await call_next(request)
def _should_skip_verification(self, request: Request) -> bool:
"""检查是否应该跳过验证"""
path = request.url.path
# 完全匹配的跳过路由
if path in self.SKIP_VERIFICATION_ROUTES:
return True
# 非API请求跳过
if not path.startswith("/api/"):
return True
return False
def _requires_verification(self, request: Request, user: User) -> bool:
"""检查是否需要验证"""
path = request.url.path
method = request.method
# 检查是否为强制验证的路由
for pattern in self.ALWAYS_VERIFY_PATTERNS:
if path.startswith(pattern):
return True
if not user.is_active:
return True
# 安全方法GET/HEAD/OPTIONS一般不需要验证
safe_methods = {"GET", "HEAD", "OPTIONS"}
if method in safe_methods:
return False
# 修改操作POST/PUT/DELETE/PATCH需要验证
return method in {"POST", "PUT", "DELETE", "PATCH"}
async def _get_current_user(self, request: Request) -> User | None:
"""获取当前用户"""
try:
# 从Authorization header提取token
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return None
token = auth_header[7:] # 移除"Bearer "前缀
# 创建专用数据库会话
async with with_db() as db:
# 获取token记录
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
# 获取用户
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user
except Exception as e:
logger.debug(f"[Verify Session Middleware] Error getting user: {e}")
return None
async def _get_session_state(self, request: Request, user: User) -> SessionState | None:
"""获取会话状态"""
try:
# 提取会话token这里简化为使用相同的auth token
auth_header = request.headers.get("Authorization", "")
api_version = 0
raw_api_version = request.headers.get("x-api-version")
if raw_api_version is not None:
try:
api_version = int(raw_api_version)
except ValueError:
api_version = 0
if not auth_header.startswith("Bearer "):
return None
session_token = auth_header[7:]
# 获取数据库和Redis连接
async with with_db() as db:
redis = get_redis()
# 查找会话
session = await LoginSessionService.find_for_verification(db, session_token)
if not session or session.user_id != user.id:
return None
return SessionState(session, user, redis, db, api_version)
except Exception as e:
logger.error(f"[Verify Session Middleware] Error getting session state: {e}")
return None
async def _initiate_verification(self, request: Request, state: SessionState) -> Response:
"""启动验证流程"""
try:
method = await state.get_method()
if method == "mail":
await state.issue_mail_if_needed()
# 返回验证要求响应
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"method": method, "message": "Session verification required"},
)
except Exception as e:
logger.error(f"[Verify Session Middleware] Error initiating verification: {e}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"error": "Verification initiation failed"}
)
logger = log("Middleware")
class SessionState:
@@ -261,7 +86,7 @@ class SessionState:
self.session.web_uuid,
)
except Exception as e:
logger.error(f"[Session State] Error marking verified: {e}")
logger.error(f"Error marking verified: {e}")
async def issue_mail_if_needed(self) -> None:
"""如果需要,发送验证邮件"""
@@ -274,7 +99,7 @@ class SessionState:
self.db, self.redis, self.user.id, self.user.username, self.user.email, None, None
)
except Exception as e:
logger.error(f"[Session State] Error issuing mail: {e}")
logger.error(f"Error issuing mail: {e}")
def get_key(self) -> str:
"""获取会话密钥"""
@@ -289,3 +114,169 @@ class SessionState:
def user_id(self) -> int:
"""获取用户ID"""
return self.user.id
class VerifySessionMiddleware(BaseHTTPMiddleware):
"""会话验证中间件
参考osu-web的VerifyUser中间件适配FastAPI
"""
# 需要跳过验证的路由
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/session/verify/mail-fallback",
"/api/v2/me",
"/api/v2/me/",
"/api/v2/logout",
"/oauth/token",
"/health",
"/metrics",
"/docs",
"/openapi.json",
"/redoc",
}
# 总是需要验证的路由前缀
ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = {
"/api/private/admin/",
}
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""中间件主处理逻辑"""
# 检查是否跳过验证
if self._should_skip_verification(request):
return await call_next(request)
# 获取当前用户
user = await self._get_current_user(request)
if not user:
# 未登录用户跳过验证
return await call_next(request)
# 获取会话状态
session_state = await self._get_session_state(request, user)
if not session_state:
# 无会话状态,继续请求
return await call_next(request)
# 检查是否已验证
if session_state.is_verified():
return await call_next(request)
# 检查是否需要验证
if not self._requires_verification(request, user):
return await call_next(request)
# 启动验证流程
return await self._initiate_verification(session_state)
def _should_skip_verification(self, request: Request) -> bool:
"""检查是否应该跳过验证"""
path = request.url.path
# 完全匹配的跳过路由
if path in self.SKIP_VERIFICATION_ROUTES:
return True
# 非API请求跳过
return bool(not path.startswith("/api/"))
def _requires_verification(self, request: Request, user: User) -> bool:
"""检查是否需要验证"""
path = request.url.path
method = request.method
# 检查是否为强制验证的路由
for pattern in self.ALWAYS_VERIFY_PATTERNS:
if path.startswith(pattern):
return True
if not user.is_active:
return True
# 安全方法GET/HEAD/OPTIONS一般不需要验证
safe_methods = {"GET", "HEAD", "OPTIONS"}
if method in safe_methods:
return False
# 修改操作POST/PUT/DELETE/PATCH需要验证
return method in {"POST", "PUT", "DELETE", "PATCH"}
async def _get_current_user(self, request: Request) -> User | None:
"""获取当前用户"""
try:
# 从Authorization header提取token
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return None
token = auth_header[7:] # 移除"Bearer "前缀
# 创建专用数据库会话
async with with_db() as db:
# 获取token记录
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
# 获取用户
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user
except Exception as e:
logger.debug(f"Error getting user: {e}")
return None
async def _get_session_state(self, request: Request, user: User) -> SessionState | None:
"""获取会话状态"""
try:
# 提取会话token这里简化为使用相同的auth token
auth_header = request.headers.get("Authorization", "")
api_version = 0
raw_api_version = request.headers.get("x-api-version")
if raw_api_version is not None:
try:
api_version = int(raw_api_version)
except ValueError:
api_version = 0
if not auth_header.startswith("Bearer "):
return None
session_token = auth_header[7:]
# 获取数据库和Redis连接
async with with_db() as db:
redis = get_redis()
# 查找会话
session = await LoginSessionService.find_for_verification(db, session_token)
if not session or session.user_id != user.id:
return None
return SessionState(session, user, redis, db, api_version)
except Exception as e:
logger.error(f"Error getting session state: {e}")
return None
async def _initiate_verification(self, state: SessionState) -> Response:
"""启动验证流程"""
try:
method = await state.get_method()
if method == "mail":
await state.issue_mail_if_needed()
# 返回验证要求响应
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"method": method, "message": "Session verification required"},
)
except Exception as e:
logger.error(f"Error initiating verification: {e}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"error": "Verification initiation failed"}
)

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, NamedTuple

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from enum import IntEnum
from typing import Annotated, Any, Literal
@@ -204,3 +202,6 @@ class SearchQueryModel(BaseModel):
default=None,
description="游标字符串,用于分页",
)
SearchQueryModel.model_rebuild()

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Any
from pydantic import BaseModel

View File

@@ -2,8 +2,6 @@
扩展的 OAuth 响应模型,支持二次验证
"""
from __future__ import annotations
from pydantic import BaseModel
@@ -11,7 +9,7 @@ class ExtendedTokenResponse(BaseModel):
"""扩展的令牌响应,支持二次验证状态"""
access_token: str | None = None
token_type: str = "Bearer"
token_type: str = "Bearer" # noqa: S105
expires_in: int | None = None
refresh_token: str | None = None
scope: str | None = None
@@ -20,14 +18,3 @@ class ExtendedTokenResponse(BaseModel):
requires_second_factor: bool = False
verification_message: str | None = None
user_id: int | None = None # 用于二次验证的用户ID
class SessionState(BaseModel):
"""会话状态"""
user_id: int
username: str
email: str
requires_verification: bool
session_token: str | None = None
verification_sent: bool = False

View File

@@ -1,157 +0,0 @@
from __future__ import annotations
from enum import IntEnum
from typing import ClassVar, Literal
from app.models.signalr import SignalRUnionMessage, UserState
from pydantic import BaseModel, Field
TOTAL_SCORE_DISTRIBUTION_BINS = 13
class _UserActivity(SignalRUnionMessage): ...
class ChoosingBeatmap(_UserActivity):
union_type: ClassVar[Literal[11]] = 11
class _InGame(_UserActivity):
beatmap_id: int
beatmap_display_title: str
ruleset_id: int
ruleset_playing_verb: str
class InSoloGame(_InGame):
union_type: ClassVar[Literal[12]] = 12
class InMultiplayerGame(_InGame):
union_type: ClassVar[Literal[23]] = 23
class SpectatingMultiplayerGame(_InGame):
union_type: ClassVar[Literal[24]] = 24
class InPlaylistGame(_InGame):
union_type: ClassVar[Literal[31]] = 31
class PlayingDailyChallenge(_InGame):
union_type: ClassVar[Literal[52]] = 52
class EditingBeatmap(_UserActivity):
union_type: ClassVar[Literal[41]] = 41
beatmap_id: int
beatmap_display_title: str
class TestingBeatmap(EditingBeatmap):
union_type: ClassVar[Literal[43]] = 43
class ModdingBeatmap(EditingBeatmap):
union_type: ClassVar[Literal[42]] = 42
class WatchingReplay(_UserActivity):
union_type: ClassVar[Literal[13]] = 13
score_id: int
player_name: str
beatmap_id: int
beatmap_display_title: str
class SpectatingUser(WatchingReplay):
union_type: ClassVar[Literal[14]] = 14
class SearchingForLobby(_UserActivity):
union_type: ClassVar[Literal[21]] = 21
class InLobby(_UserActivity):
union_type: ClassVar[Literal[22]] = 22
room_id: int
room_name: str
class InDailyChallengeLobby(_UserActivity):
union_type: ClassVar[Literal[51]] = 51
UserActivity = (
ChoosingBeatmap
| InSoloGame
| WatchingReplay
| SpectatingUser
| SearchingForLobby
| InLobby
| InMultiplayerGame
| SpectatingMultiplayerGame
| InPlaylistGame
| EditingBeatmap
| ModdingBeatmap
| TestingBeatmap
| InDailyChallengeLobby
| PlayingDailyChallenge
)
class UserPresence(BaseModel):
activity: UserActivity | None = None
status: OnlineStatus | None = None
@property
def pushable(self) -> bool:
return self.status is not None and self.status != OnlineStatus.OFFLINE
@property
def for_push(self) -> "UserPresence | None":
return UserPresence(
activity=self.activity,
status=self.status,
)
class MetadataClientState(UserPresence, UserState): ...
class OnlineStatus(IntEnum):
OFFLINE = 0 # appear offline (invisible)
DO_NOT_DISTURB = 1
ONLINE = 2
class DailyChallengeInfo(BaseModel):
room_id: int
class MultiplayerPlaylistItemStats(BaseModel):
playlist_item_id: int = 0
total_score_distribution: list[int] = Field(
default_factory=list,
min_length=TOTAL_SCORE_DISTRIBUTION_BINS,
max_length=TOTAL_SCORE_DISTRIBUTION_BINS,
)
cumulative_score: int = 0
last_processed_score_id: int = 0
class MultiplayerRoomStats(BaseModel):
room_id: int
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict)
class MultiplayerRoomScoreSetEvent(BaseModel):
room_id: int
playlist_item_id: int
score_id: int
user_id: int
total_score: int
new_rank: int | None = None
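`MultiplayerPlaylistItemStats` above keeps a fixed 13-bin histogram of total scores (`TOTAL_SCORE_DISTRIBUTION_BINS`). A minimal sketch of how such a histogram is updated; the bin width of 100,000 is an assumption for illustration, not something stated in this file:

```python
TOTAL_SCORE_DISTRIBUTION_BINS = 13
BIN_WIDTH = 100_000  # assumed bin width; any score past the last edge lands in bin 12

def record_score(distribution: list[int], total_score: int) -> None:
    # Clamp the bin index so very large scores stay in the final bin.
    index = min(total_score // BIN_WIDTH, TOTAL_SCORE_DISTRIBUTION_BINS - 1)
    distribution[index] += 1

bins = [0] * TOTAL_SCORE_DISTRIBUTION_BINS
for score in (0, 99_999, 100_000, 1_500_000):
    record_score(bins, score)
```

After these four scores, bins 0, 1, and 12 hold counts of 2, 1, and 1 respectively.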


@@ -1,5 +1,3 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime


@@ -1,11 +1,9 @@
from __future__ import annotations
import hashlib
import json
from typing import Any, Literal, NotRequired, TypedDict
from app.config import settings as app_settings
from app.log import logger
from app.log import log
from app.path import CONFIG_DIR, STATIC_DIR
from pydantic import ConfigDict, Field, create_model
@@ -268,7 +266,7 @@ def generate_ranked_mod_settings(enable_all: bool = False):
for mod_acronym in ruleset_mods:
result[ruleset_id][mod_acronym] = {}
if not enable_all:
logger.info("ENABLE_ALL_MODS_PP is deprecated, transformed to config/ranked_mods.json")
log("Mod").info("ENABLE_ALL_MODS_PP is deprecated, transformed to config/ranked_mods.json")
result["$mods_checksum"] = checksum # pyright: ignore[reportArgumentType]
ranked_mods_file.write_text(json.dumps(result, indent=4))


@@ -1,840 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import IntEnum
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Literal,
TypedDict,
cast,
override,
)
from app.database.beatmap import Beatmap
from app.dependencies.database import with_db
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from app.utils import utcnow
from .mods import API_MODS, APIMod
from .room import (
DownloadState,
MatchType,
MultiplayerRoomState,
MultiplayerUserState,
QueueMode,
RoomCategory,
RoomStatus,
)
from .signalr import (
SignalRMeta,
SignalRUnionMessage,
UserState,
)
from pydantic import BaseModel, Field
from sqlalchemy import update
from sqlmodel import col
if TYPE_CHECKING:
from app.database.room import Room
from app.signalr.hub import MultiplayerHub
HOST_LIMIT = 50
PER_USER_LIMIT = 3
class MultiplayerClientState(UserState):
room_id: int = 0
class MultiplayerRoomSettings(BaseModel):
name: str = "Unnamed Room"
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
password: str = ""
match_type: MatchType = MatchType.HEAD_TO_HEAD
queue_mode: QueueMode = QueueMode.HOST_ONLY
auto_start_duration: timedelta = timedelta(seconds=0)
auto_skip: bool = False
@property
def auto_start_enabled(self) -> bool:
return self.auto_start_duration != timedelta(seconds=0)
class BeatmapAvailability(BaseModel):
state: DownloadState = DownloadState.UNKNOWN
download_progress: float | None = None
class _MatchUserState(SignalRUnionMessage): ...
class TeamVersusUserState(_MatchUserState):
team_id: int
union_type: ClassVar[Literal[0]] = 0
MatchUserState = TeamVersusUserState
class _MatchRoomState(SignalRUnionMessage): ...
class MultiplayerTeam(BaseModel):
id: int
name: str
class TeamVersusRoomState(_MatchRoomState):
teams: list[MultiplayerTeam] = Field(
default_factory=lambda: [
MultiplayerTeam(id=0, name="Team Red"),
MultiplayerTeam(id=1, name="Team Blue"),
]
)
union_type: ClassVar[Literal[0]] = 0
MatchRoomState = TeamVersusRoomState
class PlaylistItem(BaseModel):
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
owner_id: int
beatmap_id: int
beatmap_checksum: str
ruleset_id: int
required_mods: list[APIMod] = Field(default_factory=list)
allowed_mods: list[APIMod] = Field(default_factory=list)
expired: bool
playlist_order: int
played_at: datetime | None = None
star_rating: float
freestyle: bool
def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None:
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
# Check if mod is valid for ruleset
if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]:
raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset")
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
# Check if mod is unplayable in multiplayer
if mod_settings.get("UserPlayable", True) is False:
raise InvokeException(f"{context} {mod['acronym']} is not playable by users")
if mod_settings.get("ValidForMultiplayer", True) is False:
raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer")
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
from typing import Literal, cast
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
for i, mod1 in enumerate(mods):
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
if mod1_settings:
incompatible = set(mod1_settings.get("IncompatibleMods", []))
for mod2 in mods[i + 1 :]:
if mod2["acronym"] in incompatible:
raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible")
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
from typing import Literal, cast
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
for req_mod in self.required_mods:
req_acronym = req_mod["acronym"]
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
if req_settings:
incompatible = set(req_settings.get("IncompatibleMods", []))
conflicting_allowed = allowed_acronyms & incompatible
if conflicting_allowed:
conflict_list = ", ".join(conflicting_allowed)
raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}")
def validate_playlist_item_mods(self) -> None:
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
# Validate required mods
for mod in self.required_mods:
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
# Validate allowed mods
for mod in self.allowed_mods:
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
# Check internal compatibility of required mods
self._check_mod_compatibility(self.required_mods, ruleset_key)
# Check compatibility between required and allowed mods
self._check_required_allowed_compatibility(ruleset_key)
def validate_user_mods(
self,
user: "MultiplayerRoomUser",
proposed_mods: list[APIMod],
) -> tuple[bool, list[APIMod]]:
"""
Validates user mods against playlist item rules and returns valid mods.
Returns (is_valid, valid_mods).
"""
from typing import Literal, cast
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
valid_mods = []
all_proposed_valid = True
# Check if mods are valid for the ruleset
for mod in proposed_mods:
if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]:
all_proposed_valid = False
continue
valid_mods.append(mod)
# Check mod compatibility within user mods
incompatible_mods = set()
final_valid_mods = []
for mod in valid_mods:
if mod["acronym"] in incompatible_mods:
all_proposed_valid = False
continue
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
if setting_mods:
incompatible_mods.update(setting_mods["IncompatibleMods"])
final_valid_mods.append(mod)
# If not freestyle, check against allowed mods
if not self.freestyle:
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
filtered_valid_mods = []
for mod in final_valid_mods:
if mod["acronym"] not in allowed_acronyms:
all_proposed_valid = False
else:
filtered_valid_mods.append(mod)
final_valid_mods = filtered_valid_mods
# Check compatibility with required mods
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms
# Check for incompatibility between required and user mods
filtered_valid_mods = []
for mod in final_valid_mods:
mod_acronym = mod["acronym"]
is_compatible = True
for other_acronym in all_mod_acronyms:
if other_acronym == mod_acronym:
continue
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
is_compatible = False
all_proposed_valid = False
break
if is_compatible:
filtered_valid_mods.append(mod)
return all_proposed_valid, filtered_valid_mods
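The filtering in `validate_user_mods` boils down to dropping any proposed mod that is unknown or conflicts with a mod already kept, while remembering whether everything proposed survived. Reduced to plain acronym strings with a hypothetical incompatibility table:

```python
# Hypothetical incompatibility table keyed by mod acronym.
INCOMPATIBLE: dict[str, set[str]] = {
    "DT": {"HT"},  # Double Time excludes Half Time
    "HT": {"DT"},
    "HD": set(),
}

def filter_mods(proposed: list[str]) -> tuple[bool, list[str]]:
    blocked: set[str] = set()
    kept: list[str] = []
    all_valid = True
    for acronym in proposed:
        # Unknown mods and mods blocked by an earlier pick are dropped.
        if acronym not in INCOMPATIBLE or acronym in blocked:
            all_valid = False
            continue
        blocked |= INCOMPATIBLE[acronym]
        kept.append(acronym)
    return all_valid, kept
```

For example, `filter_mods(["DT", "HT", "HD"])` keeps `["DT", "HD"]` and reports the proposal as not fully valid, mirroring the `(is_valid, valid_mods)` return shape above.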
def clone(self) -> "PlaylistItem":
copy = self.model_copy()
copy.required_mods = list(self.required_mods)
copy.allowed_mods = list(self.allowed_mods)
copy.expired = False
copy.played_at = None
return copy
class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0
time_remaining: timedelta
is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True
class MatchStartCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[0]] = 0
class ForceGameplayStartCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[1]] = 1
class ServerShuttingDownCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[2]] = 2
MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
class MultiplayerRoomUser(BaseModel):
user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None)
mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None
ruleset_id: int | None = None # freestyle
beatmap_id: int | None = None # freestyle
class MultiplayerRoom(BaseModel):
room_id: int
state: MultiplayerRoomState
settings: MultiplayerRoomSettings
users: list[MultiplayerRoomUser] = Field(default_factory=list)
host: MultiplayerRoomUser | None = None
match_state: MatchRoomState | None = None
playlist: list[PlaylistItem] = Field(default_factory=list)
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
channel_id: int
@classmethod
def from_db(cls, room: "Room") -> "MultiplayerRoom":
"""
Convert a Room (database model) into a MultiplayerRoom (business model).
"""
# user list
users = [MultiplayerRoomUser(user_id=room.host_id)]
host_user = MultiplayerRoomUser(user_id=room.host_id)
# convert the playlist
playlist = []
if room.playlist:
for item in room.playlist:
playlist.append(
PlaylistItem(
id=item.id,
owner_id=item.owner_id,
beatmap_id=item.beatmap_id,
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
ruleset_id=item.ruleset_id,
required_mods=item.required_mods,
allowed_mods=item.allowed_mods,
expired=item.expired,
playlist_order=item.playlist_order,
played_at=item.played_at,
star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0,
freestyle=item.freestyle,
)
)
return cls(
room_id=room.id,
state=getattr(room, "state", MultiplayerRoomState.OPEN),
settings=MultiplayerRoomSettings(
name=room.name,
playlist_item_id=playlist[0].id if playlist else 0,
password=getattr(room, "password", ""),
match_type=room.type,
queue_mode=room.queue_mode,
auto_start_duration=timedelta(seconds=room.auto_start_duration),
auto_skip=room.auto_skip,
),
users=users,
host=host_user,
match_state=None,
playlist=playlist,
active_countdowns=[],
channel_id=room.channel_id or 0,
)
class MultiplayerQueue:
def __init__(self, room: "ServerMultiplayerRoom"):
self.server_room = room
self.current_index = 0
@property
def hub(self) -> "MultiplayerHub":
return self.server_room.hub
@property
def upcoming_items(self):
return sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda i: i.playlist_order,
)
@property
def room(self):
return self.server_room.room
async def update_order(self):
from app.database import Playlist
match self.room.settings.queue_mode:
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
ordered_active_items = []
is_first_set = True
first_set_order_by_user_id = {}
active_items = [item for item in self.room.playlist if not item.expired]
active_items.sort(key=lambda x: x.id)
user_item_groups = {}
for item in active_items:
if item.owner_id not in user_item_groups:
user_item_groups[item.owner_id] = []
user_item_groups[item.owner_id].append(item)
max_items = max((len(items) for items in user_item_groups.values()), default=0)
for i in range(max_items):
current_set = []
for user_id, items in user_item_groups.items():
if i < len(items):
current_set.append(items[i])
if is_first_set:
current_set.sort(key=lambda item: (item.playlist_order, item.id))
ordered_active_items.extend(current_set)
first_set_order_by_user_id = {
item.owner_id: idx for idx, item in enumerate(ordered_active_items)
}
else:
current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0))
ordered_active_items.extend(current_set)
is_first_set = False
case _:
ordered_active_items = sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda x: x.id,
)
async with with_db() as session:
for idx, item in enumerate(ordered_active_items):
if item.playlist_order == idx:
continue
item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False)
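In `ALL_PLAYERS_ROUND_ROBIN` mode, `update_order` interleaves one item per owner per pass so no single player monopolizes the queue. A self-contained sketch of that interleaving, with items as plain `(item_id, owner_id)` tuples (the real code additionally honors `playlist_order` within the first pass, which this sketch omits):

```python
def round_robin(items: list[tuple[int, int]]) -> list[tuple[int, int]]:
    # Group items per owner, preserving each owner's ordering by item id.
    groups: dict[int, list[tuple[int, int]]] = {}
    for item in sorted(items, key=lambda it: it[0]):
        groups.setdefault(item[1], []).append(item)
    ordered: list[tuple[int, int]] = []
    longest = max((len(g) for g in groups.values()), default=0)
    for i in range(longest):  # each pass takes the i-th item of every owner
        for group in groups.values():
            if i < len(group):
                ordered.append(group[i])
    return ordered
```

If owner 1 queued items 1 and 2 and owner 2 queued item 3, play order becomes 1, 3, 2.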
async def update_current_item(self):
upcoming_items = self.upcoming_items
if upcoming_items:
# prefer items that have not expired
next_item = upcoming_items[0]
else:
# if every item has expired, pick the most recently added one (played_at None or latest)
# prefer items with expired=False, then the latest played_at
next_item = max(
self.room.playlist,
key=lambda i: (not i.expired, i.played_at or datetime.min),
)
self.current_index = self.room.playlist.index(next_item)
last_id = self.room.settings.playlist_item_id
self.room.settings.playlist_item_id = next_item.id
if last_id != next_item.id:
await self.hub.setting_changed(self.server_room, True)
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
is_host = self.room.host and self.room.host.user_id == user.user_id
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
raise InvokeException("You are not the host")
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit:
raise InvokeException(f"You can only have {limit} items in the queue")
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
item.validate_playlist_item_mods()
item.owner_id = user.user_id
item.star_rating = beatmap.difficulty_rating
await Playlist.add_to_db(item, self.room.room_id, session)
self.room.playlist.append(item)
await self.hub.playlist_added(self.server_room, item)
await self.update_order()
await self.update_current_item()
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
existing_item = next((i for i in self.room.playlist if i.id == item.id), None)
if existing_item is None:
raise InvokeException("Attempted to change an item that doesn't exist")
if existing_item.owner_id != user.user_id and self.room.host != user:
raise InvokeException("Attempted to change an item which is not owned by the user")
if existing_item.expired:
raise InvokeException("Attempted to change an item which has already been played")
item.validate_playlist_item_mods()
item.owner_id = user.user_id
item.star_rating = float(beatmap.difficulty_rating)
item.playlist_order = existing_item.playlist_order
await Playlist.update(item, self.room.room_id, session)
# Update item in playlist
for idx, playlist_item in enumerate(self.room.playlist):
if playlist_item.id == item.id:
self.room.playlist[idx] = item
break
await self.hub.playlist_changed(
self.server_room,
item,
beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum,
)
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
from app.database import Playlist
item = next(
(i for i in self.room.playlist if i.id == playlist_item_id),
None,
)
if item is None:
raise InvokeException("Item does not exist in the room")
# Check if it's the only item and current item
if item == self.current_item:
upcoming_items = [i for i in self.room.playlist if not i.expired]
if len(upcoming_items) == 1:
raise InvokeException("The only item in the room cannot be removed")
if item.owner_id != user.user_id and self.room.host != user:
raise InvokeException("Attempted to remove an item which is not owned by the user")
if item.expired:
raise InvokeException("Attempted to remove an item which has already been played")
async with with_db() as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
if found_item:
self.room.playlist.remove(found_item)
self.current_index = self.room.playlist.index(self.upcoming_items[0])
await self.update_order()
await self.update_current_item()
await self.hub.playlist_removed(self.server_room, item.id)
async def finish_current_item(self):
from app.database import Playlist
async with with_db() as session:
played_at = utcnow()
await session.execute(
update(Playlist)
.where(
col(Playlist.id) == self.current_item.id,
col(Playlist.room_id) == self.room.room_id,
)
.values(expired=True, played_at=played_at)
)
self.room.playlist[self.current_index].expired = True
self.room.playlist[self.current_index].played_at = played_at
await self.hub.playlist_changed(self.server_room, self.current_item, True)
await self.update_order()
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
await self.update_current_item()
async def update_queue_mode(self):
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
await self.update_order()
await self.update_current_item()
@property
def current_item(self):
return self.room.playlist[self.current_index]
@dataclass
class CountdownInfo:
countdown: MultiplayerCountdown
duration: timedelta
task: asyncio.Task | None = None
def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown
self.duration = (
countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0)
)
class _MatchRequest(SignalRUnionMessage): ...
class ChangeTeamRequest(_MatchRequest):
union_type: ClassVar[Literal[0]] = 0
team_id: int
class StartMatchCountdownRequest(_MatchRequest):
union_type: ClassVar[Literal[1]] = 1
duration: timedelta
class StopCountdownRequest(_MatchRequest):
union_type: ClassVar[Literal[2]] = 2
id: int
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
class MatchTypeHandler(ABC):
def __init__(self, room: "ServerMultiplayerRoom"):
self.room = room
self.hub = room.hub
@abstractmethod
async def handle_join(self, user: MultiplayerRoomUser): ...
@abstractmethod
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
@abstractmethod
async def handle_leave(self, user: MultiplayerRoomUser): ...
@abstractmethod
def get_details(self) -> MatchStartedEventDetail: ...
class HeadToHeadHandler(MatchTypeHandler):
@override
async def handle_join(self, user: MultiplayerRoomUser):
if user.match_state is not None:
user.match_state = None
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@override
def get_details(self) -> MatchStartedEventDetail:
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
return detail
class TeamVersusHandler(MatchTypeHandler):
@override
def __init__(self, room: "ServerMultiplayerRoom"):
super().__init__(room)
self.state = TeamVersusRoomState()
room.room.match_state = self.state
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
self.hub.tasks.add(task)
task.add_done_callback(self.hub.tasks.discard)
def _get_best_available_team(self) -> int:
for team in self.state.teams:
if all(
(
user.match_state is None
or not isinstance(user.match_state, TeamVersusUserState)
or user.match_state.team_id != team.id
)
for user in self.room.room.users
):
return team.id
from collections import defaultdict
team_counts = defaultdict(int)
for user in self.room.room.users:
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
team_counts[user.match_state.team_id] += 1
if team_counts:
min_count = min(team_counts.values())
for team_id, count in team_counts.items():
if count == min_count:
return team_id
return self.state.teams[0].id if self.state.teams else 0
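`_get_best_available_team` prefers a completely empty team and otherwise falls back to the least-populated one. The same decision, reduced to plain team ids and a list of current assignments:

```python
from collections import Counter

def best_team(team_ids: list[int], assignments: list[int]) -> int:
    counts = Counter(assignments)
    for team_id in team_ids:
        if counts[team_id] == 0:  # an empty team always wins
            return team_id
    # Otherwise take the first team with the minimum player count.
    return min(team_ids, key=lambda t: counts[t])
```

Joining an empty two-team room lands you on team 0; the next player goes to team 1; a third player rejoins whichever side is smaller.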
@override
async def handle_join(self, user: MultiplayerRoomUser):
best_team_id = self._get_best_available_team()
user.match_state = TeamVersusUserState(team_id=best_team_id)
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
if not isinstance(request, ChangeTeamRequest):
return
if request.team_id not in [team.id for team in self.state.teams]:
raise InvokeException("Invalid team ID")
user.match_state = TeamVersusUserState(team_id=request.team_id)
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@override
def get_details(self) -> MatchStartedEventDetail:
teams: dict[int, Literal["blue", "red"]] = {}
for user in self.room.room.users:
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
return detail
MATCH_TYPE_HANDLERS = {
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
MatchType.TEAM_VERSUS: TeamVersusHandler,
}
@dataclass
class ServerMultiplayerRoom:
room: MultiplayerRoom
category: RoomCategory
status: RoomStatus
start_at: datetime
hub: "MultiplayerHub"
match_type_handler: MatchTypeHandler
queue: MultiplayerQueue
_next_countdown_id: int
_countdown_id_lock: asyncio.Lock
_tracked_countdown: dict[int, CountdownInfo]
def __init__(
self,
room: MultiplayerRoom,
category: RoomCategory,
start_at: datetime,
hub: "MultiplayerHub",
):
self.room = room
self.category = category
self.status = RoomStatus.IDLE
self.start_at = start_at
self.hub = hub
self.queue = MultiplayerQueue(self)
self._next_countdown_id = 0
self._countdown_id_lock = asyncio.Lock()
self._tracked_countdown = {}
async def set_handler(self):
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self)
for i in self.room.users:
await self.match_type_handler.handle_join(i)
async def get_next_countdown_id(self) -> int:
async with self._countdown_id_lock:
self._next_countdown_id += 1
return self._next_countdown_id
async def start_countdown(
self,
countdown: MultiplayerCountdown,
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
):
async def _countdown_task(self: "ServerMultiplayerRoom"):
await asyncio.sleep(info.duration.total_seconds())
if on_complete is not None:
await on_complete(self)
await self.stop_countdown(countdown)
if countdown.is_exclusive:
await self.stop_all_countdowns(countdown.__class__)
countdown.id = await self.get_next_countdown_id()
info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown)
self._tracked_countdown[countdown.id] = info
await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown))
info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown):
info = self._tracked_countdown.get(countdown.id)
if info is None:
return
del self._tracked_countdown[countdown.id]
self.room.active_countdowns.remove(countdown)
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
if info.task is not None and not info.task.done():
info.task.cancel()
async def stop_all_countdowns(self, typ: type[MultiplayerCountdown]):
for countdown in list(self._tracked_countdown.values()):
if isinstance(countdown.countdown, typ):
await self.stop_countdown(countdown.countdown)
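Each countdown above is tracked by id and runs as a cancellable `asyncio.Task`; stopping it removes the tracking entry and cancels the task before it fires. A minimal stdlib-only sketch of that lifecycle (class and method names here are illustrative):

```python
import asyncio

class Countdowns:
    def __init__(self) -> None:
        self._next_id = 0
        self._tasks: dict[int, asyncio.Task] = {}

    def start(self, seconds: float, on_complete) -> int:
        self._next_id += 1
        countdown_id = self._next_id

        async def runner():
            await asyncio.sleep(seconds)
            self._tasks.pop(countdown_id, None)  # untrack, then notify
            await on_complete()

        self._tasks[countdown_id] = asyncio.create_task(runner())
        return countdown_id

    def stop(self, countdown_id: int) -> None:
        task = self._tasks.pop(countdown_id, None)
        if task is not None and not task.done():
            task.cancel()

async def demo() -> list[str]:
    fired: list[str] = []

    async def done():
        fired.append("fired")

    countdowns = Countdowns()
    countdowns.start(0.01, done)
    cancelled = countdowns.start(10, done)
    countdowns.stop(cancelled)  # this one never fires
    await asyncio.sleep(0.05)
    return fired
```

Running `asyncio.run(demo())` yields a single `"fired"` entry: the short countdown completes, the stopped one does not.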
class _MatchServerEvent(SignalRUnionMessage): ...
class CountdownStartedEvent(_MatchServerEvent):
countdown: MultiplayerCountdown
union_type: ClassVar[Literal[0]] = 0
class CountdownStoppedEvent(_MatchServerEvent):
id: int
union_type: ClassVar[Literal[1]] = 1
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
class GameplayAbortReason(IntEnum):
LOAD_TOOK_TOO_LONG = 0
HOST_ABORTED = 1
class MatchStartedEventDetail(TypedDict):
room_type: Literal["playlists", "head_to_head", "team_versus"]
team: dict[int, Literal["blue", "red"]] | None


@@ -1,4 +1,4 @@
from __future__ import annotations
# ruff: noqa: ARG002
from abc import abstractmethod
from enum import Enum
@@ -118,10 +118,7 @@ class ChannelMessageBase(NotificationDetail):
channel_type: "ChannelType",
) -> Self:
try:
avatar_url = (
getattr(user, "avatar_url", "https://lazer-data.g0v0.top/default.jpg")
or "https://lazer-data.g0v0.top/default.jpg"
)
avatar_url = user.avatar_url or "https://lazer-data.g0v0.top/default.jpg"
except Exception:
avatar_url = "https://lazer-data.g0v0.top/default.jpg"
instance = cls(
@@ -160,7 +157,7 @@ class ChannelMessageTeam(ChannelMessageBase):
cls,
message: "ChatMessage",
user: "User",
) -> ChannelMessageTeam:
) -> Self:
from app.database import ChannelType
return super().init(message, user, [], ChannelType.TEAM)


@@ -1,4 +1,4 @@
# OAuth-related models # noqa: I002
# OAuth-related models
from typing import Annotated, Any, cast
from typing_extensions import Doc
@@ -22,7 +22,7 @@ class TokenRequest(BaseModel):
class TokenResponse(BaseModel):
access_token: str
token_type: str = "Bearer"
token_type: str = "Bearer" # noqa: S105
expires_in: int
refresh_token: str
scope: str = "*"
@@ -67,7 +67,7 @@ class RegistrationRequestErrors(BaseModel):
class OAuth2ClientCredentialsBearer(OAuth2):
def __init__(
self,
tokenUrl: Annotated[
tokenUrl: Annotated[ # noqa: N803
str,
Doc(
"""
@@ -75,7 +75,7 @@ class OAuth2ClientCredentialsBearer(OAuth2):
"""
),
],
refreshUrl: Annotated[
refreshUrl: Annotated[ # noqa: N803
str | None,
Doc(
"""

app/models/playlist.py Normal file

@@ -0,0 +1,20 @@
from datetime import datetime
from app.models.mods import APIMod
from pydantic import BaseModel, Field
class PlaylistItem(BaseModel):
id: int = Field(default=0, ge=-1)
owner_id: int
beatmap_id: int
beatmap_checksum: str = ""
ruleset_id: int = 0
required_mods: list[APIMod] = Field(default_factory=list)
allowed_mods: list[APIMod] = Field(default_factory=list)
expired: bool = False
playlist_order: int = 0
played_at: datetime | None = None
star_rating: float = 0.0
freestyle: bool = False


@@ -1,5 +1,3 @@
from __future__ import annotations
from enum import Enum
from pydantic import BaseModel


@@ -1,5 +1,3 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypedDict, cast


@@ -1,37 +1 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar
from pydantic import (
BaseModel,
Field,
)
@dataclass
class SignalRMeta:
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
json_ignore: bool = False # implement of JsonIgnore (json) attribute
use_abbr: bool = True
class SignalRUnionMessage(BaseModel):
union_type: ClassVar[int]
class Transport(BaseModel):
transport: str
transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats")
class NegotiateResponse(BaseModel):
connectionId: str
connectionToken: str
negotiateVersion: int = 1
availableTransports: list[Transport]
class UserState(BaseModel):
connection_id: str
connection_token: str


@@ -1,131 +0,0 @@
from __future__ import annotations
import datetime
from enum import IntEnum
from typing import Annotated, Any
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from .score import (
ScoreStatistics,
)
from .signalr import SignalRMeta, UserState
from pydantic import BaseModel, Field, field_validator
class SpectatedUserState(IntEnum):
Idle = 0
Playing = 1
Paused = 2
Passed = 3
Failed = 4
Quit = 5
class SpectatorState(BaseModel):
beatmap_id: int | None = None
ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list)
state: SpectatedUserState
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SpectatorState):
return False
return (
self.beatmap_id == other.beatmap_id
and self.ruleset_id == other.ruleset_id
and self.mods == other.mods
and self.state == other.state
)
class ScoreProcessorStatistics(BaseModel):
base_score: float
maximum_base_score: float
accuracy_judgement_count: int
combo_portion: float
bonus_portion: float
class FrameHeader(BaseModel):
total_score: int
accuracy: float
combo: int
max_combo: int
statistics: ScoreStatistics = Field(default_factory=dict)
score_processor_statistics: ScoreProcessorStatistics
received_time: datetime.datetime
mods: list[APIMod] = Field(default_factory=list)
@field_validator("received_time", mode="before")
@classmethod
def validate_timestamp(cls, v: Any) -> datetime.datetime:
if isinstance(v, list):
return v[0]
if isinstance(v, datetime.datetime):
return v
if isinstance(v, int | float):
return datetime.datetime.fromtimestamp(v, tz=datetime.UTC)
if isinstance(v, str):
return datetime.datetime.fromisoformat(v)
raise ValueError(f"Cannot convert {type(v)} to datetime")
# class ReplayButtonState(IntEnum):
# NONE = 0
# LEFT1 = 1
# RIGHT1 = 2
# LEFT2 = 4
# RIGHT2 = 8
# SMOKE = 16
class LegacyReplayFrame(BaseModel):
time: float  # from ReplayFrame, the parent of LegacyReplayFrame
mouse_x: float | None = None
mouse_y: float | None = None
button_state: int
header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)]
class FrameDataBundle(BaseModel):
header: FrameHeader
frames: list[LegacyReplayFrame]
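The commented-out `ReplayButtonState` values above are bit flags packed into the integer `button_state` of each frame. A sketch of decoding them with the stdlib `enum.IntFlag` (names taken from the comment; the packed value below is an arbitrary example):

```python
from enum import IntFlag


class ReplayButtonState(IntFlag):
    NONE = 0
    LEFT1 = 1
    RIGHT1 = 2
    LEFT2 = 4
    RIGHT2 = 8
    SMOKE = 16


# Decode a packed button_state integer from a LegacyReplayFrame
state = ReplayButtonState(5)  # LEFT1 | LEFT2
print(ReplayButtonState.LEFT1 in state)   # → True
print(ReplayButtonState.RIGHT1 in state)  # → False
```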
# Use for server
class APIUser(BaseModel):
id: int
name: str
class ScoreInfo(BaseModel):
mods: list[APIMod]
user: APIUser
ruleset: int
maximum_statistics: ScoreStatistics
id: int | None = None
total_score: int | None = None
accuracy: float | None = None
max_combo: int | None = None
combo: int | None = None
statistics: ScoreStatistics = Field(default_factory=dict)
class StoreScore(BaseModel):
score_info: ScoreInfo
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
class StoreClientState(UserState):
state: SpectatorState | None = None
beatmap_status: BeatmapRankStatus | None = None
checksum: str | None = None
ruleset_id: int | None = None
score_token: int | None = None
watched_user: set[int] = Field(default_factory=set)
score: StoreScore | None = None


@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel


@@ -1,8 +1,6 @@
from __future__ import annotations
import json
from app.log import logger
from app.log import log
from app.path import STATIC_DIR
from pydantic import BaseModel
@@ -16,6 +14,7 @@ class BeatmapTags(BaseModel):
ALL_TAGS: dict[int, BeatmapTags] = {}
logger = log("BeatmapTag")
def load_tags() -> None:


@@ -1,5 +1,3 @@
from __future__ import annotations
from enum import Enum
from typing import TypedDict


@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import NotRequired, TypedDict


@@ -2,8 +2,6 @@
API models for user page editing
"""
from __future__ import annotations
from pydantic import BaseModel, Field, field_validator
@@ -56,3 +54,47 @@ class ValidateBBCodeResponse(BaseModel):
valid: bool = Field(description="Whether the BBCode is valid")
errors: list[str] = Field(default_factory=list, description="List of errors")
preview: dict[str, str] = Field(description="Preview content")
class UserpageError(Exception):
"""Base class for user page processing errors"""
def __init__(self, message: str, code: str = "userpage_error"):
self.message = message
self.code = code
super().__init__(message)
class ContentTooLongError(UserpageError):
"""Content too long"""
def __init__(self, current_length: int, max_length: int):
message = f"Content too long. Maximum {max_length} characters allowed, got {current_length}."
super().__init__(message, "content_too_long")
self.current_length = current_length
self.max_length = max_length
class ContentEmptyError(UserpageError):
"""Empty content"""
def __init__(self):
super().__init__("Content cannot be empty.", "content_empty")
class BBCodeValidationError(UserpageError):
"""BBCode validation error"""
def __init__(self, errors: list[str]):
message = f"BBCode validation failed: {'; '.join(errors)}"
super().__init__(message, "bbcode_validation_error")
self.errors = errors
class ForbiddenTagError(UserpageError):
"""Forbidden tag error"""
def __init__(self, tag: str):
message = f"Forbidden tag '{tag}' is not allowed."
super().__init__(message, "forbidden_tag")
self.tag = tag
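Because every specific error above subclasses `UserpageError`, a caller can catch the base class once and branch on the machine-readable `code`. A minimal usage sketch (re-declaring only the classes needed; `save_userpage` is a hypothetical caller, not part of the diff):

```python
class UserpageError(Exception):
    """Base class for user page processing errors."""

    def __init__(self, message: str, code: str = "userpage_error"):
        self.message = message
        self.code = code
        super().__init__(message)


class ContentTooLongError(UserpageError):
    def __init__(self, current_length: int, max_length: int):
        message = f"Content too long. Maximum {max_length} characters allowed, got {current_length}."
        super().__init__(message, "content_too_long")
        self.current_length = current_length
        self.max_length = max_length


def save_userpage(content: str, max_length: int = 10) -> str:
    """Hypothetical caller that enforces a length limit."""
    if len(content) > max_length:
        raise ContentTooLongError(len(content), max_length)
    return "saved"


try:
    save_userpage("x" * 20)
except UserpageError as e:  # one handler covers every subclass
    print(e.code)           # → content_too_long
```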


@@ -1,7 +1,5 @@
"""User-related models for the V1 API"""
from __future__ import annotations
from pydantic import BaseModel, Field
@@ -46,10 +44,10 @@ class PlayerStatsResponse(BaseModel):
class PlayerEventItem(BaseModel):
"""Player event item"""
userId: int
userId: int # noqa: N815
name: str
mapId: int | None = None
setId: int | None = None
mapId: int | None = None # noqa: N815
setId: int | None = None # noqa: N815
artist: str | None = None
title: str | None = None
version: str | None = None
@@ -88,7 +86,7 @@ class PlayerInfo(BaseModel):
custom_badge_icon: str
custom_badge_color: str
userpage_content: str
recentFailed: int
recentFailed: int # noqa: N815
social_discord: str | None = None
social_youtube: str | None = None
social_twitter: str | None = None


@@ -1,5 +1,3 @@
from __future__ import annotations
from pathlib import Path
STATIC_DIR = Path(__file__).parent.parent / "static"


@@ -1,6 +1,3 @@
from __future__ import annotations
# from app.signalr import signalr_router as signalr_router
from .auth import router as auth_router
from .fetcher import fetcher_router as fetcher_router
from .file import file_router as file_router
@@ -25,5 +22,4 @@ __all__ = [
"private_router",
"redirect_api_router",
"redirect_router",
# "signalr_router",
]


@@ -1,8 +1,6 @@
from __future__ import annotations
from datetime import timedelta
import re
from typing import Literal
from typing import Annotated, Literal
from app.auth import (
authenticate_user,
@@ -19,11 +17,10 @@ from app.const import BANCHOBOT_ID
from app.database import DailyChallengeStats, OAuthClient, User
from app.database.auth import TotpKeys
from app.database.statistics import UserStatistics
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.database import Database, Redis
from app.dependencies.geoip import GeoIPService, IPAddress
from app.dependencies.user_agent import UserAgentInfo
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.log import log
from app.models.extended_auth import ExtendedTokenResponse
from app.models.oauth import (
OAuthErrorResponse,
@@ -40,12 +37,13 @@ from app.service.verification_service import (
)
from app.utils import utcnow
from fastapi import APIRouter, Depends, Form, Header, Request
from fastapi import APIRouter, Form, Header, Request
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
from sqlmodel import exists, select
logger = log("Auth")
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
"""Create a standard OAuth error response"""
@@ -93,11 +91,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
)
async def register_user(
db: Database,
request: Request,
user_username: str = Form(..., alias="user[username]", description="Username"),
user_email: str = Form(..., alias="user[user_email]", description="Email address"),
user_password: str = Form(..., alias="user[password]", description="Password"),
geoip: GeoIPHelper = Depends(get_geoip_helper),
user_username: Annotated[str, Form(..., alias="user[username]", description="Username")],
user_email: Annotated[str, Form(..., alias="user[user_email]", description="Email address")],
user_password: Annotated[str, Form(..., alias="user[password]", description="Password")],
geoip: GeoIPService,
client_ip: IPAddress,
):
username_errors = validate_username(user_username)
email_errors = validate_email(user_email)
@@ -126,22 +124,22 @@ async def register_user(
try:
# Get the client IP and look up its geolocation
client_ip = get_client_ip(request)
country_code = "CN"  # default country code
country_code = None  # default country code
try:
# Query the IP's geolocation
geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"]
if geo_info and (country_code := geo_info.get("country_iso")):
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
if country_code is None:
country_code = "CN"
# Create the new user
# Ensure AUTO_INCREMENT starts at 3 (ID=1 is BanchoBot, ID=2 is reserved for ppy)
# Ensure AUTO_INCREMENT starts at 3 (ID=2 is BanchoBot)
result = await db.execute(
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
@@ -158,7 +156,7 @@ async def register_user(
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1,  # regular user privileges
country_code=country_code,  # country set from IP geolocation
country_code=country_code,
join_date=utcnow(),
last_visit=utcnow(),
is_supporter=settings.enable_supporter_for_all_users,
@@ -201,19 +199,21 @@ async def oauth_token(
db: Database,
request: Request,
user_agent: UserAgentInfo,
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
..., description="Grant type: password / refresh token / authorization code / client credentials"
),
client_id: int = Form(..., description="Client ID"),
client_secret: str = Form(..., description="Client secret"),
code: str | None = Form(None, description="Authorization code (authorization-code grant only)"),
scope: str = Form("*", description="Scopes (space-separated, default '*')"),
username: str | None = Form(None, description="Username (password grant only)"),
password: str | None = Form(None, description="Password (password grant only)"),
refresh_token: str | None = Form(None, description="Refresh token (refresh-token grant only)"),
redis: Redis = Depends(get_redis),
geoip: GeoIPHelper = Depends(get_geoip_helper),
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
ip_address: IPAddress,
grant_type: Annotated[
Literal["authorization_code", "refresh_token", "password", "client_credentials"],
Form(..., description="Grant type: password, refresh token, or authorization code."),
],
client_id: Annotated[int, Form(..., description="Client ID")],
client_secret: Annotated[str, Form(..., description="Client secret")],
redis: Redis,
geoip: GeoIPService,
code: Annotated[str | None, Form(description="Authorization code (authorization-code grant only)")] = None,
scope: Annotated[str, Form(description="Scopes (space-separated, default '*')")] = "*",
username: Annotated[str | None, Form(description="Username (password grant only)")] = None,
password: Annotated[str | None, Form(description="Password (password grant only)")] = None,
refresh_token: Annotated[str | None, Form(description="Refresh token (refresh-token grant only)")] = None,
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
):
scopes = scope.split(" ")
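The signature rewrite above moves FastAPI parameter metadata from default values (`x: int = Form(...)`) into `typing.Annotated`, which keeps Python defaults usable as real defaults. The underlying mechanism is plain `typing` metadata; a stdlib-only sketch of how such metadata is attached and read back (`FormMeta` is a stand-in for `fastapi.Form`, not the real class):

```python
from typing import Annotated, get_args, get_origin, get_type_hints


class FormMeta:
    """Stand-in for fastapi.Form: just carries a description."""

    def __init__(self, description: str):
        self.description = description


def handler(
    client_id: Annotated[int, FormMeta("Client ID")],
    scope: Annotated[str, FormMeta("Scopes (space-separated)")] = "*",
):
    return client_id, scope


# Frameworks recover the metadata via include_extras=True
hints = get_type_hints(handler, include_extras=True)
for name, hint in hints.items():
    if get_origin(hint) is Annotated:
        base, *meta = get_args(hint)
        print(name, base.__name__, meta[0].description)
# → client_id int Client ID
# → scope str Scopes (space-separated)
```

Note that `scope` keeps a genuine default of `"*"`, which is exactly what the old `Form("*")`-as-default style obscured.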
@@ -311,8 +311,6 @@ async def oauth_token(
)
token_id = token.id
ip_address = get_client_ip(request)
# Get the country code
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
@@ -363,9 +361,7 @@ async def oauth_token(
await LoginSessionService.mark_session_verified(
db, redis, user_id, token_id, ip_address, user_agent, web_uuid
)
logger.debug(
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
)
logger.debug(f"New location login detected but email verification disabled, auto-verifying user {user_id}")
else:
# Not a new-device login; record a normal login
await LoginLogService.record_login(
@@ -389,7 +385,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=scope,
@@ -442,7 +438,7 @@ async def oauth_token(
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=new_refresh_token,
scope=scope,
@@ -508,11 +504,11 @@ async def oauth_token(
)
# Log the JWT
logger.info(f"[Auth] Generated JWT for user {user_id}: {access_token}")
logger.info(f"Generated JWT for user {user_id}: {access_token}")
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
@@ -557,7 +553,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
@@ -571,16 +567,14 @@ async def oauth_token(
)
async def request_password_reset(
request: Request,
email: str = Form(..., description="Email address"),
redis: Redis = Depends(get_redis),
email: Annotated[str, Form(..., description="Email address")],
redis: Redis,
ip_address: IPAddress,
):
"""
Request a password reset
"""
from app.dependencies.geoip import get_client_ip
# Gather client information
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
# Request the password reset
@@ -599,20 +593,16 @@ async def request_password_reset(
@router.post("/password-reset/reset", name="Reset password", description="Reset the password using a verification code")
async def reset_password(
request: Request,
email: str = Form(..., description="Email address"),
reset_code: str = Form(..., description="Reset code"),
new_password: str = Form(..., description="New password"),
redis: Redis = Depends(get_redis),
email: Annotated[str, Form(..., description="Email address")],
reset_code: Annotated[str, Form(..., description="Reset code")],
new_password: Annotated[str, Form(..., description="New password")],
redis: Redis,
ip_address: IPAddress,
):
"""
Reset the password
"""
from app.dependencies.geoip import get_client_ip
# Gather client information
ip_address = get_client_ip(request)
# Reset the password
success, message = await password_reset_service.reset_password(
email=email.lower().strip(),

Some files were not shown because too many files have changed in this diff.