refactor(app): update database code

This commit is contained in:
MingxuanGame
2025-08-18 16:37:30 +00:00
parent 6bae937e01
commit 1c65b21bb9
34 changed files with 167 additions and 188 deletions

View File

@@ -3,9 +3,11 @@ from __future__ import annotations
from collections.abc import AsyncIterator, Callable
from contextvars import ContextVar
import json
from typing import Annotated
from app.config import settings
from fastapi import Depends
from pydantic import BaseModel
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine
@@ -52,7 +54,12 @@ async def get_db():
yield session
def with_db():
return AsyncSession(engine)
DBFactory = Callable[[], AsyncIterator[AsyncSession]]
Database = Annotated[AsyncSession, Depends(get_db)]
async def get_db_factory() -> DBFactory:

View File

@@ -8,7 +8,7 @@ from app.database import User
from app.database.auth import V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer
from .database import get_db
from .database import Database
from fastapi import Depends, HTTPException
from fastapi.security import (
@@ -19,7 +19,6 @@ from fastapi.security import (
SecurityScopes,
)
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
@@ -64,7 +63,7 @@ v1_api_key = APIKeyQuery(name="k", scheme_name="V1 API Key", description="v1 API
async def v1_authorize(
db: Annotated[AsyncSession, Depends(get_db)],
db: Database,
api_key: Annotated[str, Depends(v1_api_key)],
):
"""V1 API Key 授权"""
@@ -79,8 +78,8 @@ async def v1_authorize(
async def get_client_user(
db: Database,
token: Annotated[str, Depends(oauth2_password)],
db: Annotated[AsyncSession, Depends(get_db)],
):
token_record = await get_token_by_access_token(db, token)
if not token_record:
@@ -95,8 +94,8 @@ async def get_client_user(
async def get_current_user(
db: Database,
security_scopes: SecurityScopes,
db: Annotated[AsyncSession, Depends(get_db)],
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[

View File

@@ -18,7 +18,7 @@ from typing import (
)
from app.database.beatmap import Beatmap
from app.dependencies.database import engine
from app.dependencies.database import with_db
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
@@ -41,7 +41,6 @@ from .signalr import (
from pydantic import BaseModel, Field
from sqlalchemy import update
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.database.room import Room
@@ -473,7 +472,7 @@ class MultiplayerQueue:
(item for item in self.room.playlist if not item.expired),
key=lambda x: x.id,
)
async with AsyncSession(engine) as session:
async with with_db() as session:
for idx, item in enumerate(ordered_active_items):
if item.playlist_order == idx:
continue
@@ -522,7 +521,7 @@ class MultiplayerQueue:
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session:
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(
@@ -548,7 +547,7 @@ class MultiplayerQueue:
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session:
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(
@@ -622,7 +621,7 @@ class MultiplayerQueue:
"Attempted to remove an item which has already been played"
)
async with AsyncSession(engine) as session:
async with with_db() as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
@@ -637,7 +636,7 @@ class MultiplayerQueue:
async def finish_current_item(self):
from app.database import Playlist
async with AsyncSession(engine) as session:
async with with_db() as session:
played_at = datetime.now(UTC)
await session.execute(
update(Playlist)

View File

@@ -92,7 +92,7 @@ class GameMode(str, Enum):
def parse(cls, v: str | int) -> "GameMode | None":
if isinstance(v, int) or v.isdigit():
return cls.from_int_extra(int(v))
v = v.lower()
v = v.upper()
try:
return cls[v]
except ValueError:

View File

@@ -18,8 +18,7 @@ from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.dependencies.database import get_redis
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
@@ -37,7 +36,6 @@ from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
def create_oauth_error_response(
@@ -89,11 +87,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
description="用户注册接口",
)
async def register_user(
db: Database,
request: Request,
user_username: str = Form(..., alias="user[username]", description="用户名"),
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
user_password: str = Form(..., alias="user[password]", description="密码"),
db: AsyncSession = Depends(get_db),
geoip: GeoIPHelper = Depends(get_geoip_helper),
):
username_errors = validate_username(user_username)
@@ -205,6 +203,7 @@ async def register_user(
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
)
async def oauth_token(
db: Database,
request: Request,
grant_type: Literal[
"authorization_code", "refresh_token", "password", "client_credentials"
@@ -218,7 +217,6 @@ async def oauth_token(
refresh_token: str | None = Form(
None, description="刷新令牌(仅刷新令牌模式需要)"
),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
scopes = scope.split(" ")

View File

@@ -11,7 +11,7 @@ from app.database.chat import (
UserSilenceResp,
)
from app.database.lazer_user import User, UserResp
from app.dependencies.database import get_db, get_redis
from app.dependencies.database import Database, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
@@ -22,7 +22,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class UpdateResponse(BaseModel):
@@ -38,6 +37,7 @@ class UpdateResponse(BaseModel):
tags=["聊天"],
)
async def get_update(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
@@ -46,7 +46,6 @@ async def get_update(
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
resp = UpdateResponse()
@@ -101,10 +100,10 @@ async def get_update(
tags=["聊天"],
)
async def join_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
@@ -121,10 +120,10 @@ async def join_channel(
tags=["聊天"],
)
async def leave_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
@@ -142,8 +141,8 @@ async def leave_channel(
tags=["聊天"],
)
async def get_channel_list(
session: Database,
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
channels = (
@@ -181,9 +180,9 @@ class GetChannelResp(BaseModel):
tags=["聊天"],
)
async def get_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
db_channel = await ChatChannel.get(channel, session)
@@ -250,9 +249,9 @@ class CreateChannelReq(BaseModel):
tags=["聊天"],
)
async def create_channel(
session: Database,
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
if req.type == "PM":

View File

@@ -11,7 +11,7 @@ from app.database.chat import (
UserSilenceResp,
)
from app.database.lazer_user import User
from app.dependencies.database import get_db, get_redis
from app.dependencies.database import Database, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
@@ -23,7 +23,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class KeepAliveResp(BaseModel):
@@ -38,12 +37,12 @@ class KeepAliveResp(BaseModel):
tags=["聊天"],
)
async def keep_alive(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
):
resp = KeepAliveResp()
if history_since:
@@ -84,10 +83,10 @@ class MessageReq(BaseModel):
tags=["聊天"],
)
async def send_message(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
req: MessageReq = Depends(BodyOrForm(MessageReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
@@ -125,12 +124,12 @@ async def send_message(
tags=["聊天"],
)
async def get_message(
session: Database,
channel: str,
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
@@ -158,10 +157,10 @@ async def get_message(
tags=["聊天"],
)
async def mark_as_read(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
message: int = Path(..., description="消息 ID"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
@@ -191,9 +190,9 @@ class NewPMResp(BaseModel):
tags=["聊天"],
)
async def create_new_pm(
session: Database,
req: PMReq = Depends(BodyOrForm(PMReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
user_id = current_user.id

View File

@@ -6,9 +6,9 @@ from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMes
from app.database.lazer_user import User
from app.dependencies.database import (
DBFactory,
engine,
get_db_factory,
get_redis,
with_db,
)
from app.dependencies.user import get_current_user
from app.log import logger
@@ -200,7 +200,7 @@ class ChatServer:
)
async def join_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session:
async with with_db() as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
@@ -212,7 +212,7 @@ class ChatServer:
await self.join_channel(user, channel, session)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session:
async with with_db() as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
@@ -268,7 +268,7 @@ async def chat_websocket(
token = authorization[7:]
if (
user := await get_current_user(
SecurityScopes(scopes=["chat.read"]), session, token_pw=token
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
)
) is None:
await websocket.close(code=1008)

View File

@@ -4,7 +4,7 @@ import hashlib
from io import BytesIO
from app.database.lazer_user import User
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.storage.base import StorageService
@@ -13,7 +13,6 @@ from .router import router
from fastapi import Depends, File, HTTPException, Security
from PIL import Image
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post(
@@ -21,10 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
name="上传头像",
)
async def upload_avatar(
session: Database,
content: bytes = File(...),
current_user: User = Security(get_client_user),
storage: StorageService = Depends(get_storage_service),
session: AsyncSession = Depends(get_db),
):
"""上传用户头像

View File

@@ -4,7 +4,7 @@ import hashlib
from io import BytesIO
from app.database.lazer_user import User, UserProfileCover
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.storage.base import StorageService
@@ -13,7 +13,6 @@ from .router import router
from fastapi import Depends, File, HTTPException, Security
from PIL import Image
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post(
@@ -21,10 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
name="上传头图",
)
async def upload_cover(
session: Database,
content: bytes = File(...),
current_user: User = Security(get_client_user),
storage: StorageService = Depends(get_storage_service),
session: AsyncSession = Depends(get_db),
):
"""上传用户头图

View File

@@ -4,7 +4,7 @@ import secrets
from app.database.auth import OAuthClient, OAuthToken
from app.database.lazer_user import User
from app.dependencies.database import get_db, get_redis
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user
from .router import router
@@ -12,7 +12,6 @@ from .router import router
from fastapi import Body, Depends, HTTPException, Security
from redis.asyncio import Redis
from sqlmodel import select, text
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post(
@@ -21,11 +20,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
description="创建一个新的 OAuth 应用程序,并生成客户端 ID 和密钥",
)
async def create_oauth_app(
session: Database,
name: str = Body(..., max_length=100, description="应用程序名称"),
description: str = Body("", description="应用程序描述"),
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
):
result = await session.execute( # pyright: ignore[reportDeprecated]
text(
@@ -61,8 +60,8 @@ async def create_oauth_app(
description="通过客户端 ID 获取 OAuth 应用的详细信息",
)
async def get_oauth_app(
session: Database,
client_id: int,
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
):
oauth_app = await session.get(OAuthClient, client_id)
@@ -82,7 +81,7 @@ async def get_oauth_app(
description="获取当前用户创建的所有 OAuth 应用程序",
)
async def get_user_oauth_apps(
session: AsyncSession = Depends(get_db),
session: Database,
current_user: User = Security(get_client_user),
):
oauth_apps = await session.exec(
@@ -106,8 +105,8 @@ async def get_user_oauth_apps(
description="删除指定的 OAuth 应用程序及其关联的所有令牌",
)
async def delete_oauth_app(
session: Database,
client_id: int,
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
):
oauth_client = await session.get(OAuthClient, client_id)
@@ -134,11 +133,11 @@ async def delete_oauth_app(
description="更新指定 OAuth 应用的名称、描述和重定向 URI",
)
async def update_oauth_app(
session: Database,
client_id: int,
name: str = Body(..., max_length=100, description="应用程序新名称"),
description: str = Body("", description="应用程序新描述"),
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
):
oauth_client = await session.get(OAuthClient, client_id)
@@ -169,8 +168,8 @@ async def update_oauth_app(
description="为指定的 OAuth 应用生成新的客户端密钥,并使所有现有的令牌失效",
)
async def refresh_secret(
session: Database,
client_id: int,
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
):
oauth_client = await session.get(OAuthClient, client_id)
@@ -204,11 +203,11 @@ async def refresh_secret(
description="为特定用户和 OAuth 应用生成授权码,用于授权码授权流程",
)
async def generate_oauth_code(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
redirect_uri: str = Body(..., description="授权后重定向的 URI"),
scopes: list[str] = Body(..., description="请求的权限范围列表"),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
client = await session.get(OAuthClient, client_id)

View File

@@ -2,15 +2,14 @@ from __future__ import annotations
from app.database import Relationship, User
from app.database.relationship import RelationshipType
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from .router import router
from fastapi import Depends, HTTPException, Path, Security
from fastapi import HTTPException, Path, Security
from pydantic import BaseModel, Field
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
class CheckResponse(BaseModel):
@@ -26,9 +25,9 @@ class CheckResponse(BaseModel):
response_model=CheckResponse,
)
async def check_user_relationship(
db: Database,
user_id: int = Path(..., description="目标用户的 ID"),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
if user_id == current_user.id:
raise HTTPException(422, "Cannot check relationship with yourself")

View File

@@ -6,14 +6,13 @@ from app.auth import validate_username
from app.config import settings
from app.database.events import Event, EventType
from app.database.lazer_user import User
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from .router import router
from fastapi import Body, Depends, HTTPException, Security
from fastapi import Body, HTTPException, Security
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post(
@@ -21,8 +20,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
name="修改用户名",
)
async def user_rename(
session: Database,
new_name: str = Body(..., description="新的用户名"),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
):
"""修改用户名

View File

@@ -8,7 +8,7 @@ from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.database.beatmapset import Beatmapset
from app.database.favourite_beatmapset import FavouriteBeatmapset
from app.database.score import Score
from app.dependencies.database import get_db, get_redis
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus, Genre, Language
@@ -149,6 +149,7 @@ class V1Beatmap(AllStrModel):
description="根据指定条件搜索谱面。",
)
async def get_beatmaps(
session: Database,
since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"),
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
@@ -163,7 +164,6 @@ async def get_beatmaps(
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
mods: int = Query(0, description="应用到谱面属性的 MOD"),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):

View File

@@ -6,7 +6,7 @@ from typing import Literal
from app.database.counts import ReplayWatchedCount
from app.database.score import Score
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service
from app.models.mods import int_to_mods
from app.models.score import GameMode
@@ -17,7 +17,6 @@ from .router import router
from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class ReplayModel(BaseModel):
@@ -32,6 +31,7 @@ class ReplayModel(BaseModel):
description="获取指定谱面的回放文件。",
)
async def download_replay(
session: Database,
beatmap: int = Query(..., alias="b", description="谱面 ID"),
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query(
@@ -45,7 +45,6 @@ async def download_replay(
None, description="用户类型string 用户名称 / id 用户 ID"
),
mods: int = Query(0, description="成绩的 MOD"),
session: AsyncSession = Depends(get_db),
storage_service: StorageService = Depends(get_storage_service),
):
mods_ = int_to_mods(mods)

View File

@@ -5,16 +5,15 @@ from typing import Literal
from app.database.pp_best_score import PPBestScore
from app.database.score import Score, get_leaderboard
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.models.mods import int_to_mods, mod_to_save, mods_to_int
from app.models.score import GameMode, LeaderboardType
from .router import AllStrModel, router
from fastapi import Depends, HTTPException, Query
from fastapi import HTTPException, Query
from sqlalchemy.orm import joinedload
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
class V1Score(AllStrModel):
@@ -68,13 +67,13 @@ class V1Score(AllStrModel):
description="获取指定用户的最好成绩。",
)
async def get_user_best(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
session: AsyncSession = Depends(get_db),
):
try:
scores = (
@@ -104,13 +103,13 @@ async def get_user_best(
description="获取指定用户的最近成绩。",
)
async def get_user_recent(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
session: AsyncSession = Depends(get_db),
):
try:
scores = (
@@ -140,6 +139,7 @@ async def get_user_recent(
description="获取指定谱面的成绩。",
)
async def get_scores(
session: Database,
user: str | None = Query(None, alias="u", description="用户"),
beatmap_id: int = Query(alias="b", description="谱面 ID"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
@@ -148,7 +148,6 @@ async def get_scores(
),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
mods: int = Query(0, description="成绩的 MOD"),
session: AsyncSession = Depends(get_db),
):
try:
if user is not None:

View File

@@ -5,14 +5,13 @@ from typing import Literal
from app.database.lazer_user import User
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.models.score import GameMode
from .router import AllStrModel, router
from fastapi import Depends, HTTPException, Query
from fastapi import HTTPException, Query
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
class V1User(AllStrModel):
@@ -41,7 +40,7 @@ class V1User(AllStrModel):
@classmethod
async def from_db(
cls, session: AsyncSession, db_user: User, ruleset: GameMode | None = None
cls, session: Database, db_user: User, ruleset: GameMode | None = None
) -> "V1User":
ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None
@@ -92,6 +91,7 @@ class V1User(AllStrModel):
description="获取指定用户的信息。",
)
async def get_user(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
@@ -100,7 +100,6 @@ async def get_user(
event_days: int = Query(
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
),
session: AsyncSession = Depends(get_db),
):
db_user = (
await session.exec(

View File

@@ -6,7 +6,7 @@ import json
from app.database import Beatmap, BeatmapResp, User
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import get_db, get_redis
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
@@ -24,7 +24,6 @@ from pydantic import BaseModel
from redis.asyncio import Redis
import rosu_pp_py as rosu
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class BatchGetResp(BaseModel):
@@ -47,13 +46,13 @@ class BatchGetResp(BaseModel):
),
)
async def lookup_beatmap(
db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
filename: str | None = Query(
default=None, alias="filename", description="谱面文件名"
),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
if id is None and md5 is None and filename is None:
@@ -80,9 +79,9 @@ async def lookup_beatmap(
description="获取单个谱面详情。",
)
async def get_beatmap(
db: Database,
beatmap_id: int = Path(..., description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
try:
@@ -103,11 +102,11 @@ async def get_beatmap(
),
)
async def batch_get_beatmaps(
db: Database,
beatmap_ids: list[int] = Query(
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
if not beatmap_ids:
@@ -157,6 +156,7 @@ async def batch_get_beatmaps(
description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"),
)
async def get_beatmap_attributes(
db: Database,
beatmap_id: int = Path(..., description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
mods: list[str] = Query(
@@ -170,7 +170,6 @@ async def get_beatmap_attributes(
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
),
redis: Redis = Depends(get_redis),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
mods_ = []

View File

@@ -7,7 +7,7 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
from app.dependencies.database import engine, get_db, get_redis
from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
@@ -30,11 +30,10 @@ from fastapi import (
from fastapi.responses import RedirectResponse
from httpx import HTTPError
from sqlmodel import exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def _save_to_db(sets: SearchBeatmapsetsResp):
async with AsyncSession(engine) as session:
async with with_db() as session:
for s in sets.beatmapsets:
if not (
await session.exec(select(exists()).where(Beatmapset.id == s.id))
@@ -49,13 +48,13 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
response_model=SearchBeatmapsetsResp,
)
async def search_beatmapset(
db: Database,
query: Annotated[SearchQueryModel, Query(...)],
request: Request,
background_tasks: BackgroundTasks,
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
redis = Depends(get_redis),
redis=Depends(get_redis),
):
params = parse_qs(qs=request.url.query, keep_blank_values=True)
cursor = {}
@@ -112,9 +111,9 @@ async def search_beatmapset(
description=("通过谱面 ID 查询所属谱面集。"),
)
async def lookup_beatmapset(
db: Database,
beatmap_id: int = Query(description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
@@ -132,9 +131,9 @@ async def lookup_beatmapset(
description="获取单个谱面集详情。",
)
async def get_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
try:
@@ -196,12 +195,12 @@ async def download_beatmapset(
description="**客户端专属**\n收藏或取消收藏指定谱面集。",
)
async def favourite_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form(
description="操作类型favourite 收藏 / unfavourite 取消收藏"
),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
assert current_user.id is not None
existing_favourite = (

View File

@@ -3,13 +3,12 @@ from __future__ import annotations
from app.database import User, UserResp
from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.models.score import GameMode
from .router import router
from fastapi import Depends, Path, Security
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Path, Security
@router.get(
@@ -20,9 +19,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
tags=["用户"],
)
async def get_user_info_with_ruleset(
session: Database,
ruleset: GameMode = Path(description="指定 ruleset"),
current_user: User = Security(get_current_user, scopes=["identify"]),
session: AsyncSession = Depends(get_db),
):
return await UserResp.from_db(
current_user,
@@ -40,8 +39,8 @@ async def get_user_info_with_ruleset(
tags=["用户"],
)
async def get_user_info_default(
session: Database,
current_user: User = Security(get_current_user, scopes=["identify"]),
session: AsyncSession = Depends(get_db),
):
return await UserResp.from_db(
current_user,

View File

@@ -5,15 +5,14 @@ from typing import Literal
from app.database import User
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.models.score import GameMode
from .router import router
from fastapi import Depends, Path, Query, Security
from fastapi import Path, Query, Security
from pydantic import BaseModel
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class CountryStatistics(BaseModel):
@@ -36,10 +35,10 @@ class CountryResponse(BaseModel):
tags=["排行榜"],
)
async def get_country_ranking(
session: Database,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"), # TODO
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
):
response = CountryResponse(ranking=[])
countries = (await session.exec(select(User.country_code).distinct())).all()
@@ -85,6 +84,7 @@ class TopUsersResponse(BaseModel):
tags=["排行榜"],
)
async def get_user_ranking(
session: Database,
ruleset: GameMode = Path(..., description="指定 ruleset"),
type: Literal["performance", "score"] = Path(
..., description="排名类型performance 表现分 / score 计分成绩总分"
@@ -92,7 +92,6 @@ async def get_user_ranking(
country: str | None = Query(None, description="国家代码"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
):
wheres = [
col(UserStatistics.mode) == ruleset,

View File

@@ -1,15 +1,14 @@
from __future__ import annotations
from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.user import get_client_user, get_current_user
from .router import router
from fastapi import Depends, HTTPException, Path, Query, Request, Security
from fastapi import HTTPException, Path, Query, Request, Security
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get(
@@ -27,9 +26,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
description="获取当前用户的屏蔽用户列表。",
)
async def get_relationship(
db: Database,
request: Request,
current_user: User = Security(get_current_user, scopes=["friends.read"]),
db: AsyncSession = Depends(get_db),
):
relationship_type = (
RelationshipType.FOLLOW
@@ -67,10 +66,10 @@ class AddFriendResp(BaseModel):
description="**客户端专属**\n添加或更新与目标用户的屏蔽关系。",
)
async def add_relationship(
db: Database,
request: Request,
target: int = Query(description="目标用户 ID"),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
assert current_user.id is not None
relationship_type = (
@@ -141,10 +140,10 @@ async def add_relationship(
description="**客户端专属**\n删除与目标用户的屏蔽关系。",
)
async def delete_relationship(
db: Database,
request: Request,
target: int = Path(..., description="目标用户 ID"),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
relationship_type = (
RelationshipType.BLOCK

View File

@@ -12,7 +12,7 @@ from app.database.playlists import Playlist, PlaylistResp
from app.database.room import APIUploadedRoom, Room, RoomResp
from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score
from app.dependencies.database import get_db, get_redis
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user, get_current_user
from app.models.room import RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api
@@ -36,6 +36,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
description="获取房间列表。支持按状态/模式筛选",
)
async def get_all_rooms(
db: Database,
mode: Literal["open", "ended", "participated", "owned", None] = Query(
default="open",
description=(
@@ -51,7 +52,6 @@ async def get_all_rooms(
),
),
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
resp_list: list[RoomResp] = []
@@ -149,8 +149,8 @@ async def _participate_room(
description="**客户端专属**\n创建一个新的房间。",
)
async def create_room(
db: Database,
room: APIUploadedRoom,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
@@ -173,6 +173,7 @@ async def create_room(
description="获取单个房间详情。",
)
async def get_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
category: str = Query(
default="",
@@ -181,7 +182,6 @@ async def get_room(
" / DAILY_CHALLENGE 每日挑战 (可选)"
),
),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
@@ -201,8 +201,8 @@ async def get_room(
description="**客户端专属**\n结束歌单模式房间。",
)
async def delete_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
@@ -221,9 +221,9 @@ async def delete_room(
description="**客户端专属**\n加入指定歌单模式房间。",
)
async def add_user_to_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
):
@@ -245,9 +245,9 @@ async def add_user_to_room(
description="**客户端专属**\n离开指定歌单模式房间。",
)
async def remove_user_from_room(
db: Database,
room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
@@ -289,8 +289,8 @@ class APILeaderboard(BaseModel):
description="获取房间内累计得分排行榜。",
)
async def get_room_leaderboard(
db: Database,
room_id: int = Path(..., description="房间 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
@@ -345,8 +345,8 @@ class RoomEvents(BaseModel):
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
)
async def get_room_events(
db: Database,
room_id: int = Path(..., description="房间 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"),

View File

@@ -32,7 +32,7 @@ from app.database.score import (
process_score,
process_user,
)
from app.dependencies.database import get_db, get_redis
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user, get_current_user
@@ -220,6 +220,7 @@ class BeatmapScores(BaseModel):
description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
)
async def get_beatmap_scores(
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
mode: GameMode = Query(description="指定 auleset"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
@@ -233,7 +234,6 @@ async def get_beatmap_scores(
),
),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
):
if legacy_only:
@@ -277,13 +277,13 @@ class BeatmapUserScore(BaseModel):
description="获取指定用户在指定谱面上的最高成绩。",
)
async def get_user_beatmap_score(
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
user_id: int = Path(description="用户 ID"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"),
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
raise HTTPException(
@@ -322,12 +322,12 @@ async def get_user_beatmap_score(
description="获取指定用户在指定谱面上的全部成绩列表。",
)
async def get_user_all_beatmap_scores(
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
user_id: int = Path(description="用户 ID"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
raise HTTPException(
@@ -357,12 +357,12 @@ async def get_user_all_beatmap_scores(
)
async def create_solo_score(
background_task: BackgroundTasks,
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
version_hash: str = Form("", description="游戏版本哈希"),
beatmap_hash: str = Form(description="谱面文件哈希"),
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
assert current_user.id is not None
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
@@ -387,11 +387,11 @@ async def create_solo_score(
)
async def submit_solo_score(
req: Request,
db: Database,
beatmap_id: int = Path(description="谱面 ID"),
token: int = Path(description="成绩令牌 ID"),
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
@@ -407,6 +407,7 @@ async def submit_solo_score(
description="**客户端专属**\n为房间游玩项目创建成绩提交令牌。",
)
async def create_playlist_score(
session: Database,
background_task: BackgroundTasks,
room_id: int,
playlist_id: int,
@@ -415,7 +416,6 @@ async def create_playlist_score(
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
version_hash: str = Form("", description="谱面版本哈希"),
current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
):
assert current_user.id is not None
room = await session.get(Room, room_id)
@@ -483,12 +483,12 @@ async def create_playlist_score(
description="**客户端专属**\n提交房间游玩项目成绩。",
)
async def submit_playlist_score(
session: Database,
room_id: int,
playlist_id: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -541,6 +541,7 @@ class IndexedScoreResp(MultiplayerScores):
tags=["成绩"],
)
async def index_playlist_scores(
session: Database,
room_id: int,
playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
@@ -548,7 +549,6 @@ async def index_playlist_scores(
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
),
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
):
room = await session.get(Room, room_id)
if not room:
@@ -607,11 +607,11 @@ async def index_playlist_scores(
tags=["成绩"],
)
async def show_playlist_score(
session: Database,
room_id: int,
playlist_id: int,
score_id: int,
current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
room = await session.get(Room, room_id)
@@ -678,11 +678,11 @@ async def show_playlist_score(
tags=["成绩"],
)
async def get_user_playlist_score(
session: Database,
room_id: int,
playlist_id: int,
user_id: int,
current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
):
score_record = None
start_time = time.time()
@@ -716,9 +716,9 @@ async def get_user_playlist_score(
tags=["成绩"],
)
async def pin_score(
db: Database,
score_id: int = Path(description="成绩 ID"),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
score_record = (
await db.exec(
@@ -758,9 +758,9 @@ async def pin_score(
tags=["成绩"],
)
async def unpin_score(
db: Database,
score_id: int = Path(description="成绩 ID"),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
score_record = (
await db.exec(
@@ -797,11 +797,11 @@ async def unpin_score(
tags=["成绩"],
)
async def reorder_score_pin(
db: Database,
score_id: int = Path(description="成绩 ID"),
after_score_id: int | None = Body(default=None, description="放在该成绩之后"),
before_score_id: int | None = Body(default=None, description="放在该成绩之前"),
current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
):
score_record = (
await db.exec(
@@ -892,8 +892,8 @@ async def reorder_score_pin(
)
async def download_score_replay(
score_id: int,
db: Database,
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
storage_service: StorageService = Depends(get_storage_service),
):
score = (await db.exec(select(Score).where(Score.id == score_id))).first()

View File

@@ -15,17 +15,16 @@ from app.database.events import EventResp
from app.database.lazer_user import SEARCH_INCLUDED
from app.database.pp_best_score import PPBestScore
from app.database.score import Score, ScoreResp
from app.dependencies.database import get_db
from app.dependencies.database import Database
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security
from fastapi import HTTPException, Path, Query, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col
@@ -43,6 +42,7 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
async def get_users(
session: Database,
user_ids: list[int] = Query(
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
),
@@ -50,7 +50,6 @@ async def get_users(
include_variant_statistics: bool = Query(
default=False, description="是否包含各模式的统计信息"
), # TODO: future use
session: AsyncSession = Depends(get_db),
):
if user_ids:
searched_users = (
@@ -79,9 +78,9 @@ async def get_users(
tags=["用户"],
)
async def get_user_info_ruleset(
session: Database,
user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"),
session: AsyncSession = Depends(get_db),
# current_user: User = Security(get_current_user, scopes=["public"]),
):
searched_user = (
@@ -112,8 +111,8 @@ async def get_user_info_ruleset(
tags=["用户"],
)
async def get_user_info(
session: Database,
user_id: str = Path(description="用户 ID 或用户名"),
session: AsyncSession = Depends(get_db),
# current_user: User = Security(get_current_user, scopes=["public"]),
):
searched_user = (
@@ -142,10 +141,10 @@ async def get_user_info(
tags=["用户"],
)
async def get_user_beatmapsets(
session: Database,
user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
):
@@ -202,6 +201,7 @@ async def get_user_beatmapsets(
tags=["用户"],
)
async def get_user_scores(
session: Database,
user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=(
@@ -216,7 +216,6 @@ async def get_user_scores(
),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
db_user = await session.get(User, user_id)
@@ -267,10 +266,10 @@ async def get_user_scores(
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
)
async def get_user_events(
session: Database,
user: int,
limit: int | None = Query(None),
offset: str | None = Query(None), # TODO: 搞清楚并且添加这个奇怪的分页偏移
session: AsyncSession = Depends(get_db),
):
db_user = await session.get(User, user)
if db_user is None or db_user.id == BANCHOBOT_ID:

View File

@@ -4,12 +4,11 @@ from datetime import UTC, datetime, timedelta
from app.database import RankHistory, UserStatistics
from app.database.rank_history import RankTop
from app.dependencies.database import engine
from app.dependencies.database import with_db
from app.dependencies.scheduler import get_scheduler
from app.models.score import GameMode
from sqlmodel import col, exists, select, update
from sqlmodel.ext.asyncio.session import AsyncSession
@get_scheduler().scheduled_job(
@@ -18,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
async def calculate_user_rank(is_today: bool = False):
today = datetime.now(UTC).date()
target_date = today if is_today else today - timedelta(days=1)
async with AsyncSession(engine) as session:
async with with_db() as session:
for gamemode in GameMode:
users = await session.exec(
select(UserStatistics)

View File

@@ -3,15 +3,14 @@ from __future__ import annotations
from app.const import BANCHOBOT_ID
from app.database.lazer_user import User
from app.database.statistics import UserStatistics
from app.dependencies.database import engine
from app.dependencies.database import with_db
from app.models.score import GameMode
from sqlmodel import exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_banchobot():
async with AsyncSession(engine) as session:
async with with_db() as session:
is_exist = (
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
).first()

View File

@@ -6,7 +6,7 @@ import json
from app.const import BANCHOBOT_ID
from app.database.playlists import Playlist
from app.database.room import Room
from app.dependencies.database import engine, get_redis
from app.dependencies.database import get_redis, with_db
from app.dependencies.scheduler import get_scheduler
from app.log import logger
from app.models.metadata_hub import DailyChallengeInfo
@@ -16,13 +16,12 @@ from app.models.room import RoomCategory
from .room import create_playlist_room
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_daily_challenge_room(
beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = []
) -> Room:
async with AsyncSession(engine) as session:
async with with_db() as session:
today = datetime.now(UTC).date()
return await create_playlist_room(
session=session,
@@ -52,7 +51,7 @@ async def daily_challenge_job():
key = f"daily_challenge:{now.date()}"
if not await redis.exists(key):
return
async with AsyncSession(engine) as session:
async with with_db() as session:
room = (
await session.exec(
select(Room).where(

View File

@@ -4,16 +4,15 @@ from app.config import settings
from app.const import BANCHOBOT_ID
from app.database.lazer_user import User
from app.database.statistics import UserStatistics
from app.dependencies.database import engine
from app.dependencies.database import with_db
from app.models.score import GameMode
from sqlalchemy import exists
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_rx_statistics():
async with AsyncSession(engine) as session:
async with with_db() as session:
users = (await session.exec(select(User.id))).all()
for i in users:
if i == BANCHOBOT_ID:

View File

@@ -4,13 +4,12 @@ from typing import TYPE_CHECKING
from app.database import PlaylistBestScore, Score
from app.database.playlist_best_score import get_position
from app.dependencies.database import engine
from app.dependencies.database import with_db
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
from .base import RedisSubscriber
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.signalr.hub import MetadataHub
@@ -45,7 +44,7 @@ class ScoreSubscriber(RedisSubscriber):
async def _notify_room_score_processed(self, score_id: int):
if not self.metadata_hub:
return
async with AsyncSession(engine) as session:
async with with_db() as session:
score = await session.get(Score, score_id)
if (
not score

View File

@@ -13,7 +13,7 @@ from app.database.playlist_best_score import PlaylistBestScore
from app.database.playlists import Playlist
from app.database.room import Room
from app.database.score import Score
from app.dependencies.database import engine, get_redis
from app.dependencies.database import get_redis, with_db
from app.models.metadata_hub import (
TOTAL_SCORE_DISTRIBUTION_BINS,
DailyChallengeInfo,
@@ -30,7 +30,6 @@ from app.service.subscribers.score_processed import ScoreSubscriber
from .hub import Client, Hub
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
@@ -97,7 +96,7 @@ class MetadataHub(Hub[MetadataClientState]):
redis = get_redis()
if await redis.exists(f"metadata:online:{state.connection_id}"):
await redis.delete(f"metadata:online:{state.connection_id}")
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
user = (
await session.exec(
@@ -118,7 +117,7 @@ class MetadataHub(Hub[MetadataClientState]):
user_id = int(client.connection_id)
self.get_or_create_state(client)
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
friends = (
await session.exec(
@@ -233,7 +232,7 @@ class MetadataHub(Hub[MetadataClientState]):
return list(stats.playlist_item_stats.values())
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
async with AsyncSession(engine) as session:
async with with_db() as session:
playlist_ids = (
await session.exec(
select(Playlist.id).where(

View File

@@ -12,7 +12,7 @@ from app.database.multiplayer_event import MultiplayerEvent
from app.database.playlists import Playlist
from app.database.relationship import Relationship, RelationshipType
from app.database.room_participated_user import RoomParticipatedUser
from app.dependencies.database import engine, get_redis
from app.dependencies.database import get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from app.log import logger
@@ -50,7 +50,6 @@ from .hub import Client, Hub
from httpx import HTTPError
from sqlalchemy import update
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
GAMEPLAY_LOAD_TIMEOUT = 30
@@ -61,7 +60,7 @@ class MultiplayerEventLogger:
async def log_event(self, event: MultiplayerEvent):
try:
async with AsyncSession(engine) as session:
async with with_db() as session:
session.add(event)
await session.commit()
except Exception as e:
@@ -192,7 +191,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
store = self.get_or_create_state(client)
if store.room_id != 0:
raise InvokeException("You are already in a room")
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session:
db_room = Room(
name=room.settings.name,
@@ -282,7 +281,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await server_room.match_type_handler.handle_join(user)
await self.event_logger.player_joined(room_id, user.user_id)
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
if (
participated_user := (
@@ -398,7 +397,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
)
async def change_db_settings(self, room: ServerMultiplayerRoom):
async with AsyncSession(engine) as session:
async with with_db() as session:
await session.execute(
update(Room)
.where(col(Room.id) == room.room.room_id)
@@ -477,7 +476,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
room,
user,
)
async with AsyncSession(engine) as session:
async with with_db() as session:
try:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=room.queue.current_item.beatmap_id
@@ -535,7 +534,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if not room.queue.current_item.freestyle:
raise InvokeException("Current item does not allow free user styles.")
async with AsyncSession(engine) as session:
async with with_db() as session:
item_beatmap = await session.get(
Beatmap, room.queue.current_item.beatmap_id
)
@@ -910,7 +909,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
redis = get_redis()
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
participated_user = (
await session.exec(
@@ -954,7 +953,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
async def end_room(self, room: ServerMultiplayerRoom):
assert room.room.host
async with AsyncSession(engine) as session:
async with with_db() as session:
await session.execute(
update(Room)
.where(col(Room.id) == room.room.room_id)
@@ -1171,7 +1170,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if user is None:
raise InvokeException("You are not in this room")
async with AsyncSession(engine) as session:
async with with_db() as session:
db_user = await session.get(User, user_id)
target_relationship = (
await session.exec(

View File

@@ -14,7 +14,7 @@ from app.database.failtime import FailTime, FailTimeResp
from app.database.score import Score
from app.database.score_token import ScoreToken
from app.database.statistics import UserStatistics
from app.dependencies.database import engine, get_redis
from app.dependencies.database import get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.exception import InvokeException
@@ -38,7 +38,6 @@ from .hub import Client, Hub
from httpx import HTTPError
from sqlalchemy.orm import joinedload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 30
REPLAY_LATEST_VER = 30000016
@@ -194,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]):
return
fetcher = await get_fetcher()
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
try:
beatmap = await Beatmap.get_or_fetch(
@@ -285,7 +284,7 @@ class SpectatorHub(Hub[StoreClientState]):
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.score is not None
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session:
start_time = time.time()
score_record = None
@@ -332,7 +331,7 @@ class SpectatorHub(Hub[StoreClientState]):
self, user_id: int, state: SpectatorState, store: StoreClientState
) -> None:
async def _add_failtime():
async with AsyncSession(engine) as session:
async with with_db() as session:
failtime = await session.get(FailTime, state.beatmap_id)
total_length = (
await session.exec(
@@ -366,7 +365,7 @@ class SpectatorHub(Hub[StoreClientState]):
return
before_time = int(messages[0][1]["time"])
await redis.delete(key)
async with AsyncSession(engine) as session:
async with with_db() as session:
gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods)
statistics = (
await session.exec(
@@ -430,7 +429,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.add_to_group(client, self.group_id(target_id))
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
username = (
await session.exec(select(User.username).where(User.id == user_id))

View File

@@ -8,7 +8,7 @@ import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.database import DBFactory, get_db_factory
from app.models.signalr import NegotiateResponse, Transport
from .hub import Hubs
@@ -16,7 +16,6 @@ from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
from fastapi.security import SecurityScopes
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/signalr", include_in_schema=False)
@@ -47,7 +46,7 @@ async def connect(
websocket: WebSocket,
id: str,
authorization: str = Header(...),
db: AsyncSession = Depends(get_db),
factory: DBFactory = Depends(get_db_factory),
):
token = authorization[7:]
user_id = id.split(":")[0]
@@ -56,13 +55,14 @@ async def connect(
await websocket.close(code=1008)
return
try:
if (
user := await get_current_user(
SecurityScopes(scopes=["*"]), db, token_pw=token
)
) is None or str(user.id) != user_id:
await websocket.close(code=1008)
return
async for session in factory():
if (
user := await get_current_user(
session, SecurityScopes(scopes=["*"]), token_pw=token
)
) is None or str(user.id) != user_id:
await websocket.close(code=1008)
return
except HTTPException:
await websocket.close(code=1008)
return