refactor(api): use Annotated-style dependency injection

This commit is contained in:
MingxuanGame
2025-10-03 05:41:31 +00:00
parent 37b4eadf79
commit 346c2557cf
45 changed files with 623 additions and 577 deletions

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
import asyncio
import hashlib
import json
from typing import Annotated
from app.database import Beatmap, BeatmapResp, User
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod, int_to_mods
from app.models.score import (
@@ -18,10 +18,9 @@ from app.models.score import (
from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security
from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis.asyncio import Redis
import rosu_pp_py as rosu
from sqlmodel import col, select
@@ -44,11 +43,11 @@ class BatchGetResp(BaseModel):
)
async def lookup_beatmap(
db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
id: Annotated[int | None, Query(alias="id", description="谱面 ID")] = None,
md5: Annotated[str | None, Query(alias="checksum", description="谱面文件 MD5")] = None,
filename: Annotated[str | None, Query(alias="filename", description="谱面文件名")] = None,
):
if id is None and md5 is None and filename is None:
raise HTTPException(
@@ -75,9 +74,9 @@ async def lookup_beatmap(
)
async def get_beatmap(
db: Database,
beatmap_id: int = Path(..., description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
@@ -95,9 +94,12 @@ async def get_beatmap(
)
async def batch_get_beatmaps(
db: Database,
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
beatmap_ids: Annotated[
list[int],
Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
):
if not beatmap_ids:
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
@@ -127,16 +129,19 @@ async def batch_get_beatmaps(
)
async def get_beatmap_attributes(
db: Database,
beatmap_id: int = Path(..., description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
mods: list[str] = Query(
default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
),
ruleset: GameMode | None = Query(default=None, description="指定 ruleset为空则使用谱面自身模式"),
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
mods: Annotated[
list[str],
Query(
default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
),
],
redis: Redis,
fetcher: Fetcher,
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset为空则使用谱面自身模式")] = None,
ruleset_id: Annotated[int | None, Query(description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3)] = None,
):
mods_ = []
if mods and mods[0].isdigit():

View File

@@ -6,23 +6,20 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
from app.dependencies.beatmapset_cache import get_beatmapset_cache_dependency
from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.dependencies.beatmap_download import DownloadService
from app.dependencies.beatmapset_cache import BeatmapsetCacheService
from app.dependencies.database import Database, Redis, with_db
from app.dependencies.fetcher import Fetcher
from app.dependencies.geoip import IPAddress, get_geoip_helper
from app.dependencies.user import ClientUser, get_current_user
from app.models.beatmap import SearchQueryModel
from app.service.asset_proxy_helper import process_response_assets
from app.service.beatmap_download_service import BeatmapDownloadService
from app.service.beatmapset_cache_service import BeatmapsetCacheService, generate_hash
from app.service.beatmapset_cache_service import generate_hash
from .router import router
from fastapi import (
BackgroundTasks,
Depends,
Form,
HTTPException,
Path,
@@ -53,10 +50,10 @@ async def search_beatmapset(
query: Annotated[SearchQueryModel, Query(...)],
request: Request,
background_tasks: BackgroundTasks,
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
redis=Depends(get_redis),
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
redis: Redis,
cache_service: BeatmapsetCacheService,
):
params = parse_qs(qs=request.url.query, keep_blank_values=True)
cursor = {}
@@ -134,10 +131,10 @@ async def search_beatmapset(
async def lookup_beatmapset(
db: Database,
request: Request,
beatmap_id: int = Query(description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
beatmap_id: Annotated[int, Query(description="谱面 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
cache_service: BeatmapsetCacheService,
):
# 先尝试从缓存获取
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
@@ -170,10 +167,10 @@ async def lookup_beatmapset(
async def get_beatmapset(
db: Database,
request: Request,
beatmapset_id: int = Path(..., description="谱面集 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
cache_service: BeatmapsetCacheService,
):
# 先尝试从缓存获取
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
@@ -203,14 +200,12 @@ async def get_beatmapset(
description="\n下载谱面集文件。基于请求IP地理位置智能分流支持负载均衡和自动故障转移。中国IP使用Sayobot镜像其他地区使用Nerinyan和OsuDirect镜像。",
)
async def download_beatmapset(
request: Request,
beatmapset_id: int = Path(..., description="谱面集 ID"),
no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
current_user: User = Security(get_client_user),
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
client_ip: IPAddress,
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
current_user: ClientUser,
download_service: DownloadService,
no_video: Annotated[bool, Query(alias="noVideo", description="是否下载无视频版本")] = True,
):
client_ip = get_client_ip(request)
geoip_helper = get_geoip_helper()
geo_info = geoip_helper.lookup(client_ip)
country_code = geo_info.get("country_iso", "")
@@ -242,9 +237,12 @@ async def download_beatmapset(
)
async def favourite_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
current_user: User = Security(get_client_user),
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
action: Annotated[
Literal["favourite", "unfavourite"],
Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
],
current_user: ClientUser,
):
existing_favourite = (
await db.exec(

View File

@@ -5,14 +5,13 @@
from __future__ import annotations
from app.dependencies.database import get_redis
from app.dependencies.database import Redis
from app.service.user_cache_service import get_user_cache_service
from .router import router
from fastapi import Depends, HTTPException
from fastapi import HTTPException
from pydantic import BaseModel
from redis.asyncio import Redis
class CacheStatsResponse(BaseModel):
@@ -28,7 +27,7 @@ class CacheStatsResponse(BaseModel):
tags=["缓存管理"],
)
async def get_cache_stats(
redis: Redis = Depends(get_redis),
redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用
):
try:
@@ -68,7 +67,7 @@ async def get_cache_stats(
)
async def invalidate_user_cache(
user_id: int,
redis: Redis = Depends(get_redis),
redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
):
try:
@@ -87,7 +86,7 @@ async def invalidate_user_cache(
tags=["缓存管理"],
)
async def clear_all_user_cache(
redis: Redis = Depends(get_redis),
redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
):
try:
@@ -119,7 +118,7 @@ class CacheWarmupRequest(BaseModel):
)
async def warmup_cache(
request: CacheWarmupRequest,
redis: Redis = Depends(get_redis),
redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
):
try:

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
from typing import Annotated
from app.database import MeResp, User
from app.dependencies import get_current_user
from app.dependencies.database import Database
from app.dependencies.user import UserAndToken, get_current_user_and_token
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
from app.exceptions.userpage import UserpageError
from app.models.score import GameMode
from app.models.user import Page
@@ -29,8 +30,8 @@ from fastapi import HTTPException, Path, Security
)
async def get_user_info_with_ruleset(
session: Database,
ruleset: GameMode = Path(description="指定 ruleset"),
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
return user_resp
@@ -45,7 +46,7 @@ async def get_user_info_with_ruleset(
)
async def get_user_info_default(
session: Database,
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
return user_resp
@@ -85,8 +86,8 @@ async def get_user_info_default(
async def update_userpage(
request: UpdateUserpageRequest,
session: Database,
user_id: int = Path(description="用户ID"),
current_user: User = Security(get_current_user, scopes=["edit"]),
user_id: Annotated[int, Path(description="用户ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["edit"])],
):
"""更新用户页面内容匹配官方osu-web实现"""
# 检查权限:只能编辑自己的页面(除非是管理员)

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from typing import Literal
from typing import Annotated, Literal
from app.config import settings
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
from app.dependencies import get_current_user
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.service.ranking_cache_service import get_ranking_cache_service
@@ -45,11 +45,11 @@ SortType = Literal["performance", "score"]
async def get_team_ranking_pp(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
return await get_team_ranking(session, background_tasks, "performance", ruleset, page, current_user)
return await get_team_ranking(session, background_tasks, "performance", ruleset, current_user, page)
@router.get(
@@ -62,14 +62,17 @@ async def get_team_ranking_pp(
async def get_team_ranking(
session: Database,
background_tasks: BackgroundTasks,
sort: SortType = Path(
...,
description="排名类型performance 表现分 / score 计分成绩总分 "
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
),
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
sort: Annotated[
SortType,
Path(
...,
description="排名类型performance 表现分 / score 计分成绩总分 "
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
),
],
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
# 获取 Redis 连接和缓存服务
redis = get_redis()
@@ -193,11 +196,11 @@ class CountryResponse(BaseModel):
async def get_country_ranking_pp(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
return await get_country_ranking(session, background_tasks, ruleset, page, "performance", current_user)
return await get_country_ranking(session, background_tasks, ruleset, "performance", current_user, page)
@router.get(
@@ -210,14 +213,17 @@ async def get_country_ranking_pp(
async def get_country_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"),
sort: SortType = Path(
...,
description="排名类型performance 表现分 / score 计分成绩总分 "
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
),
current_user: User = Security(get_current_user, scopes=["public"]),
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
sort: Annotated[
SortType,
Path(
...,
description="排名类型performance 表现分 / score 计分成绩总分 "
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
# 获取 Redis 连接和缓存服务
redis = get_redis()
@@ -317,11 +323,11 @@ class TopUsersResponse(BaseModel):
async def get_user_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
sort: SortType = Path(..., description="排名类型performance 表现分 / score 计分成绩总分"),
country: str | None = Query(None, description="国家代码"),
page: int = Query(1, ge=1, description=""),
current_user: User = Security(get_current_user, scopes=["public"]),
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
sort: Annotated[SortType, Path(..., description="排名类型performance 表现分 / score 计分成绩总分")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
country: Annotated[str | None, Query(description="国家代")] = None,
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
# 获取 Redis 连接和缓存服务
redis = get_redis()

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
from typing import Annotated
from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.database.user import UserResp
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database
from app.dependencies.user import get_client_user, get_current_user
from app.dependencies.user import ClientUser, get_current_user
from .router import router
@@ -56,7 +58,7 @@ async def get_relationship(
db: Database,
request: Request,
api_version: APIVersion,
current_user: User = Security(get_current_user, scopes=["friends.read"]),
current_user: Annotated[User, Security(get_current_user, scopes=["friends.read"])],
):
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
relationships = await db.exec(
@@ -107,8 +109,8 @@ class AddFriendResp(BaseModel):
async def add_relationship(
db: Database,
request: Request,
target: int = Query(description="目标用户 ID"),
current_user: User = Security(get_client_user),
target: Annotated[int, Query(description="目标用户 ID")],
current_user: ClientUser,
):
if not (await db.exec(select(exists()).where(User.id == target))).first():
raise HTTPException(404, "Target user not found")
@@ -176,8 +178,8 @@ async def add_relationship(
async def delete_relationship(
db: Database,
request: Request,
target: int = Path(..., description="目标用户 ID"),
current_user: User = Security(get_client_user),
target: Annotated[int, Path(..., description="目标用户 ID")],
current_user: ClientUser,
):
if not (await db.exec(select(exists()).where(User.id == target))).first():
raise HTTPException(404, "Target user not found")

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import UTC
from typing import Literal
from typing import Annotated, Literal
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmapset import BeatmapsetResp
@@ -12,8 +12,8 @@ from app.database.room import APIUploadedRoom, Room, RoomResp
from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score
from app.database.user import User, UserResp
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user, get_current_user
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser, get_current_user
from app.models.room import RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api
from app.signalr.hub import MultiplayerHubs
@@ -21,9 +21,8 @@ from app.utils import utcnow
from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security
from fastapi import HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -38,16 +37,20 @@ from sqlmodel.ext.asyncio.session import AsyncSession
)
async def get_all_rooms(
db: Database,
mode: Literal["open", "ended", "participated", "owned"] | None = Query(
default="open",
description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
),
category: RoomCategory = Query(
RoomCategory.NORMAL,
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
),
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
current_user: User = Security(get_current_user, scopes=["public"]),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
mode: Annotated[
Literal["open", "ended", "participated", "owned"] | None,
Query(
description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
),
] = "open",
category: Annotated[
RoomCategory,
Query(
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
),
] = RoomCategory.NORMAL,
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
):
resp_list: list[RoomResp] = []
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
@@ -140,8 +143,8 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
async def create_room(
db: Database,
room: APIUploadedRoom,
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
current_user: ClientUser,
redis: Redis,
):
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
@@ -162,13 +165,15 @@ async def create_room(
)
async def get_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
category: str = Query(
default="",
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
),
current_user: User = Security(get_current_user, scopes=["public"]),
redis: Redis = Depends(get_redis),
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
redis: Redis,
category: Annotated[
str,
Query(
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
),
] = "",
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
@@ -185,8 +190,8 @@ async def get_room(
)
async def delete_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
current_user: User = Security(get_client_user),
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: ClientUser,
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
@@ -205,10 +210,10 @@ async def delete_room(
)
async def add_user_to_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
room_id: Annotated[int, Path(..., description="房间 ID")],
user_id: Annotated[int, Path(..., description="用户 ID")],
redis: Redis,
current_user: ClientUser,
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None:
@@ -229,10 +234,10 @@ async def add_user_to_room(
)
async def remove_user_from_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
room_id: Annotated[int, Path(..., description="房间 ID")],
user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: ClientUser,
redis: Redis,
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None:
@@ -273,8 +278,8 @@ class APILeaderboard(BaseModel):
)
async def get_room_leaderboard(
db: Database,
room_id: int = Path(..., description="房间 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
@@ -329,11 +334,11 @@ class RoomEvents(BaseModel):
)
async def get_room_events(
db: Database,
room_id: int = Path(..., description="房间 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"),
before: int | None = Query(None, ge=0, description="仅包含小于该事件 ID 的事件"),
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
after: Annotated[int | None, Query(ge=0, description="仅包含大于该事件 ID 的事件")] = None,
before: Annotated[int | None, Query(ge=0, description="仅包含小于该事件 ID 的事件")] = None,
):
events = (
await db.exec(

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import UTC, date
import time
from typing import Annotated
from app.calculator import clamp
from app.config import settings
@@ -34,11 +35,10 @@ from app.database.score import (
process_user,
)
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.dependencies.database import Database, Redis, get_redis, with_db
from app.dependencies.fetcher import Fetcher, get_fetcher
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser, get_current_user
from app.log import logger
from app.models.beatmap import BeatmapRankStatus
from app.models.room import RoomCategory
@@ -50,7 +50,6 @@ from app.models.score import (
)
from app.service.beatmap_cache_service import get_beatmap_cache_service
from app.service.user_cache_service import refresh_user_cache_background
from app.storage.base import StorageService
from app.utils import utcnow
from .router import router
@@ -69,7 +68,6 @@ from fastapi.responses import RedirectResponse
from fastapi_limiter.depends import RateLimiter
from httpx import HTTPError
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -245,16 +243,18 @@ class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
async def get_beatmap_scores(
db: Database,
api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"),
mode: GameMode = Query(description="指定 auleset"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
type: LeaderboardType = Query(
LeaderboardType.GLOBAL,
description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
mode: Annotated[GameMode, Query(description="指定 auleset")],
mods: Annotated[list[str], Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
type: Annotated[
LeaderboardType,
Query(
description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
),
] = LeaderboardType.GLOBAL,
limit: Annotated[int, Query(ge=1, le=200, description="返回条数 (1-200)")] = 50,
):
if legacy_only:
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
@@ -294,12 +294,12 @@ async def get_beatmap_scores(
async def get_user_beatmap_score(
db: Database,
api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"),
user_id: int = Path(description="用户 ID"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"),
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"),
current_user: User = Security(get_current_user, scopes=["public"]),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
user_id: Annotated[int, Path(description="用户 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
mods: Annotated[str | None, Query(description="筛选使用的 Mods (暂未实现)")] = None,
):
user_score = (
await db.exec(
@@ -342,11 +342,11 @@ async def get_user_beatmap_score(
async def get_user_all_beatmap_scores(
db: Database,
api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"),
user_id: int = Path(description="用户 ID"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"),
current_user: User = Security(get_current_user, scopes=["public"]),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
user_id: Annotated[int, Path(description="用户 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
):
all_user_scores = (
await db.exec(
@@ -374,11 +374,11 @@ async def get_user_all_beatmap_scores(
async def create_solo_score(
background_task: BackgroundTasks,
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
version_hash: str = Form("", description="游戏版本哈希"),
beatmap_hash: str = Form(description="谱面文件哈希"),
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
current_user: User = Security(get_client_user),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
beatmap_hash: Annotated[str, Form(description="谱面文件哈希")],
ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
current_user: ClientUser,
version_hash: Annotated[str, Form(description="游戏版本哈希")] = "",
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -406,12 +406,12 @@ async def create_solo_score(
async def submit_solo_score(
background_task: BackgroundTasks,
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
token: int = Path(description="成绩令牌 ID"),
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
token: Annotated[int, Path(description="成绩令牌 ID")],
info: Annotated[SoloScoreSubmissionInfo, Body(description="成绩提交信息")],
current_user: ClientUser,
redis: Redis,
fetcher: Fetcher,
):
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
@@ -428,11 +428,11 @@ async def create_playlist_score(
background_task: BackgroundTasks,
room_id: int,
playlist_id: int,
beatmap_id: int = Form(description="谱面 ID"),
beatmap_hash: str = Form(description="游戏版本哈希"),
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
version_hash: str = Form("", description="谱面版本哈希"),
current_user: User = Security(get_client_user),
beatmap_id: Annotated[int, Form(description="谱面 ID")],
beatmap_hash: Annotated[str, Form(description="游戏版本哈希")],
ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
current_user: ClientUser,
version_hash: Annotated[str, Form(description="谱面版本哈希")] = "",
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -496,9 +496,9 @@ async def submit_playlist_score(
playlist_id: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
current_user: ClientUser,
redis: Redis,
fetcher: Fetcher,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -555,9 +555,9 @@ async def index_playlist_scores(
session: Database,
room_id: int,
playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
current_user: User = Security(get_current_user, scopes=["public"]),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
limit: Annotated[int, Query(ge=1, le=50, description="返回条数 (1-50)")] = 50,
cursor: Annotated[int, Query(alias="cursor[total_score]", description="分页游标(上一页最低分)")] = 2000000,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -623,8 +623,8 @@ async def show_playlist_score(
room_id: int,
playlist_id: int,
score_id: int,
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
current_user: ClientUser,
redis: Redis,
):
room = await session.get(Room, room_id)
if not room:
@@ -692,7 +692,7 @@ async def get_user_playlist_score(
room_id: int,
playlist_id: int,
user_id: int,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
score_record = None
start_time = time.time()
@@ -725,8 +725,8 @@ async def get_user_playlist_score(
)
async def pin_score(
db: Database,
score_id: int = Path(description="成绩 ID"),
current_user: User = Security(get_client_user),
score_id: Annotated[int, Path(description="成绩 ID")],
current_user: ClientUser,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -770,8 +770,8 @@ async def pin_score(
)
async def unpin_score(
db: Database,
score_id: int = Path(description="成绩 ID"),
current_user: User = Security(get_client_user),
score_id: Annotated[int, Path(description="成绩 ID")],
current_user: ClientUser,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -805,10 +805,10 @@ async def unpin_score(
)
async def reorder_score_pin(
db: Database,
score_id: int = Path(description="成绩 ID"),
after_score_id: int | None = Body(default=None, description="放在该成绩之后"),
before_score_id: int | None = Body(default=None, description="放在该成绩之"),
current_user: User = Security(get_client_user),
score_id: Annotated[int, Path(description="成绩 ID")],
current_user: ClientUser,
after_score_id: Annotated[int | None, Body(description="放在该成绩之")] = None,
before_score_id: Annotated[int | None, Body(description="放在该成绩之前")] = None,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -893,8 +893,8 @@ async def reorder_score_pin(
async def download_score_replay(
score_id: int,
db: Database,
current_user: User = Security(get_current_user, scopes=["public"]),
storage_service: StorageService = Depends(get_storage_service),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
storage_service: StorageService,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id

View File

@@ -11,8 +11,8 @@ from app.config import settings
from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER
from app.database.auth import TotpKeys
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip
from app.dependencies.database import Database, Redis, get_redis
from app.dependencies.geoip import IPAddress
from app.dependencies.user import UserAndToken, get_client_user_and_token
from app.dependencies.user_agent import UserAgentInfo
from app.log import logger
@@ -27,7 +27,6 @@ from .router import router
from fastapi import Depends, Form, Header, HTTPException, Request, Security, status
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from redis.asyncio import Redis
class VerifyMethod(BaseModel):
@@ -64,10 +63,14 @@ async def verify_session(
db: Database,
api_version: APIVersion,
user_agent: UserAgentInfo,
ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 g0v0 扩展支持)"),
user_and_token: UserAndToken = Security(get_client_user_and_token),
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
verification_key: Annotated[
str,
Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 g0v0 扩展支持)"),
],
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
) -> Response:
current_user = user_and_token[0]
token_id = user_and_token[1].id
@@ -82,7 +85,6 @@ async def verify_session(
else await LoginSessionService.get_login_method(user_id, token_id, redis)
)
ip_address = get_client_ip(request)
login_method = "password"
try:
@@ -182,12 +184,12 @@ async def verify_session(
tags=["验证"],
)
async def reissue_verification_code(
request: Request,
db: Database,
user_agent: UserAgentInfo,
api_version: APIVersion,
ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
) -> SessionReissueResponse:
current_user = user_and_token[0]
token_id = user_and_token[1].id
@@ -203,7 +205,6 @@ async def reissue_verification_code(
return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码")
try:
ip_address = get_client_ip(request)
user_id = current_user.id
success, message = await EmailVerificationService.resend_verification_code(
db,
@@ -233,17 +234,15 @@ async def reissue_verification_code(
async def fallback_email(
db: Database,
user_agent: UserAgentInfo,
request: Request,
ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
) -> VerifyMethod:
current_user = user_and_token[0]
token_id = user_and_token[1].id
if not await LoginSessionService.get_login_method(current_user.id, token_id, redis):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
ip_address = get_client_ip(request)
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
success, message = await EmailVerificationService.resend_verification_code(
db,

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
from typing import Annotated
from app.database.beatmap import Beatmap
from app.database.beatmap_tags import BeatmapTagVote
from app.database.score import Score
from app.database.user import User
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.models.score import Rank
from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id
@@ -55,10 +57,10 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
description="为指定谱面添加标签投票。",
)
async def vote_beatmap_tags(
beatmap_id: int = Path(..., description="谱面 ID"),
tag_id: int = Path(..., description="标签 ID"),
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_client_user),
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
tag_id: Annotated[int, Path(..., description="标签 ID")],
session: Database,
current_user: Annotated[User, Depends(get_client_user)],
):
try:
get_tag_by_id(tag_id)
@@ -90,10 +92,10 @@ async def vote_beatmap_tags(
description="取消对指定谱面标签的投票。",
)
async def devote_beatmap_tags(
beatmap_id: int = Path(..., description="谱面 ID"),
tag_id: int = Path(..., description="标签 ID"),
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_client_user),
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
tag_id: Annotated[int, Path(..., description="标签 ID")],
session: Database,
current_user: Annotated[User, Depends(get_client_user)],
):
"""
取消对谱面指定标签的投票。

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import timedelta
from typing import Literal
from typing import Annotated, Literal
from app.config import settings
from app.const import BANCHOBOT_ID
@@ -51,9 +51,12 @@ async def get_users(
session: Database,
request: Request,
background_task: BackgroundTasks,
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
user_ids: Annotated[list[int], Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表")],
# current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
include_variant_statistics: Annotated[
bool,
Query(description="是否包含各模式的统计信息"),
] = False, # TODO: future use
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -119,9 +122,9 @@ async def get_users(
)
async def get_user_events(
session: Database,
user_id: int = Path(description="用户 ID"),
limit: int | None = Query(None, description="限制返回的活动数量"),
offset: int | None = Query(None, description="活动日志的偏移量"),
user_id: Annotated[int, Path(description="用户 ID")],
limit: Annotated[int | None, Query(description="限制返回的活动数量")] = None,
offset: Annotated[int | None, Query(description="活动日志的偏移量")] = None,
):
db_user = await session.get(User, user_id)
if db_user is None or db_user.id == BANCHOBOT_ID:
@@ -147,9 +150,9 @@ async def get_user_events(
)
async def get_user_kudosu(
session: Database,
user_id: int = Path(description="用户 ID"),
offset: int = Query(default=0, description="偏移量"),
limit: int = Query(default=6, description="返回记录数量限制"),
user_id: Annotated[int, Path(description="用户 ID")],
offset: Annotated[int, Query(description="偏移量")] = 0,
limit: Annotated[int, Query(description="返回记录数量限制")] = 6,
):
"""
获取用户的 kudosu 记录
@@ -176,8 +179,8 @@ async def get_user_kudosu(
async def get_user_info_ruleset(
session: Database,
background_task: BackgroundTasks,
user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"),
user_id: Annotated[str, Path(description="用户 ID 或用户名")],
ruleset: Annotated[GameMode | None, Path(description="指定 ruleset")],
# current_user: User = Security(get_current_user, scopes=["public"]),
):
redis = get_redis()
@@ -225,7 +228,7 @@ async def get_user_info(
background_task: BackgroundTasks,
session: Database,
request: Request,
user_id: str = Path(description="用户 ID 或用户名"),
user_id: Annotated[str, Path(description="用户 ID 或用户名")],
# current_user: User = Security(get_current_user, scopes=["public"]),
):
redis = get_redis()
@@ -274,11 +277,11 @@ async def get_user_info(
async def get_user_beatmapsets(
session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
user_id: Annotated[int, Path(description="用户 ID")],
type: Annotated[BeatmapsetType, Path(description="谱面集类型")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -356,16 +359,17 @@ async def get_user_scores(
session: Database,
api_version: APIVersion,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
),
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
include_fails: bool = Query(False, description="是否包含失败的成绩"),
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
current_user: User = Security(get_current_user, scopes=["public"]),
user_id: Annotated[int, Path(description="用户 ID")],
type: Annotated[
Literal["best", "recent", "firsts", "pinned"],
Path(description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
legacy_only: Annotated[bool, Query(description="是否只查询 Stable 成绩")] = False,
include_fails: Annotated[bool, Query(description="是否包含失败的成绩")] = False,
mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选,默认为用户主模式)")] = None,
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
):
is_legacy_api = api_version < 20220705
redis = get_redis()