refactor(api): use Annotated-style dependency injection

This commit is contained in:
MingxuanGame
2025-10-03 05:41:31 +00:00
parent 37b4eadf79
commit 346c2557cf
45 changed files with 623 additions and 577 deletions

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.multiplayer_hub import PlaylistItem
from .beatmap import Beatmap, BeatmapResp
@@ -22,6 +21,8 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.models.multiplayer_hub import PlaylistItem
from .room import Room
@@ -72,7 +73,7 @@ class Playlist(PlaylistBase, table=True):
return result.one()
@classmethod
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
async def from_hub(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls(
id=next_id,
@@ -89,7 +90,7 @@ class Playlist(PlaylistBase, table=True):
)
@classmethod
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
async def update(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
db_playlist = db_playlist.first()
if db_playlist is None:
@@ -106,7 +107,7 @@ class Playlist(PlaylistBase, table=True):
await session.commit()
@classmethod
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
async def add_to_db(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
db_playlist = await cls.from_hub(playlist, room_id, session)
session.add(db_playlist)
await session.commit()

View File

@@ -1,9 +1,9 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.database.item_attempts_count import PlaylistAggregateScore
from app.database.room_participated_user import RoomParticipatedUser
from app.models.model import UTCBaseModel
from app.models.multiplayer_hub import ServerMultiplayerRoom
from app.models.room import (
MatchType,
QueueMode,
@@ -32,6 +32,9 @@ from sqlmodel import (
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.models.multiplayer_hub import ServerMultiplayerRoom
class RoomBase(SQLModel, UTCBaseModel):
name: str = Field(index=True)
@@ -161,7 +164,7 @@ class RoomResp(RoomBase):
return resp
@classmethod
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
async def from_hub(cls, server_room: "ServerMultiplayerRoom") -> "RoomResp":
room = server_room.room
resp = cls(
id=room.room_id,

View File

@@ -1,4 +1 @@
from __future__ import annotations
from .database import get_db as get_db
from .user import get_current_user as get_current_user

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import Depends, Header
def get_api_version(version: int | None = Header(None, alias="x-api-version")) -> int:
def get_api_version(version: int | None = Header(None, alias="x-api-version", include_in_schema=False)) -> int:
if version is None:
return 0
if version < 1:

View File

@@ -1,8 +1,15 @@
from __future__ import annotations
from app.service.beatmap_download_service import download_service
from typing import Annotated
from app.service.beatmap_download_service import BeatmapDownloadService, download_service
from fastapi import Depends
def get_beatmap_download_service():
"""获取谱面下载服务实例"""
return download_service
DownloadService = Annotated[BeatmapDownloadService, Depends(get_beatmap_download_service)]

View File

@@ -1,16 +1,19 @@
"""
Beatmapset缓存服务依赖注入
"""
from __future__ import annotations
from app.dependencies.database import get_redis
from app.service.beatmapset_cache_service import BeatmapsetCacheService, get_beatmapset_cache_service
from typing import Annotated
from app.dependencies.database import Redis
from app.service.beatmapset_cache_service import (
BeatmapsetCacheService as OriginBeatmapsetCacheService,
get_beatmapset_cache_service,
)
from fastapi import Depends
from redis.asyncio import Redis
def get_beatmapset_cache_dependency(redis: Redis = Depends(get_redis)) -> BeatmapsetCacheService:
def get_beatmapset_cache_dependency(redis: Redis) -> OriginBeatmapsetCacheService:
"""获取beatmapset缓存服务依赖"""
return get_beatmapset_cache_service(redis)
BeatmapsetCacheService = Annotated[OriginBeatmapsetCacheService, Depends(get_beatmapset_cache_dependency)]

View File

@@ -91,6 +91,9 @@ def get_redis():
return redis_client
Redis = Annotated[redis.Redis, Depends(get_redis)]
def get_redis_binary():
"""获取二进制数据专用的 Redis 客户端 (不自动解码响应)"""
return redis_binary_client

View File

@@ -1,17 +1,21 @@
from __future__ import annotations
from typing import Annotated
from app.config import settings
from app.dependencies.database import get_redis
from app.fetcher import Fetcher
from app.fetcher import Fetcher as OriginFetcher
from app.log import logger
fetcher: Fetcher | None = None
from fastapi import Depends
fetcher: OriginFetcher | None = None
async def get_fetcher() -> Fetcher:
async def get_fetcher() -> OriginFetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
fetcher = OriginFetcher(
settings.fetcher_client_id,
settings.fetcher_client_secret,
settings.fetcher_scopes,
@@ -27,3 +31,6 @@ async def get_fetcher() -> Fetcher:
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
return fetcher
Fetcher = Annotated[OriginFetcher, Depends(get_fetcher)]

View File

@@ -6,10 +6,13 @@ from __future__ import annotations
from functools import lru_cache
import ipaddress
from typing import Annotated
from app.config import settings
from app.helpers.geoip_helper import GeoIPHelper
from fastapi import Depends, Request
@lru_cache
def get_geoip_helper() -> GeoIPHelper:
@@ -26,7 +29,7 @@ def get_geoip_helper() -> GeoIPHelper:
)
def get_client_ip(request) -> str:
def get_client_ip(request: Request) -> str:
"""
获取客户端真实 IP 地址
支持 IPv4 和 IPv6考虑代理、负载均衡器等情况
@@ -66,6 +69,10 @@ def get_client_ip(request) -> str:
return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
IPAddress = Annotated[str, Depends(get_client_ip)]
GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)]
def is_valid_ip(ip_str: str) -> bool:
"""
验证 IP 地址是否有效(支持 IPv4 和 IPv6

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import cast
from typing import Annotated, cast
from app.config import (
AWSS3StorageSettings,
@@ -9,11 +9,13 @@ from app.config import (
StorageServiceType,
settings,
)
from app.storage import StorageService
from app.storage import StorageService as OriginStorageService
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
from app.storage.local import LocalStorageService
storage: StorageService | None = None
from fastapi import Depends
storage: OriginStorageService | None = None
def init_storage_service():
@@ -50,3 +52,6 @@ def get_storage_service():
if storage is None:
return init_storage_service()
return storage
StorageService = Annotated[OriginStorageService, Depends(get_storage_service)]

View File

@@ -4,6 +4,7 @@ from typing import Annotated
from app.auth import get_token_by_access_token
from app.config import settings
from app.const import SUPPORT_TOTP_VERIFICATION_VER
from app.database import User
from app.database.auth import OAuthToken, V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer
@@ -11,7 +12,7 @@ from app.models.oauth import OAuth2ClientCredentialsBearer
from .api_version import APIVersion
from .database import Database, get_redis
from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Security
from fastapi.security import (
APIKeyQuery,
HTTPBearer,
@@ -112,13 +113,13 @@ async def get_client_user(
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
# 获取当前验证方式
verify_method = None
if api_version >= 20250913:
if api_version >= SUPPORT_TOTP_VERIFICATION_VER:
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
if verify_method is None:
# 智能选择验证方式有TOTP优先TOTP
totp_key = await user.awaitable_attrs.totp_key
if totp_key is not None and api_version >= 20240101:
if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
verify_method = "totp"
else:
verify_method = "mail"
@@ -169,3 +170,6 @@ async def get_current_user(
user_and_token: UserAndToken = Depends(get_current_user_and_token),
) -> User:
return user_and_token[0]
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import timedelta
import re
from typing import Literal
from typing import Annotated, Literal
from app.auth import (
authenticate_user,
@@ -19,10 +19,9 @@ from app.const import BANCHOBOT_ID
from app.database import DailyChallengeStats, OAuthClient, User
from app.database.auth import TotpKeys
from app.database.statistics import UserStatistics
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.database import Database, Redis
from app.dependencies.geoip import GeoIPService, IPAddress
from app.dependencies.user_agent import UserAgentInfo
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.models.extended_auth import ExtendedTokenResponse
from app.models.oauth import (
@@ -40,9 +39,8 @@ from app.service.verification_service import (
)
from app.utils import utcnow
from fastapi import APIRouter, Depends, Form, Header, Request
from fastapi import APIRouter, Form, Header, Request
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
from sqlmodel import exists, select
@@ -93,11 +91,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
)
async def register_user(
db: Database,
request: Request,
user_username: str = Form(..., alias="user[username]", description="用户名"),
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
user_password: str = Form(..., alias="user[password]", description="密码"),
geoip: GeoIPHelper = Depends(get_geoip_helper),
user_username: Annotated[str, Form(..., alias="user[username]", description="用户名")],
user_email: Annotated[str, Form(..., alias="user[user_email]", description="电子邮箱")],
user_password: Annotated[str, Form(..., alias="user[password]", description="密码")],
geoip: GeoIPService,
client_ip: IPAddress,
):
username_errors = validate_username(user_username)
email_errors = validate_email(user_email)
@@ -126,7 +124,6 @@ async def register_user(
try:
# 获取客户端 IP 并查询地理位置
client_ip = get_client_ip(request)
country_code = "CN" # 默认国家代码
try:
@@ -201,19 +198,21 @@ async def oauth_token(
db: Database,
request: Request,
user_agent: UserAgentInfo,
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
),
client_id: int = Form(..., description="客户端 ID"),
client_secret: str = Form(..., description="客户端密钥"),
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*'"),
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
password: str | None = Form(None, description="密码(仅密码模式需要)"),
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
redis: Redis = Depends(get_redis),
geoip: GeoIPHelper = Depends(get_geoip_helper),
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
ip_address: IPAddress,
grant_type: Annotated[
Literal["authorization_code", "refresh_token", "password", "client_credentials"],
Form(..., description="授权类型:密码、刷新令牌和授权码三种授权方式。"),
],
client_id: Annotated[int, Form(..., description="客户端 ID")],
client_secret: Annotated[str, Form(..., description="客户端密钥")],
redis: Redis,
geoip: GeoIPService,
code: Annotated[str | None, Form(description="授权码(仅授权码模式需要)")] = None,
scope: Annotated[str, Form(description="权限范围(空格分隔,默认为 '*'")] = "*",
username: Annotated[str | None, Form(description="用户名(仅密码模式需要)")] = None,
password: Annotated[str | None, Form(description="密码(仅密码模式需要)")] = None,
refresh_token: Annotated[str | None, Form(description="刷新令牌(仅刷新令牌模式需要)")] = None,
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
):
scopes = scope.split(" ")
@@ -311,8 +310,6 @@ async def oauth_token(
)
token_id = token.id
ip_address = get_client_ip(request)
# 获取国家代码
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
@@ -571,16 +568,14 @@ async def oauth_token(
)
async def request_password_reset(
request: Request,
email: str = Form(..., description="邮箱地址"),
redis: Redis = Depends(get_redis),
email: Annotated[str, Form(..., description="邮箱地址")],
redis: Redis,
ip_address: IPAddress,
):
"""
请求密码重置
"""
from app.dependencies.geoip import get_client_ip
# 获取客户端信息
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
# 请求密码重置
@@ -599,20 +594,16 @@ async def request_password_reset(
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
async def reset_password(
request: Request,
email: str = Form(..., description="邮箱地址"),
reset_code: str = Form(..., description="重置验证"),
new_password: str = Form(..., description="新密码"),
redis: Redis = Depends(get_redis),
email: Annotated[str, Form(..., description="邮箱地址")],
reset_code: Annotated[str, Form(..., description="重置验证码")],
new_password: Annotated[str, Form(..., description="新密")],
redis: Redis,
ip_address: IPAddress,
):
"""
重置密码
"""
from app.dependencies.geoip import get_client_ip
# 获取客户端信息
ip_address = get_client_ip(request)
# 重置密码
success, message = await password_reset_service.reset_password(
email=email.lower().strip(),

View File

@@ -1,14 +1,13 @@
from __future__ import annotations
from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher
from app.dependencies.fetcher import Fetcher
from fastapi import APIRouter, Depends
from fastapi import APIRouter
fetcher_router = APIRouter(prefix="/fetcher", include_in_schema=False)
@fetcher_router.get("/callback")
async def callback(code: str, fetcher: Fetcher = Depends(get_fetcher)):
async def callback(code: str, fetcher: Fetcher):
await fetcher.grant_access_token(code)
return {"message": "Login successful"}

View File

@@ -1,16 +1,16 @@
from __future__ import annotations
from app.dependencies.storage import get_storage_service
from app.storage import LocalStorageService, StorageService
from app.dependencies.storage import StorageService as StorageServiceDep
from app.storage import LocalStorageService
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse
file_router = APIRouter(prefix="/file", include_in_schema=False)
@file_router.get("/{path:path}")
async def get_file(path: str, storage: StorageService = Depends(get_storage_service)):
async def get_file(path: str, storage: StorageServiceDep):
if not isinstance(storage, LocalStorageService):
raise HTTPException(404, "Not Found")
if not await storage.is_exists(path):

View File

@@ -11,21 +11,18 @@ from app.database.playlists import Playlist as DBPlaylist
from app.database.room import Room
from app.database.room_participated_user import RoomParticipatedUser
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.fetcher import Fetcher
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.storage import StorageService
from app.log import logger
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus
from app.storage.base import StorageService
from app.utils import utcnow
from .notification.server import server
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi import APIRouter, HTTPException, Request, status
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy import update
from sqlmodel import col, select
@@ -637,8 +634,8 @@ async def add_user_to_room(
async def ensure_beatmap_present(
beatmap_data: BeatmapEnsureRequest,
db: Database,
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
redis: Redis,
fetcher: Fetcher,
) -> dict[str, Any]:
"""
确保谱面在服务器中存在(包括元数据和原始文件缓存)。
@@ -677,7 +674,7 @@ class ReplayDataRequest(BaseModel):
@router.post("/scores/replay")
async def save_replay(
req: ReplayDataRequest,
storage_service: StorageService = Depends(get_storage_service),
storage_service: StorageService,
):
replay_data = req.mreplay
replay_path = f"replays/{req.score_id}_{req.beatmap_id}_{req.user_id}_lazer_replay.osr"

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Literal, Self
from typing import Annotated, Any, Literal, Self
from app.database.chat import (
ChannelType,
@@ -11,7 +11,7 @@ from app.database.chat import (
UserSilenceResp,
)
from app.database.user import User, UserResp
from app.dependencies.database import Database, get_redis
from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
@@ -20,7 +20,6 @@ from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -38,11 +37,14 @@ class UpdateResponse(BaseModel):
)
async def get_update(
session: Database,
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
redis: Redis,
history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
includes: Annotated[
list[str],
Query(alias="includes[]", description="要包含的更新类型"),
] = ["presence", "silences"],
):
resp = UpdateResponse()
if "presence" in includes:
@@ -86,9 +88,9 @@ async def get_update(
)
async def join_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
channel: Annotated[str, Path(..., description="频道 ID/名称")],
user: Annotated[str, Path(..., description="用户 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
@@ -110,9 +112,9 @@ async def join_channel(
)
async def leave_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
channel: Annotated[str, Path(..., description="频道 ID/名称")],
user: Annotated[str, Path(..., description="用户 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
@@ -135,8 +137,8 @@ async def leave_channel(
)
async def get_channel_list(
session: Database,
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
redis: Redis,
):
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
results = []
@@ -171,9 +173,9 @@ class GetChannelResp(BaseModel):
)
async def get_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
channel: Annotated[str, Path(..., description="频道 ID/名称")],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
redis: Redis,
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
@@ -245,9 +247,9 @@ class CreateChannelReq(BaseModel):
)
async def create_channel(
session: Database,
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
redis: Redis = Depends(get_redis),
req: Annotated[CreateChannelReq, Depends(BodyOrForm(CreateChannelReq))],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
redis: Redis,
):
if req.type == "PM":
target = await session.get(User, req.target_id)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Annotated
from app.database import ChatMessageResp
from app.database.chat import (
ChannelType,
@@ -11,7 +13,7 @@ from app.database.chat import (
UserSilenceResp,
)
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.log import logger
@@ -24,7 +26,6 @@ from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -41,9 +42,9 @@ class KeepAliveResp(BaseModel):
)
async def keep_alive(
session: Database,
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
):
resp = KeepAliveResp()
if history_since:
@@ -73,9 +74,9 @@ class MessageReq(BaseModel):
)
async def send_message(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
req: MessageReq = Depends(BodyOrForm(MessageReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
channel: Annotated[str, Path(..., description="频道 ID/名称")],
req: Annotated[MessageReq, Depends(BodyOrForm(MessageReq))],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
):
# 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit():
@@ -156,10 +157,10 @@ async def send_message(
async def get_message(
session: Database,
channel: str,
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
since: int = Query(0, ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)"),
until: int | None = Query(None, description="获取自此消息 ID 之的消息(向后翻历史"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
limit: Annotated[int, Query(ge=1, le=50, description="获取消息的数量")] = 50,
since: Annotated[int, Query(ge=0, description="获取自此消息 ID 之的消息(向前加载新消息")] = 0,
until: Annotated[int | None, Query(description="获取自此消息 ID 之前的消息(向后翻历史)")] = None,
):
# 1) 查频道
if channel.isdigit():
@@ -220,9 +221,9 @@ async def get_message(
)
async def mark_as_read(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
message: int = Path(..., description="消息 ID"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
channel: Annotated[str, Path(..., description="频道 ID/名称")],
message: Annotated[int, Path(..., description="消息 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
):
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
@@ -259,9 +260,9 @@ class NewPMResp(BaseModel):
)
async def create_new_pm(
session: Database,
req: PMReq = Depends(BodyOrForm(PMReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
redis: Redis = Depends(get_redis),
req: Annotated[PMReq, Depends(BodyOrForm(PMReq))],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
redis: Redis,
):
user_id = current_user.id
target = await session.get(User, req.target_id)

View File

@@ -1,13 +1,14 @@
from __future__ import annotations
import asyncio
from typing import overload
from typing import Annotated, overload
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.notification import UserNotification, insert_notification
from app.database.user import User
from app.dependencies.database import (
DBFactory,
Redis,
get_db_factory,
get_redis,
with_db,
@@ -22,7 +23,6 @@ from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
from fastapi.websockets import WebSocketState
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -298,10 +298,10 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
@chat_router.websocket("/notification-server")
async def chat_websocket(
websocket: WebSocket,
token: str | None = Query(None, description="认证令牌支持通过URL参数传递"),
access_token: str | None = Query(None, description="访问令牌支持通过URL参数传递"),
authorization: str | None = Header(None, description="Bearer认证头"),
factory: DBFactory = Depends(get_db_factory),
factory: Annotated[DBFactory, Depends(get_db_factory)],
token: Annotated[str | None, Query(description="认证令牌支持通过URL参数传递")] = None,
access_token: Annotated[str | None, Query(description="访问令牌支持通过URL参数传递")] = None,
authorization: Annotated[str | None, Header(description="Bearer认证头")] = None,
):
if not server._subscribed:
server._subscribed = True

View File

@@ -1,15 +1,16 @@
from __future__ import annotations
from typing import Annotated
from app.database.auth import OAuthToken
from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp
from app.dependencies.database import Database
from app.dependencies.geoip import get_geoip_helper
from app.dependencies.geoip import GeoIPService
from app.dependencies.user import UserAndToken, get_client_user_and_token
from app.helpers.geoip_helper import GeoIPHelper
from .router import router
from fastapi import Depends, HTTPException, Security
from fastapi import HTTPException, Security
from pydantic import BaseModel
from sqlmodel import col, select
@@ -28,8 +29,8 @@ class SessionsResp(BaseModel):
)
async def get_sessions(
session: Database,
user_and_token: UserAndToken = Security(get_client_user_and_token),
geoip: GeoIPHelper = Depends(get_geoip_helper),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
geoip: GeoIPService,
):
current_user, token = user_and_token
sessions = (
@@ -57,7 +58,7 @@ async def get_sessions(
async def delete_session(
session: Database,
session_id: int,
user_and_token: UserAndToken = Security(get_client_user_and_token),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
):
current_user, token = user_and_token
if session_id == token.id:
@@ -91,8 +92,8 @@ class TrustedDevicesResp(BaseModel):
)
async def get_trusted_devices(
session: Database,
user_and_token: UserAndToken = Security(get_client_user_and_token),
geoip: GeoIPHelper = Depends(get_geoip_helper),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
geoip: GeoIPService,
):
current_user, token = user_and_token
devices = (
@@ -131,7 +132,7 @@ async def get_trusted_devices(
async def delete_trusted_device(
session: Database,
device_id: int,
user_and_token: UserAndToken = Security(get_client_user_and_token),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
):
current_user, token = user_and_token
device = await session.get(TrustedDevice, device_id)

View File

@@ -1,25 +1,24 @@
from __future__ import annotations
import hashlib
from typing import Annotated
from app.database.user import User
from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.storage.base import StorageService
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
from app.utils import check_image
from .router import router
from fastapi import Depends, File, Security
from fastapi import File
@router.post("/avatar/upload", name="上传头像", tags=["用户", "g0v0 API"])
async def upload_avatar(
session: Database,
content: bytes = File(...),
current_user: User = Security(get_client_user),
storage: StorageService = Depends(get_storage_service),
content: Annotated[bytes, File(...)],
current_user: ClientUser,
storage: StorageService,
):
"""上传用户头像

View File

@@ -1,17 +1,18 @@
from __future__ import annotations
from typing import Annotated
from app.database.beatmap import Beatmap
from app.database.beatmapset import Beatmapset
from app.database.beatmapset_ratings import BeatmapRating
from app.database.score import Score
from app.database.user import User
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.dependencies.user import ClientUser
from app.service.beatmapset_update_service import get_beatmapset_update_service
from .router import router
from fastapi import Body, Depends, HTTPException, Security
from fastapi import Body, Depends, HTTPException
from fastapi_limiter.depends import RateLimiter
from sqlmodel import col, exists, select
@@ -25,7 +26,7 @@ from sqlmodel import col, exists, select
async def can_rate_beatmapset(
beatmapset_id: int,
session: Database,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
"""检查用户是否可以评价谱面集
@@ -57,8 +58,8 @@ async def can_rate_beatmapset(
async def rate_beatmaps(
beatmapset_id: int,
session: Database,
rating: int = Body(..., ge=0, le=10),
current_user: User = Security(get_client_user),
rating: Annotated[int, Body(..., ge=0, le=10)],
current_user: ClientUser,
):
"""为谱面集评分
@@ -96,7 +97,7 @@ async def rate_beatmaps(
async def sync_beatmapset(
beatmapset_id: int,
session: Database,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
"""请求同步谱面集

View File

@@ -1,25 +1,25 @@
from __future__ import annotations
import hashlib
from typing import Annotated
from app.database.user import User, UserProfileCover
from app.database.user import UserProfileCover
from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.storage.base import StorageService
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
from app.utils import check_image
from .router import router
from fastapi import Depends, File, Security
from fastapi import File
@router.post("/cover/upload", name="上传头图", tags=["用户", "g0v0 API"])
async def upload_cover(
session: Database,
content: bytes = File(...),
current_user: User = Security(get_client_user),
storage: StorageService = Depends(get_storage_service),
content: Annotated[bytes, File(...)],
current_user: ClientUser,
storage: StorageService,
):
"""上传用户头图

View File

@@ -1,16 +1,15 @@
from __future__ import annotations
import secrets
from typing import Annotated
from app.database.auth import OAuthClient, OAuthToken
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser
from .router import router
from fastapi import Body, Depends, HTTPException, Security
from redis.asyncio import Redis
from fastapi import Body, HTTPException
from sqlmodel import select, text
@@ -22,10 +21,10 @@ from sqlmodel import select, text
)
async def create_oauth_app(
session: Database,
name: str = Body(..., max_length=100, description="应用程序名称"),
description: str = Body("", description="应用程序描述"),
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
current_user: User = Security(get_client_user),
name: Annotated[str, Body(..., max_length=100, description="应用程序名称")],
redirect_uris: Annotated[list[str], Body(..., description="允许的重定向 URI 列表")],
current_user: ClientUser,
description: Annotated[str, Body(description="应用程序描述")] = "",
):
result = await session.execute(
text(
@@ -64,7 +63,7 @@ async def create_oauth_app(
async def get_oauth_app(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
oauth_app = await session.get(OAuthClient, client_id)
if not oauth_app:
@@ -85,7 +84,7 @@ async def get_oauth_app(
)
async def get_user_oauth_apps(
session: Database,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
return [
@@ -109,7 +108,7 @@ async def get_user_oauth_apps(
async def delete_oauth_app(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client:
@@ -134,10 +133,10 @@ async def delete_oauth_app(
async def update_oauth_app(
session: Database,
client_id: int,
name: str = Body(..., max_length=100, description="应用程序新名称"),
description: str = Body("", description="应用程序新描述"),
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"),
current_user: User = Security(get_client_user),
name: Annotated[str, Body(..., max_length=100, description="应用程序新名称")],
redirect_uris: Annotated[list[str], Body(..., description="新的重定向 URI 列表")],
current_user: ClientUser,
description: Annotated[str, Body(description="应用程序新描述")] = "",
):
oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client:
@@ -168,7 +167,7 @@ async def update_oauth_app(
async def refresh_secret(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client:
@@ -200,10 +199,10 @@ async def refresh_secret(
async def generate_oauth_code(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
redirect_uri: str = Body(..., description="授权后重定向的 URI"),
scopes: list[str] = Body(..., description="请求的权限范围列表"),
redis: Redis = Depends(get_redis),
current_user: ClientUser,
redirect_uri: Annotated[str, Body(..., description="授权后重定向的 URI")],
scopes: Annotated[list[str], Body(..., description="请求的权限范围列表")],
redis: Redis,
):
client = await session.get(OAuthClient, client_id)
if not client:

View File

@@ -1,13 +1,15 @@
from __future__ import annotations
from app.database import Relationship, User
from typing import Annotated
from app.database import Relationship
from app.database.relationship import RelationshipType
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.dependencies.user import ClientUser
from .router import router
from fastapi import HTTPException, Path, Security
from fastapi import HTTPException, Path
from pydantic import BaseModel, Field
from sqlmodel import select
@@ -27,8 +29,8 @@ class CheckResponse(BaseModel):
)
async def check_user_relationship(
db: Database,
user_id: int = Path(..., description="目标用户的 ID"),
current_user: User = Security(get_client_user),
user_id: Annotated[int, Path(..., description="目标用户的 ID")],
current_user: ClientUser,
):
if user_id == current_user.id:
raise HTTPException(422, "Cannot check relationship with yourself")

View File

@@ -1,17 +1,14 @@
from __future__ import annotations
from app.database.score import Score
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
from app.service.user_cache_service import refresh_user_cache_background
from app.storage.base import StorageService
from .router import router
from fastapi import BackgroundTasks, Depends, HTTPException, Security
from redis.asyncio import Redis
from fastapi import BackgroundTasks, HTTPException
@router.delete(
@@ -24,9 +21,9 @@ async def delete_score(
session: Database,
background_task: BackgroundTasks,
score_id: int,
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
storage_service: StorageService = Depends(get_storage_service),
redis: Redis,
current_user: ClientUser,
storage_service: StorageService,
):
"""删除成绩

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
import hashlib
from typing import Annotated
from app.database.team import Team, TeamMember, TeamRequest
from app.database.user import BASE_INCLUDES, User, UserResp
from app.dependencies.database import Database, get_redis
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
from app.models.notification import (
TeamApplicationAccept,
TeamApplicationReject,
@@ -14,27 +15,25 @@ from app.models.notification import (
)
from app.router.notification import server
from app.service.ranking_cache_service import get_ranking_cache_service
from app.storage.base import StorageService
from app.utils import check_image, utcnow
from .router import router
from fastapi import Depends, File, Form, HTTPException, Path, Request, Security
from fastapi import File, Form, HTTPException, Path, Request
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import exists, select
@router.post("/team", name="创建战队", response_model=Team, tags=["战队", "g0v0 API"])
async def create_team(
session: Database,
storage: StorageService = Depends(get_storage_service),
current_user: User = Security(get_client_user),
flag: bytes = File(..., description="战队图标文件"),
cover: bytes = File(..., description="战队头图文件"),
name: str = Form(max_length=100, description="战队名称"),
short_name: str = Form(max_length=10, description="战队缩写"),
redis: Redis = Depends(get_redis),
storage: StorageService,
current_user: ClientUser,
flag: Annotated[bytes, File(..., description="战队图标文件")],
cover: Annotated[bytes, File(..., description="战队头图文件")],
name: Annotated[str, Form(max_length=100, description="战队名称")],
short_name: Annotated[str, Form(max_length=10, description="战队缩写")],
redis: Redis,
):
"""创建战队。
@@ -88,13 +87,13 @@ async def create_team(
async def update_team(
team_id: int,
session: Database,
storage: StorageService = Depends(get_storage_service),
current_user: User = Security(get_client_user),
flag: bytes | None = File(default=None, description="战队图标文件"),
cover: bytes | None = File(default=None, description="战队头图文件"),
name: str | None = Form(default=None, max_length=100, description="战队名称"),
short_name: str | None = Form(default=None, max_length=10, description="战队缩写"),
leader_id: int | None = Form(default=None, description="战队队长 ID"),
storage: StorageService,
current_user: ClientUser,
flag: Annotated[bytes | None, File(description="战队图标文件")] = None,
cover: Annotated[bytes | None, File(description="战队头图文件")] = None,
name: Annotated[str | None, Form(max_length=100, description="战队名称")] = None,
short_name: Annotated[str | None, Form(max_length=10, description="战队缩写")] = None,
leader_id: Annotated[int | None, Form(description="战队队长 ID")] = None,
):
"""修改战队。
@@ -161,9 +160,9 @@ async def update_team(
@router.delete("/team/{team_id}", name="删除战队", status_code=204, tags=["战队", "g0v0 API"])
async def delete_team(
session: Database,
team_id: int = Path(..., description="战队 ID"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
team_id: Annotated[int, Path(..., description="战队 ID")],
current_user: ClientUser,
redis: Redis,
):
team = await session.get(Team, team_id)
if not team:
@@ -191,7 +190,7 @@ class TeamQueryResp(BaseModel):
@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
async def get_team(
session: Database,
team_id: int = Path(..., description="战队 ID"),
team_id: Annotated[int, Path(..., description="战队 ID")],
):
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
return TeamQueryResp(
@@ -203,8 +202,8 @@ async def get_team(
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])
async def request_join_team(
session: Database,
team_id: int = Path(..., description="战队 ID"),
current_user: User = Security(get_client_user),
team_id: Annotated[int, Path(..., description="战队 ID")],
current_user: ClientUser,
):
team = await session.get(Team, team_id)
if not team:
@@ -231,10 +230,10 @@ async def request_join_team(
async def handle_request(
req: Request,
session: Database,
team_id: int = Path(..., description="战队 ID"),
user_id: int = Path(..., description="用户 ID"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
team_id: Annotated[int, Path(..., description="战队 ID")],
user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: ClientUser,
redis: Redis,
):
team = await session.get(Team, team_id)
if not team:
@@ -272,10 +271,10 @@ async def handle_request(
@router.delete("/team/{team_id}/{user_id}", name="踢出成员 / 退出战队", status_code=204, tags=["战队", "g0v0 API"])
async def kick_member(
session: Database,
team_id: int = Path(..., description="战队 ID"),
user_id: int = Path(..., description="用户 ID"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
team_id: Annotated[int, Path(..., description="战队 ID")],
user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: ClientUser,
redis: Redis,
):
team = await session.get(Team, team_id)
if not team:

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Annotated
from app.auth import (
check_totp_backup_code,
finish_create_totp_key,
@@ -9,17 +11,15 @@ from app.auth import (
)
from app.const import BACKUP_CODE_LENGTH
from app.database.auth import TotpKeys
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
from .router import router
from fastapi import Body, Depends, HTTPException, Security
from fastapi import Body, HTTPException
from pydantic import BaseModel
import pyotp
from redis.asyncio import Redis
class TotpStatusResp(BaseModel):
@@ -37,7 +37,7 @@ class TotpStatusResp(BaseModel):
response_model=TotpStatusResp,
)
async def get_totp_status(
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
"""检查用户是否已创建TOTP"""
totp_key = await current_user.awaitable_attrs.totp_key
@@ -62,8 +62,8 @@ async def get_totp_status(
status_code=201,
)
async def start_create_totp(
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
redis: Redis,
current_user: ClientUser,
):
if await current_user.awaitable_attrs.totp_key:
raise HTTPException(status_code=400, detail="TOTP is already enabled for this user")
@@ -98,9 +98,9 @@ async def start_create_totp(
)
async def finish_create_totp(
session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码"),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码")],
redis: Redis,
current_user: ClientUser,
):
status, backup_codes = await finish_create_totp_key(current_user, code, redis, session)
if status == FinishStatus.SUCCESS:
@@ -122,9 +122,9 @@ async def finish_create_totp(
)
async def disable_totp(
session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码或备份码")],
redis: Redis,
current_user: ClientUser,
):
totp = await session.get(TotpKeys, current_user.id)
if not totp:

View File

@@ -1,24 +1,26 @@
from __future__ import annotations
from typing import Annotated
from app.auth import validate_username
from app.config import settings
from app.database.events import Event, EventType
from app.database.user import User
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.dependencies.user import ClientUser
from app.utils import utcnow
from .router import router
from fastapi import Body, HTTPException, Security
from fastapi import Body, HTTPException
from sqlmodel import exists, select
@router.post("/rename", name="修改用户名", tags=["用户", "g0v0 API"])
async def user_rename(
session: Database,
new_name: str = Body(..., description="新的用户名"),
current_user: User = Security(get_client_user),
new_name: Annotated[str, Body(..., description="新的用户名")],
current_user: ClientUser,
):
"""修改用户名

View File

@@ -1,24 +1,22 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from typing import Annotated, Literal
from app.database.beatmap import Beatmap, calculate_beatmap_attributes
from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.database.beatmapset import Beatmapset
from app.database.favourite_beatmapset import FavouriteBeatmapset
from app.database.score import Score
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.mods import int_to_mods
from app.models.score import GameMode
from .router import AllStrModel, router
from fastapi import Depends, Query
from redis.asyncio import Redis
from fastapi import Query
from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -148,18 +146,18 @@ class V1Beatmap(AllStrModel):
)
async def get_beatmaps(
session: Database,
since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"),
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
user: str | None = Query(None, alias="u", description=""),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO
convert: bool = Query(False, alias="a", description="转谱"), # TODO
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
mods: int = Query(0, description="应用到谱面属性的 MOD"),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
redis: Redis,
fetcher: Fetcher,
since: Annotated[datetime | None, Query(description="自指定时间后拥有排行榜的谱面")] = None,
beatmapset_id: Annotated[int | None, Query(alias="s", description="面集 ID")] = None,
beatmap_id: Annotated[int | None, Query(alias="b", description="谱面 ID")] = None,
user: Annotated[str | None, Query(alias="u", description="谱师")] = None,
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
ruleset_id: Annotated[int | None, Query(alias="m", description="Ruleset ID", ge=0, le=3)] = None, # TODO
convert: Annotated[bool, Query(alias="a", description="转谱")] = False, # TODO
checksum: Annotated[str | None, Query(alias="h", description="谱面文件 MD5")] = None,
limit: Annotated[int, Query(ge=1, le=500, description="返回结果数量限制")] = 500,
mods: Annotated[int, Query(description="应用到谱面属性的 MOD")] = 0,
):
beatmaps: list[Beatmap] = []
results = []

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Literal
from typing import Annotated, Literal
from app.database.statistics import UserStatistics
from app.database.user import User
@@ -181,9 +181,9 @@ async def _count_online_users_optimized(redis):
)
async def api_get_player_info(
session: Database,
scope: Literal["stats", "events", "info", "all"] = Query(..., description="信息范围"),
id: int | None = Query(None, ge=3, le=2147483647, description="用户 ID"),
name: str | None = Query(None, regex=r"^[\w \[\]-]{2,32}$", description="用户名"),
scope: Annotated[Literal["stats", "events", "info", "all"], Query(..., description="信息范围")],
id: Annotated[int | None, Query(ge=3, le=2147483647, description="用户 ID")] = None,
name: Annotated[str | None, Query(regex=r"^[\w \[\]-]{2,32}$", description="用户名")] = None,
):
"""
获取指定玩家的信息

View File

@@ -2,15 +2,14 @@ from __future__ import annotations
import base64
from datetime import date
from typing import Literal
from typing import Annotated, Literal
from app.database.counts import ReplayWatchedCount
from app.database.score import Score
from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service
from app.dependencies.storage import StorageService
from app.models.mods import int_to_mods
from app.models.score import GameMode
from app.storage import StorageService
from .router import router
@@ -34,18 +33,20 @@ class ReplayModel(BaseModel):
)
async def download_replay(
session: Database,
beatmap: int = Query(..., alias="b", description="谱面 ID"),
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query(
None,
alias="m",
description="Ruleset ID",
ge=0,
),
score_id: int | None = Query(None, alias="s", description="成绩 ID"),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
mods: int = Query(0, description="成绩的 MOD"),
storage_service: StorageService = Depends(get_storage_service),
beatmap: Annotated[int, Query(..., alias="b", description="谱面 ID")],
user: Annotated[str, Query(..., alias="u", description="用户")],
storage_service: StorageService,
ruleset_id: Annotated[
int | None,
Query(
alias="m",
description="Ruleset ID",
ge=0,
),
] = None,
score_id: Annotated[int | None, Query(alias="s", description="成绩 ID")] = None,
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
mods: Annotated[int, Query(description="成绩的 MOD")] = 0,
):
mods_ = int_to_mods(mods)
if score_id is not None:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Literal
from typing import Annotated, Literal
from app.database.best_scores import PPBestScore
from app.database.score import Score, get_leaderboard
@@ -69,10 +69,10 @@ class V1Score(AllStrModel):
)
async def get_user_best(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
user: Annotated[str, Query(..., alias="u", description="用户")],
ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
):
try:
scores = (
@@ -101,10 +101,10 @@ async def get_user_best(
)
async def get_user_recent(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
user: Annotated[str, Query(..., alias="u", description="用户")],
ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
):
try:
scores = (
@@ -133,12 +133,12 @@ async def get_user_recent(
)
async def get_scores(
session: Database,
user: str | None = Query(None, alias="u", description="用户"),
beatmap_id: int = Query(alias="b", description="谱面 ID"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
mods: int = Query(0, description="成绩的 MOD"),
beatmap_id: Annotated[int, Query(alias="b", description="谱面 ID")],
user: Annotated[str | None, Query(alias="u", description="用户")] = None,
ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
mods: Annotated[int, Query(description="成绩的 MOD")] = 0,
):
try:
if user is not None:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from typing import Annotated, Literal
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.database.user import User
@@ -104,10 +104,10 @@ class V1User(AllStrModel):
async def get_user(
session: Database,
background_tasks: BackgroundTasks,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"),
user: Annotated[str, Query(..., alias="u", description="用户")],
ruleset_id: Annotated[int | None, Query(alias="m", description="Ruleset ID", ge=0)] = None,
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
event_days: Annotated[int, Query(ge=1, le=31, description="从现在起所有事件的最大天数")] = 1,
):
redis = get_redis()
cache_service = get_user_cache_service(redis)

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
import asyncio
import hashlib
import json
from typing import Annotated
from app.database import Beatmap, BeatmapResp, User
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod, int_to_mods
from app.models.score import (
@@ -18,10 +18,9 @@ from app.models.score import (
from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security
from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis.asyncio import Redis
import rosu_pp_py as rosu
from sqlmodel import col, select
@@ -44,11 +43,11 @@ class BatchGetResp(BaseModel):
)
async def lookup_beatmap(
db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
id: Annotated[int | None, Query(alias="id", description="谱面 ID")] = None,
md5: Annotated[str | None, Query(alias="checksum", description="谱面文件 MD5")] = None,
filename: Annotated[str | None, Query(alias="filename", description="谱面文件名")] = None,
):
if id is None and md5 is None and filename is None:
raise HTTPException(
@@ -75,9 +74,9 @@ async def lookup_beatmap(
)
async def get_beatmap(
db: Database,
beatmap_id: int = Path(..., description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
@@ -95,9 +94,12 @@ async def get_beatmap(
)
async def batch_get_beatmaps(
db: Database,
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
beatmap_ids: Annotated[
list[int],
Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
):
if not beatmap_ids:
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
@@ -127,16 +129,19 @@ async def batch_get_beatmaps(
)
async def get_beatmap_attributes(
db: Database,
beatmap_id: int = Path(..., description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
mods: list[str] = Query(
default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
),
ruleset: GameMode | None = Query(default=None, description="指定 ruleset为空则使用谱面自身模式"),
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
mods: Annotated[
list[str],
Query(
default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
),
],
redis: Redis,
fetcher: Fetcher,
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset为空则使用谱面自身模式")] = None,
ruleset_id: Annotated[int | None, Query(description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3)] = None,
):
mods_ = []
if mods and mods[0].isdigit():

View File

@@ -6,23 +6,20 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
from app.dependencies.beatmapset_cache import get_beatmapset_cache_dependency
from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.dependencies.beatmap_download import DownloadService
from app.dependencies.beatmapset_cache import BeatmapsetCacheService
from app.dependencies.database import Database, Redis, with_db
from app.dependencies.fetcher import Fetcher
from app.dependencies.geoip import IPAddress, get_geoip_helper
from app.dependencies.user import ClientUser, get_current_user
from app.models.beatmap import SearchQueryModel
from app.service.asset_proxy_helper import process_response_assets
from app.service.beatmap_download_service import BeatmapDownloadService
from app.service.beatmapset_cache_service import BeatmapsetCacheService, generate_hash
from app.service.beatmapset_cache_service import generate_hash
from .router import router
from fastapi import (
BackgroundTasks,
Depends,
Form,
HTTPException,
Path,
@@ -53,10 +50,10 @@ async def search_beatmapset(
query: Annotated[SearchQueryModel, Query(...)],
request: Request,
background_tasks: BackgroundTasks,
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
redis=Depends(get_redis),
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
redis: Redis,
cache_service: BeatmapsetCacheService,
):
params = parse_qs(qs=request.url.query, keep_blank_values=True)
cursor = {}
@@ -134,10 +131,10 @@ async def search_beatmapset(
async def lookup_beatmapset(
db: Database,
request: Request,
beatmap_id: int = Query(description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
beatmap_id: Annotated[int, Query(description="谱面 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
cache_service: BeatmapsetCacheService,
):
# 先尝试从缓存获取
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
@@ -170,10 +167,10 @@ async def lookup_beatmapset(
async def get_beatmapset(
db: Database,
request: Request,
beatmapset_id: int = Path(..., description="谱面集 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
cache_service: BeatmapsetCacheService,
):
# 先尝试从缓存获取
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
@@ -203,14 +200,12 @@ async def get_beatmapset(
description="\n下载谱面集文件。基于请求IP地理位置智能分流支持负载均衡和自动故障转移。中国IP使用Sayobot镜像其他地区使用Nerinyan和OsuDirect镜像。",
)
async def download_beatmapset(
request: Request,
beatmapset_id: int = Path(..., description="谱面集 ID"),
no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
current_user: User = Security(get_client_user),
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
client_ip: IPAddress,
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
current_user: ClientUser,
download_service: DownloadService,
no_video: Annotated[bool, Query(alias="noVideo", description="是否下载无视频版本")] = True,
):
client_ip = get_client_ip(request)
geoip_helper = get_geoip_helper()
geo_info = geoip_helper.lookup(client_ip)
country_code = geo_info.get("country_iso", "")
@@ -242,9 +237,12 @@ async def download_beatmapset(
)
async def favourite_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
current_user: User = Security(get_client_user),
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
action: Annotated[
Literal["favourite", "unfavourite"],
Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
],
current_user: ClientUser,
):
existing_favourite = (
await db.exec(

View File

@@ -5,14 +5,13 @@
from __future__ import annotations
from app.dependencies.database import get_redis
from app.dependencies.database import Redis
from app.service.user_cache_service import get_user_cache_service
from .router import router
from fastapi import Depends, HTTPException
from fastapi import HTTPException
from pydantic import BaseModel
from redis.asyncio import Redis
class CacheStatsResponse(BaseModel):
@@ -28,7 +27,7 @@ class CacheStatsResponse(BaseModel):
tags=["缓存管理"],
)
async def get_cache_stats(
redis: Redis = Depends(get_redis),
redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用
):
try:
@@ -68,7 +67,7 @@ async def get_cache_stats(
)
async def invalidate_user_cache(
user_id: int,
redis: Redis = Depends(get_redis),
redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
):
try:
@@ -87,7 +86,7 @@ async def invalidate_user_cache(
tags=["缓存管理"],
)
async def clear_all_user_cache(
redis: Redis = Depends(get_redis),
redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
):
try:
@@ -119,7 +118,7 @@ class CacheWarmupRequest(BaseModel):
)
async def warmup_cache(
request: CacheWarmupRequest,
redis: Redis = Depends(get_redis),
redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
):
try:

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
from typing import Annotated
from app.database import MeResp, User
from app.dependencies import get_current_user
from app.dependencies.database import Database
from app.dependencies.user import UserAndToken, get_current_user_and_token
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
from app.exceptions.userpage import UserpageError
from app.models.score import GameMode
from app.models.user import Page
@@ -29,8 +30,8 @@ from fastapi import HTTPException, Path, Security
)
async def get_user_info_with_ruleset(
session: Database,
ruleset: GameMode = Path(description="指定 ruleset"),
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
return user_resp
@@ -45,7 +46,7 @@ async def get_user_info_with_ruleset(
)
async def get_user_info_default(
session: Database,
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
return user_resp
@@ -85,8 +86,8 @@ async def get_user_info_default(
async def update_userpage(
request: UpdateUserpageRequest,
session: Database,
user_id: int = Path(description="用户ID"),
current_user: User = Security(get_current_user, scopes=["edit"]),
user_id: Annotated[int, Path(description="用户ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["edit"])],
):
"""更新用户页面内容匹配官方osu-web实现"""
# 检查权限:只能编辑自己的页面(除非是管理员)

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from typing import Literal
from typing import Annotated, Literal
from app.config import settings
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
from app.dependencies import get_current_user
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.service.ranking_cache_service import get_ranking_cache_service
@@ -45,11 +45,11 @@ SortType = Literal["performance", "score"]
async def get_team_ranking_pp(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
return await get_team_ranking(session, background_tasks, "performance", ruleset, page, current_user)
return await get_team_ranking(session, background_tasks, "performance", ruleset, current_user, page)
@router.get(
@@ -62,14 +62,17 @@ async def get_team_ranking_pp(
async def get_team_ranking(
session: Database,
background_tasks: BackgroundTasks,
sort: SortType = Path(
...,
description="排名类型performance 表现分 / score 计分成绩总分 "
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
),
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
sort: Annotated[
SortType,
Path(
...,
description="排名类型performance 表现分 / score 计分成绩总分 "
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
),
],
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
# 获取 Redis 连接和缓存服务
redis = get_redis()
@@ -193,11 +196,11 @@ class CountryResponse(BaseModel):
async def get_country_ranking_pp(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
return await get_country_ranking(session, background_tasks, ruleset, page, "performance", current_user)
return await get_country_ranking(session, background_tasks, ruleset, "performance", current_user, page)
@router.get(
@@ -210,14 +213,17 @@ async def get_country_ranking_pp(
async def get_country_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"),
sort: SortType = Path(
...,
description="排名类型performance 表现分 / score 计分成绩总分 "
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
),
current_user: User = Security(get_current_user, scopes=["public"]),
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
sort: Annotated[
SortType,
Path(
...,
description="排名类型performance 表现分 / score 计分成绩总分 "
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
# 获取 Redis 连接和缓存服务
redis = get_redis()
@@ -317,11 +323,11 @@ class TopUsersResponse(BaseModel):
async def get_user_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
sort: SortType = Path(..., description="排名类型performance 表现分 / score 计分成绩总分"),
country: str | None = Query(None, description="国家代码"),
page: int = Query(1, ge=1, description=""),
current_user: User = Security(get_current_user, scopes=["public"]),
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
sort: Annotated[SortType, Path(..., description="排名类型performance 表现分 / score 计分成绩总分")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
country: Annotated[str | None, Query(description="国家代")] = None,
page: Annotated[int, Query(ge=1, description="页码")] = 1,
):
# 获取 Redis 连接和缓存服务
redis = get_redis()

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
from typing import Annotated
from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.database.user import UserResp
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database
from app.dependencies.user import get_client_user, get_current_user
from app.dependencies.user import ClientUser, get_current_user
from .router import router
@@ -56,7 +58,7 @@ async def get_relationship(
db: Database,
request: Request,
api_version: APIVersion,
current_user: User = Security(get_current_user, scopes=["friends.read"]),
current_user: Annotated[User, Security(get_current_user, scopes=["friends.read"])],
):
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
relationships = await db.exec(
@@ -107,8 +109,8 @@ class AddFriendResp(BaseModel):
async def add_relationship(
db: Database,
request: Request,
target: int = Query(description="目标用户 ID"),
current_user: User = Security(get_client_user),
target: Annotated[int, Query(description="目标用户 ID")],
current_user: ClientUser,
):
if not (await db.exec(select(exists()).where(User.id == target))).first():
raise HTTPException(404, "Target user not found")
@@ -176,8 +178,8 @@ async def add_relationship(
async def delete_relationship(
db: Database,
request: Request,
target: int = Path(..., description="目标用户 ID"),
current_user: User = Security(get_client_user),
target: Annotated[int, Path(..., description="目标用户 ID")],
current_user: ClientUser,
):
if not (await db.exec(select(exists()).where(User.id == target))).first():
raise HTTPException(404, "Target user not found")

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import UTC
from typing import Literal
from typing import Annotated, Literal
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmapset import BeatmapsetResp
@@ -12,8 +12,8 @@ from app.database.room import APIUploadedRoom, Room, RoomResp
from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score
from app.database.user import User, UserResp
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user, get_current_user
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser, get_current_user
from app.models.room import RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api
from app.signalr.hub import MultiplayerHubs
@@ -21,9 +21,8 @@ from app.utils import utcnow
from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security
from fastapi import HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -38,16 +37,20 @@ from sqlmodel.ext.asyncio.session import AsyncSession
)
async def get_all_rooms(
db: Database,
mode: Literal["open", "ended", "participated", "owned"] | None = Query(
default="open",
description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
),
category: RoomCategory = Query(
RoomCategory.NORMAL,
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
),
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
current_user: User = Security(get_current_user, scopes=["public"]),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
mode: Annotated[
Literal["open", "ended", "participated", "owned"] | None,
Query(
description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
),
] = "open",
category: Annotated[
RoomCategory,
Query(
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
),
] = RoomCategory.NORMAL,
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
):
resp_list: list[RoomResp] = []
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
@@ -140,8 +143,8 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
async def create_room(
db: Database,
room: APIUploadedRoom,
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
current_user: ClientUser,
redis: Redis,
):
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
@@ -162,13 +165,15 @@ async def create_room(
)
async def get_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
category: str = Query(
default="",
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
),
current_user: User = Security(get_current_user, scopes=["public"]),
redis: Redis = Depends(get_redis),
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
redis: Redis,
category: Annotated[
str,
Query(
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
),
] = "",
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
@@ -185,8 +190,8 @@ async def get_room(
)
async def delete_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
current_user: User = Security(get_client_user),
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: ClientUser,
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
@@ -205,10 +210,10 @@ async def delete_room(
)
async def add_user_to_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
room_id: Annotated[int, Path(..., description="房间 ID")],
user_id: Annotated[int, Path(..., description="用户 ID")],
redis: Redis,
current_user: ClientUser,
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None:
@@ -229,10 +234,10 @@ async def add_user_to_room(
)
async def remove_user_from_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
room_id: Annotated[int, Path(..., description="房间 ID")],
user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: ClientUser,
redis: Redis,
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None:
@@ -273,8 +278,8 @@ class APILeaderboard(BaseModel):
)
async def get_room_leaderboard(
db: Database,
room_id: int = Path(..., description="房间 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
@@ -329,11 +334,11 @@ class RoomEvents(BaseModel):
)
async def get_room_events(
db: Database,
room_id: int = Path(..., description="房间 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"),
before: int | None = Query(None, ge=0, description="仅包含小于该事件 ID 的事件"),
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
after: Annotated[int | None, Query(ge=0, description="仅包含大于该事件 ID 的事件")] = None,
before: Annotated[int | None, Query(ge=0, description="仅包含小于该事件 ID 的事件")] = None,
):
events = (
await db.exec(

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import UTC, date
import time
from typing import Annotated
from app.calculator import clamp
from app.config import settings
@@ -34,11 +35,10 @@ from app.database.score import (
process_user,
)
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.dependencies.database import Database, Redis, get_redis, with_db
from app.dependencies.fetcher import Fetcher, get_fetcher
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser, get_current_user
from app.log import logger
from app.models.beatmap import BeatmapRankStatus
from app.models.room import RoomCategory
@@ -50,7 +50,6 @@ from app.models.score import (
)
from app.service.beatmap_cache_service import get_beatmap_cache_service
from app.service.user_cache_service import refresh_user_cache_background
from app.storage.base import StorageService
from app.utils import utcnow
from .router import router
@@ -69,7 +68,6 @@ from fastapi.responses import RedirectResponse
from fastapi_limiter.depends import RateLimiter
from httpx import HTTPError
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -245,16 +243,18 @@ class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
async def get_beatmap_scores(
db: Database,
api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"),
mode: GameMode = Query(description="指定 auleset"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
type: LeaderboardType = Query(
LeaderboardType.GLOBAL,
description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
mode: Annotated[GameMode, Query(description="指定 auleset")],
mods: Annotated[list[str], Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
type: Annotated[
LeaderboardType,
Query(
description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
),
] = LeaderboardType.GLOBAL,
limit: Annotated[int, Query(ge=1, le=200, description="返回条数 (1-200)")] = 50,
):
if legacy_only:
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
@@ -294,12 +294,12 @@ async def get_beatmap_scores(
async def get_user_beatmap_score(
db: Database,
api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"),
user_id: int = Path(description="用户 ID"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"),
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"),
current_user: User = Security(get_current_user, scopes=["public"]),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
user_id: Annotated[int, Path(description="用户 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
mods: Annotated[str | None, Query(description="筛选使用的 Mods (暂未实现)")] = None,
):
user_score = (
await db.exec(
@@ -342,11 +342,11 @@ async def get_user_beatmap_score(
async def get_user_all_beatmap_scores(
db: Database,
api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"),
user_id: int = Path(description="用户 ID"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"),
current_user: User = Security(get_current_user, scopes=["public"]),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
user_id: Annotated[int, Path(description="用户 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
):
all_user_scores = (
await db.exec(
@@ -374,11 +374,11 @@ async def get_user_all_beatmap_scores(
async def create_solo_score(
background_task: BackgroundTasks,
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
version_hash: str = Form("", description="游戏版本哈希"),
beatmap_hash: str = Form(description="谱面文件哈希"),
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
current_user: User = Security(get_client_user),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
beatmap_hash: Annotated[str, Form(description="谱面文件哈希")],
ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
current_user: ClientUser,
version_hash: Annotated[str, Form(description="游戏版本哈希")] = "",
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -406,12 +406,12 @@ async def create_solo_score(
async def submit_solo_score(
background_task: BackgroundTasks,
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
token: int = Path(description="成绩令牌 ID"),
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
beatmap_id: Annotated[int, Path(description="谱面 ID")],
token: Annotated[int, Path(description="成绩令牌 ID")],
info: Annotated[SoloScoreSubmissionInfo, Body(description="成绩提交信息")],
current_user: ClientUser,
redis: Redis,
fetcher: Fetcher,
):
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
@@ -428,11 +428,11 @@ async def create_playlist_score(
background_task: BackgroundTasks,
room_id: int,
playlist_id: int,
beatmap_id: int = Form(description="谱面 ID"),
beatmap_hash: str = Form(description="游戏版本哈希"),
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
version_hash: str = Form("", description="谱面版本哈希"),
current_user: User = Security(get_client_user),
beatmap_id: Annotated[int, Form(description="谱面 ID")],
beatmap_hash: Annotated[str, Form(description="游戏版本哈希")],
ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
current_user: ClientUser,
version_hash: Annotated[str, Form(description="谱面版本哈希")] = "",
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -496,9 +496,9 @@ async def submit_playlist_score(
playlist_id: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
current_user: ClientUser,
redis: Redis,
fetcher: Fetcher,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -555,9 +555,9 @@ async def index_playlist_scores(
session: Database,
room_id: int,
playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
current_user: User = Security(get_current_user, scopes=["public"]),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
limit: Annotated[int, Query(ge=1, le=50, description="返回条数 (1-50)")] = 50,
cursor: Annotated[int, Query(alias="cursor[total_score]", description="分页游标(上一页最低分)")] = 2000000,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -623,8 +623,8 @@ async def show_playlist_score(
room_id: int,
playlist_id: int,
score_id: int,
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
current_user: ClientUser,
redis: Redis,
):
room = await session.get(Room, room_id)
if not room:
@@ -692,7 +692,7 @@ async def get_user_playlist_score(
room_id: int,
playlist_id: int,
user_id: int,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
score_record = None
start_time = time.time()
@@ -725,8 +725,8 @@ async def get_user_playlist_score(
)
async def pin_score(
db: Database,
score_id: int = Path(description="成绩 ID"),
current_user: User = Security(get_client_user),
score_id: Annotated[int, Path(description="成绩 ID")],
current_user: ClientUser,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -770,8 +770,8 @@ async def pin_score(
)
async def unpin_score(
db: Database,
score_id: int = Path(description="成绩 ID"),
current_user: User = Security(get_client_user),
score_id: Annotated[int, Path(description="成绩 ID")],
current_user: ClientUser,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -805,10 +805,10 @@ async def unpin_score(
)
async def reorder_score_pin(
db: Database,
score_id: int = Path(description="成绩 ID"),
after_score_id: int | None = Body(default=None, description="放在该成绩之后"),
before_score_id: int | None = Body(default=None, description="放在该成绩之"),
current_user: User = Security(get_client_user),
score_id: Annotated[int, Path(description="成绩 ID")],
current_user: ClientUser,
after_score_id: Annotated[int | None, Body(description="放在该成绩之")] = None,
before_score_id: Annotated[int | None, Body(description="放在该成绩之前")] = None,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -893,8 +893,8 @@ async def reorder_score_pin(
async def download_score_replay(
score_id: int,
db: Database,
current_user: User = Security(get_current_user, scopes=["public"]),
storage_service: StorageService = Depends(get_storage_service),
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
storage_service: StorageService,
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id

View File

@@ -11,8 +11,8 @@ from app.config import settings
from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER
from app.database.auth import TotpKeys
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip
from app.dependencies.database import Database, Redis, get_redis
from app.dependencies.geoip import IPAddress
from app.dependencies.user import UserAndToken, get_client_user_and_token
from app.dependencies.user_agent import UserAgentInfo
from app.log import logger
@@ -27,7 +27,6 @@ from .router import router
from fastapi import Depends, Form, Header, HTTPException, Request, Security, status
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from redis.asyncio import Redis
class VerifyMethod(BaseModel):
@@ -64,10 +63,14 @@ async def verify_session(
db: Database,
api_version: APIVersion,
user_agent: UserAgentInfo,
ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 g0v0 扩展支持)"),
user_and_token: UserAndToken = Security(get_client_user_and_token),
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
verification_key: Annotated[
str,
Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 g0v0 扩展支持)"),
],
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
) -> Response:
current_user = user_and_token[0]
token_id = user_and_token[1].id
@@ -82,7 +85,6 @@ async def verify_session(
else await LoginSessionService.get_login_method(user_id, token_id, redis)
)
ip_address = get_client_ip(request)
login_method = "password"
try:
@@ -182,12 +184,12 @@ async def verify_session(
tags=["验证"],
)
async def reissue_verification_code(
request: Request,
db: Database,
user_agent: UserAgentInfo,
api_version: APIVersion,
ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
) -> SessionReissueResponse:
current_user = user_and_token[0]
token_id = user_and_token[1].id
@@ -203,7 +205,6 @@ async def reissue_verification_code(
return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码")
try:
ip_address = get_client_ip(request)
user_id = current_user.id
success, message = await EmailVerificationService.resend_verification_code(
db,
@@ -233,17 +234,15 @@ async def reissue_verification_code(
async def fallback_email(
db: Database,
user_agent: UserAgentInfo,
request: Request,
ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
) -> VerifyMethod:
current_user = user_and_token[0]
token_id = user_and_token[1].id
if not await LoginSessionService.get_login_method(current_user.id, token_id, redis):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
ip_address = get_client_ip(request)
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
success, message = await EmailVerificationService.resend_verification_code(
db,

View File

@@ -1,10 +1,12 @@
from __future__ import annotations
from typing import Annotated
from app.database.beatmap import Beatmap
from app.database.beatmap_tags import BeatmapTagVote
from app.database.score import Score
from app.database.user import User
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.models.score import Rank
from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id
@@ -55,10 +57,10 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
description="为指定谱面添加标签投票。",
)
async def vote_beatmap_tags(
beatmap_id: int = Path(..., description="谱面 ID"),
tag_id: int = Path(..., description="标签 ID"),
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_client_user),
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
tag_id: Annotated[int, Path(..., description="标签 ID")],
session: Database,
current_user: Annotated[User, Depends(get_client_user)],
):
try:
get_tag_by_id(tag_id)
@@ -90,10 +92,10 @@ async def vote_beatmap_tags(
description="取消对指定谱面标签的投票。",
)
async def devote_beatmap_tags(
beatmap_id: int = Path(..., description="谱面 ID"),
tag_id: int = Path(..., description="标签 ID"),
session: AsyncSession = Depends(get_db),
current_user: User = Depends(get_client_user),
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
tag_id: Annotated[int, Path(..., description="标签 ID")],
session: Database,
current_user: Annotated[User, Depends(get_client_user)],
):
"""
取消对谱面指定标签的投票。

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import timedelta
from typing import Literal
from typing import Annotated, Literal
from app.config import settings
from app.const import BANCHOBOT_ID
@@ -51,9 +51,12 @@ async def get_users(
session: Database,
request: Request,
background_task: BackgroundTasks,
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
user_ids: Annotated[list[int], Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表")],
# current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
include_variant_statistics: Annotated[
bool,
Query(description="是否包含各模式的统计信息"),
] = False, # TODO: future use
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -119,9 +122,9 @@ async def get_users(
)
async def get_user_events(
session: Database,
user_id: int = Path(description="用户 ID"),
limit: int | None = Query(None, description="限制返回的活动数量"),
offset: int | None = Query(None, description="活动日志的偏移量"),
user_id: Annotated[int, Path(description="用户 ID")],
limit: Annotated[int | None, Query(description="限制返回的活动数量")] = None,
offset: Annotated[int | None, Query(description="活动日志的偏移量")] = None,
):
db_user = await session.get(User, user_id)
if db_user is None or db_user.id == BANCHOBOT_ID:
@@ -147,9 +150,9 @@ async def get_user_events(
)
async def get_user_kudosu(
session: Database,
user_id: int = Path(description="用户 ID"),
offset: int = Query(default=0, description="偏移量"),
limit: int = Query(default=6, description="返回记录数量限制"),
user_id: Annotated[int, Path(description="用户 ID")],
offset: Annotated[int, Query(description="偏移量")] = 0,
limit: Annotated[int, Query(description="返回记录数量限制")] = 6,
):
"""
获取用户的 kudosu 记录
@@ -176,8 +179,8 @@ async def get_user_kudosu(
async def get_user_info_ruleset(
session: Database,
background_task: BackgroundTasks,
user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"),
user_id: Annotated[str, Path(description="用户 ID 或用户名")],
ruleset: Annotated[GameMode | None, Path(description="指定 ruleset")],
# current_user: User = Security(get_current_user, scopes=["public"]),
):
redis = get_redis()
@@ -225,7 +228,7 @@ async def get_user_info(
background_task: BackgroundTasks,
session: Database,
request: Request,
user_id: str = Path(description="用户 ID 或用户名"),
user_id: Annotated[str, Path(description="用户 ID 或用户名")],
# current_user: User = Security(get_current_user, scopes=["public"]),
):
redis = get_redis()
@@ -274,11 +277,11 @@ async def get_user_info(
async def get_user_beatmapsets(
session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
user_id: Annotated[int, Path(description="用户 ID")],
type: Annotated[BeatmapsetType, Path(description="谱面集类型")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -356,16 +359,17 @@ async def get_user_scores(
session: Database,
api_version: APIVersion,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
),
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
include_fails: bool = Query(False, description="是否包含失败的成绩"),
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
current_user: User = Security(get_current_user, scopes=["public"]),
user_id: Annotated[int, Path(description="用户 ID")],
type: Annotated[
Literal["best", "recent", "firsts", "pinned"],
Path(description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
legacy_only: Annotated[bool, Query(description="是否只查询 Stable 成绩")] = False,
include_fails: Annotated[bool, Query(description="是否包含失败的成绩")] = False,
mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选,默认为用户主模式)")] = None,
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
):
is_legacy_api = api_version < 20220705
redis = get_redis()

View File

@@ -7,9 +7,8 @@ from typing import Literal
import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import DBFactory, get_db_factory
from app.dependencies.user import get_current_user_and_token
from app.dependencies.user import get_current_user, get_current_user_and_token
from app.log import logger
from app.models.signalr import NegotiateResponse, Transport