refactor(api): use Annotated-style dependency injection
This commit is contained in:
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
from app.models.multiplayer_hub import PlaylistItem
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
|
||||
@@ -22,6 +21,8 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.multiplayer_hub import PlaylistItem
|
||||
|
||||
from .room import Room
|
||||
|
||||
|
||||
@@ -72,7 +73,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
return result.one()
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
|
||||
async def from_hub(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession) -> "Playlist":
|
||||
next_id = await cls.get_next_id_for_room(room_id, session=session)
|
||||
return cls(
|
||||
id=next_id,
|
||||
@@ -89,7 +90,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
async def update(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
@@ -106,7 +107,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
async def add_to_db(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.commit()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.database.item_attempts_count import PlaylistAggregateScore
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.multiplayer_hub import ServerMultiplayerRoom
|
||||
from app.models.room import (
|
||||
MatchType,
|
||||
QueueMode,
|
||||
@@ -32,6 +32,9 @@ from sqlmodel import (
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.multiplayer_hub import ServerMultiplayerRoom
|
||||
|
||||
|
||||
class RoomBase(SQLModel, UTCBaseModel):
|
||||
name: str = Field(index=True)
|
||||
@@ -161,7 +164,7 @@ class RoomResp(RoomBase):
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
|
||||
async def from_hub(cls, server_room: "ServerMultiplayerRoom") -> "RoomResp":
|
||||
room = server_room.room
|
||||
resp = cls(
|
||||
id=room.room_id,
|
||||
|
||||
@@ -1,4 +1 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .database import get_db as get_db
|
||||
from .user import get_current_user as get_current_user
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Annotated
|
||||
from fastapi import Depends, Header
|
||||
|
||||
|
||||
def get_api_version(version: int | None = Header(None, alias="x-api-version")) -> int:
|
||||
def get_api_version(version: int | None = Header(None, alias="x-api-version", include_in_schema=False)) -> int:
|
||||
if version is None:
|
||||
return 0
|
||||
if version < 1:
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.service.beatmap_download_service import download_service
|
||||
from typing import Annotated
|
||||
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService, download_service
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
|
||||
def get_beatmap_download_service():
|
||||
"""获取谱面下载服务实例"""
|
||||
return download_service
|
||||
|
||||
|
||||
DownloadService = Annotated[BeatmapDownloadService, Depends(get_beatmap_download_service)]
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
"""
|
||||
Beatmapset缓存服务依赖注入
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.service.beatmapset_cache_service import BeatmapsetCacheService, get_beatmapset_cache_service
|
||||
from typing import Annotated
|
||||
|
||||
from app.dependencies.database import Redis
|
||||
from app.service.beatmapset_cache_service import (
|
||||
BeatmapsetCacheService as OriginBeatmapsetCacheService,
|
||||
get_beatmapset_cache_service,
|
||||
)
|
||||
|
||||
from fastapi import Depends
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
def get_beatmapset_cache_dependency(redis: Redis = Depends(get_redis)) -> BeatmapsetCacheService:
|
||||
def get_beatmapset_cache_dependency(redis: Redis) -> OriginBeatmapsetCacheService:
|
||||
"""获取beatmapset缓存服务依赖"""
|
||||
return get_beatmapset_cache_service(redis)
|
||||
|
||||
|
||||
BeatmapsetCacheService = Annotated[OriginBeatmapsetCacheService, Depends(get_beatmapset_cache_dependency)]
|
||||
|
||||
@@ -91,6 +91,9 @@ def get_redis():
|
||||
return redis_client
|
||||
|
||||
|
||||
Redis = Annotated[redis.Redis, Depends(get_redis)]
|
||||
|
||||
|
||||
def get_redis_binary():
|
||||
"""获取二进制数据专用的 Redis 客户端 (不自动解码响应)"""
|
||||
return redis_binary_client
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies.database import get_redis
|
||||
from app.fetcher import Fetcher
|
||||
from app.fetcher import Fetcher as OriginFetcher
|
||||
from app.log import logger
|
||||
|
||||
fetcher: Fetcher | None = None
|
||||
from fastapi import Depends
|
||||
|
||||
fetcher: OriginFetcher | None = None
|
||||
|
||||
|
||||
async def get_fetcher() -> Fetcher:
|
||||
async def get_fetcher() -> OriginFetcher:
|
||||
global fetcher
|
||||
if fetcher is None:
|
||||
fetcher = Fetcher(
|
||||
fetcher = OriginFetcher(
|
||||
settings.fetcher_client_id,
|
||||
settings.fetcher_client_secret,
|
||||
settings.fetcher_scopes,
|
||||
@@ -27,3 +31,6 @@ async def get_fetcher() -> Fetcher:
|
||||
if not fetcher.access_token or not fetcher.refresh_token:
|
||||
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
|
||||
return fetcher
|
||||
|
||||
|
||||
Fetcher = Annotated[OriginFetcher, Depends(get_fetcher)]
|
||||
|
||||
@@ -6,10 +6,13 @@ from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
import ipaddress
|
||||
from typing import Annotated
|
||||
|
||||
from app.config import settings
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
|
||||
from fastapi import Depends, Request
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_geoip_helper() -> GeoIPHelper:
|
||||
@@ -26,7 +29,7 @@ def get_geoip_helper() -> GeoIPHelper:
|
||||
)
|
||||
|
||||
|
||||
def get_client_ip(request) -> str:
|
||||
def get_client_ip(request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实 IP 地址
|
||||
支持 IPv4 和 IPv6,考虑代理、负载均衡器等情况
|
||||
@@ -66,6 +69,10 @@ def get_client_ip(request) -> str:
|
||||
return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
|
||||
|
||||
|
||||
IPAddress = Annotated[str, Depends(get_client_ip)]
|
||||
GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)]
|
||||
|
||||
|
||||
def is_valid_ip(ip_str: str) -> bool:
|
||||
"""
|
||||
验证 IP 地址是否有效(支持 IPv4 和 IPv6)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from typing import Annotated, cast
|
||||
|
||||
from app.config import (
|
||||
AWSS3StorageSettings,
|
||||
@@ -9,11 +9,13 @@ from app.config import (
|
||||
StorageServiceType,
|
||||
settings,
|
||||
)
|
||||
from app.storage import StorageService
|
||||
from app.storage import StorageService as OriginStorageService
|
||||
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
|
||||
from app.storage.local import LocalStorageService
|
||||
|
||||
storage: StorageService | None = None
|
||||
from fastapi import Depends
|
||||
|
||||
storage: OriginStorageService | None = None
|
||||
|
||||
|
||||
def init_storage_service():
|
||||
@@ -50,3 +52,6 @@ def get_storage_service():
|
||||
if storage is None:
|
||||
return init_storage_service()
|
||||
return storage
|
||||
|
||||
|
||||
StorageService = Annotated[OriginStorageService, Depends(get_storage_service)]
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Annotated
|
||||
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.config import settings
|
||||
from app.const import SUPPORT_TOTP_VERIFICATION_VER
|
||||
from app.database import User
|
||||
from app.database.auth import OAuthToken, V1APIKeys
|
||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
@@ -11,7 +12,7 @@ from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
from .api_version import APIVersion
|
||||
from .database import Database, get_redis
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import Depends, HTTPException, Security
|
||||
from fastapi.security import (
|
||||
APIKeyQuery,
|
||||
HTTPBearer,
|
||||
@@ -112,13 +113,13 @@ async def get_client_user(
|
||||
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
|
||||
# 获取当前验证方式
|
||||
verify_method = None
|
||||
if api_version >= 20250913:
|
||||
if api_version >= SUPPORT_TOTP_VERIFICATION_VER:
|
||||
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
|
||||
|
||||
if verify_method is None:
|
||||
# 智能选择验证方式(有TOTP优先TOTP)
|
||||
totp_key = await user.awaitable_attrs.totp_key
|
||||
if totp_key is not None and api_version >= 20240101:
|
||||
if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
|
||||
verify_method = "totp"
|
||||
else:
|
||||
verify_method = "mail"
|
||||
@@ -169,3 +170,6 @@ async def get_current_user(
|
||||
user_and_token: UserAndToken = Depends(get_current_user_and_token),
|
||||
) -> User:
|
||||
return user_and_token[0]
|
||||
|
||||
|
||||
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
import re
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
@@ -19,10 +19,9 @@ from app.const import BANCHOBOT_ID
|
||||
from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.auth import TotpKeys
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.geoip import GeoIPService, IPAddress
|
||||
from app.dependencies.user_agent import UserAgentInfo
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.oauth import (
|
||||
@@ -40,9 +39,8 @@ from app.service.verification_service import (
|
||||
)
|
||||
from app.utils import utcnow
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Header, Request
|
||||
from fastapi import APIRouter, Form, Header, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import exists, select
|
||||
|
||||
@@ -93,11 +91,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
)
|
||||
async def register_user(
|
||||
db: Database,
|
||||
request: Request,
|
||||
user_username: str = Form(..., alias="user[username]", description="用户名"),
|
||||
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
||||
user_password: str = Form(..., alias="user[password]", description="密码"),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
user_username: Annotated[str, Form(..., alias="user[username]", description="用户名")],
|
||||
user_email: Annotated[str, Form(..., alias="user[user_email]", description="电子邮箱")],
|
||||
user_password: Annotated[str, Form(..., alias="user[password]", description="密码")],
|
||||
geoip: GeoIPService,
|
||||
client_ip: IPAddress,
|
||||
):
|
||||
username_errors = validate_username(user_username)
|
||||
email_errors = validate_email(user_email)
|
||||
@@ -126,7 +124,6 @@ async def register_user(
|
||||
|
||||
try:
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
client_ip = get_client_ip(request)
|
||||
country_code = "CN" # 默认国家代码
|
||||
|
||||
try:
|
||||
@@ -201,19 +198,21 @@ async def oauth_token(
|
||||
db: Database,
|
||||
request: Request,
|
||||
user_agent: UserAgentInfo,
|
||||
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
|
||||
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
|
||||
),
|
||||
client_id: int = Form(..., description="客户端 ID"),
|
||||
client_secret: str = Form(..., description="客户端密钥"),
|
||||
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
|
||||
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*')"),
|
||||
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
|
||||
password: str | None = Form(None, description="密码(仅密码模式需要)"),
|
||||
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
|
||||
ip_address: IPAddress,
|
||||
grant_type: Annotated[
|
||||
Literal["authorization_code", "refresh_token", "password", "client_credentials"],
|
||||
Form(..., description="授权类型:密码、刷新令牌和授权码三种授权方式。"),
|
||||
],
|
||||
client_id: Annotated[int, Form(..., description="客户端 ID")],
|
||||
client_secret: Annotated[str, Form(..., description="客户端密钥")],
|
||||
redis: Redis,
|
||||
geoip: GeoIPService,
|
||||
code: Annotated[str | None, Form(description="授权码(仅授权码模式需要)")] = None,
|
||||
scope: Annotated[str, Form(description="权限范围(空格分隔,默认为 '*')")] = "*",
|
||||
username: Annotated[str | None, Form(description="用户名(仅密码模式需要)")] = None,
|
||||
password: Annotated[str | None, Form(description="密码(仅密码模式需要)")] = None,
|
||||
refresh_token: Annotated[str | None, Form(description="刷新令牌(仅刷新令牌模式需要)")] = None,
|
||||
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
|
||||
):
|
||||
scopes = scope.split(" ")
|
||||
|
||||
@@ -311,8 +310,6 @@ async def oauth_token(
|
||||
)
|
||||
token_id = token.id
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
# 获取国家代码
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
@@ -571,16 +568,14 @@ async def oauth_token(
|
||||
)
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
email: Annotated[str, Form(..., description="邮箱地址")],
|
||||
redis: Redis,
|
||||
ip_address: IPAddress,
|
||||
):
|
||||
"""
|
||||
请求密码重置
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
# 请求密码重置
|
||||
@@ -599,20 +594,16 @@ async def request_password_reset(
|
||||
|
||||
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
reset_code: str = Form(..., description="重置验证码"),
|
||||
new_password: str = Form(..., description="新密码"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
email: Annotated[str, Form(..., description="邮箱地址")],
|
||||
reset_code: Annotated[str, Form(..., description="重置验证码")],
|
||||
new_password: Annotated[str, Form(..., description="新密码")],
|
||||
redis: Redis,
|
||||
ip_address: IPAddress,
|
||||
):
|
||||
"""
|
||||
重置密码
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
# 重置密码
|
||||
success, message = await password_reset_service.reset_password(
|
||||
email=email.lower().strip(),
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.fetcher import Fetcher
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter
|
||||
|
||||
fetcher_router = APIRouter(prefix="/fetcher", include_in_schema=False)
|
||||
|
||||
|
||||
@fetcher_router.get("/callback")
|
||||
async def callback(code: str, fetcher: Fetcher = Depends(get_fetcher)):
|
||||
async def callback(code: str, fetcher: Fetcher):
|
||||
await fetcher.grant_access_token(code)
|
||||
return {"message": "Login successful"}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.storage import LocalStorageService, StorageService
|
||||
from app.dependencies.storage import StorageService as StorageServiceDep
|
||||
from app.storage import LocalStorageService
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
file_router = APIRouter(prefix="/file", include_in_schema=False)
|
||||
|
||||
|
||||
@file_router.get("/{path:path}")
|
||||
async def get_file(path: str, storage: StorageService = Depends(get_storage_service)):
|
||||
async def get_file(path: str, storage: StorageServiceDep):
|
||||
if not isinstance(storage, LocalStorageService):
|
||||
raise HTTPException(404, "Not Found")
|
||||
if not await storage.is_exists(path):
|
||||
|
||||
@@ -11,21 +11,18 @@ from app.database.playlists import Playlist as DBPlaylist
|
||||
from app.database.room import Room
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.fetcher import Fetcher
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.log import logger
|
||||
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
|
||||
from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus
|
||||
from app.storage.base import StorageService
|
||||
from app.utils import utcnow
|
||||
|
||||
from .notification.server import server
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi import APIRouter, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -637,8 +634,8 @@ async def add_user_to_room(
|
||||
async def ensure_beatmap_present(
|
||||
beatmap_data: BeatmapEnsureRequest,
|
||||
db: Database,
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
确保谱面在服务器中存在(包括元数据和原始文件缓存)。
|
||||
@@ -677,7 +674,7 @@ class ReplayDataRequest(BaseModel):
|
||||
@router.post("/scores/replay")
|
||||
async def save_replay(
|
||||
req: ReplayDataRequest,
|
||||
storage_service: StorageService = Depends(get_storage_service),
|
||||
storage_service: StorageService,
|
||||
):
|
||||
replay_data = req.mreplay
|
||||
replay_path = f"replays/{req.score_id}_{req.beatmap_id}_{req.user_id}_lazer_replay.osr"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Self
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
@@ -11,7 +11,7 @@ from app.database.chat import (
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.user import User, UserResp
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
@@ -20,7 +20,6 @@ from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
@@ -38,11 +37,14 @@ class UpdateResponse(BaseModel):
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
redis: Redis,
|
||||
history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
|
||||
since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
|
||||
includes: Annotated[
|
||||
list[str],
|
||||
Query(alias="includes[]", description="要包含的更新类型"),
|
||||
] = ["presence", "silences"],
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
if "presence" in includes:
|
||||
@@ -86,9 +88,9 @@ async def get_update(
|
||||
)
|
||||
async def join_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
user: Annotated[str, Path(..., description="用户 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -110,9 +112,9 @@ async def join_channel(
|
||||
)
|
||||
async def leave_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
user: Annotated[str, Path(..., description="用户 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -135,8 +137,8 @@ async def leave_channel(
|
||||
)
|
||||
async def get_channel_list(
|
||||
session: Database,
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
redis: Redis,
|
||||
):
|
||||
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
||||
results = []
|
||||
@@ -171,9 +173,9 @@ class GetChannelResp(BaseModel):
|
||||
)
|
||||
async def get_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
redis: Redis,
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -245,9 +247,9 @@ class CreateChannelReq(BaseModel):
|
||||
)
|
||||
async def create_channel(
|
||||
session: Database,
|
||||
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
req: Annotated[CreateChannelReq, Depends(BodyOrForm(CreateChannelReq))],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
|
||||
redis: Redis,
|
||||
):
|
||||
if req.type == "PM":
|
||||
target = await session.get(User, req.target_id)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
@@ -11,7 +13,7 @@ from app.database.chat import (
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
@@ -24,7 +26,6 @@ from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
@@ -41,9 +42,9 @@ class KeepAliveResp(BaseModel):
|
||||
)
|
||||
async def keep_alive(
|
||||
session: Database,
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
|
||||
since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
|
||||
):
|
||||
resp = KeepAliveResp()
|
||||
if history_since:
|
||||
@@ -73,9 +74,9 @@ class MessageReq(BaseModel):
|
||||
)
|
||||
async def send_message(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
req: Annotated[MessageReq, Depends(BodyOrForm(MessageReq))],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
|
||||
):
|
||||
# 使用明确的查询来获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -156,10 +157,10 @@ async def send_message(
|
||||
async def get_message(
|
||||
session: Database,
|
||||
channel: str,
|
||||
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
|
||||
since: int = Query(0, ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)"),
|
||||
until: int | None = Query(None, description="获取自此消息 ID 之前的消息(向后翻历史)"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
limit: Annotated[int, Query(ge=1, le=50, description="获取消息的数量")] = 50,
|
||||
since: Annotated[int, Query(ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)")] = 0,
|
||||
until: Annotated[int | None, Query(description="获取自此消息 ID 之前的消息(向后翻历史)")] = None,
|
||||
):
|
||||
# 1) 查频道
|
||||
if channel.isdigit():
|
||||
@@ -220,9 +221,9 @@ async def get_message(
|
||||
)
|
||||
async def mark_as_read(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
message: int = Path(..., description="消息 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
message: Annotated[int, Path(..., description="消息 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
):
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -259,9 +260,9 @@ class NewPMResp(BaseModel):
|
||||
)
|
||||
async def create_new_pm(
|
||||
session: Database,
|
||||
req: PMReq = Depends(BodyOrForm(PMReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
req: Annotated[PMReq, Depends(BodyOrForm(PMReq))],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
|
||||
redis: Redis,
|
||||
):
|
||||
user_id = current_user.id
|
||||
target = await session.get(User, req.target_id)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import overload
|
||||
from typing import Annotated, overload
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database.notification import UserNotification, insert_notification
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import (
|
||||
DBFactory,
|
||||
Redis,
|
||||
get_db_factory,
|
||||
get_redis,
|
||||
with_db,
|
||||
@@ -22,7 +23,6 @@ from app.utils import bg_tasks
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
from fastapi.websockets import WebSocketState
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -298,10 +298,10 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
@chat_router.websocket("/notification-server")
|
||||
async def chat_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str | None = Query(None, description="认证令牌,支持通过URL参数传递"),
|
||||
access_token: str | None = Query(None, description="访问令牌,支持通过URL参数传递"),
|
||||
authorization: str | None = Header(None, description="Bearer认证头"),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
factory: Annotated[DBFactory, Depends(get_db_factory)],
|
||||
token: Annotated[str | None, Query(description="认证令牌,支持通过URL参数传递")] = None,
|
||||
access_token: Annotated[str | None, Query(description="访问令牌,支持通过URL参数传递")] = None,
|
||||
authorization: Annotated[str | None, Header(description="Bearer认证头")] = None,
|
||||
):
|
||||
if not server._subscribed:
|
||||
server._subscribed = True
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.auth import OAuthToken
|
||||
from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.geoip import get_geoip_helper
|
||||
from app.dependencies.geoip import GeoIPService
|
||||
from app.dependencies.user import UserAndToken, get_client_user_and_token
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Security
|
||||
from fastapi import HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -28,8 +29,8 @@ class SessionsResp(BaseModel):
|
||||
)
|
||||
async def get_sessions(
|
||||
session: Database,
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
geoip: GeoIPService,
|
||||
):
|
||||
current_user, token = user_and_token
|
||||
sessions = (
|
||||
@@ -57,7 +58,7 @@ async def get_sessions(
|
||||
async def delete_session(
|
||||
session: Database,
|
||||
session_id: int,
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
):
|
||||
current_user, token = user_and_token
|
||||
if session_id == token.id:
|
||||
@@ -91,8 +92,8 @@ class TrustedDevicesResp(BaseModel):
|
||||
)
|
||||
async def get_trusted_devices(
|
||||
session: Database,
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
geoip: GeoIPService,
|
||||
):
|
||||
current_user, token = user_and_token
|
||||
devices = (
|
||||
@@ -131,7 +132,7 @@ async def get_trusted_devices(
|
||||
async def delete_trusted_device(
|
||||
session: Database,
|
||||
device_id: int,
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
):
|
||||
current_user, token = user_and_token
|
||||
device = await session.get(TrustedDevice, device_id)
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.storage.base import StorageService
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.dependencies.user import ClientUser
|
||||
from app.utils import check_image
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, File, Security
|
||||
from fastapi import File
|
||||
|
||||
|
||||
@router.post("/avatar/upload", name="上传头像", tags=["用户", "g0v0 API"])
|
||||
async def upload_avatar(
|
||||
session: Database,
|
||||
content: bytes = File(...),
|
||||
current_user: User = Security(get_client_user),
|
||||
storage: StorageService = Depends(get_storage_service),
|
||||
content: Annotated[bytes, File(...)],
|
||||
current_user: ClientUser,
|
||||
storage: StorageService,
|
||||
):
|
||||
"""上传用户头像
|
||||
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.beatmapset_ratings import BeatmapRating
|
||||
from app.database.score import Score
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.dependencies.user import ClientUser
|
||||
from app.service.beatmapset_update_service import get_beatmapset_update_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Body, Depends, HTTPException, Security
|
||||
from fastapi import Body, Depends, HTTPException
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from sqlmodel import col, exists, select
|
||||
|
||||
@@ -25,7 +26,7 @@ from sqlmodel import col, exists, select
|
||||
async def can_rate_beatmapset(
|
||||
beatmapset_id: int,
|
||||
session: Database,
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
"""检查用户是否可以评价谱面集
|
||||
|
||||
@@ -57,8 +58,8 @@ async def can_rate_beatmapset(
|
||||
async def rate_beatmaps(
|
||||
beatmapset_id: int,
|
||||
session: Database,
|
||||
rating: int = Body(..., ge=0, le=10),
|
||||
current_user: User = Security(get_client_user),
|
||||
rating: Annotated[int, Body(..., ge=0, le=10)],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
"""为谱面集评分
|
||||
|
||||
@@ -96,7 +97,7 @@ async def rate_beatmaps(
|
||||
async def sync_beatmapset(
|
||||
beatmapset_id: int,
|
||||
session: Database,
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
"""请求同步谱面集
|
||||
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.user import User, UserProfileCover
|
||||
from app.database.user import UserProfileCover
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.storage.base import StorageService
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.dependencies.user import ClientUser
|
||||
from app.utils import check_image
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, File, Security
|
||||
from fastapi import File
|
||||
|
||||
|
||||
@router.post("/cover/upload", name="上传头图", tags=["用户", "g0v0 API"])
|
||||
async def upload_cover(
|
||||
session: Database,
|
||||
content: bytes = File(...),
|
||||
current_user: User = Security(get_client_user),
|
||||
storage: StorageService = Depends(get_storage_service),
|
||||
content: Annotated[bytes, File(...)],
|
||||
current_user: ClientUser,
|
||||
storage: StorageService,
|
||||
):
|
||||
"""上传用户头图
|
||||
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.auth import OAuthClient, OAuthToken
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.user import ClientUser
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Body, Depends, HTTPException, Security
|
||||
from redis.asyncio import Redis
|
||||
from fastapi import Body, HTTPException
|
||||
from sqlmodel import select, text
|
||||
|
||||
|
||||
@@ -22,10 +21,10 @@ from sqlmodel import select, text
|
||||
)
|
||||
async def create_oauth_app(
|
||||
session: Database,
|
||||
name: str = Body(..., max_length=100, description="应用程序名称"),
|
||||
description: str = Body("", description="应用程序描述"),
|
||||
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
|
||||
current_user: User = Security(get_client_user),
|
||||
name: Annotated[str, Body(..., max_length=100, description="应用程序名称")],
|
||||
redirect_uris: Annotated[list[str], Body(..., description="允许的重定向 URI 列表")],
|
||||
current_user: ClientUser,
|
||||
description: Annotated[str, Body(description="应用程序描述")] = "",
|
||||
):
|
||||
result = await session.execute(
|
||||
text(
|
||||
@@ -64,7 +63,7 @@ async def create_oauth_app(
|
||||
async def get_oauth_app(
|
||||
session: Database,
|
||||
client_id: int,
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
oauth_app = await session.get(OAuthClient, client_id)
|
||||
if not oauth_app:
|
||||
@@ -85,7 +84,7 @@ async def get_oauth_app(
|
||||
)
|
||||
async def get_user_oauth_apps(
|
||||
session: Database,
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
|
||||
return [
|
||||
@@ -109,7 +108,7 @@ async def get_user_oauth_apps(
|
||||
async def delete_oauth_app(
|
||||
session: Database,
|
||||
client_id: int,
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
oauth_client = await session.get(OAuthClient, client_id)
|
||||
if not oauth_client:
|
||||
@@ -134,10 +133,10 @@ async def delete_oauth_app(
|
||||
async def update_oauth_app(
|
||||
session: Database,
|
||||
client_id: int,
|
||||
name: str = Body(..., max_length=100, description="应用程序新名称"),
|
||||
description: str = Body("", description="应用程序新描述"),
|
||||
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"),
|
||||
current_user: User = Security(get_client_user),
|
||||
name: Annotated[str, Body(..., max_length=100, description="应用程序新名称")],
|
||||
redirect_uris: Annotated[list[str], Body(..., description="新的重定向 URI 列表")],
|
||||
current_user: ClientUser,
|
||||
description: Annotated[str, Body(description="应用程序新描述")] = "",
|
||||
):
|
||||
oauth_client = await session.get(OAuthClient, client_id)
|
||||
if not oauth_client:
|
||||
@@ -168,7 +167,7 @@ async def update_oauth_app(
|
||||
async def refresh_secret(
|
||||
session: Database,
|
||||
client_id: int,
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
oauth_client = await session.get(OAuthClient, client_id)
|
||||
if not oauth_client:
|
||||
@@ -200,10 +199,10 @@ async def refresh_secret(
|
||||
async def generate_oauth_code(
|
||||
session: Database,
|
||||
client_id: int,
|
||||
current_user: User = Security(get_client_user),
|
||||
redirect_uri: str = Body(..., description="授权后重定向的 URI"),
|
||||
scopes: list[str] = Body(..., description="请求的权限范围列表"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: ClientUser,
|
||||
redirect_uri: Annotated[str, Body(..., description="授权后重定向的 URI")],
|
||||
scopes: Annotated[list[str], Body(..., description="请求的权限范围列表")],
|
||||
redis: Redis,
|
||||
):
|
||||
client = await session.get(OAuthClient, client_id)
|
||||
if not client:
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import Relationship, User
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import Relationship
|
||||
from app.database.relationship import RelationshipType
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.dependencies.user import ClientUser
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import HTTPException, Path, Security
|
||||
from fastapi import HTTPException, Path
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel import select
|
||||
|
||||
@@ -27,8 +29,8 @@ class CheckResponse(BaseModel):
|
||||
)
|
||||
async def check_user_relationship(
|
||||
db: Database,
|
||||
user_id: int = Path(..., description="目标用户的 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
user_id: Annotated[int, Path(..., description="目标用户的 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
if user_id == current_user.id:
|
||||
raise HTTPException(422, "Cannot check relationship with yourself")
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.score import Score
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.dependencies.user import ClientUser
|
||||
from app.service.user_cache_service import refresh_user_cache_background
|
||||
from app.storage.base import StorageService
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import BackgroundTasks, Depends, HTTPException, Security
|
||||
from redis.asyncio import Redis
|
||||
from fastapi import BackgroundTasks, HTTPException
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -24,9 +21,9 @@ async def delete_score(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
score_id: int,
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
storage_service: StorageService = Depends(get_storage_service),
|
||||
redis: Redis,
|
||||
current_user: ClientUser,
|
||||
storage_service: StorageService,
|
||||
):
|
||||
"""删除成绩
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.team import Team, TeamMember, TeamRequest
|
||||
from app.database.user import BASE_INCLUDES, User, UserResp
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.dependencies.user import ClientUser
|
||||
from app.models.notification import (
|
||||
TeamApplicationAccept,
|
||||
TeamApplicationReject,
|
||||
@@ -14,27 +15,25 @@ from app.models.notification import (
|
||||
)
|
||||
from app.router.notification import server
|
||||
from app.service.ranking_cache_service import get_ranking_cache_service
|
||||
from app.storage.base import StorageService
|
||||
from app.utils import check_image, utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, File, Form, HTTPException, Path, Request, Security
|
||||
from fastapi import File, Form, HTTPException, Path, Request
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import exists, select
|
||||
|
||||
|
||||
@router.post("/team", name="创建战队", response_model=Team, tags=["战队", "g0v0 API"])
|
||||
async def create_team(
|
||||
session: Database,
|
||||
storage: StorageService = Depends(get_storage_service),
|
||||
current_user: User = Security(get_client_user),
|
||||
flag: bytes = File(..., description="战队图标文件"),
|
||||
cover: bytes = File(..., description="战队头图文件"),
|
||||
name: str = Form(max_length=100, description="战队名称"),
|
||||
short_name: str = Form(max_length=10, description="战队缩写"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
storage: StorageService,
|
||||
current_user: ClientUser,
|
||||
flag: Annotated[bytes, File(..., description="战队图标文件")],
|
||||
cover: Annotated[bytes, File(..., description="战队头图文件")],
|
||||
name: Annotated[str, Form(max_length=100, description="战队名称")],
|
||||
short_name: Annotated[str, Form(max_length=10, description="战队缩写")],
|
||||
redis: Redis,
|
||||
):
|
||||
"""创建战队。
|
||||
|
||||
@@ -88,13 +87,13 @@ async def create_team(
|
||||
async def update_team(
|
||||
team_id: int,
|
||||
session: Database,
|
||||
storage: StorageService = Depends(get_storage_service),
|
||||
current_user: User = Security(get_client_user),
|
||||
flag: bytes | None = File(default=None, description="战队图标文件"),
|
||||
cover: bytes | None = File(default=None, description="战队头图文件"),
|
||||
name: str | None = Form(default=None, max_length=100, description="战队名称"),
|
||||
short_name: str | None = Form(default=None, max_length=10, description="战队缩写"),
|
||||
leader_id: int | None = Form(default=None, description="战队队长 ID"),
|
||||
storage: StorageService,
|
||||
current_user: ClientUser,
|
||||
flag: Annotated[bytes | None, File(description="战队图标文件")] = None,
|
||||
cover: Annotated[bytes | None, File(description="战队头图文件")] = None,
|
||||
name: Annotated[str | None, Form(max_length=100, description="战队名称")] = None,
|
||||
short_name: Annotated[str | None, Form(max_length=10, description="战队缩写")] = None,
|
||||
leader_id: Annotated[int | None, Form(description="战队队长 ID")] = None,
|
||||
):
|
||||
"""修改战队。
|
||||
|
||||
@@ -161,9 +160,9 @@ async def update_team(
|
||||
@router.delete("/team/{team_id}", name="删除战队", status_code=204, tags=["战队", "g0v0 API"])
|
||||
async def delete_team(
|
||||
session: Database,
|
||||
team_id: int = Path(..., description="战队 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
team_id: Annotated[int, Path(..., description="战队 ID")],
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
team = await session.get(Team, team_id)
|
||||
if not team:
|
||||
@@ -191,7 +190,7 @@ class TeamQueryResp(BaseModel):
|
||||
@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
|
||||
async def get_team(
|
||||
session: Database,
|
||||
team_id: int = Path(..., description="战队 ID"),
|
||||
team_id: Annotated[int, Path(..., description="战队 ID")],
|
||||
):
|
||||
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
|
||||
return TeamQueryResp(
|
||||
@@ -203,8 +202,8 @@ async def get_team(
|
||||
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])
|
||||
async def request_join_team(
|
||||
session: Database,
|
||||
team_id: int = Path(..., description="战队 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
team_id: Annotated[int, Path(..., description="战队 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
team = await session.get(Team, team_id)
|
||||
if not team:
|
||||
@@ -231,10 +230,10 @@ async def request_join_team(
|
||||
async def handle_request(
|
||||
req: Request,
|
||||
session: Database,
|
||||
team_id: int = Path(..., description="战队 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
team_id: Annotated[int, Path(..., description="战队 ID")],
|
||||
user_id: Annotated[int, Path(..., description="用户 ID")],
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
team = await session.get(Team, team_id)
|
||||
if not team:
|
||||
@@ -272,10 +271,10 @@ async def handle_request(
|
||||
@router.delete("/team/{team_id}/{user_id}", name="踢出成员 / 退出战队", status_code=204, tags=["战队", "g0v0 API"])
|
||||
async def kick_member(
|
||||
session: Database,
|
||||
team_id: int = Path(..., description="战队 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
team_id: Annotated[int, Path(..., description="战队 ID")],
|
||||
user_id: Annotated[int, Path(..., description="用户 ID")],
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
team = await session.get(Team, team_id)
|
||||
if not team:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.auth import (
|
||||
check_totp_backup_code,
|
||||
finish_create_totp_key,
|
||||
@@ -9,17 +11,15 @@ from app.auth import (
|
||||
)
|
||||
from app.const import BACKUP_CODE_LENGTH
|
||||
from app.database.auth import TotpKeys
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.user import ClientUser
|
||||
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Body, Depends, HTTPException, Security
|
||||
from fastapi import Body, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import pyotp
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class TotpStatusResp(BaseModel):
|
||||
@@ -37,7 +37,7 @@ class TotpStatusResp(BaseModel):
|
||||
response_model=TotpStatusResp,
|
||||
)
|
||||
async def get_totp_status(
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
"""检查用户是否已创建TOTP"""
|
||||
totp_key = await current_user.awaitable_attrs.totp_key
|
||||
@@ -62,8 +62,8 @@ async def get_totp_status(
|
||||
status_code=201,
|
||||
)
|
||||
async def start_create_totp(
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis,
|
||||
current_user: ClientUser,
|
||||
):
|
||||
if await current_user.awaitable_attrs.totp_key:
|
||||
raise HTTPException(status_code=400, detail="TOTP is already enabled for this user")
|
||||
@@ -98,9 +98,9 @@ async def start_create_totp(
|
||||
)
|
||||
async def finish_create_totp(
|
||||
session: Database,
|
||||
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码")],
|
||||
redis: Redis,
|
||||
current_user: ClientUser,
|
||||
):
|
||||
status, backup_codes = await finish_create_totp_key(current_user, code, redis, session)
|
||||
if status == FinishStatus.SUCCESS:
|
||||
@@ -122,9 +122,9 @@ async def finish_create_totp(
|
||||
)
|
||||
async def disable_totp(
|
||||
session: Database,
|
||||
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码或备份码")],
|
||||
redis: Redis,
|
||||
current_user: ClientUser,
|
||||
):
|
||||
totp = await session.get(TotpKeys, current_user.id)
|
||||
if not totp:
|
||||
|
||||
@@ -1,24 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.auth import validate_username
|
||||
from app.config import settings
|
||||
from app.database.events import Event, EventType
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.dependencies.user import ClientUser
|
||||
from app.utils import utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Body, HTTPException, Security
|
||||
from fastapi import Body, HTTPException
|
||||
from sqlmodel import exists, select
|
||||
|
||||
|
||||
@router.post("/rename", name="修改用户名", tags=["用户", "g0v0 API"])
|
||||
async def user_rename(
|
||||
session: Database,
|
||||
new_name: str = Body(..., description="新的用户名"),
|
||||
current_user: User = Security(get_client_user),
|
||||
new_name: Annotated[str, Body(..., description="新的用户名")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
"""修改用户名
|
||||
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.beatmap import Beatmap, calculate_beatmap_attributes
|
||||
from app.database.beatmap_playcounts import BeatmapPlaycounts
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.favourite_beatmapset import FavouriteBeatmapset
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.fetcher import Fetcher
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.mods import int_to_mods
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .router import AllStrModel, router
|
||||
|
||||
from fastapi import Depends, Query
|
||||
from redis.asyncio import Redis
|
||||
from fastapi import Query
|
||||
from sqlmodel import col, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -148,18 +146,18 @@ class V1Beatmap(AllStrModel):
|
||||
)
|
||||
async def get_beatmaps(
|
||||
session: Database,
|
||||
since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"),
|
||||
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
|
||||
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
|
||||
user: str | None = Query(None, alias="u", description="谱师"),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO
|
||||
convert: bool = Query(False, alias="a", description="转谱"), # TODO
|
||||
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
|
||||
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
|
||||
mods: int = Query(0, description="应用到谱面属性的 MOD"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
since: Annotated[datetime | None, Query(description="自指定时间后拥有排行榜的谱面")] = None,
|
||||
beatmapset_id: Annotated[int | None, Query(alias="s", description="谱面集 ID")] = None,
|
||||
beatmap_id: Annotated[int | None, Query(alias="b", description="谱面 ID")] = None,
|
||||
user: Annotated[str | None, Query(alias="u", description="谱师")] = None,
|
||||
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None,
|
||||
ruleset_id: Annotated[int | None, Query(alias="m", description="Ruleset ID", ge=0, le=3)] = None, # TODO
|
||||
convert: Annotated[bool, Query(alias="a", description="转谱")] = False, # TODO
|
||||
checksum: Annotated[str | None, Query(alias="h", description="谱面文件 MD5")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=500, description="返回结果数量限制")] = 500,
|
||||
mods: Annotated[int, Query(description="应用到谱面属性的 MOD")] = 0,
|
||||
):
|
||||
beatmaps: list[Beatmap] = []
|
||||
results = []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.database.user import User
|
||||
@@ -181,9 +181,9 @@ async def _count_online_users_optimized(redis):
|
||||
)
|
||||
async def api_get_player_info(
|
||||
session: Database,
|
||||
scope: Literal["stats", "events", "info", "all"] = Query(..., description="信息范围"),
|
||||
id: int | None = Query(None, ge=3, le=2147483647, description="用户 ID"),
|
||||
name: str | None = Query(None, regex=r"^[\w \[\]-]{2,32}$", description="用户名"),
|
||||
scope: Annotated[Literal["stats", "events", "info", "all"], Query(..., description="信息范围")],
|
||||
id: Annotated[int | None, Query(ge=3, le=2147483647, description="用户 ID")] = None,
|
||||
name: Annotated[str | None, Query(regex=r"^[\w \[\]-]{2,32}$", description="用户名")] = None,
|
||||
):
|
||||
"""
|
||||
获取指定玩家的信息
|
||||
|
||||
@@ -2,15 +2,14 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from datetime import date
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.counts import ReplayWatchedCount
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.models.mods import int_to_mods
|
||||
from app.models.score import GameMode
|
||||
from app.storage import StorageService
|
||||
|
||||
from .router import router
|
||||
|
||||
@@ -34,18 +33,20 @@ class ReplayModel(BaseModel):
|
||||
)
|
||||
async def download_replay(
|
||||
session: Database,
|
||||
beatmap: int = Query(..., alias="b", description="谱面 ID"),
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int | None = Query(
|
||||
None,
|
||||
alias="m",
|
||||
description="Ruleset ID",
|
||||
ge=0,
|
||||
),
|
||||
score_id: int | None = Query(None, alias="s", description="成绩 ID"),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
mods: int = Query(0, description="成绩的 MOD"),
|
||||
storage_service: StorageService = Depends(get_storage_service),
|
||||
beatmap: Annotated[int, Query(..., alias="b", description="谱面 ID")],
|
||||
user: Annotated[str, Query(..., alias="u", description="用户")],
|
||||
storage_service: StorageService,
|
||||
ruleset_id: Annotated[
|
||||
int | None,
|
||||
Query(
|
||||
alias="m",
|
||||
description="Ruleset ID",
|
||||
ge=0,
|
||||
),
|
||||
] = None,
|
||||
score_id: Annotated[int | None, Query(alias="s", description="成绩 ID")] = None,
|
||||
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None,
|
||||
mods: Annotated[int, Query(description="成绩的 MOD")] = 0,
|
||||
):
|
||||
mods_ = int_to_mods(mods)
|
||||
if score_id is not None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.best_scores import PPBestScore
|
||||
from app.database.score import Score, get_leaderboard
|
||||
@@ -69,10 +69,10 @@ class V1Score(AllStrModel):
|
||||
)
|
||||
async def get_user_best(
|
||||
session: Database,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
user: Annotated[str, Query(..., alias="u", description="用户")],
|
||||
ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
|
||||
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
|
||||
):
|
||||
try:
|
||||
scores = (
|
||||
@@ -101,10 +101,10 @@ async def get_user_best(
|
||||
)
|
||||
async def get_user_recent(
|
||||
session: Database,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
user: Annotated[str, Query(..., alias="u", description="用户")],
|
||||
ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
|
||||
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
|
||||
):
|
||||
try:
|
||||
scores = (
|
||||
@@ -133,12 +133,12 @@ async def get_user_recent(
|
||||
)
|
||||
async def get_scores(
|
||||
session: Database,
|
||||
user: str | None = Query(None, alias="u", description="用户"),
|
||||
beatmap_id: int = Query(alias="b", description="谱面 ID"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
mods: int = Query(0, description="成绩的 MOD"),
|
||||
beatmap_id: Annotated[int, Query(alias="b", description="谱面 ID")],
|
||||
user: Annotated[str | None, Query(alias="u", description="用户")] = None,
|
||||
ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
|
||||
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
|
||||
mods: Annotated[int, Query(description="成绩的 MOD")] = 0,
|
||||
):
|
||||
try:
|
||||
if user is not None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
||||
from app.database.user import User
|
||||
@@ -104,10 +104,10 @@ class V1User(AllStrModel):
|
||||
async def get_user(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"),
|
||||
user: Annotated[str, Query(..., alias="u", description="用户")],
|
||||
ruleset_id: Annotated[int | None, Query(alias="m", description="Ruleset ID", ge=0)] = None,
|
||||
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None,
|
||||
event_days: Annotated[int, Query(ge=1, le=31, description="从现在起所有事件的最大天数")] = 1,
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
@@ -3,13 +3,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import Beatmap, BeatmapResp, User
|
||||
from app.database.beatmap import calculate_beatmap_attributes
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod, int_to_mods
|
||||
from app.models.score import (
|
||||
@@ -18,10 +18,9 @@ from app.models.score import (
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
import rosu_pp_py as rosu
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -44,11 +43,11 @@ class BatchGetResp(BaseModel):
|
||||
)
|
||||
async def lookup_beatmap(
|
||||
db: Database,
|
||||
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
|
||||
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
|
||||
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
id: Annotated[int | None, Query(alias="id", description="谱面 ID")] = None,
|
||||
md5: Annotated[str | None, Query(alias="checksum", description="谱面文件 MD5")] = None,
|
||||
filename: Annotated[str | None, Query(alias="filename", description="谱面文件名")] = None,
|
||||
):
|
||||
if id is None and md5 is None and filename is None:
|
||||
raise HTTPException(
|
||||
@@ -75,9 +74,9 @@ async def lookup_beatmap(
|
||||
)
|
||||
async def get_beatmap(
|
||||
db: Database,
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
):
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||
@@ -95,9 +94,12 @@ async def get_beatmap(
|
||||
)
|
||||
async def batch_get_beatmaps(
|
||||
db: Database,
|
||||
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
beatmap_ids: Annotated[
|
||||
list[int],
|
||||
Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
|
||||
],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
):
|
||||
if not beatmap_ids:
|
||||
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
|
||||
@@ -127,16 +129,19 @@ async def batch_get_beatmaps(
|
||||
)
|
||||
async def get_beatmap_attributes(
|
||||
db: Database,
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
mods: list[str] = Query(
|
||||
default_factory=list,
|
||||
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
|
||||
),
|
||||
ruleset: GameMode | None = Query(default=None, description="指定 ruleset;为空则使用谱面自身模式"),
|
||||
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
mods: Annotated[
|
||||
list[str],
|
||||
Query(
|
||||
default_factory=list,
|
||||
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
|
||||
),
|
||||
],
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset;为空则使用谱面自身模式")] = None,
|
||||
ruleset_id: Annotated[int | None, Query(description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3)] = None,
|
||||
):
|
||||
mods_ = []
|
||||
if mods and mods[0].isdigit():
|
||||
|
||||
@@ -6,23 +6,20 @@ from urllib.parse import parse_qs
|
||||
|
||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.database.beatmapset import SearchBeatmapsetsResp
|
||||
from app.dependencies.beatmap_download import get_beatmap_download_service
|
||||
from app.dependencies.beatmapset_cache import get_beatmapset_cache_dependency
|
||||
from app.dependencies.database import Database, get_redis, with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.dependencies.beatmap_download import DownloadService
|
||||
from app.dependencies.beatmapset_cache import BeatmapsetCacheService
|
||||
from app.dependencies.database import Database, Redis, with_db
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
from app.dependencies.geoip import IPAddress, get_geoip_helper
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.models.beatmap import SearchQueryModel
|
||||
from app.service.asset_proxy_helper import process_response_assets
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService
|
||||
from app.service.beatmapset_cache_service import BeatmapsetCacheService, generate_hash
|
||||
from app.service.beatmapset_cache_service import generate_hash
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import (
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
Form,
|
||||
HTTPException,
|
||||
Path,
|
||||
@@ -53,10 +50,10 @@ async def search_beatmapset(
|
||||
query: Annotated[SearchQueryModel, Query(...)],
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
redis=Depends(get_redis),
|
||||
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
redis: Redis,
|
||||
cache_service: BeatmapsetCacheService,
|
||||
):
|
||||
params = parse_qs(qs=request.url.query, keep_blank_values=True)
|
||||
cursor = {}
|
||||
@@ -134,10 +131,10 @@ async def search_beatmapset(
|
||||
async def lookup_beatmapset(
|
||||
db: Database,
|
||||
request: Request,
|
||||
beatmap_id: int = Query(description="谱面 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
|
||||
beatmap_id: Annotated[int, Query(description="谱面 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
cache_service: BeatmapsetCacheService,
|
||||
):
|
||||
# 先尝试从缓存获取
|
||||
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
|
||||
@@ -170,10 +167,10 @@ async def lookup_beatmapset(
|
||||
async def get_beatmapset(
|
||||
db: Database,
|
||||
request: Request,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency),
|
||||
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
fetcher: Fetcher,
|
||||
cache_service: BeatmapsetCacheService,
|
||||
):
|
||||
# 先尝试从缓存获取
|
||||
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
|
||||
@@ -203,14 +200,12 @@ async def get_beatmapset(
|
||||
description="\n下载谱面集文件。基于请求IP地理位置智能分流,支持负载均衡和自动故障转移。中国IP使用Sayobot镜像,其他地区使用Nerinyan和OsuDirect镜像。",
|
||||
)
|
||||
async def download_beatmapset(
|
||||
request: Request,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
|
||||
current_user: User = Security(get_client_user),
|
||||
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
|
||||
client_ip: IPAddress,
|
||||
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
|
||||
current_user: ClientUser,
|
||||
download_service: DownloadService,
|
||||
no_video: Annotated[bool, Query(alias="noVideo", description="是否下载无视频版本")] = True,
|
||||
):
|
||||
client_ip = get_client_ip(request)
|
||||
|
||||
geoip_helper = get_geoip_helper()
|
||||
geo_info = geoip_helper.lookup(client_ip)
|
||||
country_code = geo_info.get("country_iso", "")
|
||||
@@ -242,9 +237,12 @@ async def download_beatmapset(
|
||||
)
|
||||
async def favourite_beatmapset(
|
||||
db: Database,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
action: Literal["favourite", "unfavourite"] = Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"),
|
||||
current_user: User = Security(get_client_user),
|
||||
beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
|
||||
action: Annotated[
|
||||
Literal["favourite", "unfavourite"],
|
||||
Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"),
|
||||
],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
existing_favourite = (
|
||||
await db.exec(
|
||||
|
||||
@@ -5,14 +5,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.dependencies.database import Redis
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class CacheStatsResponse(BaseModel):
|
||||
@@ -28,7 +27,7 @@ class CacheStatsResponse(BaseModel):
|
||||
tags=["缓存管理"],
|
||||
)
|
||||
async def get_cache_stats(
|
||||
redis: Redis = Depends(get_redis),
|
||||
redis: Redis,
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用
|
||||
):
|
||||
try:
|
||||
@@ -68,7 +67,7 @@ async def get_cache_stats(
|
||||
)
|
||||
async def invalidate_user_cache(
|
||||
user_id: int,
|
||||
redis: Redis = Depends(get_redis),
|
||||
redis: Redis,
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
@@ -87,7 +86,7 @@ async def invalidate_user_cache(
|
||||
tags=["缓存管理"],
|
||||
)
|
||||
async def clear_all_user_cache(
|
||||
redis: Redis = Depends(get_redis),
|
||||
redis: Redis,
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
@@ -119,7 +118,7 @@ class CacheWarmupRequest(BaseModel):
|
||||
)
|
||||
async def warmup_cache(
|
||||
request: CacheWarmupRequest,
|
||||
redis: Redis = Depends(get_redis),
|
||||
redis: Redis,
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import MeResp, User
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import UserAndToken, get_current_user_and_token
|
||||
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
|
||||
from app.exceptions.userpage import UserpageError
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import Page
|
||||
@@ -29,8 +30,8 @@ from fastapi import HTTPException, Path, Security
|
||||
)
|
||||
async def get_user_info_with_ruleset(
|
||||
session: Database,
|
||||
ruleset: GameMode = Path(description="指定 ruleset"),
|
||||
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
|
||||
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
|
||||
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
||||
):
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
|
||||
return user_resp
|
||||
@@ -45,7 +46,7 @@ async def get_user_info_with_ruleset(
|
||||
)
|
||||
async def get_user_info_default(
|
||||
session: Database,
|
||||
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
|
||||
):
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
|
||||
return user_resp
|
||||
@@ -85,8 +86,8 @@ async def get_user_info_default(
|
||||
async def update_userpage(
|
||||
request: UpdateUserpageRequest,
|
||||
session: Database,
|
||||
user_id: int = Path(description="用户ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["edit"]),
|
||||
user_id: Annotated[int, Path(description="用户ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["edit"])],
|
||||
):
|
||||
"""更新用户页面内容(匹配官方osu-web实现)"""
|
||||
# 检查权限:只能编辑自己的页面(除非是管理员)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.score import GameMode
|
||||
from app.service.ranking_cache_service import get_ranking_cache_service
|
||||
|
||||
@@ -45,11 +45,11 @@ SortType = Literal["performance", "score"]
|
||||
async def get_team_ranking_pp(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
return await get_team_ranking(session, background_tasks, "performance", ruleset, page, current_user)
|
||||
return await get_team_ranking(session, background_tasks, "performance", ruleset, current_user, page)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -62,14 +62,17 @@ async def get_team_ranking_pp(
|
||||
async def get_team_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
sort: SortType = Path(
|
||||
...,
|
||||
description="排名类型:performance 表现分 / score 计分成绩总分 "
|
||||
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
|
||||
),
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
sort: Annotated[
|
||||
SortType,
|
||||
Path(
|
||||
...,
|
||||
description="排名类型:performance 表现分 / score 计分成绩总分 "
|
||||
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
|
||||
),
|
||||
],
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
# 获取 Redis 连接和缓存服务
|
||||
redis = get_redis()
|
||||
@@ -193,11 +196,11 @@ class CountryResponse(BaseModel):
|
||||
async def get_country_ranking_pp(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
return await get_country_ranking(session, background_tasks, ruleset, page, "performance", current_user)
|
||||
return await get_country_ranking(session, background_tasks, ruleset, "performance", current_user, page)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -210,14 +213,17 @@ async def get_country_ranking_pp(
|
||||
async def get_country_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
sort: SortType = Path(
|
||||
...,
|
||||
description="排名类型:performance 表现分 / score 计分成绩总分 "
|
||||
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
sort: Annotated[
|
||||
SortType,
|
||||
Path(
|
||||
...,
|
||||
description="排名类型:performance 表现分 / score 计分成绩总分 "
|
||||
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
|
||||
),
|
||||
],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
# 获取 Redis 连接和缓存服务
|
||||
redis = get_redis()
|
||||
@@ -317,11 +323,11 @@ class TopUsersResponse(BaseModel):
|
||||
async def get_user_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
sort: SortType = Path(..., description="排名类型:performance 表现分 / score 计分成绩总分"),
|
||||
country: str | None = Query(None, description="国家代码"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
|
||||
sort: Annotated[SortType, Path(..., description="排名类型:performance 表现分 / score 计分成绩总分")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
country: Annotated[str | None, Query(description="国家代码")] = None,
|
||||
page: Annotated[int, Query(ge=1, description="页码")] = 1,
|
||||
):
|
||||
# 获取 Redis 连接和缓存服务
|
||||
redis = get_redis()
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import Relationship, RelationshipResp, RelationshipType, User
|
||||
from app.database.user import UserResp
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
|
||||
from .router import router
|
||||
|
||||
@@ -56,7 +58,7 @@ async def get_relationship(
|
||||
db: Database,
|
||||
request: Request,
|
||||
api_version: APIVersion,
|
||||
current_user: User = Security(get_current_user, scopes=["friends.read"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["friends.read"])],
|
||||
):
|
||||
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
|
||||
relationships = await db.exec(
|
||||
@@ -107,8 +109,8 @@ class AddFriendResp(BaseModel):
|
||||
async def add_relationship(
|
||||
db: Database,
|
||||
request: Request,
|
||||
target: int = Query(description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
target: Annotated[int, Query(description="目标用户 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
if not (await db.exec(select(exists()).where(User.id == target))).first():
|
||||
raise HTTPException(404, "Target user not found")
|
||||
@@ -176,8 +178,8 @@ async def add_relationship(
|
||||
async def delete_relationship(
|
||||
db: Database,
|
||||
request: Request,
|
||||
target: int = Path(..., description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
target: Annotated[int, Path(..., description="目标用户 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
if not (await db.exec(select(exists()).where(User.id == target))).first():
|
||||
raise HTTPException(404, "Target user not found")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database.beatmap import Beatmap, BeatmapResp
|
||||
from app.database.beatmapset import BeatmapsetResp
|
||||
@@ -12,8 +12,8 @@ from app.database.room import APIUploadedRoom, Room, RoomResp
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.database.score import Score
|
||||
from app.database.user import User, UserResp
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.models.room import RoomCategory, RoomStatus
|
||||
from app.service.room import create_playlist_room_from_api
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
@@ -21,9 +21,8 @@ from app.utils import utcnow
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -38,16 +37,20 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
)
|
||||
async def get_all_rooms(
|
||||
db: Database,
|
||||
mode: Literal["open", "ended", "participated", "owned"] | None = Query(
|
||||
default="open",
|
||||
description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
|
||||
),
|
||||
category: RoomCategory = Query(
|
||||
RoomCategory.NORMAL,
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
|
||||
),
|
||||
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
mode: Annotated[
|
||||
Literal["open", "ended", "participated", "owned"] | None,
|
||||
Query(
|
||||
description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
|
||||
),
|
||||
] = "open",
|
||||
category: Annotated[
|
||||
RoomCategory,
|
||||
Query(
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
|
||||
),
|
||||
] = RoomCategory.NORMAL,
|
||||
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
|
||||
):
|
||||
resp_list: list[RoomResp] = []
|
||||
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
|
||||
@@ -140,8 +143,8 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
|
||||
async def create_room(
|
||||
db: Database,
|
||||
room: APIUploadedRoom,
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
@@ -162,13 +165,15 @@ async def create_room(
|
||||
)
|
||||
async def get_room(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
category: str = Query(
|
||||
default="",
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
redis: Redis,
|
||||
category: Annotated[
|
||||
str,
|
||||
Query(
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
|
||||
),
|
||||
] = "",
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
@@ -185,8 +190,8 @@ async def get_room(
|
||||
)
|
||||
async def delete_room(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
@@ -205,10 +210,10 @@ async def delete_room(
|
||||
)
|
||||
async def add_user_to_room(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
user_id: Annotated[int, Path(..., description="用户 ID")],
|
||||
redis: Redis,
|
||||
current_user: ClientUser,
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
@@ -229,10 +234,10 @@ async def add_user_to_room(
|
||||
)
|
||||
async def remove_user_from_room(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
user_id: Annotated[int, Path(..., description="用户 ID")],
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
@@ -273,8 +278,8 @@ class APILeaderboard(BaseModel):
|
||||
)
|
||||
async def get_room_leaderboard(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
@@ -329,11 +334,11 @@ class RoomEvents(BaseModel):
|
||||
)
|
||||
async def get_room_events(
|
||||
db: Database,
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"),
|
||||
before: int | None = Query(None, ge=0, description="仅包含小于该事件 ID 的事件"),
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
|
||||
after: Annotated[int | None, Query(ge=0, description="仅包含大于该事件 ID 的事件")] = None,
|
||||
before: Annotated[int | None, Query(ge=0, description="仅包含小于该事件 ID 的事件")] = None,
|
||||
):
|
||||
events = (
|
||||
await db.exec(
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, date
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
from app.calculator import clamp
|
||||
from app.config import settings
|
||||
@@ -34,11 +35,10 @@ from app.database.score import (
|
||||
process_user,
|
||||
)
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database, get_redis, with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.dependencies.database import Database, Redis, get_redis, with_db
|
||||
from app.dependencies.fetcher import Fetcher, get_fetcher
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.log import logger
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.room import RoomCategory
|
||||
@@ -50,7 +50,6 @@ from app.models.score import (
|
||||
)
|
||||
from app.service.beatmap_cache_service import get_beatmap_cache_service
|
||||
from app.service.user_cache_service import refresh_user_cache_background
|
||||
from app.storage.base import StorageService
|
||||
from app.utils import utcnow
|
||||
|
||||
from .router import router
|
||||
@@ -69,7 +68,6 @@ from fastapi.responses import RedirectResponse
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from httpx import HTTPError
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, exists, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -245,16 +243,18 @@ class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
|
||||
async def get_beatmap_scores(
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
mode: GameMode = Query(description="指定 auleset"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
|
||||
type: LeaderboardType = Query(
|
||||
LeaderboardType.GLOBAL,
|
||||
description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
mode: Annotated[GameMode, Query(description="指定 auleset")],
|
||||
mods: Annotated[list[str], Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
|
||||
type: Annotated[
|
||||
LeaderboardType,
|
||||
Query(
|
||||
description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
|
||||
),
|
||||
] = LeaderboardType.GLOBAL,
|
||||
limit: Annotated[int, Query(ge=1, le=200, description="返回条数 (1-200)")] = 50,
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
|
||||
@@ -294,12 +294,12 @@ async def get_beatmap_scores(
|
||||
async def get_user_beatmap_score(
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"),
|
||||
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
|
||||
mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
|
||||
mods: Annotated[str | None, Query(description="筛选使用的 Mods (暂未实现)")] = None,
|
||||
):
|
||||
user_score = (
|
||||
await db.exec(
|
||||
@@ -342,11 +342,11 @@ async def get_user_beatmap_score(
|
||||
async def get_user_all_beatmap_scores(
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
|
||||
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
|
||||
):
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
@@ -374,11 +374,11 @@ async def get_user_all_beatmap_scores(
|
||||
async def create_solo_score(
|
||||
background_task: BackgroundTasks,
|
||||
db: Database,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
version_hash: str = Form("", description="游戏版本哈希"),
|
||||
beatmap_hash: str = Form(description="谱面文件哈希"),
|
||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||
current_user: User = Security(get_client_user),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
beatmap_hash: Annotated[str, Form(description="谱面文件哈希")],
|
||||
ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
|
||||
current_user: ClientUser,
|
||||
version_hash: Annotated[str, Form(description="游戏版本哈希")] = "",
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -406,12 +406,12 @@ async def create_solo_score(
|
||||
async def submit_solo_score(
|
||||
background_task: BackgroundTasks,
|
||||
db: Database,
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
token: int = Path(description="成绩令牌 ID"),
|
||||
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||
token: Annotated[int, Path(description="成绩令牌 ID")],
|
||||
info: Annotated[SoloScoreSubmissionInfo, Body(description="成绩提交信息")],
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
):
|
||||
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
|
||||
|
||||
@@ -428,11 +428,11 @@ async def create_playlist_score(
|
||||
background_task: BackgroundTasks,
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
beatmap_id: int = Form(description="谱面 ID"),
|
||||
beatmap_hash: str = Form(description="游戏版本哈希"),
|
||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||
version_hash: str = Form("", description="谱面版本哈希"),
|
||||
current_user: User = Security(get_client_user),
|
||||
beatmap_id: Annotated[int, Form(description="谱面 ID")],
|
||||
beatmap_hash: Annotated[str, Form(description="游戏版本哈希")],
|
||||
ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
|
||||
current_user: ClientUser,
|
||||
version_hash: Annotated[str, Form(description="谱面版本哈希")] = "",
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -496,9 +496,9 @@ async def submit_playlist_score(
|
||||
playlist_id: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -555,9 +555,9 @@ async def index_playlist_scores(
|
||||
session: Database,
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
|
||||
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
limit: Annotated[int, Query(ge=1, le=50, description="返回条数 (1-50)")] = 50,
|
||||
cursor: Annotated[int, Query(alias="cursor[total_score]", description="分页游标(上一页最低分)")] = 2000000,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -623,8 +623,8 @@ async def show_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
score_id: int,
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: ClientUser,
|
||||
redis: Redis,
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
@@ -692,7 +692,7 @@ async def get_user_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
current_user: User = Security(get_client_user),
|
||||
current_user: ClientUser,
|
||||
):
|
||||
score_record = None
|
||||
start_time = time.time()
|
||||
@@ -725,8 +725,8 @@ async def get_user_playlist_score(
|
||||
)
|
||||
async def pin_score(
|
||||
db: Database,
|
||||
score_id: int = Path(description="成绩 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
score_id: Annotated[int, Path(description="成绩 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -770,8 +770,8 @@ async def pin_score(
|
||||
)
|
||||
async def unpin_score(
|
||||
db: Database,
|
||||
score_id: int = Path(description="成绩 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
score_id: Annotated[int, Path(description="成绩 ID")],
|
||||
current_user: ClientUser,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -805,10 +805,10 @@ async def unpin_score(
|
||||
)
|
||||
async def reorder_score_pin(
|
||||
db: Database,
|
||||
score_id: int = Path(description="成绩 ID"),
|
||||
after_score_id: int | None = Body(default=None, description="放在该成绩之后"),
|
||||
before_score_id: int | None = Body(default=None, description="放在该成绩之前"),
|
||||
current_user: User = Security(get_client_user),
|
||||
score_id: Annotated[int, Path(description="成绩 ID")],
|
||||
current_user: ClientUser,
|
||||
after_score_id: Annotated[int | None, Body(description="放在该成绩之后")] = None,
|
||||
before_score_id: Annotated[int | None, Body(description="放在该成绩之前")] = None,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
@@ -893,8 +893,8 @@ async def reorder_score_pin(
|
||||
async def download_score_replay(
|
||||
score_id: int,
|
||||
db: Database,
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
storage_service: StorageService = Depends(get_storage_service),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
storage_service: StorageService,
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -11,8 +11,8 @@ from app.config import settings
|
||||
from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER
|
||||
from app.database.auth import TotpKeys
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
from app.dependencies.database import Database, Redis, get_redis
|
||||
from app.dependencies.geoip import IPAddress
|
||||
from app.dependencies.user import UserAndToken, get_client_user_and_token
|
||||
from app.dependencies.user_agent import UserAgentInfo
|
||||
from app.log import logger
|
||||
@@ -27,7 +27,6 @@ from .router import router
|
||||
from fastapi import Depends, Form, Header, HTTPException, Request, Security, status
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class VerifyMethod(BaseModel):
|
||||
@@ -64,10 +63,14 @@ async def verify_session(
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
user_agent: UserAgentInfo,
|
||||
ip_address: IPAddress,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"),
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
|
||||
verification_key: Annotated[
|
||||
str,
|
||||
Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"),
|
||||
],
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
|
||||
) -> Response:
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
@@ -82,7 +85,6 @@ async def verify_session(
|
||||
else await LoginSessionService.get_login_method(user_id, token_id, redis)
|
||||
)
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
login_method = "password"
|
||||
|
||||
try:
|
||||
@@ -182,12 +184,12 @@ async def verify_session(
|
||||
tags=["验证"],
|
||||
)
|
||||
async def reissue_verification_code(
|
||||
request: Request,
|
||||
db: Database,
|
||||
user_agent: UserAgentInfo,
|
||||
api_version: APIVersion,
|
||||
ip_address: IPAddress,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
) -> SessionReissueResponse:
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
@@ -203,7 +205,6 @@ async def reissue_verification_code(
|
||||
return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码")
|
||||
|
||||
try:
|
||||
ip_address = get_client_ip(request)
|
||||
user_id = current_user.id
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db,
|
||||
@@ -233,17 +234,15 @@ async def reissue_verification_code(
|
||||
async def fallback_email(
|
||||
db: Database,
|
||||
user_agent: UserAgentInfo,
|
||||
request: Request,
|
||||
ip_address: IPAddress,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
|
||||
) -> VerifyMethod:
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
if not await LoginSessionService.get_login_method(current_user.id, token_id, redis):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db,
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.beatmap_tags import BeatmapTagVote
|
||||
from app.database.score import Score
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.models.score import Rank
|
||||
from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id
|
||||
@@ -55,10 +57,10 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
|
||||
description="为指定谱面添加标签投票。",
|
||||
)
|
||||
async def vote_beatmap_tags(
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
tag_id: int = Path(..., description="标签 ID"),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_client_user),
|
||||
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
|
||||
tag_id: Annotated[int, Path(..., description="标签 ID")],
|
||||
session: Database,
|
||||
current_user: Annotated[User, Depends(get_client_user)],
|
||||
):
|
||||
try:
|
||||
get_tag_by_id(tag_id)
|
||||
@@ -90,10 +92,10 @@ async def vote_beatmap_tags(
|
||||
description="取消对指定谱面标签的投票。",
|
||||
)
|
||||
async def devote_beatmap_tags(
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
tag_id: int = Path(..., description="标签 ID"),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_client_user),
|
||||
beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
|
||||
tag_id: Annotated[int, Path(..., description="标签 ID")],
|
||||
session: Database,
|
||||
current_user: Annotated[User, Depends(get_client_user)],
|
||||
):
|
||||
"""
|
||||
取消对谱面指定标签的投票。
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
@@ -51,9 +51,12 @@ async def get_users(
|
||||
session: Database,
|
||||
request: Request,
|
||||
background_task: BackgroundTasks,
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
|
||||
user_ids: Annotated[list[int], Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表")],
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
|
||||
include_variant_statistics: Annotated[
|
||||
bool,
|
||||
Query(description="是否包含各模式的统计信息"),
|
||||
] = False, # TODO: future use
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -119,9 +122,9 @@ async def get_users(
|
||||
)
|
||||
async def get_user_events(
|
||||
session: Database,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
limit: int | None = Query(None, description="限制返回的活动数量"),
|
||||
offset: int | None = Query(None, description="活动日志的偏移量"),
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
limit: Annotated[int | None, Query(description="限制返回的活动数量")] = None,
|
||||
offset: Annotated[int | None, Query(description="活动日志的偏移量")] = None,
|
||||
):
|
||||
db_user = await session.get(User, user_id)
|
||||
if db_user is None or db_user.id == BANCHOBOT_ID:
|
||||
@@ -147,9 +150,9 @@ async def get_user_events(
|
||||
)
|
||||
async def get_user_kudosu(
|
||||
session: Database,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
offset: int = Query(default=0, description="偏移量"),
|
||||
limit: int = Query(default=6, description="返回记录数量限制"),
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
offset: Annotated[int, Query(description="偏移量")] = 0,
|
||||
limit: Annotated[int, Query(description="返回记录数量限制")] = 6,
|
||||
):
|
||||
"""
|
||||
获取用户的 kudosu 记录
|
||||
@@ -176,8 +179,8 @@ async def get_user_kudosu(
|
||||
async def get_user_info_ruleset(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
||||
user_id: Annotated[str, Path(description="用户 ID 或用户名")],
|
||||
ruleset: Annotated[GameMode | None, Path(description="指定 ruleset")],
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
redis = get_redis()
|
||||
@@ -225,7 +228,7 @@ async def get_user_info(
|
||||
background_task: BackgroundTasks,
|
||||
session: Database,
|
||||
request: Request,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
user_id: Annotated[str, Path(description="用户 ID 或用户名")],
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
redis = get_redis()
|
||||
@@ -274,11 +277,11 @@ async def get_user_info(
|
||||
async def get_user_beatmapsets(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: BeatmapsetType = Path(description="谱面集类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
type: Annotated[BeatmapsetType, Path(description="谱面集类型")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
|
||||
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -356,16 +359,17 @@ async def get_user_scores(
|
||||
session: Database,
|
||||
api_version: APIVersion,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: Literal["best", "recent", "firsts", "pinned"] = Path(
|
||||
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
|
||||
),
|
||||
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
|
||||
include_fails: bool = Query(False, description="是否包含失败的成绩"),
|
||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
user_id: Annotated[int, Path(description="用户 ID")],
|
||||
type: Annotated[
|
||||
Literal["best", "recent", "firsts", "pinned"],
|
||||
Path(description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")),
|
||||
],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
legacy_only: Annotated[bool, Query(description="是否只查询 Stable 成绩")] = False,
|
||||
include_fails: Annotated[bool, Query(description="是否包含失败的成绩")] = False,
|
||||
mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选,默认为用户主模式)")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
|
||||
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
|
||||
):
|
||||
is_legacy_api = api_version < 20220705
|
||||
redis = get_redis()
|
||||
|
||||
@@ -7,9 +7,8 @@ from typing import Literal
|
||||
import uuid
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import DBFactory, get_db_factory
|
||||
from app.dependencies.user import get_current_user_and_token
|
||||
from app.dependencies.user import get_current_user, get_current_user_and_token
|
||||
from app.log import logger
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
|
||||
|
||||
Reference in New Issue
Block a user