refactor(api): use Annotated-style dependency injection

This commit is contained in:
MingxuanGame
2025-10-03 05:41:31 +00:00
parent 37b4eadf79
commit 346c2557cf
45 changed files with 623 additions and 577 deletions

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel from app.models.model import UTCBaseModel
from app.models.mods import APIMod from app.models.mods import APIMod
from app.models.multiplayer_hub import PlaylistItem
from .beatmap import Beatmap, BeatmapResp from .beatmap import Beatmap, BeatmapResp
@@ -22,6 +21,8 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from app.models.multiplayer_hub import PlaylistItem
from .room import Room from .room import Room
@@ -72,7 +73,7 @@ class Playlist(PlaylistBase, table=True):
return result.one() return result.one()
@classmethod @classmethod
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist": async def from_hub(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session) next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls( return cls(
id=next_id, id=next_id,
@@ -89,7 +90,7 @@ class Playlist(PlaylistBase, table=True):
) )
@classmethod @classmethod
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): async def update(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id)) db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
db_playlist = db_playlist.first() db_playlist = db_playlist.first()
if db_playlist is None: if db_playlist is None:
@@ -106,7 +107,7 @@ class Playlist(PlaylistBase, table=True):
await session.commit() await session.commit()
@classmethod @classmethod
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): async def add_to_db(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
db_playlist = await cls.from_hub(playlist, room_id, session) db_playlist = await cls.from_hub(playlist, room_id, session)
session.add(db_playlist) session.add(db_playlist)
await session.commit() await session.commit()

View File

@@ -1,9 +1,9 @@
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING
from app.database.item_attempts_count import PlaylistAggregateScore from app.database.item_attempts_count import PlaylistAggregateScore
from app.database.room_participated_user import RoomParticipatedUser from app.database.room_participated_user import RoomParticipatedUser
from app.models.model import UTCBaseModel from app.models.model import UTCBaseModel
from app.models.multiplayer_hub import ServerMultiplayerRoom
from app.models.room import ( from app.models.room import (
MatchType, MatchType,
QueueMode, QueueMode,
@@ -32,6 +32,9 @@ from sqlmodel import (
) )
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.models.multiplayer_hub import ServerMultiplayerRoom
class RoomBase(SQLModel, UTCBaseModel): class RoomBase(SQLModel, UTCBaseModel):
name: str = Field(index=True) name: str = Field(index=True)
@@ -161,7 +164,7 @@ class RoomResp(RoomBase):
return resp return resp
@classmethod @classmethod
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp": async def from_hub(cls, server_room: "ServerMultiplayerRoom") -> "RoomResp":
room = server_room.room room = server_room.room
resp = cls( resp = cls(
id=room.room_id, id=room.room_id,

View File

@@ -1,4 +1 @@
from __future__ import annotations
from .database import get_db as get_db
from .user import get_current_user as get_current_user

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import Depends, Header from fastapi import Depends, Header
def get_api_version(version: int | None = Header(None, alias="x-api-version")) -> int: def get_api_version(version: int | None = Header(None, alias="x-api-version", include_in_schema=False)) -> int:
if version is None: if version is None:
return 0 return 0
if version < 1: if version < 1:

View File

@@ -1,8 +1,15 @@
from __future__ import annotations from __future__ import annotations
from app.service.beatmap_download_service import download_service from typing import Annotated
from app.service.beatmap_download_service import BeatmapDownloadService, download_service
from fastapi import Depends
def get_beatmap_download_service(): def get_beatmap_download_service():
"""获取谱面下载服务实例""" """获取谱面下载服务实例"""
return download_service return download_service
DownloadService = Annotated[BeatmapDownloadService, Depends(get_beatmap_download_service)]

View File

@@ -1,16 +1,19 @@
"""
Beatmapset缓存服务依赖注入
"""
from __future__ import annotations from __future__ import annotations
from app.dependencies.database import get_redis from typing import Annotated
from app.service.beatmapset_cache_service import BeatmapsetCacheService, get_beatmapset_cache_service
from app.dependencies.database import Redis
from app.service.beatmapset_cache_service import (
BeatmapsetCacheService as OriginBeatmapsetCacheService,
get_beatmapset_cache_service,
)
from fastapi import Depends from fastapi import Depends
from redis.asyncio import Redis
def get_beatmapset_cache_dependency(redis: Redis = Depends(get_redis)) -> BeatmapsetCacheService: def get_beatmapset_cache_dependency(redis: Redis) -> OriginBeatmapsetCacheService:
"""获取beatmapset缓存服务依赖""" """获取beatmapset缓存服务依赖"""
return get_beatmapset_cache_service(redis) return get_beatmapset_cache_service(redis)
BeatmapsetCacheService = Annotated[OriginBeatmapsetCacheService, Depends(get_beatmapset_cache_dependency)]

View File

@@ -91,6 +91,9 @@ def get_redis():
return redis_client return redis_client
Redis = Annotated[redis.Redis, Depends(get_redis)]
def get_redis_binary(): def get_redis_binary():
"""获取二进制数据专用的 Redis 客户端 (不自动解码响应)""" """获取二进制数据专用的 Redis 客户端 (不自动解码响应)"""
return redis_binary_client return redis_binary_client

View File

@@ -1,17 +1,21 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.config import settings from app.config import settings
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
from app.fetcher import Fetcher from app.fetcher import Fetcher as OriginFetcher
from app.log import logger from app.log import logger
fetcher: Fetcher | None = None from fastapi import Depends
fetcher: OriginFetcher | None = None
async def get_fetcher() -> Fetcher: async def get_fetcher() -> OriginFetcher:
global fetcher global fetcher
if fetcher is None: if fetcher is None:
fetcher = Fetcher( fetcher = OriginFetcher(
settings.fetcher_client_id, settings.fetcher_client_id,
settings.fetcher_client_secret, settings.fetcher_client_secret,
settings.fetcher_scopes, settings.fetcher_scopes,
@@ -27,3 +31,6 @@ async def get_fetcher() -> Fetcher:
if not fetcher.access_token or not fetcher.refresh_token: if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>") logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
return fetcher return fetcher
Fetcher = Annotated[OriginFetcher, Depends(get_fetcher)]

View File

@@ -6,10 +6,13 @@ from __future__ import annotations
from functools import lru_cache from functools import lru_cache
import ipaddress import ipaddress
from typing import Annotated
from app.config import settings from app.config import settings
from app.helpers.geoip_helper import GeoIPHelper from app.helpers.geoip_helper import GeoIPHelper
from fastapi import Depends, Request
@lru_cache @lru_cache
def get_geoip_helper() -> GeoIPHelper: def get_geoip_helper() -> GeoIPHelper:
@@ -26,7 +29,7 @@ def get_geoip_helper() -> GeoIPHelper:
) )
def get_client_ip(request) -> str: def get_client_ip(request: Request) -> str:
""" """
获取客户端真实 IP 地址 获取客户端真实 IP 地址
支持 IPv4 和 IPv6考虑代理、负载均衡器等情况 支持 IPv4 和 IPv6考虑代理、负载均衡器等情况
@@ -66,6 +69,10 @@ def get_client_ip(request) -> str:
return client_ip if is_valid_ip(client_ip) else "127.0.0.1" return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
IPAddress = Annotated[str, Depends(get_client_ip)]
GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)]
def is_valid_ip(ip_str: str) -> bool: def is_valid_ip(ip_str: str) -> bool:
""" """
验证 IP 地址是否有效(支持 IPv4 和 IPv6 验证 IP 地址是否有效(支持 IPv4 和 IPv6

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import cast from typing import Annotated, cast
from app.config import ( from app.config import (
AWSS3StorageSettings, AWSS3StorageSettings,
@@ -9,11 +9,13 @@ from app.config import (
StorageServiceType, StorageServiceType,
settings, settings,
) )
from app.storage import StorageService from app.storage import StorageService as OriginStorageService
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
from app.storage.local import LocalStorageService from app.storage.local import LocalStorageService
storage: StorageService | None = None from fastapi import Depends
storage: OriginStorageService | None = None
def init_storage_service(): def init_storage_service():
@@ -50,3 +52,6 @@ def get_storage_service():
if storage is None: if storage is None:
return init_storage_service() return init_storage_service()
return storage return storage
StorageService = Annotated[OriginStorageService, Depends(get_storage_service)]

View File

@@ -4,6 +4,7 @@ from typing import Annotated
from app.auth import get_token_by_access_token from app.auth import get_token_by_access_token
from app.config import settings from app.config import settings
from app.const import SUPPORT_TOTP_VERIFICATION_VER
from app.database import User from app.database import User
from app.database.auth import OAuthToken, V1APIKeys from app.database.auth import OAuthToken, V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer from app.models.oauth import OAuth2ClientCredentialsBearer
@@ -11,7 +12,7 @@ from app.models.oauth import OAuth2ClientCredentialsBearer
from .api_version import APIVersion from .api_version import APIVersion
from .database import Database, get_redis from .database import Database, get_redis
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException, Security
from fastapi.security import ( from fastapi.security import (
APIKeyQuery, APIKeyQuery,
HTTPBearer, HTTPBearer,
@@ -112,13 +113,13 @@ async def get_client_user(
if await LoginSessionService.check_is_need_verification(db, user.id, token.id): if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
# 获取当前验证方式 # 获取当前验证方式
verify_method = None verify_method = None
if api_version >= 20250913: if api_version >= SUPPORT_TOTP_VERIFICATION_VER:
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis) verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
if verify_method is None: if verify_method is None:
# 智能选择验证方式有TOTP优先TOTP # 智能选择验证方式有TOTP优先TOTP
totp_key = await user.awaitable_attrs.totp_key totp_key = await user.awaitable_attrs.totp_key
if totp_key is not None and api_version >= 20240101: if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
verify_method = "totp" verify_method = "totp"
else: else:
verify_method = "mail" verify_method = "mail"
@@ -169,3 +170,6 @@ async def get_current_user(
user_and_token: UserAndToken = Depends(get_current_user_and_token), user_and_token: UserAndToken = Depends(get_current_user_and_token),
) -> User: ) -> User:
return user_and_token[0] return user_and_token[0]
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import timedelta from datetime import timedelta
import re import re
from typing import Literal from typing import Annotated, Literal
from app.auth import ( from app.auth import (
authenticate_user, authenticate_user,
@@ -19,10 +19,9 @@ from app.const import BANCHOBOT_ID
from app.database import DailyChallengeStats, OAuthClient, User from app.database import DailyChallengeStats, OAuthClient, User
from app.database.auth import TotpKeys from app.database.auth import TotpKeys
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.dependencies.geoip import GeoIPService, IPAddress
from app.dependencies.user_agent import UserAgentInfo from app.dependencies.user_agent import UserAgentInfo
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger from app.log import logger
from app.models.extended_auth import ExtendedTokenResponse from app.models.extended_auth import ExtendedTokenResponse
from app.models.oauth import ( from app.models.oauth import (
@@ -40,9 +39,8 @@ from app.service.verification_service import (
) )
from app.utils import utcnow from app.utils import utcnow
from fastapi import APIRouter, Depends, Form, Header, Request from fastapi import APIRouter, Form, Header, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text from sqlalchemy import text
from sqlmodel import exists, select from sqlmodel import exists, select
@@ -93,11 +91,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
) )
async def register_user( async def register_user(
db: Database, db: Database,
request: Request, user_username: Annotated[str, Form(..., alias="user[username]", description="用户名")],
user_username: str = Form(..., alias="user[username]", description="用户名"), user_email: Annotated[str, Form(..., alias="user[user_email]", description="电子邮箱")],
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"), user_password: Annotated[str, Form(..., alias="user[password]", description="密码")],
user_password: str = Form(..., alias="user[password]", description="密码"), geoip: GeoIPService,
geoip: GeoIPHelper = Depends(get_geoip_helper), client_ip: IPAddress,
): ):
username_errors = validate_username(user_username) username_errors = validate_username(user_username)
email_errors = validate_email(user_email) email_errors = validate_email(user_email)
@@ -126,7 +124,6 @@ async def register_user(
try: try:
# 获取客户端 IP 并查询地理位置 # 获取客户端 IP 并查询地理位置
client_ip = get_client_ip(request)
country_code = "CN" # 默认国家代码 country_code = "CN" # 默认国家代码
try: try:
@@ -201,19 +198,21 @@ async def oauth_token(
db: Database, db: Database,
request: Request, request: Request,
user_agent: UserAgentInfo, user_agent: UserAgentInfo,
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form( ip_address: IPAddress,
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证" grant_type: Annotated[
), Literal["authorization_code", "refresh_token", "password", "client_credentials"],
client_id: int = Form(..., description="客户端 ID"), Form(..., description="授权类型:密码、刷新令牌和授权码三种授权方式。"),
client_secret: str = Form(..., description="客户端密钥"), ],
code: str | None = Form(None, description="授权码(仅授权码模式需要)"), client_id: Annotated[int, Form(..., description="客户端 ID")],
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*'"), client_secret: Annotated[str, Form(..., description="客户端密钥")],
username: str | None = Form(None, description="用户名(仅密码模式需要)"), redis: Redis,
password: str | None = Form(None, description="密码(仅密码模式需要)"), geoip: GeoIPService,
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"), code: Annotated[str | None, Form(description="授权码(仅授权码模式需要)")] = None,
redis: Redis = Depends(get_redis), scope: Annotated[str, Form(description="权限范围(空格分隔,默认为 '*'")] = "*",
geoip: GeoIPHelper = Depends(get_geoip_helper), username: Annotated[str | None, Form(description="用户名(仅密码模式需要)")] = None,
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"), password: Annotated[str | None, Form(description="密码(仅密码模式需要)")] = None,
refresh_token: Annotated[str | None, Form(description="刷新令牌(仅刷新令牌模式需要)")] = None,
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
): ):
scopes = scope.split(" ") scopes = scope.split(" ")
@@ -311,8 +310,6 @@ async def oauth_token(
) )
token_id = token.id token_id = token.id
ip_address = get_client_ip(request)
# 获取国家代码 # 获取国家代码
geo_info = geoip.lookup(ip_address) geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX") country_code = geo_info.get("country_iso", "XX")
@@ -571,16 +568,14 @@ async def oauth_token(
) )
async def request_password_reset( async def request_password_reset(
request: Request, request: Request,
email: str = Form(..., description="邮箱地址"), email: Annotated[str, Form(..., description="邮箱地址")],
redis: Redis = Depends(get_redis), redis: Redis,
ip_address: IPAddress,
): ):
""" """
请求密码重置 请求密码重置
""" """
from app.dependencies.geoip import get_client_ip
# 获取客户端信息 # 获取客户端信息
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "") user_agent = request.headers.get("User-Agent", "")
# 请求密码重置 # 请求密码重置
@@ -599,20 +594,16 @@ async def request_password_reset(
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码") @router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
async def reset_password( async def reset_password(
request: Request, email: Annotated[str, Form(..., description="邮箱地址")],
email: str = Form(..., description="邮箱地址"), reset_code: Annotated[str, Form(..., description="重置验证码")],
reset_code: str = Form(..., description="重置验证"), new_password: Annotated[str, Form(..., description="新密")],
new_password: str = Form(..., description="新密码"), redis: Redis,
redis: Redis = Depends(get_redis), ip_address: IPAddress,
): ):
""" """
重置密码 重置密码
""" """
from app.dependencies.geoip import get_client_ip
# 获取客户端信息 # 获取客户端信息
ip_address = get_client_ip(request)
# 重置密码 # 重置密码
success, message = await password_reset_service.reset_password( success, message = await password_reset_service.reset_password(
email=email.lower().strip(), email=email.lower().strip(),

View File

@@ -1,14 +1,13 @@
from __future__ import annotations from __future__ import annotations
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import Fetcher
from app.fetcher import Fetcher
from fastapi import APIRouter, Depends from fastapi import APIRouter
fetcher_router = APIRouter(prefix="/fetcher", include_in_schema=False) fetcher_router = APIRouter(prefix="/fetcher", include_in_schema=False)
@fetcher_router.get("/callback") @fetcher_router.get("/callback")
async def callback(code: str, fetcher: Fetcher = Depends(get_fetcher)): async def callback(code: str, fetcher: Fetcher):
await fetcher.grant_access_token(code) await fetcher.grant_access_token(code)
return {"message": "Login successful"} return {"message": "Login successful"}

View File

@@ -1,16 +1,16 @@
from __future__ import annotations from __future__ import annotations
from app.dependencies.storage import get_storage_service from app.dependencies.storage import StorageService as StorageServiceDep
from app.storage import LocalStorageService, StorageService from app.storage import LocalStorageService
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
file_router = APIRouter(prefix="/file", include_in_schema=False) file_router = APIRouter(prefix="/file", include_in_schema=False)
@file_router.get("/{path:path}") @file_router.get("/{path:path}")
async def get_file(path: str, storage: StorageService = Depends(get_storage_service)): async def get_file(path: str, storage: StorageServiceDep):
if not isinstance(storage, LocalStorageService): if not isinstance(storage, LocalStorageService):
raise HTTPException(404, "Not Found") raise HTTPException(404, "Not Found")
if not await storage.is_exists(path): if not await storage.is_exists(path):

View File

@@ -11,21 +11,18 @@ from app.database.playlists import Playlist as DBPlaylist
from app.database.room import Room from app.database.room import Room
from app.database.room_participated_user import RoomParticipatedUser from app.database.room_participated_user import RoomParticipatedUser
from app.database.user import User from app.database.user import User
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import Fetcher
from app.dependencies.storage import get_storage_service from app.dependencies.storage import StorageService
from app.fetcher import Fetcher
from app.log import logger from app.log import logger
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus
from app.storage.base import StorageService
from app.utils import utcnow from app.utils import utcnow
from .notification.server import server from .notification.server import server
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, HTTPException, Request, status
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy import update from sqlalchemy import update
from sqlmodel import col, select from sqlmodel import col, select
@@ -637,8 +634,8 @@ async def add_user_to_room(
async def ensure_beatmap_present( async def ensure_beatmap_present(
beatmap_data: BeatmapEnsureRequest, beatmap_data: BeatmapEnsureRequest,
db: Database, db: Database,
redis: Redis = Depends(get_redis), redis: Redis,
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
确保谱面在服务器中存在(包括元数据和原始文件缓存)。 确保谱面在服务器中存在(包括元数据和原始文件缓存)。
@@ -677,7 +674,7 @@ class ReplayDataRequest(BaseModel):
@router.post("/scores/replay") @router.post("/scores/replay")
async def save_replay( async def save_replay(
req: ReplayDataRequest, req: ReplayDataRequest,
storage_service: StorageService = Depends(get_storage_service), storage_service: StorageService,
): ):
replay_data = req.mreplay replay_data = req.mreplay
replay_path = f"replays/{req.score_id}_{req.beatmap_id}_{req.user_id}_lazer_replay.osr" replay_path = f"replays/{req.score_id}_{req.beatmap_id}_{req.user_id}_lazer_replay.osr"

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Literal, Self from typing import Annotated, Any, Literal, Self
from app.database.chat import ( from app.database.chat import (
ChannelType, ChannelType,
@@ -11,7 +11,7 @@ from app.database.chat import (
UserSilenceResp, UserSilenceResp,
) )
from app.database.user import User, UserResp from app.database.user import User, UserResp
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router from app.router.v2 import api_v2_router as router
@@ -20,7 +20,6 @@ from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from redis.asyncio import Redis
from sqlmodel import col, select from sqlmodel import col, select
@@ -38,11 +37,14 @@ class UpdateResponse(BaseModel):
) )
async def get_update( async def get_update(
session: Database, session: Database,
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"), current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), redis: Redis,
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"), history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
current_user: User = Security(get_current_user, scopes=["chat.read"]), since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
redis: Redis = Depends(get_redis), includes: Annotated[
list[str],
Query(alias="includes[]", description="要包含的更新类型"),
] = ["presence", "silences"],
): ):
resp = UpdateResponse() resp = UpdateResponse()
if "presence" in includes: if "presence" in includes:
@@ -86,9 +88,9 @@ async def get_update(
) )
async def join_channel( async def join_channel(
session: Database, session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: Annotated[str, Path(..., description="频道 ID/名称")],
user: str = Path(..., description="用户 ID"), user: Annotated[str, Path(..., description="用户 ID")],
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
): ):
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
if channel.isdigit(): if channel.isdigit():
@@ -110,9 +112,9 @@ async def join_channel(
) )
async def leave_channel( async def leave_channel(
session: Database, session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: Annotated[str, Path(..., description="频道 ID/名称")],
user: str = Path(..., description="用户 ID"), user: Annotated[str, Path(..., description="用户 ID")],
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
): ):
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
if channel.isdigit(): if channel.isdigit():
@@ -135,8 +137,8 @@ async def leave_channel(
) )
async def get_channel_list( async def get_channel_list(
session: Database, session: Database,
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
redis: Redis = Depends(get_redis), redis: Redis,
): ):
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all() channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
results = [] results = []
@@ -171,9 +173,9 @@ class GetChannelResp(BaseModel):
) )
async def get_channel( async def get_channel(
session: Database, session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: Annotated[str, Path(..., description="频道 ID/名称")],
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
redis: Redis = Depends(get_redis), redis: Redis,
): ):
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
if channel.isdigit(): if channel.isdigit():
@@ -245,9 +247,9 @@ class CreateChannelReq(BaseModel):
) )
async def create_channel( async def create_channel(
session: Database, session: Database,
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)), req: Annotated[CreateChannelReq, Depends(BodyOrForm(CreateChannelReq))],
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
redis: Redis = Depends(get_redis), redis: Redis,
): ):
if req.type == "PM": if req.type == "PM":
target = await session.get(User, req.target_id) target = await session.get(User, req.target_id)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.database import ChatMessageResp from app.database import ChatMessageResp
from app.database.chat import ( from app.database.chat import (
ChannelType, ChannelType,
@@ -11,7 +13,7 @@ from app.database.chat import (
UserSilenceResp, UserSilenceResp,
) )
from app.database.user import User from app.database.user import User
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.log import logger from app.log import logger
@@ -24,7 +26,6 @@ from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlmodel import col, select from sqlmodel import col, select
@@ -41,9 +42,9 @@ class KeepAliveResp(BaseModel):
) )
async def keep_alive( async def keep_alive(
session: Database, session: Database,
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"), current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
current_user: User = Security(get_current_user, scopes=["chat.read"]), since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
): ):
resp = KeepAliveResp() resp = KeepAliveResp()
if history_since: if history_since:
@@ -73,9 +74,9 @@ class MessageReq(BaseModel):
) )
async def send_message( async def send_message(
session: Database, session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: Annotated[str, Path(..., description="频道 ID/名称")],
req: MessageReq = Depends(BodyOrForm(MessageReq)), req: Annotated[MessageReq, Depends(BodyOrForm(MessageReq))],
current_user: User = Security(get_current_user, scopes=["chat.write"]), current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
): ):
# 使用明确的查询来获取 channel避免延迟加载 # 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit(): if channel.isdigit():
@@ -156,10 +157,10 @@ async def send_message(
async def get_message( async def get_message(
session: Database, session: Database,
channel: str, channel: str,
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"), current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
since: int = Query(0, ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)"), limit: Annotated[int, Query(ge=1, le=50, description="获取消息的数量")] = 50,
until: int | None = Query(None, description="获取自此消息 ID 之的消息(向后翻历史"), since: Annotated[int, Query(ge=0, description="获取自此消息 ID 之的消息(向前加载新消息")] = 0,
current_user: User = Security(get_current_user, scopes=["chat.read"]), until: Annotated[int | None, Query(description="获取自此消息 ID 之前的消息(向后翻历史)")] = None,
): ):
# 1) 查频道 # 1) 查频道
if channel.isdigit(): if channel.isdigit():
@@ -220,9 +221,9 @@ async def get_message(
) )
async def mark_as_read( async def mark_as_read(
session: Database, session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: Annotated[str, Path(..., description="频道 ID/名称")],
message: int = Path(..., description="消息 ID"), message: Annotated[int, Path(..., description="消息 ID")],
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
): ):
# 使用明确的查询获取 channel避免延迟加载 # 使用明确的查询获取 channel避免延迟加载
if channel.isdigit(): if channel.isdigit():
@@ -259,9 +260,9 @@ class NewPMResp(BaseModel):
) )
async def create_new_pm( async def create_new_pm(
session: Database, session: Database,
req: PMReq = Depends(BodyOrForm(PMReq)), req: Annotated[PMReq, Depends(BodyOrForm(PMReq))],
current_user: User = Security(get_current_user, scopes=["chat.write"]), current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
redis: Redis = Depends(get_redis), redis: Redis,
): ):
user_id = current_user.id user_id = current_user.id
target = await session.get(User, req.target_id) target = await session.get(User, req.target_id)

View File

@@ -1,13 +1,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import overload from typing import Annotated, overload
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.notification import UserNotification, insert_notification from app.database.notification import UserNotification, insert_notification
from app.database.user import User from app.database.user import User
from app.dependencies.database import ( from app.dependencies.database import (
DBFactory, DBFactory,
Redis,
get_db_factory, get_db_factory,
get_redis, get_redis,
with_db, with_db,
@@ -22,7 +23,6 @@ from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes from fastapi.security import SecurityScopes
from fastapi.websockets import WebSocketState from fastapi.websockets import WebSocketState
from redis.asyncio import Redis
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -298,10 +298,10 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
@chat_router.websocket("/notification-server") @chat_router.websocket("/notification-server")
async def chat_websocket( async def chat_websocket(
websocket: WebSocket, websocket: WebSocket,
token: str | None = Query(None, description="认证令牌支持通过URL参数传递"), factory: Annotated[DBFactory, Depends(get_db_factory)],
access_token: str | None = Query(None, description="访问令牌支持通过URL参数传递"), token: Annotated[str | None, Query(description="认证令牌支持通过URL参数传递")] = None,
authorization: str | None = Header(None, description="Bearer认证头"), access_token: Annotated[str | None, Query(description="访问令牌支持通过URL参数传递")] = None,
factory: DBFactory = Depends(get_db_factory), authorization: Annotated[str | None, Header(description="Bearer认证头")] = None,
): ):
if not server._subscribed: if not server._subscribed:
server._subscribed = True server._subscribed = True

View File

@@ -1,15 +1,16 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.database.auth import OAuthToken from app.database.auth import OAuthToken
from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.geoip import get_geoip_helper from app.dependencies.geoip import GeoIPService
from app.dependencies.user import UserAndToken, get_client_user_and_token from app.dependencies.user import UserAndToken, get_client_user_and_token
from app.helpers.geoip_helper import GeoIPHelper
from .router import router from .router import router
from fastapi import Depends, HTTPException, Security from fastapi import HTTPException, Security
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import col, select from sqlmodel import col, select
@@ -28,8 +29,8 @@ class SessionsResp(BaseModel):
) )
async def get_sessions( async def get_sessions(
session: Database, session: Database,
user_and_token: UserAndToken = Security(get_client_user_and_token), user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
geoip: GeoIPHelper = Depends(get_geoip_helper), geoip: GeoIPService,
): ):
current_user, token = user_and_token current_user, token = user_and_token
sessions = ( sessions = (
@@ -57,7 +58,7 @@ async def get_sessions(
async def delete_session( async def delete_session(
session: Database, session: Database,
session_id: int, session_id: int,
user_and_token: UserAndToken = Security(get_client_user_and_token), user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
): ):
current_user, token = user_and_token current_user, token = user_and_token
if session_id == token.id: if session_id == token.id:
@@ -91,8 +92,8 @@ class TrustedDevicesResp(BaseModel):
) )
async def get_trusted_devices( async def get_trusted_devices(
session: Database, session: Database,
user_and_token: UserAndToken = Security(get_client_user_and_token), user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
geoip: GeoIPHelper = Depends(get_geoip_helper), geoip: GeoIPService,
): ):
current_user, token = user_and_token current_user, token = user_and_token
devices = ( devices = (
@@ -131,7 +132,7 @@ async def get_trusted_devices(
async def delete_trusted_device( async def delete_trusted_device(
session: Database, session: Database,
device_id: int, device_id: int,
user_and_token: UserAndToken = Security(get_client_user_and_token), user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
): ):
current_user, token = user_and_token current_user, token = user_and_token
device = await session.get(TrustedDevice, device_id) device = await session.get(TrustedDevice, device_id)

View File

@@ -1,25 +1,24 @@
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
from typing import Annotated
from app.database.user import User
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service from app.dependencies.storage import StorageService
from app.dependencies.user import get_client_user from app.dependencies.user import ClientUser
from app.storage.base import StorageService
from app.utils import check_image from app.utils import check_image
from .router import router from .router import router
from fastapi import Depends, File, Security from fastapi import File
@router.post("/avatar/upload", name="上传头像", tags=["用户", "g0v0 API"]) @router.post("/avatar/upload", name="上传头像", tags=["用户", "g0v0 API"])
async def upload_avatar( async def upload_avatar(
session: Database, session: Database,
content: bytes = File(...), content: Annotated[bytes, File(...)],
current_user: User = Security(get_client_user), current_user: ClientUser,
storage: StorageService = Depends(get_storage_service), storage: StorageService,
): ):
"""上传用户头像 """上传用户头像

View File

@@ -1,17 +1,18 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.database.beatmap import Beatmap from app.database.beatmap import Beatmap
from app.database.beatmapset import Beatmapset from app.database.beatmapset import Beatmapset
from app.database.beatmapset_ratings import BeatmapRating from app.database.beatmapset_ratings import BeatmapRating
from app.database.score import Score from app.database.score import Score
from app.database.user import User
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.user import get_client_user from app.dependencies.user import ClientUser
from app.service.beatmapset_update_service import get_beatmapset_update_service from app.service.beatmapset_update_service import get_beatmapset_update_service
from .router import router from .router import router
from fastapi import Body, Depends, HTTPException, Security from fastapi import Body, Depends, HTTPException
from fastapi_limiter.depends import RateLimiter from fastapi_limiter.depends import RateLimiter
from sqlmodel import col, exists, select from sqlmodel import col, exists, select
@@ -25,7 +26,7 @@ from sqlmodel import col, exists, select
async def can_rate_beatmapset( async def can_rate_beatmapset(
beatmapset_id: int, beatmapset_id: int,
session: Database, session: Database,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
"""检查用户是否可以评价谱面集 """检查用户是否可以评价谱面集
@@ -57,8 +58,8 @@ async def can_rate_beatmapset(
async def rate_beatmaps( async def rate_beatmaps(
beatmapset_id: int, beatmapset_id: int,
session: Database, session: Database,
rating: int = Body(..., ge=0, le=10), rating: Annotated[int, Body(..., ge=0, le=10)],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
"""为谱面集评分 """为谱面集评分
@@ -96,7 +97,7 @@ async def rate_beatmaps(
async def sync_beatmapset( async def sync_beatmapset(
beatmapset_id: int, beatmapset_id: int,
session: Database, session: Database,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
"""请求同步谱面集 """请求同步谱面集

View File

@@ -1,25 +1,25 @@
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
from typing import Annotated
from app.database.user import User, UserProfileCover from app.database.user import UserProfileCover
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service from app.dependencies.storage import StorageService
from app.dependencies.user import get_client_user from app.dependencies.user import ClientUser
from app.storage.base import StorageService
from app.utils import check_image from app.utils import check_image
from .router import router from .router import router
from fastapi import Depends, File, Security from fastapi import File
@router.post("/cover/upload", name="上传头图", tags=["用户", "g0v0 API"]) @router.post("/cover/upload", name="上传头图", tags=["用户", "g0v0 API"])
async def upload_cover( async def upload_cover(
session: Database, session: Database,
content: bytes = File(...), content: Annotated[bytes, File(...)],
current_user: User = Security(get_client_user), current_user: ClientUser,
storage: StorageService = Depends(get_storage_service), storage: StorageService,
): ):
"""上传用户头图 """上传用户头图

View File

@@ -1,16 +1,15 @@
from __future__ import annotations from __future__ import annotations
import secrets import secrets
from typing import Annotated
from app.database.auth import OAuthClient, OAuthToken from app.database.auth import OAuthClient, OAuthToken
from app.database.user import User from app.dependencies.database import Database, Redis
from app.dependencies.database import Database, get_redis from app.dependencies.user import ClientUser
from app.dependencies.user import get_client_user
from .router import router from .router import router
from fastapi import Body, Depends, HTTPException, Security from fastapi import Body, HTTPException
from redis.asyncio import Redis
from sqlmodel import select, text from sqlmodel import select, text
@@ -22,10 +21,10 @@ from sqlmodel import select, text
) )
async def create_oauth_app( async def create_oauth_app(
session: Database, session: Database,
name: str = Body(..., max_length=100, description="应用程序名称"), name: Annotated[str, Body(..., max_length=100, description="应用程序名称")],
description: str = Body("", description="应用程序描述"), redirect_uris: Annotated[list[str], Body(..., description="允许的重定向 URI 列表")],
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"), current_user: ClientUser,
current_user: User = Security(get_client_user), description: Annotated[str, Body(description="应用程序描述")] = "",
): ):
result = await session.execute( result = await session.execute(
text( text(
@@ -64,7 +63,7 @@ async def create_oauth_app(
async def get_oauth_app( async def get_oauth_app(
session: Database, session: Database,
client_id: int, client_id: int,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
oauth_app = await session.get(OAuthClient, client_id) oauth_app = await session.get(OAuthClient, client_id)
if not oauth_app: if not oauth_app:
@@ -85,7 +84,7 @@ async def get_oauth_app(
) )
async def get_user_oauth_apps( async def get_user_oauth_apps(
session: Database, session: Database,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id)) oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
return [ return [
@@ -109,7 +108,7 @@ async def get_user_oauth_apps(
async def delete_oauth_app( async def delete_oauth_app(
session: Database, session: Database,
client_id: int, client_id: int,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client: if not oauth_client:
@@ -134,10 +133,10 @@ async def delete_oauth_app(
async def update_oauth_app( async def update_oauth_app(
session: Database, session: Database,
client_id: int, client_id: int,
name: str = Body(..., max_length=100, description="应用程序新名称"), name: Annotated[str, Body(..., max_length=100, description="应用程序新名称")],
description: str = Body("", description="应用程序新描述"), redirect_uris: Annotated[list[str], Body(..., description="新的重定向 URI 列表")],
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"), current_user: ClientUser,
current_user: User = Security(get_client_user), description: Annotated[str, Body(description="应用程序新描述")] = "",
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client: if not oauth_client:
@@ -168,7 +167,7 @@ async def update_oauth_app(
async def refresh_secret( async def refresh_secret(
session: Database, session: Database,
client_id: int, client_id: int,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client: if not oauth_client:
@@ -200,10 +199,10 @@ async def refresh_secret(
async def generate_oauth_code( async def generate_oauth_code(
session: Database, session: Database,
client_id: int, client_id: int,
current_user: User = Security(get_client_user), current_user: ClientUser,
redirect_uri: str = Body(..., description="授权后重定向的 URI"), redirect_uri: Annotated[str, Body(..., description="授权后重定向的 URI")],
scopes: list[str] = Body(..., description="请求的权限范围列表"), scopes: Annotated[list[str], Body(..., description="请求的权限范围列表")],
redis: Redis = Depends(get_redis), redis: Redis,
): ):
client = await session.get(OAuthClient, client_id) client = await session.get(OAuthClient, client_id)
if not client: if not client:

View File

@@ -1,13 +1,15 @@
from __future__ import annotations from __future__ import annotations
from app.database import Relationship, User from typing import Annotated
from app.database import Relationship
from app.database.relationship import RelationshipType from app.database.relationship import RelationshipType
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.user import get_client_user from app.dependencies.user import ClientUser
from .router import router from .router import router
from fastapi import HTTPException, Path, Security from fastapi import HTTPException, Path
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlmodel import select from sqlmodel import select
@@ -27,8 +29,8 @@ class CheckResponse(BaseModel):
) )
async def check_user_relationship( async def check_user_relationship(
db: Database, db: Database,
user_id: int = Path(..., description="目标用户的 ID"), user_id: Annotated[int, Path(..., description="目标用户的 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
if user_id == current_user.id: if user_id == current_user.id:
raise HTTPException(422, "Cannot check relationship with yourself") raise HTTPException(422, "Cannot check relationship with yourself")

View File

@@ -1,17 +1,14 @@
from __future__ import annotations from __future__ import annotations
from app.database.score import Score from app.database.score import Score
from app.database.user import User from app.dependencies.database import Database, Redis
from app.dependencies.database import Database, get_redis from app.dependencies.storage import StorageService
from app.dependencies.storage import get_storage_service from app.dependencies.user import ClientUser
from app.dependencies.user import get_client_user
from app.service.user_cache_service import refresh_user_cache_background from app.service.user_cache_service import refresh_user_cache_background
from app.storage.base import StorageService
from .router import router from .router import router
from fastapi import BackgroundTasks, Depends, HTTPException, Security from fastapi import BackgroundTasks, HTTPException
from redis.asyncio import Redis
@router.delete( @router.delete(
@@ -24,9 +21,9 @@ async def delete_score(
session: Database, session: Database,
background_task: BackgroundTasks, background_task: BackgroundTasks,
score_id: int, score_id: int,
redis: Redis = Depends(get_redis), redis: Redis,
current_user: User = Security(get_client_user), current_user: ClientUser,
storage_service: StorageService = Depends(get_storage_service), storage_service: StorageService,
): ):
"""删除成绩 """删除成绩

View File

@@ -1,12 +1,13 @@
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
from typing import Annotated
from app.database.team import Team, TeamMember, TeamRequest from app.database.team import Team, TeamMember, TeamRequest
from app.database.user import BASE_INCLUDES, User, UserResp from app.database.user import BASE_INCLUDES, User, UserResp
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis
from app.dependencies.storage import get_storage_service from app.dependencies.storage import StorageService
from app.dependencies.user import get_client_user from app.dependencies.user import ClientUser
from app.models.notification import ( from app.models.notification import (
TeamApplicationAccept, TeamApplicationAccept,
TeamApplicationReject, TeamApplicationReject,
@@ -14,27 +15,25 @@ from app.models.notification import (
) )
from app.router.notification import server from app.router.notification import server
from app.service.ranking_cache_service import get_ranking_cache_service from app.service.ranking_cache_service import get_ranking_cache_service
from app.storage.base import StorageService
from app.utils import check_image, utcnow from app.utils import check_image, utcnow
from .router import router from .router import router
from fastapi import Depends, File, Form, HTTPException, Path, Request, Security from fastapi import File, Form, HTTPException, Path, Request
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import exists, select from sqlmodel import exists, select
@router.post("/team", name="创建战队", response_model=Team, tags=["战队", "g0v0 API"]) @router.post("/team", name="创建战队", response_model=Team, tags=["战队", "g0v0 API"])
async def create_team( async def create_team(
session: Database, session: Database,
storage: StorageService = Depends(get_storage_service), storage: StorageService,
current_user: User = Security(get_client_user), current_user: ClientUser,
flag: bytes = File(..., description="战队图标文件"), flag: Annotated[bytes, File(..., description="战队图标文件")],
cover: bytes = File(..., description="战队头图文件"), cover: Annotated[bytes, File(..., description="战队头图文件")],
name: str = Form(max_length=100, description="战队名称"), name: Annotated[str, Form(max_length=100, description="战队名称")],
short_name: str = Form(max_length=10, description="战队缩写"), short_name: Annotated[str, Form(max_length=10, description="战队缩写")],
redis: Redis = Depends(get_redis), redis: Redis,
): ):
"""创建战队。 """创建战队。
@@ -88,13 +87,13 @@ async def create_team(
async def update_team( async def update_team(
team_id: int, team_id: int,
session: Database, session: Database,
storage: StorageService = Depends(get_storage_service), storage: StorageService,
current_user: User = Security(get_client_user), current_user: ClientUser,
flag: bytes | None = File(default=None, description="战队图标文件"), flag: Annotated[bytes | None, File(description="战队图标文件")] = None,
cover: bytes | None = File(default=None, description="战队头图文件"), cover: Annotated[bytes | None, File(description="战队头图文件")] = None,
name: str | None = Form(default=None, max_length=100, description="战队名称"), name: Annotated[str | None, Form(max_length=100, description="战队名称")] = None,
short_name: str | None = Form(default=None, max_length=10, description="战队缩写"), short_name: Annotated[str | None, Form(max_length=10, description="战队缩写")] = None,
leader_id: int | None = Form(default=None, description="战队队长 ID"), leader_id: Annotated[int | None, Form(description="战队队长 ID")] = None,
): ):
"""修改战队。 """修改战队。
@@ -161,9 +160,9 @@ async def update_team(
@router.delete("/team/{team_id}", name="删除战队", status_code=204, tags=["战队", "g0v0 API"]) @router.delete("/team/{team_id}", name="删除战队", status_code=204, tags=["战队", "g0v0 API"])
async def delete_team( async def delete_team(
session: Database, session: Database,
team_id: int = Path(..., description="战队 ID"), team_id: Annotated[int, Path(..., description="战队 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
redis: Redis = Depends(get_redis), redis: Redis,
): ):
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
if not team: if not team:
@@ -191,7 +190,7 @@ class TeamQueryResp(BaseModel):
@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"]) @router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
async def get_team( async def get_team(
session: Database, session: Database,
team_id: int = Path(..., description="战队 ID"), team_id: Annotated[int, Path(..., description="战队 ID")],
): ):
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all() members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
return TeamQueryResp( return TeamQueryResp(
@@ -203,8 +202,8 @@ async def get_team(
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"]) @router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])
async def request_join_team( async def request_join_team(
session: Database, session: Database,
team_id: int = Path(..., description="战队 ID"), team_id: Annotated[int, Path(..., description="战队 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
if not team: if not team:
@@ -231,10 +230,10 @@ async def request_join_team(
async def handle_request( async def handle_request(
req: Request, req: Request,
session: Database, session: Database,
team_id: int = Path(..., description="战队 ID"), team_id: Annotated[int, Path(..., description="战队 ID")],
user_id: int = Path(..., description="用户 ID"), user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
redis: Redis = Depends(get_redis), redis: Redis,
): ):
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
if not team: if not team:
@@ -272,10 +271,10 @@ async def handle_request(
@router.delete("/team/{team_id}/{user_id}", name="踢出成员 / 退出战队", status_code=204, tags=["战队", "g0v0 API"]) @router.delete("/team/{team_id}/{user_id}", name="踢出成员 / 退出战队", status_code=204, tags=["战队", "g0v0 API"])
async def kick_member( async def kick_member(
session: Database, session: Database,
team_id: int = Path(..., description="战队 ID"), team_id: Annotated[int, Path(..., description="战队 ID")],
user_id: int = Path(..., description="用户 ID"), user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
redis: Redis = Depends(get_redis), redis: Redis,
): ):
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
if not team: if not team:

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.auth import ( from app.auth import (
check_totp_backup_code, check_totp_backup_code,
finish_create_totp_key, finish_create_totp_key,
@@ -9,17 +11,15 @@ from app.auth import (
) )
from app.const import BACKUP_CODE_LENGTH from app.const import BACKUP_CODE_LENGTH
from app.database.auth import TotpKeys from app.database.auth import TotpKeys
from app.database.user import User from app.dependencies.database import Database, Redis
from app.dependencies.database import Database, get_redis from app.dependencies.user import ClientUser
from app.dependencies.user import get_client_user
from app.models.totp import FinishStatus, StartCreateTotpKeyResp from app.models.totp import FinishStatus, StartCreateTotpKeyResp
from .router import router from .router import router
from fastapi import Body, Depends, HTTPException, Security from fastapi import Body, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import pyotp import pyotp
from redis.asyncio import Redis
class TotpStatusResp(BaseModel): class TotpStatusResp(BaseModel):
@@ -37,7 +37,7 @@ class TotpStatusResp(BaseModel):
response_model=TotpStatusResp, response_model=TotpStatusResp,
) )
async def get_totp_status( async def get_totp_status(
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
"""检查用户是否已创建TOTP""" """检查用户是否已创建TOTP"""
totp_key = await current_user.awaitable_attrs.totp_key totp_key = await current_user.awaitable_attrs.totp_key
@@ -62,8 +62,8 @@ async def get_totp_status(
status_code=201, status_code=201,
) )
async def start_create_totp( async def start_create_totp(
redis: Redis = Depends(get_redis), redis: Redis,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
if await current_user.awaitable_attrs.totp_key: if await current_user.awaitable_attrs.totp_key:
raise HTTPException(status_code=400, detail="TOTP is already enabled for this user") raise HTTPException(status_code=400, detail="TOTP is already enabled for this user")
@@ -98,9 +98,9 @@ async def start_create_totp(
) )
async def finish_create_totp( async def finish_create_totp(
session: Database, session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码"), code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码")],
redis: Redis = Depends(get_redis), redis: Redis,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
status, backup_codes = await finish_create_totp_key(current_user, code, redis, session) status, backup_codes = await finish_create_totp_key(current_user, code, redis, session)
if status == FinishStatus.SUCCESS: if status == FinishStatus.SUCCESS:
@@ -122,9 +122,9 @@ async def finish_create_totp(
) )
async def disable_totp( async def disable_totp(
session: Database, session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"), code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码或备份码")],
redis: Redis = Depends(get_redis), redis: Redis,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
totp = await session.get(TotpKeys, current_user.id) totp = await session.get(TotpKeys, current_user.id)
if not totp: if not totp:

View File

@@ -1,24 +1,26 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.auth import validate_username from app.auth import validate_username
from app.config import settings from app.config import settings
from app.database.events import Event, EventType from app.database.events import Event, EventType
from app.database.user import User from app.database.user import User
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.user import get_client_user from app.dependencies.user import ClientUser
from app.utils import utcnow from app.utils import utcnow
from .router import router from .router import router
from fastapi import Body, HTTPException, Security from fastapi import Body, HTTPException
from sqlmodel import exists, select from sqlmodel import exists, select
@router.post("/rename", name="修改用户名", tags=["用户", "g0v0 API"]) @router.post("/rename", name="修改用户名", tags=["用户", "g0v0 API"])
async def user_rename( async def user_rename(
session: Database, session: Database,
new_name: str = Body(..., description="新的用户名"), new_name: Annotated[str, Body(..., description="新的用户名")],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
"""修改用户名 """修改用户名

View File

@@ -1,24 +1,22 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Literal from typing import Annotated, Literal
from app.database.beatmap import Beatmap, calculate_beatmap_attributes from app.database.beatmap import Beatmap, calculate_beatmap_attributes
from app.database.beatmap_playcounts import BeatmapPlaycounts from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.database.beatmapset import Beatmapset from app.database.beatmapset import Beatmapset
from app.database.favourite_beatmapset import FavouriteBeatmapset from app.database.favourite_beatmapset import FavouriteBeatmapset
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import Fetcher
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.mods import int_to_mods from app.models.mods import int_to_mods
from app.models.score import GameMode from app.models.score import GameMode
from .router import AllStrModel, router from .router import AllStrModel, router
from fastapi import Depends, Query from fastapi import Query
from redis.asyncio import Redis
from sqlmodel import col, func, select from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -148,18 +146,18 @@ class V1Beatmap(AllStrModel):
) )
async def get_beatmaps( async def get_beatmaps(
session: Database, session: Database,
since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"), redis: Redis,
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"), fetcher: Fetcher,
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"), since: Annotated[datetime | None, Query(description="自指定时间后拥有排行榜的谱面")] = None,
user: str | None = Query(None, alias="u", description=""), beatmapset_id: Annotated[int | None, Query(alias="s", description="面集 ID")] = None,
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"), beatmap_id: Annotated[int | None, Query(alias="b", description="谱面 ID")] = None,
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO user: Annotated[str | None, Query(alias="u", description="谱师")] = None,
convert: bool = Query(False, alias="a", description="转谱"), # TODO type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"), ruleset_id: Annotated[int | None, Query(alias="m", description="Ruleset ID", ge=0, le=3)] = None, # TODO
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"), convert: Annotated[bool, Query(alias="a", description="转谱")] = False, # TODO
mods: int = Query(0, description="应用到谱面属性的 MOD"), checksum: Annotated[str | None, Query(alias="h", description="谱面文件 MD5")] = None,
redis: Redis = Depends(get_redis), limit: Annotated[int, Query(ge=1, le=500, description="返回结果数量限制")] = 500,
fetcher: Fetcher = Depends(get_fetcher), mods: Annotated[int, Query(description="应用到谱面属性的 MOD")] = 0,
): ):
beatmaps: list[Beatmap] = [] beatmaps: list[Beatmap] = []
results = [] results = []

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Annotated, Literal
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.database.user import User from app.database.user import User
@@ -181,9 +181,9 @@ async def _count_online_users_optimized(redis):
) )
async def api_get_player_info( async def api_get_player_info(
session: Database, session: Database,
scope: Literal["stats", "events", "info", "all"] = Query(..., description="信息范围"), scope: Annotated[Literal["stats", "events", "info", "all"], Query(..., description="信息范围")],
id: int | None = Query(None, ge=3, le=2147483647, description="用户 ID"), id: Annotated[int | None, Query(ge=3, le=2147483647, description="用户 ID")] = None,
name: str | None = Query(None, regex=r"^[\w \[\]-]{2,32}$", description="用户名"), name: Annotated[str | None, Query(regex=r"^[\w \[\]-]{2,32}$", description="用户名")] = None,
): ):
""" """
获取指定玩家的信息 获取指定玩家的信息

View File

@@ -2,15 +2,14 @@ from __future__ import annotations
import base64 import base64
from datetime import date from datetime import date
from typing import Literal from typing import Annotated, Literal
from app.database.counts import ReplayWatchedCount from app.database.counts import ReplayWatchedCount
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service from app.dependencies.storage import StorageService
from app.models.mods import int_to_mods from app.models.mods import int_to_mods
from app.models.score import GameMode from app.models.score import GameMode
from app.storage import StorageService
from .router import router from .router import router
@@ -34,18 +33,20 @@ class ReplayModel(BaseModel):
) )
async def download_replay( async def download_replay(
session: Database, session: Database,
beatmap: int = Query(..., alias="b", description="谱面 ID"), beatmap: Annotated[int, Query(..., alias="b", description="谱面 ID")],
user: str = Query(..., alias="u", description="用户"), user: Annotated[str, Query(..., alias="u", description="用户")],
ruleset_id: int | None = Query( storage_service: StorageService,
None, ruleset_id: Annotated[
alias="m", int | None,
description="Ruleset ID", Query(
ge=0, alias="m",
), description="Ruleset ID",
score_id: int | None = Query(None, alias="s", description="成绩 ID"), ge=0,
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"), ),
mods: int = Query(0, description="成绩的 MOD"), ] = None,
storage_service: StorageService = Depends(get_storage_service), score_id: Annotated[int | None, Query(alias="s", description="成绩 ID")] = None,
type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
mods: Annotated[int, Query(description="成绩的 MOD")] = 0,
): ):
mods_ = int_to_mods(mods) mods_ = int_to_mods(mods)
if score_id is not None: if score_id is not None:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Literal from typing import Annotated, Literal
from app.database.best_scores import PPBestScore from app.database.best_scores import PPBestScore
from app.database.score import Score, get_leaderboard from app.database.score import Score, get_leaderboard
@@ -69,10 +69,10 @@ class V1Score(AllStrModel):
) )
async def get_user_best( async def get_user_best(
session: Database, session: Database,
user: str = Query(..., alias="u", description="用户"), user: Annotated[str, Query(..., alias="u", description="用户")],
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"), type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
): ):
try: try:
scores = ( scores = (
@@ -101,10 +101,10 @@ async def get_user_best(
) )
async def get_user_recent( async def get_user_recent(
session: Database, session: Database,
user: str = Query(..., alias="u", description="用户"), user: Annotated[str, Query(..., alias="u", description="用户")],
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"), type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
): ):
try: try:
scores = ( scores = (
@@ -133,12 +133,12 @@ async def get_user_recent(
) )
async def get_scores( async def get_scores(
session: Database, session: Database,
user: str | None = Query(None, alias="u", description="用户"), beatmap_id: Annotated[int, Query(alias="b", description="谱面 ID")],
beatmap_id: int = Query(alias="b", description="谱面 ID"), user: Annotated[str | None, Query(alias="u", description="用户")] = None,
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0,
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"), type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10,
mods: int = Query(0, description="成绩的 MOD"), mods: Annotated[int, Query(description="成绩的 MOD")] = 0,
): ):
try: try:
if user is not None: if user is not None:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Literal from typing import Annotated, Literal
from app.database.statistics import UserStatistics, UserStatisticsResp from app.database.statistics import UserStatistics, UserStatisticsResp
from app.database.user import User from app.database.user import User
@@ -104,10 +104,10 @@ class V1User(AllStrModel):
async def get_user( async def get_user(
session: Database, session: Database,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
user: str = Query(..., alias="u", description="用户"), user: Annotated[str, Query(..., alias="u", description="用户")],
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0), ruleset_id: Annotated[int | None, Query(alias="m", description="Ruleset ID", ge=0)] = None,
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"), type: Annotated[Literal["string", "id"] | None, Query(description="用户类型string 用户名称 / id 用户 ID")] = None,
event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"), event_days: Annotated[int, Query(ge=1, le=31, description="从现在起所有事件的最大天数")] = 1,
): ):
redis = get_redis() redis = get_redis()
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
import asyncio import asyncio
import hashlib import hashlib
import json import json
from typing import Annotated
from app.database import Beatmap, BeatmapResp, User from app.database import Beatmap, BeatmapResp, User
from app.database.beatmap import calculate_beatmap_attributes from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import Fetcher
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapAttributes from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod, int_to_mods from app.models.mods import APIMod, int_to_mods
from app.models.score import ( from app.models.score import (
@@ -18,10 +18,9 @@ from app.models.score import (
from .router import router from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis
import rosu_pp_py as rosu import rosu_pp_py as rosu
from sqlmodel import col, select from sqlmodel import col, select
@@ -44,11 +43,11 @@ class BatchGetResp(BaseModel):
) )
async def lookup_beatmap( async def lookup_beatmap(
db: Database, db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"), fetcher: Fetcher,
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"), id: Annotated[int | None, Query(alias="id", description="谱面 ID")] = None,
current_user: User = Security(get_current_user, scopes=["public"]), md5: Annotated[str | None, Query(alias="checksum", description="谱面文件 MD5")] = None,
fetcher: Fetcher = Depends(get_fetcher), filename: Annotated[str | None, Query(alias="filename", description="谱面文件名")] = None,
): ):
if id is None and md5 is None and filename is None: if id is None and md5 is None and filename is None:
raise HTTPException( raise HTTPException(
@@ -75,9 +74,9 @@ async def lookup_beatmap(
) )
async def get_beatmap( async def get_beatmap(
db: Database, db: Database,
beatmap_id: int = Path(..., description="谱面 ID"), beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher,
): ):
try: try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id) beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
@@ -95,9 +94,12 @@ async def get_beatmap(
) )
async def batch_get_beatmaps( async def batch_get_beatmaps(
db: Database, db: Database,
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"), beatmap_ids: Annotated[
current_user: User = Security(get_current_user, scopes=["public"]), list[int],
fetcher: Fetcher = Depends(get_fetcher), Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher,
): ):
if not beatmap_ids: if not beatmap_ids:
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all() beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
@@ -127,16 +129,19 @@ async def batch_get_beatmaps(
) )
async def get_beatmap_attributes( async def get_beatmap_attributes(
db: Database, db: Database,
beatmap_id: int = Path(..., description="谱面 ID"), beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
mods: list[str] = Query( mods: Annotated[
default_factory=list, list[str],
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称", Query(
), default_factory=list,
ruleset: GameMode | None = Query(default=None, description="指定 ruleset为空则使用谱面自身模式"), description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3), ),
redis: Redis = Depends(get_redis), ],
fetcher: Fetcher = Depends(get_fetcher), redis: Redis,
fetcher: Fetcher,
ruleset: Annotated[GameMode | None, Query(description="指定 ruleset为空则使用谱面自身模式")] = None,
ruleset_id: Annotated[int | None, Query(description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3)] = None,
): ):
mods_ = [] mods_ = []
if mods and mods[0].isdigit(): if mods and mods[0].isdigit():

View File

@@ -6,23 +6,20 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service from app.dependencies.beatmap_download import DownloadService
from app.dependencies.beatmapset_cache import get_beatmapset_cache_dependency from app.dependencies.beatmapset_cache import BeatmapsetCacheService
from app.dependencies.database import Database, get_redis, with_db from app.dependencies.database import Database, Redis, with_db
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import Fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.dependencies.geoip import IPAddress, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import ClientUser, get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import SearchQueryModel from app.models.beatmap import SearchQueryModel
from app.service.asset_proxy_helper import process_response_assets from app.service.asset_proxy_helper import process_response_assets
from app.service.beatmap_download_service import BeatmapDownloadService from app.service.beatmapset_cache_service import generate_hash
from app.service.beatmapset_cache_service import BeatmapsetCacheService, generate_hash
from .router import router from .router import router
from fastapi import ( from fastapi import (
BackgroundTasks, BackgroundTasks,
Depends,
Form, Form,
HTTPException, HTTPException,
Path, Path,
@@ -53,10 +50,10 @@ async def search_beatmapset(
query: Annotated[SearchQueryModel, Query(...)], query: Annotated[SearchQueryModel, Query(...)],
request: Request, request: Request,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher,
redis=Depends(get_redis), redis: Redis,
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency), cache_service: BeatmapsetCacheService,
): ):
params = parse_qs(qs=request.url.query, keep_blank_values=True) params = parse_qs(qs=request.url.query, keep_blank_values=True)
cursor = {} cursor = {}
@@ -134,10 +131,10 @@ async def search_beatmapset(
async def lookup_beatmapset( async def lookup_beatmapset(
db: Database, db: Database,
request: Request, request: Request,
beatmap_id: int = Query(description="谱面 ID"), beatmap_id: Annotated[int, Query(description="谱面 ID")],
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher,
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency), cache_service: BeatmapsetCacheService,
): ):
# 先尝试从缓存获取 # 先尝试从缓存获取
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id) cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
@@ -170,10 +167,10 @@ async def lookup_beatmapset(
async def get_beatmapset( async def get_beatmapset(
db: Database, db: Database,
request: Request, request: Request,
beatmapset_id: int = Path(..., description="谱面集 ID"), beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher,
cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency), cache_service: BeatmapsetCacheService,
): ):
# 先尝试从缓存获取 # 先尝试从缓存获取
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id) cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
@@ -203,14 +200,12 @@ async def get_beatmapset(
description="\n下载谱面集文件。基于请求IP地理位置智能分流支持负载均衡和自动故障转移。中国IP使用Sayobot镜像其他地区使用Nerinyan和OsuDirect镜像。", description="\n下载谱面集文件。基于请求IP地理位置智能分流支持负载均衡和自动故障转移。中国IP使用Sayobot镜像其他地区使用Nerinyan和OsuDirect镜像。",
) )
async def download_beatmapset( async def download_beatmapset(
request: Request, client_ip: IPAddress,
beatmapset_id: int = Path(..., description="谱面集 ID"), beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"), current_user: ClientUser,
current_user: User = Security(get_client_user), download_service: DownloadService,
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service), no_video: Annotated[bool, Query(alias="noVideo", description="是否下载无视频版本")] = True,
): ):
client_ip = get_client_ip(request)
geoip_helper = get_geoip_helper() geoip_helper = get_geoip_helper()
geo_info = geoip_helper.lookup(client_ip) geo_info = geoip_helper.lookup(client_ip)
country_code = geo_info.get("country_iso", "") country_code = geo_info.get("country_iso", "")
@@ -242,9 +237,12 @@ async def download_beatmapset(
) )
async def favourite_beatmapset( async def favourite_beatmapset(
db: Database, db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"), beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")],
action: Literal["favourite", "unfavourite"] = Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"), action: Annotated[
current_user: User = Security(get_client_user), Literal["favourite", "unfavourite"],
Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
],
current_user: ClientUser,
): ):
existing_favourite = ( existing_favourite = (
await db.exec( await db.exec(

View File

@@ -5,14 +5,13 @@
from __future__ import annotations from __future__ import annotations
from app.dependencies.database import get_redis from app.dependencies.database import Redis
from app.service.user_cache_service import get_user_cache_service from app.service.user_cache_service import get_user_cache_service
from .router import router from .router import router
from fastapi import Depends, HTTPException from fastapi import HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis
class CacheStatsResponse(BaseModel): class CacheStatsResponse(BaseModel):
@@ -28,7 +27,7 @@ class CacheStatsResponse(BaseModel):
tags=["缓存管理"], tags=["缓存管理"],
) )
async def get_cache_stats( async def get_cache_stats(
redis: Redis = Depends(get_redis), redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用 # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用
): ):
try: try:
@@ -68,7 +67,7 @@ async def get_cache_stats(
) )
async def invalidate_user_cache( async def invalidate_user_cache(
user_id: int, user_id: int,
redis: Redis = Depends(get_redis), redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
): ):
try: try:
@@ -87,7 +86,7 @@ async def invalidate_user_cache(
tags=["缓存管理"], tags=["缓存管理"],
) )
async def clear_all_user_cache( async def clear_all_user_cache(
redis: Redis = Depends(get_redis), redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
): ):
try: try:
@@ -119,7 +118,7 @@ class CacheWarmupRequest(BaseModel):
) )
async def warmup_cache( async def warmup_cache(
request: CacheWarmupRequest, request: CacheWarmupRequest,
redis: Redis = Depends(get_redis), redis: Redis,
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
): ):
try: try:

View File

@@ -1,9 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.database import MeResp, User from app.database import MeResp, User
from app.dependencies import get_current_user
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.user import UserAndToken, get_current_user_and_token from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
from app.exceptions.userpage import UserpageError from app.exceptions.userpage import UserpageError
from app.models.score import GameMode from app.models.score import GameMode
from app.models.user import Page from app.models.user import Page
@@ -29,8 +30,8 @@ from fastapi import HTTPException, Path, Security
) )
async def get_user_info_with_ruleset( async def get_user_info_with_ruleset(
session: Database, session: Database,
ruleset: GameMode = Path(description="指定 ruleset"), ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]), user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
): ):
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id) user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
return user_resp return user_resp
@@ -45,7 +46,7 @@ async def get_user_info_with_ruleset(
) )
async def get_user_info_default( async def get_user_info_default(
session: Database, session: Database,
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]), user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
): ):
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id) user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
return user_resp return user_resp
@@ -85,8 +86,8 @@ async def get_user_info_default(
async def update_userpage( async def update_userpage(
request: UpdateUserpageRequest, request: UpdateUserpageRequest,
session: Database, session: Database,
user_id: int = Path(description="用户ID"), user_id: Annotated[int, Path(description="用户ID")],
current_user: User = Security(get_current_user, scopes=["edit"]), current_user: Annotated[User, Security(get_current_user, scopes=["edit"])],
): ):
"""更新用户页面内容匹配官方osu-web实现""" """更新用户页面内容匹配官方osu-web实现"""
# 检查权限:只能编辑自己的页面(除非是管理员) # 检查权限:只能编辑自己的页面(除非是管理员)

View File

@@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Annotated, Literal
from app.config import settings from app.config import settings
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
from app.dependencies import get_current_user
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user
from app.models.score import GameMode from app.models.score import GameMode
from app.service.ranking_cache_service import get_ranking_cache_service from app.service.ranking_cache_service import get_ranking_cache_service
@@ -45,11 +45,11 @@ SortType = Literal["performance", "score"]
async def get_team_ranking_pp( async def get_team_ranking_pp(
session: Database, session: Database,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"), ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
page: int = Query(1, ge=1, description="页码"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
current_user: User = Security(get_current_user, scopes=["public"]), page: Annotated[int, Query(ge=1, description="页码")] = 1,
): ):
return await get_team_ranking(session, background_tasks, "performance", ruleset, page, current_user) return await get_team_ranking(session, background_tasks, "performance", ruleset, current_user, page)
@router.get( @router.get(
@@ -62,14 +62,17 @@ async def get_team_ranking_pp(
async def get_team_ranking( async def get_team_ranking(
session: Database, session: Database,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
sort: SortType = Path( sort: Annotated[
..., SortType,
description="排名类型performance 表现分 / score 计分成绩总分 " Path(
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**", ...,
), description="排名类型performance 表现分 / score 计分成绩总分 "
ruleset: GameMode = Path(..., description="指定 ruleset"), "**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
page: int = Query(1, ge=1, description="页码"), ),
current_user: User = Security(get_current_user, scopes=["public"]), ],
ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
): ):
# 获取 Redis 连接和缓存服务 # 获取 Redis 连接和缓存服务
redis = get_redis() redis = get_redis()
@@ -193,11 +196,11 @@ class CountryResponse(BaseModel):
async def get_country_ranking_pp( async def get_country_ranking_pp(
session: Database, session: Database,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"), ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
page: int = Query(1, ge=1, description="页码"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
current_user: User = Security(get_current_user, scopes=["public"]), page: Annotated[int, Query(ge=1, description="页码")] = 1,
): ):
return await get_country_ranking(session, background_tasks, ruleset, page, "performance", current_user) return await get_country_ranking(session, background_tasks, ruleset, "performance", current_user, page)
@router.get( @router.get(
@@ -210,14 +213,17 @@ async def get_country_ranking_pp(
async def get_country_ranking( async def get_country_ranking(
session: Database, session: Database,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"), ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
page: int = Query(1, ge=1, description="页码"), sort: Annotated[
sort: SortType = Path( SortType,
..., Path(
description="排名类型performance 表现分 / score 计分成绩总分 " ...,
"**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**", description="排名类型performance 表现分 / score 计分成绩总分 "
), "**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**",
current_user: User = Security(get_current_user, scopes=["public"]), ),
],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: Annotated[int, Query(ge=1, description="页码")] = 1,
): ):
# 获取 Redis 连接和缓存服务 # 获取 Redis 连接和缓存服务
redis = get_redis() redis = get_redis()
@@ -317,11 +323,11 @@ class TopUsersResponse(BaseModel):
async def get_user_ranking( async def get_user_ranking(
session: Database, session: Database,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"), ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")],
sort: SortType = Path(..., description="排名类型performance 表现分 / score 计分成绩总分"), sort: Annotated[SortType, Path(..., description="排名类型performance 表现分 / score 计分成绩总分")],
country: str | None = Query(None, description="国家代码"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
page: int = Query(1, ge=1, description=""), country: Annotated[str | None, Query(description="国家代")] = None,
current_user: User = Security(get_current_user, scopes=["public"]), page: Annotated[int, Query(ge=1, description="页码")] = 1,
): ):
# 获取 Redis 连接和缓存服务 # 获取 Redis 连接和缓存服务
redis = get_redis() redis = get_redis()

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.database import Relationship, RelationshipResp, RelationshipType, User from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.database.user import UserResp from app.database.user import UserResp
from app.dependencies.api_version import APIVersion from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import ClientUser, get_current_user
from .router import router from .router import router
@@ -56,7 +58,7 @@ async def get_relationship(
db: Database, db: Database,
request: Request, request: Request,
api_version: APIVersion, api_version: APIVersion,
current_user: User = Security(get_current_user, scopes=["friends.read"]), current_user: Annotated[User, Security(get_current_user, scopes=["friends.read"])],
): ):
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
relationships = await db.exec( relationships = await db.exec(
@@ -107,8 +109,8 @@ class AddFriendResp(BaseModel):
async def add_relationship( async def add_relationship(
db: Database, db: Database,
request: Request, request: Request,
target: int = Query(description="目标用户 ID"), target: Annotated[int, Query(description="目标用户 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
if not (await db.exec(select(exists()).where(User.id == target))).first(): if not (await db.exec(select(exists()).where(User.id == target))).first():
raise HTTPException(404, "Target user not found") raise HTTPException(404, "Target user not found")
@@ -176,8 +178,8 @@ async def add_relationship(
async def delete_relationship( async def delete_relationship(
db: Database, db: Database,
request: Request, request: Request,
target: int = Path(..., description="目标用户 ID"), target: Annotated[int, Path(..., description="目标用户 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
if not (await db.exec(select(exists()).where(User.id == target))).first(): if not (await db.exec(select(exists()).where(User.id == target))).first():
raise HTTPException(404, "Target user not found") raise HTTPException(404, "Target user not found")

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC from datetime import UTC
from typing import Literal from typing import Annotated, Literal
from app.database.beatmap import Beatmap, BeatmapResp from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmapset import BeatmapsetResp from app.database.beatmapset import BeatmapsetResp
@@ -12,8 +12,8 @@ from app.database.room import APIUploadedRoom, Room, RoomResp
from app.database.room_participated_user import RoomParticipatedUser from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score from app.database.score import Score
from app.database.user import User, UserResp from app.database.user import User, UserResp
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import ClientUser, get_current_user
from app.models.room import RoomCategory, RoomStatus from app.models.room import RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api from app.service.room import create_playlist_room_from_api
from app.signalr.hub import MultiplayerHubs from app.signalr.hub import MultiplayerHubs
@@ -21,9 +21,8 @@ from app.utils import utcnow
from .router import router from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security from fastapi import HTTPException, Path, Query, Security
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, exists, select from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -38,16 +37,20 @@ from sqlmodel.ext.asyncio.session import AsyncSession
) )
async def get_all_rooms( async def get_all_rooms(
db: Database, db: Database,
mode: Literal["open", "ended", "participated", "owned"] | None = Query( current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
default="open", mode: Annotated[
description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"), Literal["open", "ended", "participated", "owned"] | None,
), Query(
category: RoomCategory = Query( description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
RoomCategory.NORMAL, ),
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"), ] = "open",
), category: Annotated[
status: RoomStatus | None = Query(None, description="房间状态(可选)"), RoomCategory,
current_user: User = Security(get_current_user, scopes=["public"]), Query(
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
),
] = RoomCategory.NORMAL,
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
): ):
resp_list: list[RoomResp] = [] resp_list: list[RoomResp] = []
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category] where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
@@ -140,8 +143,8 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
async def create_room( async def create_room(
db: Database, db: Database,
room: APIUploadedRoom, room: APIUploadedRoom,
current_user: User = Security(get_client_user), current_user: ClientUser,
redis: Redis = Depends(get_redis), redis: Redis,
): ):
user_id = current_user.id user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id) db_room = await create_playlist_room_from_api(db, room, user_id)
@@ -162,13 +165,15 @@ async def create_room(
) )
async def get_room( async def get_room(
db: Database, db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: Annotated[int, Path(..., description="房间 ID")],
category: str = Query( current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
default="", redis: Redis,
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"), category: Annotated[
), str,
current_user: User = Security(get_current_user, scopes=["public"]), Query(
redis: Redis = Depends(get_redis), description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
),
] = "",
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None: if db_room is None:
@@ -185,8 +190,8 @@ async def get_room(
) )
async def delete_room( async def delete_room(
db: Database, db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None: if db_room is None:
@@ -205,10 +210,10 @@ async def delete_room(
) )
async def add_user_to_room( async def add_user_to_room(
db: Database, db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: Annotated[int, Path(..., description="房间 ID")],
user_id: int = Path(..., description="用户 ID"), user_id: Annotated[int, Path(..., description="用户 ID")],
redis: Redis = Depends(get_redis), redis: Redis,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None: if db_room is not None:
@@ -229,10 +234,10 @@ async def add_user_to_room(
) )
async def remove_user_from_room( async def remove_user_from_room(
db: Database, db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: Annotated[int, Path(..., description="房间 ID")],
user_id: int = Path(..., description="用户 ID"), user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
redis: Redis = Depends(get_redis), redis: Redis,
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None: if db_room is not None:
@@ -273,8 +278,8 @@ class APILeaderboard(BaseModel):
) )
async def get_room_leaderboard( async def get_room_leaderboard(
db: Database, db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None: if db_room is None:
@@ -329,11 +334,11 @@ class RoomEvents(BaseModel):
) )
async def get_room_events( async def get_room_events(
db: Database, db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"), after: Annotated[int | None, Query(ge=0, description="仅包含大于该事件 ID 的事件")] = None,
before: int | None = Query(None, ge=0, description="仅包含小于该事件 ID 的事件"), before: Annotated[int | None, Query(ge=0, description="仅包含小于该事件 ID 的事件")] = None,
): ):
events = ( events = (
await db.exec( await db.exec(

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import UTC, date from datetime import UTC, date
import time import time
from typing import Annotated
from app.calculator import clamp from app.calculator import clamp
from app.config import settings from app.config import settings
@@ -34,11 +35,10 @@ from app.database.score import (
process_user, process_user,
) )
from app.dependencies.api_version import APIVersion from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis, with_db from app.dependencies.database import Database, Redis, get_redis, with_db
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import Fetcher, get_fetcher
from app.dependencies.storage import get_storage_service from app.dependencies.storage import StorageService
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import ClientUser, get_current_user
from app.fetcher import Fetcher
from app.log import logger from app.log import logger
from app.models.beatmap import BeatmapRankStatus from app.models.beatmap import BeatmapRankStatus
from app.models.room import RoomCategory from app.models.room import RoomCategory
@@ -50,7 +50,6 @@ from app.models.score import (
) )
from app.service.beatmap_cache_service import get_beatmap_cache_service from app.service.beatmap_cache_service import get_beatmap_cache_service
from app.service.user_cache_service import refresh_user_cache_background from app.service.user_cache_service import refresh_user_cache_background
from app.storage.base import StorageService
from app.utils import utcnow from app.utils import utcnow
from .router import router from .router import router
@@ -69,7 +68,6 @@ from fastapi.responses import RedirectResponse
from fastapi_limiter.depends import RateLimiter from fastapi_limiter.depends import RateLimiter
from httpx import HTTPError from httpx import HTTPError
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlmodel import col, exists, func, select from sqlmodel import col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -245,16 +243,18 @@ class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
async def get_beatmap_scores( async def get_beatmap_scores(
db: Database, db: Database,
api_version: APIVersion, api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: Annotated[int, Path(description="谱面 ID")],
mode: GameMode = Query(description="指定 auleset"), mode: Annotated[GameMode, Query(description="指定 auleset")],
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), mods: Annotated[list[str], Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)")],
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
type: LeaderboardType = Query( legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
LeaderboardType.GLOBAL, type: Annotated[
description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"), LeaderboardType,
), Query(
current_user: User = Security(get_current_user, scopes=["public"]), description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"), ),
] = LeaderboardType.GLOBAL,
limit: Annotated[int, Query(ge=1, le=200, description="返回条数 (1-200)")] = 50,
): ):
if legacy_only: if legacy_only:
raise HTTPException(status_code=404, detail="this server only contains lazer scores") raise HTTPException(status_code=404, detail="this server only contains lazer scores")
@@ -294,12 +294,12 @@ async def get_beatmap_scores(
async def get_user_beatmap_score( async def get_user_beatmap_score(
db: Database, db: Database,
api_version: APIVersion, api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: Annotated[int, Path(description="谱面 ID")],
user_id: int = Path(description="用户 ID"), user_id: Annotated[int, Path(description="用户 ID")],
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"), legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"), mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
current_user: User = Security(get_current_user, scopes=["public"]), mods: Annotated[str | None, Query(description="筛选使用的 Mods (暂未实现)")] = None,
): ):
user_score = ( user_score = (
await db.exec( await db.exec(
@@ -342,11 +342,11 @@ async def get_user_beatmap_score(
async def get_user_all_beatmap_scores( async def get_user_all_beatmap_scores(
db: Database, db: Database,
api_version: APIVersion, api_version: APIVersion,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: Annotated[int, Path(description="谱面 ID")],
user_id: int = Path(description="用户 ID"), user_id: Annotated[int, Path(description="用户 ID")],
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"), legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None,
current_user: User = Security(get_current_user, scopes=["public"]), ruleset: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None,
): ):
all_user_scores = ( all_user_scores = (
await db.exec( await db.exec(
@@ -374,11 +374,11 @@ async def get_user_all_beatmap_scores(
async def create_solo_score( async def create_solo_score(
background_task: BackgroundTasks, background_task: BackgroundTasks,
db: Database, db: Database,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: Annotated[int, Path(description="谱面 ID")],
version_hash: str = Form("", description="游戏版本哈希"), beatmap_hash: Annotated[str, Form(description="谱面文件哈希")],
beatmap_hash: str = Form(description="谱面文件哈希"), ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), current_user: ClientUser,
current_user: User = Security(get_client_user), version_hash: Annotated[str, Form(description="游戏版本哈希")] = "",
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -406,12 +406,12 @@ async def create_solo_score(
async def submit_solo_score( async def submit_solo_score(
background_task: BackgroundTasks, background_task: BackgroundTasks,
db: Database, db: Database,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: Annotated[int, Path(description="谱面 ID")],
token: int = Path(description="成绩令牌 ID"), token: Annotated[int, Path(description="成绩令牌 ID")],
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"), info: Annotated[SoloScoreSubmissionInfo, Body(description="成绩提交信息")],
current_user: User = Security(get_client_user), current_user: ClientUser,
redis: Redis = Depends(get_redis), redis: Redis,
fetcher=Depends(get_fetcher), fetcher: Fetcher,
): ):
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher) return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
@@ -428,11 +428,11 @@ async def create_playlist_score(
background_task: BackgroundTasks, background_task: BackgroundTasks,
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
beatmap_id: int = Form(description="谱面 ID"), beatmap_id: Annotated[int, Form(description="谱面 ID")],
beatmap_hash: str = Form(description="游戏版本哈希"), beatmap_hash: Annotated[str, Form(description="游戏版本哈希")],
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")],
version_hash: str = Form("", description="谱面版本哈希"), current_user: ClientUser,
current_user: User = Security(get_client_user), version_hash: Annotated[str, Form(description="谱面版本哈希")] = "",
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -496,9 +496,9 @@ async def submit_playlist_score(
playlist_id: int, playlist_id: int,
token: int, token: int,
info: SoloScoreSubmissionInfo, info: SoloScoreSubmissionInfo,
current_user: User = Security(get_client_user), current_user: ClientUser,
redis: Redis = Depends(get_redis), redis: Redis,
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher,
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -555,9 +555,9 @@ async def index_playlist_scores(
session: Database, session: Database,
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"), limit: Annotated[int, Query(ge=1, le=50, description="返回条数 (1-50)")] = 50,
current_user: User = Security(get_current_user, scopes=["public"]), cursor: Annotated[int, Query(alias="cursor[total_score]", description="分页游标(上一页最低分)")] = 2000000,
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -623,8 +623,8 @@ async def show_playlist_score(
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
score_id: int, score_id: int,
current_user: User = Security(get_client_user), current_user: ClientUser,
redis: Redis = Depends(get_redis), redis: Redis,
): ):
room = await session.get(Room, room_id) room = await session.get(Room, room_id)
if not room: if not room:
@@ -692,7 +692,7 @@ async def get_user_playlist_score(
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
user_id: int, user_id: int,
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
score_record = None score_record = None
start_time = time.time() start_time = time.time()
@@ -725,8 +725,8 @@ async def get_user_playlist_score(
) )
async def pin_score( async def pin_score(
db: Database, db: Database,
score_id: int = Path(description="成绩 ID"), score_id: Annotated[int, Path(description="成绩 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -770,8 +770,8 @@ async def pin_score(
) )
async def unpin_score( async def unpin_score(
db: Database, db: Database,
score_id: int = Path(description="成绩 ID"), score_id: Annotated[int, Path(description="成绩 ID")],
current_user: User = Security(get_client_user), current_user: ClientUser,
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -805,10 +805,10 @@ async def unpin_score(
) )
async def reorder_score_pin( async def reorder_score_pin(
db: Database, db: Database,
score_id: int = Path(description="成绩 ID"), score_id: Annotated[int, Path(description="成绩 ID")],
after_score_id: int | None = Body(default=None, description="放在该成绩之后"), current_user: ClientUser,
before_score_id: int | None = Body(default=None, description="放在该成绩之"), after_score_id: Annotated[int | None, Body(description="放在该成绩之")] = None,
current_user: User = Security(get_client_user), before_score_id: Annotated[int | None, Body(description="放在该成绩之前")] = None,
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -893,8 +893,8 @@ async def reorder_score_pin(
async def download_score_replay( async def download_score_replay(
score_id: int, score_id: int,
db: Database, db: Database,
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
storage_service: StorageService = Depends(get_storage_service), storage_service: StorageService,
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id

View File

@@ -11,8 +11,8 @@ from app.config import settings
from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER
from app.database.auth import TotpKeys from app.database.auth import TotpKeys
from app.dependencies.api_version import APIVersion from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, Redis, get_redis
from app.dependencies.geoip import get_client_ip from app.dependencies.geoip import IPAddress
from app.dependencies.user import UserAndToken, get_client_user_and_token from app.dependencies.user import UserAndToken, get_client_user_and_token
from app.dependencies.user_agent import UserAgentInfo from app.dependencies.user_agent import UserAgentInfo
from app.log import logger from app.log import logger
@@ -27,7 +27,6 @@ from .router import router
from fastapi import Depends, Form, Header, HTTPException, Request, Security, status from fastapi import Depends, Form, Header, HTTPException, Request, Security, status
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis
class VerifyMethod(BaseModel): class VerifyMethod(BaseModel):
@@ -64,10 +63,14 @@ async def verify_session(
db: Database, db: Database,
api_version: APIVersion, api_version: APIVersion,
user_agent: UserAgentInfo, user_agent: UserAgentInfo,
ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)], redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 g0v0 扩展支持)"), verification_key: Annotated[
user_and_token: UserAndToken = Security(get_client_user_and_token), str,
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"), Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 g0v0 扩展支持)"),
],
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
) -> Response: ) -> Response:
current_user = user_and_token[0] current_user = user_and_token[0]
token_id = user_and_token[1].id token_id = user_and_token[1].id
@@ -82,7 +85,6 @@ async def verify_session(
else await LoginSessionService.get_login_method(user_id, token_id, redis) else await LoginSessionService.get_login_method(user_id, token_id, redis)
) )
ip_address = get_client_ip(request)
login_method = "password" login_method = "password"
try: try:
@@ -182,12 +184,12 @@ async def verify_session(
tags=["验证"], tags=["验证"],
) )
async def reissue_verification_code( async def reissue_verification_code(
request: Request,
db: Database, db: Database,
user_agent: UserAgentInfo, user_agent: UserAgentInfo,
api_version: APIVersion, api_version: APIVersion,
ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)], redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token), user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
) -> SessionReissueResponse: ) -> SessionReissueResponse:
current_user = user_and_token[0] current_user = user_and_token[0]
token_id = user_and_token[1].id token_id = user_and_token[1].id
@@ -203,7 +205,6 @@ async def reissue_verification_code(
return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码") return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码")
try: try:
ip_address = get_client_ip(request)
user_id = current_user.id user_id = current_user.id
success, message = await EmailVerificationService.resend_verification_code( success, message = await EmailVerificationService.resend_verification_code(
db, db,
@@ -233,17 +234,15 @@ async def reissue_verification_code(
async def fallback_email( async def fallback_email(
db: Database, db: Database,
user_agent: UserAgentInfo, user_agent: UserAgentInfo,
request: Request, ip_address: IPAddress,
redis: Annotated[Redis, Depends(get_redis)], redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token), user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
) -> VerifyMethod: ) -> VerifyMethod:
current_user = user_and_token[0] current_user = user_and_token[0]
token_id = user_and_token[1].id token_id = user_and_token[1].id
if not await LoginSessionService.get_login_method(current_user.id, token_id, redis): if not await LoginSessionService.get_login_method(current_user.id, token_id, redis):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
ip_address = get_client_ip(request)
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis) await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
success, message = await EmailVerificationService.resend_verification_code( success, message = await EmailVerificationService.resend_verification_code(
db, db,

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated
from app.database.beatmap import Beatmap from app.database.beatmap import Beatmap
from app.database.beatmap_tags import BeatmapTagVote from app.database.beatmap_tags import BeatmapTagVote
from app.database.score import Score from app.database.score import Score
from app.database.user import User from app.database.user import User
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.dependencies.user import get_client_user from app.dependencies.user import get_client_user
from app.models.score import Rank from app.models.score import Rank
from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id
@@ -55,10 +57,10 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
description="为指定谱面添加标签投票。", description="为指定谱面添加标签投票。",
) )
async def vote_beatmap_tags( async def vote_beatmap_tags(
beatmap_id: int = Path(..., description="谱面 ID"), beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
tag_id: int = Path(..., description="标签 ID"), tag_id: Annotated[int, Path(..., description="标签 ID")],
session: AsyncSession = Depends(get_db), session: Database,
current_user: User = Depends(get_client_user), current_user: Annotated[User, Depends(get_client_user)],
): ):
try: try:
get_tag_by_id(tag_id) get_tag_by_id(tag_id)
@@ -90,10 +92,10 @@ async def vote_beatmap_tags(
description="取消对指定谱面标签的投票。", description="取消对指定谱面标签的投票。",
) )
async def devote_beatmap_tags( async def devote_beatmap_tags(
beatmap_id: int = Path(..., description="谱面 ID"), beatmap_id: Annotated[int, Path(..., description="谱面 ID")],
tag_id: int = Path(..., description="标签 ID"), tag_id: Annotated[int, Path(..., description="标签 ID")],
session: AsyncSession = Depends(get_db), session: Database,
current_user: User = Depends(get_client_user), current_user: Annotated[User, Depends(get_client_user)],
): ):
""" """
取消对谱面指定标签的投票。 取消对谱面指定标签的投票。

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from typing import Literal from typing import Annotated, Literal
from app.config import settings from app.config import settings
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
@@ -51,9 +51,12 @@ async def get_users(
session: Database, session: Database,
request: Request, request: Request,
background_task: BackgroundTasks, background_task: BackgroundTasks,
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"), user_ids: Annotated[list[int], Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表")],
# current_user: User = Security(get_current_user, scopes=["public"]), # current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use include_variant_statistics: Annotated[
bool,
Query(description="是否包含各模式的统计信息"),
] = False, # TODO: future use
): ):
redis = get_redis() redis = get_redis()
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
@@ -119,9 +122,9 @@ async def get_users(
) )
async def get_user_events( async def get_user_events(
session: Database, session: Database,
user_id: int = Path(description="用户 ID"), user_id: Annotated[int, Path(description="用户 ID")],
limit: int | None = Query(None, description="限制返回的活动数量"), limit: Annotated[int | None, Query(description="限制返回的活动数量")] = None,
offset: int | None = Query(None, description="活动日志的偏移量"), offset: Annotated[int | None, Query(description="活动日志的偏移量")] = None,
): ):
db_user = await session.get(User, user_id) db_user = await session.get(User, user_id)
if db_user is None or db_user.id == BANCHOBOT_ID: if db_user is None or db_user.id == BANCHOBOT_ID:
@@ -147,9 +150,9 @@ async def get_user_events(
) )
async def get_user_kudosu( async def get_user_kudosu(
session: Database, session: Database,
user_id: int = Path(description="用户 ID"), user_id: Annotated[int, Path(description="用户 ID")],
offset: int = Query(default=0, description="偏移量"), offset: Annotated[int, Query(description="偏移量")] = 0,
limit: int = Query(default=6, description="返回记录数量限制"), limit: Annotated[int, Query(description="返回记录数量限制")] = 6,
): ):
""" """
获取用户的 kudosu 记录 获取用户的 kudosu 记录
@@ -176,8 +179,8 @@ async def get_user_kudosu(
async def get_user_info_ruleset( async def get_user_info_ruleset(
session: Database, session: Database,
background_task: BackgroundTasks, background_task: BackgroundTasks,
user_id: str = Path(description="用户 ID 或用户名"), user_id: Annotated[str, Path(description="用户 ID 或用户名")],
ruleset: GameMode | None = Path(description="指定 ruleset"), ruleset: Annotated[GameMode | None, Path(description="指定 ruleset")],
# current_user: User = Security(get_current_user, scopes=["public"]), # current_user: User = Security(get_current_user, scopes=["public"]),
): ):
redis = get_redis() redis = get_redis()
@@ -225,7 +228,7 @@ async def get_user_info(
background_task: BackgroundTasks, background_task: BackgroundTasks,
session: Database, session: Database,
request: Request, request: Request,
user_id: str = Path(description="用户 ID 或用户名"), user_id: Annotated[str, Path(description="用户 ID 或用户名")],
# current_user: User = Security(get_current_user, scopes=["public"]), # current_user: User = Security(get_current_user, scopes=["public"]),
): ):
redis = get_redis() redis = get_redis()
@@ -274,11 +277,11 @@ async def get_user_info(
async def get_user_beatmapsets( async def get_user_beatmapsets(
session: Database, session: Database,
background_task: BackgroundTasks, background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"), user_id: Annotated[int, Path(description="用户 ID")],
type: BeatmapsetType = Path(description="谱面集类型"), type: Annotated[BeatmapsetType, Path(description="谱面集类型")],
current_user: User = Security(get_current_user, scopes=["public"]), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
offset: int = Query(0, ge=0, description="偏移量"), offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
): ):
redis = get_redis() redis = get_redis()
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
@@ -356,16 +359,17 @@ async def get_user_scores(
session: Database, session: Database,
api_version: APIVersion, api_version: APIVersion,
background_task: BackgroundTasks, background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"), user_id: Annotated[int, Path(description="用户 ID")],
type: Literal["best", "recent", "firsts", "pinned"] = Path( type: Annotated[
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩") Literal["best", "recent", "firsts", "pinned"],
), Path(description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")),
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"), ],
include_fails: bool = Query(False, description="是否包含失败的成绩"), current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"), legacy_only: Annotated[bool, Query(description="是否只查询 Stable 成绩")] = False,
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), include_fails: Annotated[bool, Query(description="是否包含失败的成绩")] = False,
offset: int = Query(0, ge=0, description="偏移量"), mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选,默认为用户主模式)")] = None,
current_user: User = Security(get_current_user, scopes=["public"]), limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100,
offset: Annotated[int, Query(ge=0, description="偏移量")] = 0,
): ):
is_legacy_api = api_version < 20220705 is_legacy_api = api_version < 20220705
redis = get_redis() redis = get_redis()

View File

@@ -7,9 +7,8 @@ from typing import Literal
import uuid import uuid
from app.database import User as DBUser from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import DBFactory, get_db_factory from app.dependencies.database import DBFactory, get_db_factory
from app.dependencies.user import get_current_user_and_token from app.dependencies.user import get_current_user, get_current_user_and_token
from app.log import logger from app.log import logger
from app.models.signalr import NegotiateResponse, Transport from app.models.signalr import NegotiateResponse, Transport