refactor(api): use Annotated-style dependency injection
This commit is contained in:
@@ -3,13 +3,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import Beatmap, BeatmapResp, User
|
||||
from app.database.beatmap import calculate_beatmap_attributes
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod, int_to_mods
|
||||
from app.models.score import (
|
||||
@@ -18,10 +18,9 @@ from app.models.score import (
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
import rosu_pp_py as rosu
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -44,11 +43,11 @@ class BatchGetResp(BaseModel):
|
||||
)
|
||||
async def lookup_beatmap(
|
||||
db: Database,
|
||||
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
|
||||
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
|
||||
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
id: Annotated[int | None, Query(alias="id", description="谱面 ID")] = None,
|
||||
md5: Annotated[str | None, Query(alias="checksum", description="谱面文件 MD5")] = None,
|
||||
filename: Annotated[str | None, Query(alias="filename", description="谱面文件名")] = None,
|
||||
):
|
||||
if id is None and md5 is None and filename is None:
|
||||
raise HTTPException(
|
||||
@@ -75,9 +74,9 @@ async def lookup_beatmap(
|
||||
)
|
||||
async def get_beatmap(
|
||||
db: Database,
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
):
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||
@@ -95,9 +94,12 @@ async def get_beatmap(
|
||||
)
|
||||
async def batch_get_beatmaps(
|
||||
db: Database,
|
||||
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
beatmap_ids: Annotated[
|
||||
list[int],
|
||||
Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
|
||||
],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
):
|
||||
if not beatmap_ids:
|
||||
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
|
||||
@@ -127,16 +129,19 @@ async def batch_get_beatmaps(
|
||||
)
|
||||
async def get_beatmap_attributes(
|
||||
db: Database,
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
mods: list[str] = Query(
|
||||
default_factory=list,
|
||||
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
|
||||
),
|
||||
ruleset: GameMode | None = Query(default=None, description="指定 ruleset;为空则使用谱面自身模式"),
|
||||
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
mods: Annotated[
|
||||
list[str],
|
||||
Query(
|
||||
default_factory=list,
|
||||
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
|
||||
),
|
||||
],
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset;为空则使用谱面自身模式")] = None,
|
||||
ruleset_id: Annotated[int | None, Query(description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3)] = None,
|
||||
):
|
||||
mods_ = []
|
||||
if mods and mods[0].isdigit():
|
||||
|
||||
@@ -6,23 +6,20 @@ from urllib.parse import parse_qs
|
||||
|
||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.database.beatmapset import SearchBeatmapsetsResp
|
||||
from app.dependencies.beatmap_download import get_beatmap_download_service
|
||||
from app.dependencies.beatmapset_cache import get_beatmapset_cache_dependency
|
||||
from app.dependencies.database import Database, get_redis, with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.dependencies.beatmap_download import DownloadService
|
||||
from app.dependencies.beatmapset_cache import BeatmapsetCacheService
|
||||
from app.dependencies.database import Database, Redis, with_db
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
from app.dependencies.geoip import IPAddress, get_geoip_helper
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.models.beatmap import SearchQueryModel
|
||||
from app.service.asset_proxy_helper import process_response_assets
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService
|
||||
from app.service.beatmapset_cache_service import BeatmapsetCacheService, generate_hash
|
||||
from app.service.beatmapset_cache_service import generate_hash
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import (
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
Form,
|
||||
HTTPException,
|
||||
Path,
|
||||
@@ -53,10 +50,10 @@ async def search_beatmapset(
|
||||
query: Annotated[SearchQueryModel, Query(...)],
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
redis=Depends(get_redis),
|
||||
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
redis: Redis,
|
||||
cache_service: BeatmapsetCacheService,
|
||||
):
|
||||
params = parse_qs(qs=request.url.query, keep_blank_values=True)
|
||||
cursor = {}
|
||||
@@ -134,10 +131,10 @@ async def search_beatmapset(
|
||||
async def lookup_beatmapset(
|
||||
db: Database,
|
||||
request: Request,
|
||||
beatmap_id: int = Query(description="谱面 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
|
||||
beatmap_id: Annotated[int, Query(description="谱面 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
cache_service: BeatmapsetCacheService,
|
||||
):
|
||||
# 先尝试从缓存获取
|
||||
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
|
||||
@@ -170,10 +167,10 @@ async def lookup_beatmapset(
|
||||
async def get_beatmapset(
|
||||
db: Database,
|
||||
request: Request,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
|
||||
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
cache_service: BeatmapsetCacheService,
|
||||
):
|
||||
# 先尝试从缓存获取
|
||||
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
|
||||
@@ -203,14 +200,12 @@ async def get_beatmapset(
|
||||
description="\n下载谱面集文件。基于请求IP地理位置智能分流,支持负载均衡和自动故障转移。中国IP使用Sayobot镜像,其他地区使用Nerinyan和OsuDirect镜像。",
|
||||
)
|
||||
async def download_beatmapset(
|
||||
request: Request,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
|
||||
current_user: User = Security(get_client_user),
|
||||
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
|
||||
client_ip: IPAddress,
|
||||
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
|
||||
current_user: ClientUser,
|
||||
download_service: DownloadService,
|
||||
no_video: Annotated[bool, Query(alias="noVideo", description="是否下载无视频版本")] = True,
|
||||
):
|
||||
client_ip = get_client_ip(request)
|
||||
|
||||
geoip_helper = get_geoip_helper()
|
||||
geo_info = geoip_helper.lookup(client_ip)
|
||||
country_code = geo_info.get("country_iso", "")
|
||||
@@ -242,9 +237,12 @@ async def download_beatmapset(
|
||||
)
|
||||
async def favourite_beatmapset(
|
||||
db: Database,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
action: Literal["favourite", "unfavourite"] = Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"),
|
||||
current_user: User = Security(get_client_user),
|
||||
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
|
||||
action: Annotated[
|
||||
Literal["favourite", "unfavourite"],
|
||||
Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"),
|
||||
],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
existing_favourite = (
|
||||
await db.exec(
|
||||
|
||||
@@ -5,14 +5,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.dependencies.database import Redis
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class CacheStatsResponse(BaseModel):
|
||||
@@ -28,7 +27,7 @@ class CacheStatsResponse(BaseModel):
|
||||
tags=["缓存管理"],
|
||||
)
|
||||
async def get_cache_stats(
|
||||
redis: Redis = Depends(get_redis),
|
||||
redis: Redis,
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用
|
||||
):
|
||||
try:
|
||||
@@ -68,7 +67,7 @@ async def get_cache_stats(
|
||||
)
|
||||
async def invalidate_user_cache(
|
||||
user_id: int,
|
||||
redis: Redis = Depends(get_redis),
|
||||
redis: Redis,
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
@@ -87,7 +86,7 @@ async def invalidate_user_cache(
|
||||
tags=["缓存管理"],
|
||||
)
|
||||
async def clear_all_user_cache(
|
||||
redis: Redis = Depends(get_redis),
|
||||
redis: Redis,
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
@@ -119,7 +118,7 @@ class CacheWarmupRequest(BaseModel):
|
||||
)
|
||||
async def warmup_cache(
|
||||
request: CacheWarmupRequest,
|
||||
redis: Redis = Depends(get_redis),
|
||||
redis: Redis,
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import MeResp, User
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import UserAndToken, get_current_user_and_token
|
||||
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
|
||||
from app.exceptions.userpage import UserpageError
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import Page
|
||||
@@ -29,8 +30,8 @@ from fastapi import HTTPException, Path, Security
|
||||
)
|
||||
async def get_user_info_with_ruleset(
|
||||
session: Database,
|
||||
ruleset: GameMode = Path(description="指定 ruleset"),
|
||||
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
|
||||
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
|
||||
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
||||
):
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
|
||||
return user_resp
|
||||
@@ -45,7 +46,7 @@ async def get_user_info_with_ruleset(
|
||||
)
|
||||
async def get_user_info_default(
|
||||
session: Database,
|
||||
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
||||
):
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
|
||||
return user_resp
|
||||
@@ -85,8 +86,8 @@ async def get_user_info_default(
|
||||
async def update_userpage(
|
||||
request: UpdateUserpageRequest,
|
||||
session: Database,
|
||||
user_id: int = Path(description="用户ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["edit"]),
|
||||
user_id: Annotated[int, Path(description="用户ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["edit"])],
|
||||
):
|
||||
"""更新用户页面内容(匹配官方osu-web实现)"""
|
||||
# 检查权限:只能编辑自己的页面(除非是管理员)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.score import GameMode
|
||||
from app.service.ranking_cache_service import get_ranking_cache_service
|
||||
|
||||
@@ -45,11 +45,11 @@ SortType = Literal["performance", "score"]
|
||||
async def get_team_ranking_pp(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
return await get_team_ranking(session, background_tasks, "performance", ruleset, page, current_user)
|
||||
return await get_team_ranking(session, background_tasks, "performance", ruleset, current_user, page)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -62,14 +62,17 @@ async def get_team_ranking_pp(
|
||||
async def get_team_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
sort: SortType = Path(
|
||||
...,
|
||||
description="排名类型:performance 表现分 / score 计分成绩总分 "
|
||||
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
|
||||
),
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
sort: Annotated[
|
||||
SortType,
|
||||
Path(
|
||||
...,
|
||||
description="排名类型:performance 表现分 / score 计分成绩总分 "
|
||||
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
|
||||
),
|
||||
],
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
# 获取 Redis 连接和缓存服务
|
||||
redis = get_redis()
|
||||
@@ -193,11 +196,11 @@ class CountryResponse(BaseModel):
|
||||
async def get_country_ranking_pp(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
return await get_country_ranking(session, background_tasks, ruleset, page, "performance", current_user)
|
||||
return await get_country_ranking(session, background_tasks, ruleset, "performance", current_user, page)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -210,14 +213,17 @@ async def get_country_ranking_pp(
|
||||
async def get_country_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
sort: SortType = Path(
|
||||
...,
|
||||
description="排名类型:performance 表现分 / score 计分成绩总分 "
|
||||
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
sort: Annotated[
|
||||
SortType,
|
||||
Path(
|
||||
...,
|
||||
description="排名类型:performance 表现分 / score 计分成绩总分 "
|
||||
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
|
||||
),
|
||||
],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
# 获取 Redis 连接和缓存服务
|
||||
redis = get_redis()
|
||||
@@ -317,11 +323,11 @@ class TopUsersResponse(BaseModel):
|
||||
async def get_user_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
sort: SortType = Path(..., description="排名类型:performance 表现分 / score 计分成绩总分"),
|
||||
country: str | None = Query(None, description="国家代码"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
sort: Annotated[SortType, Path(..., description="排名类型:performance 表现分 / score 计分成绩总分")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
country: Annotated[str | None, Query(description="国家代码")] = None,
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
# 获取 Redis 连接和缓存服务
|
||||
redis = get_redis()
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import Relationship, RelationshipResp, RelationshipType, User
|
||||
from app.database.user import UserResp
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
|
||||
from .router import router
|
||||
|
||||
@@ -56,7 +58,7 @@ async def get_relationship(
|
||||
db: Database,
|
||||
request: Request,
|
||||
api_version: APIVersion,
|
||||
current_user: User = Security(get_current_user, scopes=["friends.read"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["friends.read"])],
|
||||
):
|
||||
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
|
||||
relationships = await db.exec(
|
||||
@@ -107,8 +109,8 @@ class AddFriendResp(BaseModel):
|
||||
async def add_relationship(
|
||||
db: Database,
|
||||
request: Request,
|
||||
target: int = Query(description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
target: Annotated[int, Query(description="目标用户 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
if not (await db.exec(select(exists()).where(User.id == target))).first():
|
||||
raise HTTPException(404, "Target user not found")
|
||||
@@ -176,8 +178,8 @@ async def add_relationship(
|
||||
async def delete_relationship(
|
||||
db: Database,
|
||||
request: Request,
|
||||
target: int = Path(..., description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
target: Annotated[int, Path(..., description="目标用户 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
if not (await db.exec(select(exists()).where(User.id == target))).first():
|
||||
raise HTTPException(404, "Target user not found")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.beatmap import Beatmap, BeatmapResp
|
||||
from app.database.beatmapset import BeatmapsetResp
|
||||
@@ -12,8 +12,8 @@ from app.database.room import APIUploadedRoom, Room, RoomResp
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.database.score import Score
|
||||
from app.database.user import User, UserResp
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.models.room import RoomCategory, RoomStatus
|
||||
from app.service.room import create_playlist_room_from_api
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
@@ -21,9 +21,8 @@ from app.utils import utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -38,16 +37,20 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
)
|
||||
async def get_all_rooms(
|
||||
db: Database,
|
||||
mode: Literal["open", "ended", "participated", "owned"] | None = Query(
|
||||
default="open",
|
||||
description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
|
||||
),
|
||||
category: RoomCategory = Query(
|
||||
RoomCategory.NORMAL,
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
|
||||
),
|
||||
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
mode: Annotated[
|
||||
Literal["open", "ended", "participated", "owned"] | None,
|
||||
Query(
|
||||
description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
|
||||
),
|
||||
] = "open",
|
||||
category: Annotated[
|
||||
RoomCategory,
|
||||
Query(
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
|
||||
),
|
||||
] = RoomCategory.NORMAL,
|
||||
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
|
||||
):
|
||||
resp_list: list[RoomResp] = []
|
||||
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
|
||||
@@ -140,8 +143,8 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
|
||||
async def create_room(
|
||||
db: Database,
|
||||
room: APIUploadedRoom,
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
@@ -162,13 +165,15 @@ async def create_room(
|
||||
)
|
||||
async def get_room(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
category: str = Query(
|
||||
default="",
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
redis: Redis,
|
||||
category: Annotated[
|
||||
str,
|
||||
Query(
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
|
||||
),
|
||||
] = "",
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
@@ -185,8 +190,8 @@ async def get_room(
|
||||
)
|
||||
async def delete_room(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
@@ -205,10 +210,10 @@ async def delete_room(
|
||||
)
|
||||
async def add_user_to_room(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
user_id: Annotated[int, Path(..., description="用户 ID")],
|
||||
redis: Redis,
|
||||
current_user: ClientUser,
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
@@ -229,10 +234,10 @@ async def add_user_to_room(
|
||||
)
|
||||
async def remove_user_from_room(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
user_id: Annotated[int, Path(..., description="用户 ID")],
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
@@ -273,8 +278,8 @@ class APILeaderboard(BaseModel):
|
||||
)
|
||||
async def get_room_leaderboard(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
@@ -329,11 +334,11 @@ class RoomEvents(BaseModel):
|
||||
)
|
||||
async def get_room_events(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"),
|
||||
before: int | None = Query(None, ge=0, description="仅包含小于该事件 ID 的事件"),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
|
||||
after: Annotated[int | None, Query(ge=0, description="仅包含大于该事件 ID 的事件")] = None,
|
||||
before: Annotated[int | None, Query(ge=0, description="仅包含小于该事件 ID 的事件")] = None,
|
||||
):
|
||||
events = (
|
||||
await db.exec(
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, date
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
from app.calculator import clamp
|
||||
from app.config import settings
|
||||
@@ -34,11 +35,10 @@ from app.database.score import (
|
||||
process_user,
|
||||
)
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database, get_redis, with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.dependencies.database import Database, Redis, get_redis, with_db
|
||||
from app.dependencies.fetcher import Fetcher, get_fetcher
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.log import logger
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.room import RoomCategory
|
||||
@@ -50,7 +50,6 @@ from app.models.score import (
|
||||
)
|
||||
from app.service.beatmap_cache_service import get_beatmap_cache_service
|
||||
from app.service.user_cache_service import refresh_user_cache_background
|
||||
from app.storage.base import StorageService
|
||||
from app.utils import utcnow
|
||||
|
||||
from .router import router
|
||||
@@ -69,7 +68,6 @@ from fastapi.responses import RedirectResponse
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from httpx import HTTPError
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, exists, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -245,16 +243,18 @@ class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
|
||||
async def get_beatmap_scores(
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
mode: GameMode = Query(description="指定 auleset"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
|
||||
type: LeaderboardType = Query(
|
||||
LeaderboardType.GLOBAL,
|
||||
description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
mode: Annotated[GameMode, Query(description="指定 auleset")],
|
||||
mods: Annotated[list[str], Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
|
||||
type: Annotated[
|
||||
LeaderboardType,
|
||||
Query(
|
||||
description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
|
||||
),
|
||||
] = LeaderboardType.GLOBAL,
|
||||
limit: Annotated[int, Query(ge=1, le=200, description="返回条数 (1-200)")] = 50,
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
|
||||
@@ -294,12 +294,12 @@ async def get_beatmap_scores(
|
||||
async def get_user_beatmap_score(
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"),
|
||||
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
|
||||
mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
|
||||
mods: Annotated[str | None, Query(description="筛选使用的 Mods (暂未实现)")] = None,
|
||||
):
|
||||
user_score = (
|
||||
await db.exec(
|
||||
@@ -342,11 +342,11 @@ async def get_user_beatmap_score(
|
||||
async def get_user_all_beatmap_scores(
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
|
||||
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
|
||||
):
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
@@ -374,11 +374,11 @@ async def get_user_all_beatmap_scores(
|
||||
async def create_solo_score(
|
||||
background_task: BackgroundTasks,
|
||||
db: Database,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
version_hash: str = Form("", description="游戏版本哈希"),
|
||||
beatmap_hash: str = Form(description="谱面文件哈希"),
|
||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||
current_user: User = Security(get_client_user),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
beatmap_hash: Annotated[str, Form(description="谱面文件哈希")],
|
||||
ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
|
||||
current_user: ClientUser,
|
||||
version_hash: Annotated[str, Form(description="游戏版本哈希")] = "",
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -406,12 +406,12 @@ async def create_solo_score(
|
||||
async def submit_solo_score(
|
||||
background_task: BackgroundTasks,
|
||||
db: Database,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
token: int = Path(description="成绩令牌 ID"),
|
||||
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
token: Annotated[int, Path(description="成绩令牌 ID")],
|
||||
info: Annotated[SoloScoreSubmissionInfo, Body(description="成绩提交信息")],
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
):
|
||||
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
|
||||
|
||||
@@ -428,11 +428,11 @@ async def create_playlist_score(
|
||||
background_task: BackgroundTasks,
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
beatmap_id: int = Form(description="谱面 ID"),
|
||||
beatmap_hash: str = Form(description="游戏版本哈希"),
|
||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||
version_hash: str = Form("", description="谱面版本哈希"),
|
||||
current_user: User = Security(get_client_user),
|
||||
beatmap_id: Annotated[int, Form(description="谱面 ID")],
|
||||
beatmap_hash: Annotated[str, Form(description="游戏版本哈希")],
|
||||
ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
|
||||
current_user: ClientUser,
|
||||
version_hash: Annotated[str, Form(description="谱面版本哈希")] = "",
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -496,9 +496,9 @@ async def submit_playlist_score(
|
||||
playlist_id: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -555,9 +555,9 @@ async def index_playlist_scores(
|
||||
session: Database,
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
|
||||
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
limit: Annotated[int, Query(ge=1, le=50, description="返回条数 (1-50)")] = 50,
|
||||
cursor: Annotated[int, Query(alias="cursor[total_score]", description="分页游标(上一页最低分)")] = 2000000,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -623,8 +623,8 @@ async def show_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
score_id: int,
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
@@ -692,7 +692,7 @@ async def get_user_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
score_record = None
|
||||
start_time = time.time()
|
||||
@@ -725,8 +725,8 @@ async def get_user_playlist_score(
|
||||
)
|
||||
async def pin_score(
|
||||
db: Database,
|
||||
score_id: int = Path(description="成绩 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
score_id: Annotated[int, Path(description="成绩 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -770,8 +770,8 @@ async def pin_score(
|
||||
)
|
||||
async def unpin_score(
|
||||
db: Database,
|
||||
score_id: int = Path(description="成绩 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
score_id: Annotated[int, Path(description="成绩 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -805,10 +805,10 @@ async def unpin_score(
|
||||
)
|
||||
async def reorder_score_pin(
|
||||
db: Database,
|
||||
score_id: int = Path(description="成绩 ID"),
|
||||
after_score_id: int | None = Body(default=None, description="放在该成绩之后"),
|
||||
before_score_id: int | None = Body(default=None, description="放在该成绩之前"),
|
||||
current_user: User = Security(get_client_user),
|
||||
score_id: Annotated[int, Path(description="成绩 ID")],
|
||||
current_user: ClientUser,
|
||||
after_score_id: Annotated[int | None, Body(description="放在该成绩之后")] = None,
|
||||
before_score_id: Annotated[int | None, Body(description="放在该成绩之前")] = None,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -893,8 +893,8 @@ async def reorder_score_pin(
|
||||
async def download_score_replay(
|
||||
score_id: int,
|
||||
db: Database,
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
storage_service: StorageService = Depends(get_storage_service),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
storage_service: StorageService,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -11,8 +11,8 @@ from app.config import settings
|
||||
from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER
|
||||
from app.database.auth import TotpKeys
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
from app.dependencies.database import Database, Redis, get_redis
|
||||
from app.dependencies.geoip import IPAddress
|
||||
from app.dependencies.user import UserAndToken, get_client_user_and_token
|
||||
from app.dependencies.user_agent import UserAgentInfo
|
||||
from app.log import logger
|
||||
@@ -27,7 +27,6 @@ from .router import router
|
||||
from fastapi import Depends, Form, Header, HTTPException, Request, Security, status
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class VerifyMethod(BaseModel):
|
||||
@@ -64,10 +63,14 @@ async def verify_session(
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
user_agent: UserAgentInfo,
|
||||
ip_address: IPAddress,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"),
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
|
||||
verification_key: Annotated[
|
||||
str,
|
||||
Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"),
|
||||
],
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
|
||||
) -> Response:
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
@@ -82,7 +85,6 @@ async def verify_session(
|
||||
else await LoginSessionService.get_login_method(user_id, token_id, redis)
|
||||
)
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
login_method = "password"
|
||||
|
||||
try:
|
||||
@@ -182,12 +184,12 @@ async def verify_session(
|
||||
tags=["验证"],
|
||||
)
|
||||
async def reissue_verification_code(
|
||||
request: Request,
|
||||
db: Database,
|
||||
user_agent: UserAgentInfo,
|
||||
api_version: APIVersion,
|
||||
ip_address: IPAddress,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
) -> SessionReissueResponse:
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
@@ -203,7 +205,6 @@ async def reissue_verification_code(
|
||||
return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码")
|
||||
|
||||
try:
|
||||
ip_address = get_client_ip(request)
|
||||
user_id = current_user.id
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db,
|
||||
@@ -233,17 +234,15 @@ async def reissue_verification_code(
|
||||
async def fallback_email(
|
||||
db: Database,
|
||||
user_agent: UserAgentInfo,
|
||||
request: Request,
|
||||
ip_address: IPAddress,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
) -> VerifyMethod:
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
if not await LoginSessionService.get_login_method(current_user.id, token_id, redis):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db,
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.beatmap_tags import BeatmapTagVote
|
||||
from app.database.score import Score
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.models.score import Rank
|
||||
from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id
|
||||
@@ -55,10 +57,10 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
|
||||
description="为指定谱面添加标签投票。",
|
||||
)
|
||||
async def vote_beatmap_tags(
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
tag_id: int = Path(..., description="标签 ID"),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_client_user),
|
||||
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
|
||||
tag_id: Annotated[int, Path(..., description="标签 ID")],
|
||||
session: Database,
|
||||
current_user: Annotated[User, Depends(get_client_user)],
|
||||
):
|
||||
try:
|
||||
get_tag_by_id(tag_id)
|
||||
@@ -90,10 +92,10 @@ async def vote_beatmap_tags(
|
||||
description="取消对指定谱面标签的投票。",
|
||||
)
|
||||
async def devote_beatmap_tags(
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
tag_id: int = Path(..., description="标签 ID"),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_client_user),
|
||||
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
|
||||
tag_id: Annotated[int, Path(..., description="标签 ID")],
|
||||
session: Database,
|
||||
current_user: Annotated[User, Depends(get_client_user)],
|
||||
):
|
||||
"""
|
||||
取消对谱面指定标签的投票。
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
@@ -51,9 +51,12 @@ async def get_users(
|
||||
session: Database,
|
||||
request: Request,
|
||||
background_task: BackgroundTasks,
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
|
||||
user_ids: Annotated[list[int], Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表")],
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
|
||||
include_variant_statistics: Annotated[
|
||||
bool,
|
||||
Query(description="是否包含各模式的统计信息"),
|
||||
] = False, # TODO: future use
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -119,9 +122,9 @@ async def get_users(
|
||||
)
|
||||
async def get_user_events(
|
||||
session: Database,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
limit: int | None = Query(None, description="限制返回的活动数量"),
|
||||
offset: int | None = Query(None, description="活动日志的偏移量"),
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
limit: Annotated[int | None, Query(description="限制返回的活动数量")] = None,
|
||||
offset: Annotated[int | None, Query(description="活动日志的偏移量")] = None,
|
||||
):
|
||||
db_user = await session.get(User, user_id)
|
||||
if db_user is None or db_user.id == BANCHOBOT_ID:
|
||||
@@ -147,9 +150,9 @@ async def get_user_events(
|
||||
)
|
||||
async def get_user_kudosu(
|
||||
session: Database,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
offset: int = Query(default=0, description="偏移量"),
|
||||
limit: int = Query(default=6, description="返回记录数量限制"),
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
offset: Annotated[int, Query(description="偏移量")] = 0,
|
||||
limit: Annotated[int, Query(description="返回记录数量限制")] = 6,
|
||||
):
|
||||
"""
|
||||
获取用户的 kudosu 记录
|
||||
@@ -176,8 +179,8 @@ async def get_user_kudosu(
|
||||
async def get_user_info_ruleset(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
||||
user_id: Annotated[str, Path(description="用户 ID 或用户名")],
|
||||
ruleset: Annotated[GameMode | None, Path(description="指定 ruleset")],
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
redis = get_redis()
|
||||
@@ -225,7 +228,7 @@ async def get_user_info(
|
||||
background_task: BackgroundTasks,
|
||||
session: Database,
|
||||
request: Request,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
user_id: Annotated[str, Path(description="用户 ID 或用户名")],
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
redis = get_redis()
|
||||
@@ -274,11 +277,11 @@ async def get_user_info(
|
||||
async def get_user_beatmapsets(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: BeatmapsetType = Path(description="谱面集类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
type: Annotated[BeatmapsetType, Path(description="谱面集类型")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
|
||||
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -356,16 +359,17 @@ async def get_user_scores(
|
||||
session: Database,
|
||||
api_version: APIVersion,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: Literal["best", "recent", "firsts", "pinned"] = Path(
|
||||
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
|
||||
),
|
||||
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
|
||||
include_fails: bool = Query(False, description="是否包含失败的成绩"),
|
||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
type: Annotated[
|
||||
Literal["best", "recent", "firsts", "pinned"],
|
||||
Path(description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")),
|
||||
],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
legacy_only: Annotated[bool, Query(description="是否只查询 Stable 成绩")] = False,
|
||||
include_fails: Annotated[bool, Query(description="是否包含失败的成绩")] = False,
|
||||
mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选,默认为用户主模式)")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
|
||||
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
|
||||
):
|
||||
is_legacy_api = api_version < 20220705
|
||||
redis = get_redis()
|
||||
|
||||
Reference in New Issue
Block a user