From 1c65b21bb95eb9f5fd360e8dd9c54825780c6ff6 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 18 Aug 2025 16:37:30 +0000 Subject: [PATCH] refactor(app): update database code --- app/dependencies/database.py | 7 +++++ app/dependencies/user.py | 9 +++---- app/models/multiplayer_hub.py | 13 +++++----- app/models/score.py | 2 +- app/router/auth.py | 8 +++--- app/router/chat/channel.py | 15 +++++------ app/router/chat/message.py | 13 +++++----- app/router/chat/server.py | 8 +++--- app/router/private/avatar.py | 5 ++-- app/router/private/cover.py | 5 ++-- app/router/private/oauth.py | 17 ++++++------ app/router/private/relationship.py | 7 +++-- app/router/private/username.py | 7 +++-- app/router/v1/beatmap.py | 4 +-- app/router/v1/replay.py | 5 ++-- app/router/v1/score.py | 11 ++++---- app/router/v1/user.py | 9 +++---- app/router/v2/beatmap.py | 11 ++++---- app/router/v2/beatmapset.py | 15 +++++------ app/router/v2/me.py | 9 +++---- app/router/v2/ranking.py | 9 +++---- app/router/v2/relationship.py | 11 ++++---- app/router/v2/room.py | 18 ++++++------- app/router/v2/score.py | 30 +++++++++++----------- app/router/v2/user.py | 17 ++++++------ app/service/calculate_all_user_rank.py | 5 ++-- app/service/create_banchobot.py | 5 ++-- app/service/daily_challenge.py | 7 +++-- app/service/osu_rx_statistics.py | 5 ++-- app/service/subscribers/score_processed.py | 5 ++-- app/signalr/hub/metadata.py | 9 +++---- app/signalr/hub/multiplayer.py | 21 ++++++++------- app/signalr/hub/spectator.py | 13 +++++----- app/signalr/router.py | 20 +++++++-------- 34 files changed, 167 insertions(+), 188 deletions(-) diff --git a/app/dependencies/database.py b/app/dependencies/database.py index f7e94ce..eb5b94d 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -3,9 +3,11 @@ from __future__ import annotations from collections.abc import AsyncIterator, Callable from contextvars import ContextVar import json +from typing import Annotated from app.config import settings +from fastapi import Depends from pydantic import BaseModel import redis.asyncio as redis from sqlalchemy.ext.asyncio import create_async_engine @@ -52,7 +54,12 @@ async def get_db(): yield session +def with_db(): + return AsyncSession(engine) + + DBFactory = Callable[[], AsyncIterator[AsyncSession]] +Database = Annotated[AsyncSession, Depends(get_db)] async def get_db_factory() -> DBFactory: diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 50c06da..d3787dc 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -8,7 +8,7 @@ from app.database import User from app.database.auth import V1APIKeys from app.models.oauth import OAuth2ClientCredentialsBearer -from .database import get_db +from .database import Database from fastapi import Depends, HTTPException from fastapi.security import ( @@ -19,7 +19,6 @@ from fastapi.security import ( SecurityScopes, ) from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession security = HTTPBearer() @@ -64,7 +63,7 @@ v1_api_key = APIKeyQuery(name="k", scheme_name="V1 API Key", description="v1 API async def v1_authorize( - db: Annotated[AsyncSession, Depends(get_db)], + db: Database, api_key: Annotated[str, Depends(v1_api_key)], ): """V1 API Key 授权""" @@ -79,8 +78,8 @@ async def v1_authorize( async def get_client_user( + db: Database, token: Annotated[str, Depends(oauth2_password)], - db: Annotated[AsyncSession, Depends(get_db)], ): token_record = await get_token_by_access_token(db, token) if not token_record: @@ -95,8 +94,8 @@ async def get_client_user( async def get_current_user( + db: Database, security_scopes: SecurityScopes, - db: Annotated[AsyncSession, Depends(get_db)], token_pw: Annotated[str | None, Depends(oauth2_password)] = None, token_code: Annotated[str | None, Depends(oauth2_code)] = None, token_client_credentials: Annotated[ diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index c315300..5011af9 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -18,7 +18,7 @@ from typing import ( ) from app.database.beatmap import Beatmap -from app.dependencies.database import engine +from app.dependencies.database import with_db from app.dependencies.fetcher import get_fetcher from app.exception import InvokeException @@ -41,7 +41,6 @@ from .signalr import ( from pydantic import BaseModel, Field from sqlalchemy import update from sqlmodel import col -from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.database.room import Room @@ -473,7 +472,7 @@ class MultiplayerQueue: (item for item in self.room.playlist if not item.expired), key=lambda x: x.id, ) - async with AsyncSession(engine) as session: + async with with_db() as session: for idx, item in enumerate(ordered_active_items): if item.playlist_order == idx: continue @@ -522,7 +521,7 @@ class MultiplayerQueue: if item.freestyle and len(item.allowed_mods) > 0: raise InvokeException("Freestyle items cannot have allowed mods") - async with AsyncSession(engine) as session: + async with with_db() as session: fetcher = await get_fetcher() async with session: beatmap = await Beatmap.get_or_fetch( @@ -548,7 +547,7 @@ class MultiplayerQueue: if item.freestyle and len(item.allowed_mods) > 0: raise InvokeException("Freestyle items cannot have allowed mods") - async with AsyncSession(engine) as session: + async with with_db() as session: fetcher = await get_fetcher() async with session: beatmap = await Beatmap.get_or_fetch( @@ -622,7 +621,7 @@ class MultiplayerQueue: "Attempted to remove an item which has already been played" ) - async with AsyncSession(engine) as session: + async with with_db() as session: await Playlist.delete_item(item.id, self.room.room_id, session) found_item = next((i for i in self.room.playlist if i.id == item.id), None) @@ -637,7 +636,7 @@ class MultiplayerQueue: async def finish_current_item(self): from app.database import Playlist - async with AsyncSession(engine) as session: + async with with_db() as session: played_at = datetime.now(UTC) await session.execute( update(Playlist) diff --git a/app/models/score.py b/app/models/score.py index 701028f..26a67f1 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -92,7 +92,7 @@ class GameMode(str, Enum): def parse(cls, v: str | int) -> "GameMode | None": if isinstance(v, int) or v.isdigit(): return cls.from_int_extra(int(v)) - v = v.lower() + v = v.upper() try: return cls[v] except ValueError: diff --git a/app/router/auth.py b/app/router/auth.py index 1c06ffb..4120c46 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -18,8 +18,7 @@ from app.config import settings from app.const import BANCHOBOT_ID from app.database import DailyChallengeStats, OAuthClient, User from app.database.statistics import UserStatistics -from app.dependencies import get_db -from app.dependencies.database import get_redis +from app.dependencies.database import Database, get_redis from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.helpers.geoip_helper import GeoIPHelper from app.log import logger @@ -37,7 +36,6 @@ from fastapi.responses import JSONResponse from redis.asyncio import Redis from sqlalchemy import text from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession def create_oauth_error_response( @@ -89,11 +87,11 @@ router = APIRouter(tags=["osu! OAuth 认证"]) description="用户注册接口", ) async def register_user( + db: Database, request: Request, user_username: str = Form(..., alias="user[username]", description="用户名"), user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"), user_password: str = Form(..., alias="user[password]", description="密码"), - db: AsyncSession = Depends(get_db), geoip: GeoIPHelper = Depends(get_geoip_helper), ): username_errors = validate_username(user_username) @@ -205,6 +203,7 @@ async def register_user( description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。", ) async def oauth_token( + db: Database, request: Request, grant_type: Literal[ "authorization_code", "refresh_token", "password", "client_credentials" @@ -218,7 +217,6 @@ async def oauth_token( refresh_token: str | None = Form( None, description="刷新令牌(仅刷新令牌模式需要)" ), - db: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): scopes = scope.split(" ") diff --git a/app/router/chat/channel.py b/app/router/chat/channel.py index 059fc50..fbba4ae 100644 --- a/app/router/chat/channel.py +++ b/app/router/chat/channel.py @@ -11,7 +11,7 @@ from app.database.chat import ( UserSilenceResp, ) from app.database.lazer_user import User, UserResp -from app.dependencies.database import get_db, get_redis +from app.dependencies.database import Database, get_redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router @@ -22,7 +22,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security from pydantic import BaseModel, Field, model_validator from redis.asyncio import Redis from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession class UpdateResponse(BaseModel): @@ -38,6 +37,7 @@ class UpdateResponse(BaseModel): tags=["聊天"], ) async def get_update( + session: Database, history_since: int | None = Query( None, description="获取自此禁言 ID 之后的禁言记录" ), @@ -46,7 +46,6 @@ async def get_update( ["presence", "silences"], alias="includes[]", description="要包含的更新类型" ), current_user: User = Security(get_current_user, scopes=["chat.read"]), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): resp = UpdateResponse() @@ -101,10 +100,10 @@ async def get_update( tags=["聊天"], ) async def join_channel( + session: Database, channel: str = Path(..., description="频道 ID/名称"), user: str = Path(..., description="用户 ID"), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), - session: AsyncSession = Depends(get_db), ): db_channel = await ChatChannel.get(channel, session) @@ -121,10 +120,10 @@ async def join_channel( tags=["聊天"], ) async def leave_channel( + session: Database, channel: str = Path(..., description="频道 ID/名称"), user: str = Path(..., description="用户 ID"), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), - session: AsyncSession = Depends(get_db), ): db_channel = await ChatChannel.get(channel, session) @@ -142,8 +141,8 @@ async def leave_channel( tags=["聊天"], ) async def get_channel_list( + session: Database, current_user: User = Security(get_current_user, scopes=["chat.read"]), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): channels = ( @@ -181,9 +180,9 @@ class GetChannelResp(BaseModel): tags=["聊天"], ) async def get_channel( + session: Database, channel: str = Path(..., description="频道 ID/名称"), current_user: User = Security(get_current_user, scopes=["chat.read"]), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): db_channel = await ChatChannel.get(channel, session) @@ -250,9 +249,9 @@ class CreateChannelReq(BaseModel): tags=["聊天"], ) async def create_channel( + session: Database, req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): if req.type == "PM": diff --git a/app/router/chat/message.py b/app/router/chat/message.py index e6caf05..45888bc 100644 --- a/app/router/chat/message.py +++ b/app/router/chat/message.py @@ -11,7 +11,7 @@ from app.database.chat import ( UserSilenceResp, ) from app.database.lazer_user import User -from app.dependencies.database import get_db, get_redis +from app.dependencies.database import Database, get_redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router @@ -23,7 +23,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security from pydantic import BaseModel, Field from redis.asyncio import Redis from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession class KeepAliveResp(BaseModel): @@ -38,12 +37,12 @@ class KeepAliveResp(BaseModel): tags=["聊天"], ) async def keep_alive( + session: Database, history_since: int | None = Query( None, description="获取自此禁言 ID 之后的禁言记录" ), since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), current_user: User = Security(get_current_user, scopes=["chat.read"]), - session: AsyncSession = Depends(get_db), ): resp = KeepAliveResp() if history_since: @@ -84,10 +83,10 @@ class MessageReq(BaseModel): tags=["聊天"], ) async def send_message( + session: Database, channel: str = Path(..., description="频道 ID/名称"), req: MessageReq = Depends(BodyOrForm(MessageReq)), current_user: User = Security(get_current_user, scopes=["chat.write"]), - session: AsyncSession = Depends(get_db), ): db_channel = await ChatChannel.get(channel, session) if db_channel is None: @@ -125,12 +124,12 @@ async def send_message( tags=["聊天"], ) async def get_message( + session: Database, channel: str, limit: int = Query(50, ge=1, le=50, description="获取消息的数量"), since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"), until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"), current_user: User = Security(get_current_user, scopes=["chat.read"]), - session: AsyncSession = Depends(get_db), ): db_channel = await ChatChannel.get(channel, session) if db_channel is None: @@ -158,10 +157,10 @@ async def get_message( tags=["聊天"], ) async def mark_as_read( + session: Database, channel: str = Path(..., description="频道 ID/名称"), message: int = Path(..., description="消息 ID"), current_user: User = Security(get_current_user, scopes=["chat.read"]), - session: AsyncSession = Depends(get_db), ): db_channel = await ChatChannel.get(channel, session) if db_channel is None: @@ -191,9 +190,9 @@ class NewPMResp(BaseModel): tags=["聊天"], ) async def create_new_pm( + session: Database, req: PMReq = Depends(BodyOrForm(PMReq)), current_user: User = Security(get_current_user, scopes=["chat.write"]), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): user_id = current_user.id diff --git a/app/router/chat/server.py b/app/router/chat/server.py index 61e0877..88990a8 100644 --- a/app/router/chat/server.py +++ b/app/router/chat/server.py @@ -6,9 +6,9 @@ from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMes from app.database.lazer_user import User from app.dependencies.database import ( DBFactory, - engine, get_db_factory, get_redis, + with_db, ) from app.dependencies.user import get_current_user from app.log import logger @@ -200,7 +200,7 @@ class ChatServer: ) async def join_room_channel(self, channel_id: int, user_id: int): - async with AsyncSession(engine) as session: + async with with_db() as session: channel = await ChatChannel.get(channel_id, session) if channel is None: return @@ -212,7 +212,7 @@ class ChatServer: await self.join_channel(user, channel, session) async def leave_room_channel(self, channel_id: int, user_id: int): - async with AsyncSession(engine) as session: + async with with_db() as session: channel = await ChatChannel.get(channel_id, session) if channel is None: return @@ -268,7 +268,7 @@ async def chat_websocket( token = authorization[7:] if ( user := await get_current_user( - SecurityScopes(scopes=["chat.read"]), session, token_pw=token + session, SecurityScopes(scopes=["chat.read"]), token_pw=token ) ) is None: await websocket.close(code=1008) diff --git a/app/router/private/avatar.py b/app/router/private/avatar.py index e724daf..0315e5a 100644 --- a/app/router/private/avatar.py +++ b/app/router/private/avatar.py @@ -4,7 +4,7 @@ import hashlib from io import BytesIO from app.database.lazer_user import User -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.dependencies.storage import get_storage_service from app.dependencies.user import get_client_user from app.storage.base import StorageService @@ -13,7 +13,6 @@ from .router import router from fastapi import Depends, File, HTTPException, Security from PIL import Image -from sqlmodel.ext.asyncio.session import AsyncSession @router.post( @@ -21,10 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession name="上传头像", ) async def upload_avatar( + session: Database, content: bytes = File(...), current_user: User = Security(get_client_user), storage: StorageService = Depends(get_storage_service), - session: AsyncSession = Depends(get_db), ): """上传用户头像 diff --git a/app/router/private/cover.py b/app/router/private/cover.py index 08397fe..3a8e5f8 100644 --- a/app/router/private/cover.py +++ b/app/router/private/cover.py @@ -4,7 +4,7 @@ import hashlib from io import BytesIO from app.database.lazer_user import User, UserProfileCover -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.dependencies.storage import get_storage_service from app.dependencies.user import get_client_user from app.storage.base import StorageService @@ -13,7 +13,6 @@ from .router import router from fastapi import Depends, File, HTTPException, Security from PIL import Image -from sqlmodel.ext.asyncio.session import AsyncSession @router.post( @@ -21,10 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession name="上传头图", ) async def upload_cover( + session: Database, content: bytes = File(...), current_user: User = Security(get_client_user), storage: StorageService = Depends(get_storage_service), - session: AsyncSession = Depends(get_db), ): """上传用户头图 diff --git a/app/router/private/oauth.py b/app/router/private/oauth.py index d3950e3..c18d8bd 100644 --- a/app/router/private/oauth.py +++ b/app/router/private/oauth.py @@ -4,7 +4,7 @@ import secrets from app.database.auth import OAuthClient, OAuthToken from app.database.lazer_user import User -from app.dependencies.database import get_db, get_redis +from app.dependencies.database import Database, get_redis from app.dependencies.user import get_client_user from .router import router @@ -12,7 +12,6 @@ from .router import router from fastapi import Body, Depends, HTTPException, Security from redis.asyncio import Redis from sqlmodel import select, text -from sqlmodel.ext.asyncio.session import AsyncSession @router.post( @@ -21,11 +20,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession description="创建一个新的 OAuth 应用程序,并生成客户端 ID 和密钥", ) async def create_oauth_app( + session: Database, name: str = Body(..., max_length=100, description="应用程序名称"), description: str = Body("", description="应用程序描述"), redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"), current_user: User = Security(get_client_user), - session: AsyncSession = Depends(get_db), ): result = await session.execute( # pyright: ignore[reportDeprecated] text( @@ -61,8 +60,8 @@ async def create_oauth_app( description="通过客户端 ID 获取 OAuth 应用的详细信息", ) async def get_oauth_app( + session: Database, client_id: int, - session: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), ): oauth_app = await session.get(OAuthClient, client_id) @@ -82,7 +81,7 @@ async def get_oauth_app( description="获取当前用户创建的所有 OAuth 应用程序", ) async def get_user_oauth_apps( - session: AsyncSession = Depends(get_db), + session: Database, current_user: User = Security(get_client_user), ): oauth_apps = await session.exec( @@ -106,8 +105,8 @@ async def get_user_oauth_apps( description="删除指定的 OAuth 应用程序及其关联的所有令牌", ) async def delete_oauth_app( + session: Database, client_id: int, - session: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), ): oauth_client = await session.get(OAuthClient, client_id) @@ -134,11 +133,11 @@ async def delete_oauth_app( description="更新指定 OAuth 应用的名称、描述和重定向 URI", ) async def update_oauth_app( + session: Database, client_id: int, name: str = Body(..., max_length=100, description="应用程序新名称"), description: str = Body("", description="应用程序新描述"), redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"), - session: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), ): oauth_client = await session.get(OAuthClient, client_id) @@ -169,8 +168,8 @@ async def update_oauth_app( description="为指定的 OAuth 应用生成新的客户端密钥,并使所有现有的令牌失效", ) async def refresh_secret( + session: Database, client_id: int, - session: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), ): oauth_client = await session.get(OAuthClient, client_id) @@ -204,11 +203,11 @@ async def refresh_secret( description="为特定用户和 OAuth 应用生成授权码,用于授权码授权流程", ) async def generate_oauth_code( + session: Database, client_id: int, current_user: User = Security(get_client_user), redirect_uri: str = Body(..., description="授权后重定向的 URI"), scopes: list[str] = Body(..., description="请求的权限范围列表"), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): client = await session.get(OAuthClient, client_id) diff --git a/app/router/private/relationship.py b/app/router/private/relationship.py index 15482bf..1a8af44 100644 --- a/app/router/private/relationship.py +++ b/app/router/private/relationship.py @@ -2,15 +2,14 @@ from __future__ import annotations from app.database import Relationship, User from app.database.relationship import RelationshipType -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.dependencies.user import get_client_user from .router import router -from fastapi import Depends, HTTPException, Path, Security +from fastapi import HTTPException, Path, Security from pydantic import BaseModel, Field from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession class CheckResponse(BaseModel): @@ -26,9 +25,9 @@ class CheckResponse(BaseModel): response_model=CheckResponse, ) async def check_user_relationship( + db: Database, user_id: int = Path(..., description="目标用户的 ID"), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), ): if user_id == current_user.id: raise HTTPException(422, "Cannot check relationship with yourself") diff --git a/app/router/private/username.py b/app/router/private/username.py index 5319d2c..0e66bd8 100644 --- a/app/router/private/username.py +++ b/app/router/private/username.py @@ -6,14 +6,13 @@ from app.auth import validate_username from app.config import settings from app.database.events import Event, EventType from app.database.lazer_user import User -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.dependencies.user import get_client_user from .router import router -from fastapi import Body, Depends, HTTPException, Security +from fastapi import Body, HTTPException, Security from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession @router.post( @@ -21,8 +20,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession name="修改用户名", ) async def user_rename( + session: Database, new_name: str = Body(..., description="新的用户名"), - session: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), ): """修改用户名 diff --git a/app/router/v1/beatmap.py b/app/router/v1/beatmap.py index 97d4e83..82209dc 100644 --- a/app/router/v1/beatmap.py +++ b/app/router/v1/beatmap.py @@ -8,7 +8,7 @@ from app.database.beatmap_playcounts import BeatmapPlaycounts from app.database.beatmapset import Beatmapset from app.database.favourite_beatmapset import FavouriteBeatmapset from app.database.score import Score -from app.dependencies.database import get_db, get_redis +from app.dependencies.database import Database, get_redis from app.dependencies.fetcher import get_fetcher from app.fetcher import Fetcher from app.models.beatmap import BeatmapRankStatus, Genre, Language @@ -149,6 +149,7 @@ class V1Beatmap(AllStrModel): description="根据指定条件搜索谱面。", ) async def get_beatmaps( + session: Database, since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"), beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"), beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"), @@ -163,7 +164,6 @@ async def get_beatmaps( checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"), limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"), mods: int = Query(0, description="应用到谱面属性的 MOD"), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), fetcher: Fetcher = Depends(get_fetcher), ): diff --git a/app/router/v1/replay.py b/app/router/v1/replay.py index 29e13d1..f1cc97d 100644 --- a/app/router/v1/replay.py +++ b/app/router/v1/replay.py @@ -6,7 +6,7 @@ from typing import Literal from app.database.counts import ReplayWatchedCount from app.database.score import Score -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.dependencies.storage import get_storage_service from app.models.mods import int_to_mods from app.models.score import GameMode @@ -17,7 +17,6 @@ from .router import router from fastapi import Depends, HTTPException, Query from pydantic import BaseModel from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession class ReplayModel(BaseModel): @@ -32,6 +31,7 @@ class ReplayModel(BaseModel): description="获取指定谱面的回放文件。", ) async def download_replay( + session: Database, beatmap: int = Query(..., alias="b", description="谱面 ID"), user: str = Query(..., alias="u", description="用户"), ruleset_id: int | None = Query( @@ -45,7 +45,6 @@ async def download_replay( None, description="用户类型:string 用户名称 / id 用户 ID" ), mods: int = Query(0, description="成绩的 MOD"), - session: AsyncSession = Depends(get_db), storage_service: StorageService = Depends(get_storage_service), ): mods_ = int_to_mods(mods) diff --git a/app/router/v1/score.py b/app/router/v1/score.py index ffe0acb..d382522 100644 --- a/app/router/v1/score.py +++ b/app/router/v1/score.py @@ -5,16 +5,15 @@ from typing import Literal from app.database.pp_best_score import PPBestScore from app.database.score import Score, get_leaderboard -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.models.mods import int_to_mods, mod_to_save, mods_to_int from app.models.score import GameMode, LeaderboardType from .router import AllStrModel, router -from fastapi import Depends, HTTPException, Query +from fastapi import HTTPException, Query from sqlalchemy.orm import joinedload from sqlmodel import col, exists, select -from sqlmodel.ext.asyncio.session import AsyncSession class V1Score(AllStrModel): @@ -68,13 +67,13 @@ class V1Score(AllStrModel): description="获取指定用户的最好成绩。", ) async def get_user_best( + session: Database, user: str = Query(..., alias="u", description="用户"), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), type: Literal["string", "id"] | None = Query( None, description="用户类型:string 用户名称 / id 用户 ID" ), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), - session: AsyncSession = Depends(get_db), ): try: scores = ( @@ -104,13 +103,13 @@ async def get_user_best( description="获取指定用户的最近成绩。", ) async def get_user_recent( + session: Database, user: str = Query(..., alias="u", description="用户"), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), type: Literal["string", "id"] | None = Query( None, description="用户类型:string 用户名称 / id 用户 ID" ), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), - session: AsyncSession = Depends(get_db), ): try: scores = ( @@ -140,6 +139,7 @@ async def get_user_recent( description="获取指定谱面的成绩。", ) async def get_scores( + session: Database, user: str | None = Query(None, alias="u", description="用户"), beatmap_id: int = Query(alias="b", description="谱面 ID"), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), @@ -148,7 +148,6 @@ async def get_scores( ), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), mods: int = Query(0, description="成绩的 MOD"), - session: AsyncSession = Depends(get_db), ): try: if user is not None: diff --git a/app/router/v1/user.py b/app/router/v1/user.py index 59ab2d4..19ada39 100644 --- a/app/router/v1/user.py +++ b/app/router/v1/user.py @@ -5,14 +5,13 @@ from typing import Literal from app.database.lazer_user import User from app.database.statistics import UserStatistics, UserStatisticsResp -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.models.score import GameMode from .router import AllStrModel, router -from fastapi import Depends, HTTPException, Query +from fastapi import HTTPException, Query from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession class V1User(AllStrModel): @@ -41,7 +40,7 @@ class V1User(AllStrModel): @classmethod async def from_db( - cls, session: AsyncSession, db_user: User, ruleset: GameMode | None = None + cls, session: Database, db_user: User, ruleset: GameMode | None = None ) -> "V1User": ruleset = ruleset or db_user.playmode current_statistics: UserStatistics | None = None @@ -92,6 +91,7 @@ class V1User(AllStrModel): description="获取指定用户的信息。", ) async def get_user( + session: Database, user: str = Query(..., alias="u", description="用户"), ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0), type: Literal["string", "id"] | None = Query( @@ -100,7 +100,6 @@ async def get_user( event_days: int = Query( default=1, ge=1, le=31, description="从现在起所有事件的最大天数" ), - session: AsyncSession = Depends(get_db), ): db_user = ( await session.exec( diff --git a/app/router/v2/beatmap.py b/app/router/v2/beatmap.py index b6eef19..e5a2775 100644 --- a/app/router/v2/beatmap.py +++ b/app/router/v2/beatmap.py @@ -6,7 +6,7 @@ import json from app.database import Beatmap, BeatmapResp, User from app.database.beatmap import calculate_beatmap_attributes -from app.dependencies.database import get_db, get_redis +from app.dependencies.database import Database, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user from app.fetcher import Fetcher @@ -24,7 +24,6 @@ from pydantic import BaseModel from redis.asyncio import Redis import rosu_pp_py as rosu from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession class BatchGetResp(BaseModel): @@ -47,13 +46,13 @@ class BatchGetResp(BaseModel): ), ) async def lookup_beatmap( + db: Database, id: int | None = Query(default=None, alias="id", description="谱面 ID"), md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"), filename: str | None = Query( default=None, alias="filename", description="谱面文件名" ), current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): if id is None and md5 is None and filename is None: @@ -80,9 +79,9 @@ async def lookup_beatmap( description="获取单个谱面详情。", ) async def get_beatmap( + db: Database, beatmap_id: int = Path(..., description="谱面 ID"), current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): try: @@ -103,11 +102,11 @@ async def get_beatmap( ), ) async def batch_get_beatmaps( + db: Database, beatmap_ids: list[int] = Query( alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)" ), current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): if not beatmap_ids: @@ -157,6 +156,7 @@ async def batch_get_beatmaps( description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"), ) async def get_beatmap_attributes( + db: Database, beatmap_id: int = Path(..., description="谱面 ID"), current_user: User = Security(get_current_user, scopes=["public"]), mods: list[str] = Query( @@ -170,7 +170,6 @@ async def get_beatmap_attributes( default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3 ), redis: Redis = Depends(get_redis), - db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): mods_ = [] diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py index 52a4f1d..5d4dadc 100644 --- a/app/router/v2/beatmapset.py +++ b/app/router/v2/beatmapset.py @@ -7,7 +7,7 @@ from urllib.parse import parse_qs from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.database.beatmapset import SearchBeatmapsetsResp from app.dependencies.beatmap_download import get_beatmap_download_service -from app.dependencies.database import engine, get_db, get_redis +from app.dependencies.database import Database, get_redis, with_db from app.dependencies.fetcher import get_fetcher from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.dependencies.user import get_client_user, get_current_user @@ -30,11 +30,10 @@ from fastapi import ( from fastapi.responses import RedirectResponse from httpx import HTTPError from sqlmodel import exists, select -from sqlmodel.ext.asyncio.session import AsyncSession async def _save_to_db(sets: SearchBeatmapsetsResp): - async with AsyncSession(engine) as session: + async with with_db() as session: for s in sets.beatmapsets: if not ( await session.exec(select(exists()).where(Beatmapset.id == s.id)) @@ -49,13 +48,13 @@ async def _save_to_db(sets: SearchBeatmapsetsResp): response_model=SearchBeatmapsetsResp, ) async def search_beatmapset( + db: Database, query: Annotated[SearchQueryModel, Query(...)], request: Request, background_tasks: BackgroundTasks, current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), - redis = Depends(get_redis), + redis=Depends(get_redis), ): params = parse_qs(qs=request.url.query, keep_blank_values=True) cursor = {} @@ -112,9 +111,9 @@ async def search_beatmapset( description=("通过谱面 ID 查询所属谱面集。"), ) async def lookup_beatmapset( + db: Database, beatmap_id: int = Query(description="谱面 ID"), current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) @@ -132,9 +131,9 @@ async def lookup_beatmapset( description="获取单个谱面集详情。", ) async def get_beatmapset( + db: Database, beatmapset_id: int = Path(..., description="谱面集 ID"), current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): try: @@ -196,12 +195,12 @@ async def download_beatmapset( description="**客户端专属**\n收藏或取消收藏指定谱面集。", ) async def favourite_beatmapset( + db: Database, beatmapset_id: int = Path(..., description="谱面集 ID"), action: Literal["favourite", "unfavourite"] = Form( description="操作类型:favourite 收藏 / unfavourite 取消收藏" ), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), ): assert current_user.id is not None existing_favourite = ( diff --git a/app/router/v2/me.py b/app/router/v2/me.py index 4ff227a..1ca9097 100644 --- a/app/router/v2/me.py +++ b/app/router/v2/me.py @@ -3,13 +3,12 @@ from __future__ import annotations from app.database import User, UserResp from app.database.lazer_user import ALL_INCLUDED from app.dependencies import get_current_user -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.models.score import GameMode from .router import router -from fastapi import Depends, Path, Security -from sqlmodel.ext.asyncio.session import AsyncSession +from fastapi import Path, Security @router.get( @@ -20,9 +19,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession tags=["用户"], ) async def get_user_info_with_ruleset( + session: Database, ruleset: GameMode = Path(description="指定 ruleset"), current_user: User = Security(get_current_user, scopes=["identify"]), - session: AsyncSession = Depends(get_db), ): return await UserResp.from_db( current_user, @@ -40,8 +39,8 @@ async def get_user_info_with_ruleset( tags=["用户"], ) async def get_user_info_default( + session: Database, current_user: User = Security(get_current_user, scopes=["identify"]), - session: AsyncSession = Depends(get_db), ): return await UserResp.from_db( current_user, diff --git a/app/router/v2/ranking.py b/app/router/v2/ranking.py index 6e86d55..711a793 100644 --- a/app/router/v2/ranking.py +++ b/app/router/v2/ranking.py @@ -5,15 +5,14 @@ from typing import Literal from app.database import User from app.database.statistics import UserStatistics, UserStatisticsResp from app.dependencies import get_current_user -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.models.score import GameMode from .router import router -from fastapi import Depends, Path, Query, Security +from fastapi import Path, Query, Security from pydantic import BaseModel from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession class CountryStatistics(BaseModel): @@ -36,10 +35,10 @@ class CountryResponse(BaseModel): tags=["排行榜"], ) async def get_country_ranking( + session: Database, ruleset: GameMode = Path(..., description="指定 ruleset"), page: int = Query(1, ge=1, description="页码"), # TODO current_user: User = Security(get_current_user, scopes=["public"]), - session: AsyncSession = Depends(get_db), ): response = CountryResponse(ranking=[]) countries = (await session.exec(select(User.country_code).distinct())).all() @@ -85,6 +84,7 @@ class TopUsersResponse(BaseModel): tags=["排行榜"], ) async def get_user_ranking( + session: Database, ruleset: GameMode = Path(..., description="指定 ruleset"), type: Literal["performance", "score"] = Path( ..., description="排名类型:performance 表现分 / score 计分成绩总分" @@ -92,7 +92,6 @@ async def get_user_ranking( country: str | None = Query(None, description="国家代码"), page: int = Query(1, ge=1, description="页码"), current_user: User = Security(get_current_user, scopes=["public"]), - session: AsyncSession = Depends(get_db), ): wheres = [ col(UserStatistics.mode) == ruleset, diff --git a/app/router/v2/relationship.py b/app/router/v2/relationship.py index e257a83..accbde8 100644 --- a/app/router/v2/relationship.py +++ b/app/router/v2/relationship.py @@ -1,15 +1,14 @@ from __future__ import annotations from app.database import Relationship, RelationshipResp, RelationshipType, User -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.dependencies.user import get_client_user, get_current_user from .router import router -from fastapi import Depends, HTTPException, Path, Query, Request, Security +from fastapi import HTTPException, Path, Query, Request, Security from pydantic import BaseModel from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession @router.get( @@ -27,9 +26,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession description="获取当前用户的屏蔽用户列表。", ) async def get_relationship( + db: Database, request: Request, current_user: User = Security(get_current_user, scopes=["friends.read"]), - db: AsyncSession = Depends(get_db), ): relationship_type = ( RelationshipType.FOLLOW @@ -67,10 +66,10 @@ class AddFriendResp(BaseModel): description="**客户端专属**\n添加或更新与目标用户的屏蔽关系。", ) async def add_relationship( + db: Database, request: Request, target: int = Query(description="目标用户 ID"), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), ): assert current_user.id is not None relationship_type = ( @@ -141,10 +140,10 @@ async def add_relationship( description="**客户端专属**\n删除与目标用户的屏蔽关系。", ) async def delete_relationship( + db: Database, request: Request, target: int = Path(..., description="目标用户 ID"), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), ): relationship_type = ( RelationshipType.BLOCK diff --git a/app/router/v2/room.py b/app/router/v2/room.py index 1a05fe0..1f4fd79 100644 --- a/app/router/v2/room.py +++ b/app/router/v2/room.py @@ -12,7 +12,7 @@ from app.database.playlists import Playlist, PlaylistResp from app.database.room import APIUploadedRoom, Room, RoomResp from app.database.room_participated_user import RoomParticipatedUser from app.database.score import Score -from app.dependencies.database import get_db, get_redis +from app.dependencies.database import Database, get_redis from app.dependencies.user import get_client_user, get_current_user from app.models.room import RoomCategory, RoomStatus from app.service.room import create_playlist_room_from_api @@ -36,6 +36,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession description="获取房间列表。支持按状态/模式筛选", ) async def get_all_rooms( + db: Database, mode: Literal["open", "ended", "participated", "owned", None] = Query( default="open", description=( @@ -51,7 +52,6 @@ async def get_all_rooms( ), ), status: RoomStatus | None = Query(None, description="房间状态(可选)"), - db: AsyncSession = Depends(get_db), current_user: User = Security(get_current_user, scopes=["public"]), ): resp_list: list[RoomResp] = [] @@ -149,8 +149,8 @@ async def _participate_room( description="**客户端专属**\n创建一个新的房间。", ) async def create_room( + db: Database, room: APIUploadedRoom, - db: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), redis: Redis = Depends(get_redis), ): @@ -173,6 +173,7 @@ async def create_room( description="获取单个房间详情。", ) async def get_room( + db: Database, room_id: int = Path(..., description="房间 ID"), category: str = Query( default="", @@ -181,7 +182,6 @@ async def get_room( " / DAILY_CHALLENGE 每日挑战 (可选)" ), ), - db: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), redis: Redis = Depends(get_redis), ): @@ -201,8 +201,8 @@ async def get_room( description="**客户端专属**\n结束歌单模式房间。", ) async def delete_room( + db: Database, room_id: int = Path(..., description="房间 ID"), - db: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() @@ -221,9 +221,9 @@ async def delete_room( description="**客户端专属**\n加入指定歌单模式房间。", ) async def add_user_to_room( + db: Database, room_id: int = Path(..., description="房间 ID"), user_id: int = Path(..., description="用户 ID"), - db: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), current_user: User = Security(get_client_user), ): @@ -245,9 +245,9 @@ async def add_user_to_room( description="**客户端专属**\n离开指定歌单模式房间。", ) async def remove_user_from_room( + db: Database, room_id: int = Path(..., description="房间 ID"), user_id: int = Path(..., description="用户 ID"), - db: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), redis: Redis = Depends(get_redis), ): @@ -289,8 +289,8 @@ class APILeaderboard(BaseModel): description="获取房间内累计得分排行榜。", ) async def get_room_leaderboard( + db: Database, room_id: int = Path(..., description="房间 ID"), - db: AsyncSession = Depends(get_db), current_user: User = Security(get_current_user, scopes=["public"]), ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() @@ -345,8 +345,8 @@ class RoomEvents(BaseModel): description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。", ) async def get_room_events( + db: Database, room_id: int = Path(..., description="房间 ID"), - db: AsyncSession = Depends(get_db), current_user: User = Security(get_current_user, scopes=["public"]), limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"), diff --git a/app/router/v2/score.py b/app/router/v2/score.py index 07b9f02..93eeb3c 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -32,7 +32,7 @@ from app.database.score import ( process_score, process_user, ) -from app.dependencies.database import get_db, get_redis +from app.dependencies.database import Database, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.storage import get_storage_service from app.dependencies.user import get_client_user, get_current_user @@ -220,6 +220,7 @@ class BeatmapScores(BaseModel): description="获取指定谱面在特定条件下的排行榜及当前用户成绩。", ) async def get_beatmap_scores( + db: Database, beatmap_id: int = Path(description="谱面 ID"), mode: GameMode = Query(description="指定 auleset"), legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), @@ -233,7 +234,6 @@ async def get_beatmap_scores( ), ), current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"), ): if legacy_only: @@ -277,13 +277,13 @@ class BeatmapUserScore(BaseModel): description="获取指定用户在指定谱面上的最高成绩。", ) async def get_user_beatmap_score( + db: Database, beatmap_id: int = Path(description="谱面 ID"), user_id: int = Path(description="用户 ID"), legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), mode: GameMode | None = Query(None, description="指定 ruleset (可选)"), mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"), current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), ): if legacy_only: raise HTTPException( @@ -322,12 +322,12 @@ async def get_user_beatmap_score( description="获取指定用户在指定谱面上的全部成绩列表。", ) async def get_user_all_beatmap_scores( + db: Database, beatmap_id: int = Path(description="谱面 ID"), user_id: int = Path(description="用户 ID"), legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"), current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), ): if legacy_only: raise HTTPException( @@ -357,12 +357,12 @@ async def get_user_all_beatmap_scores( ) async def create_solo_score( background_task: BackgroundTasks, + db: Database, beatmap_id: int = Path(description="谱面 ID"), version_hash: str = Form("", description="游戏版本哈希"), beatmap_hash: str = Form(description="谱面文件哈希"), ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), ): assert current_user.id is not None background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id) @@ -387,11 +387,11 @@ async def create_solo_score( ) async def submit_solo_score( req: Request, + db: Database, beatmap_id: int = Path(description="谱面 ID"), token: int = Path(description="成绩令牌 ID"), info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), ): @@ -407,6 +407,7 @@ async def submit_solo_score( description="**客户端专属**\n为房间游玩项目创建成绩提交令牌。", ) async def create_playlist_score( + session: Database, background_task: BackgroundTasks, room_id: int, playlist_id: int, @@ -415,7 +416,6 @@ async def create_playlist_score( ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), version_hash: str = Form("", description="谱面版本哈希"), current_user: User = Security(get_client_user), - session: AsyncSession = Depends(get_db), ): assert current_user.id is not None room = await session.get(Room, room_id) @@ -483,12 +483,12 @@ async def create_playlist_score( description="**客户端专属**\n提交房间游玩项目成绩。", ) async def submit_playlist_score( + session: Database, room_id: int, playlist_id: int, token: int, info: SoloScoreSubmissionInfo, current_user: User = Security(get_client_user), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), fetcher: Fetcher = Depends(get_fetcher), ): @@ -541,6 +541,7 @@ class IndexedScoreResp(MultiplayerScores): tags=["成绩"], ) async def index_playlist_scores( + session: Database, room_id: int, playlist_id: int, limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"), @@ -548,7 +549,6 @@ async def index_playlist_scores( 2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)" ), current_user: User = Security(get_current_user, scopes=["public"]), - session: AsyncSession = Depends(get_db), ): room = await session.get(Room, room_id) if not room: @@ -607,11 +607,11 @@ async def index_playlist_scores( tags=["成绩"], ) async def show_playlist_score( + session: Database, room_id: int, playlist_id: int, score_id: int, current_user: User = Security(get_client_user), - session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): room = await session.get(Room, room_id) @@ -678,11 +678,11 @@ async def show_playlist_score( tags=["成绩"], ) async def get_user_playlist_score( + session: Database, room_id: int, playlist_id: int, user_id: int, current_user: User = Security(get_client_user), - session: AsyncSession = Depends(get_db), ): score_record = None start_time = time.time() @@ -716,9 +716,9 @@ async def get_user_playlist_score( tags=["成绩"], ) async def pin_score( + db: Database, score_id: int = Path(description="成绩 ID"), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), ): score_record = ( await db.exec( @@ -758,9 +758,9 @@ async def pin_score( tags=["成绩"], ) async def unpin_score( + db: Database, score_id: int = Path(description="成绩 ID"), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), ): score_record = ( await db.exec( @@ -797,11 +797,11 @@ async def unpin_score( tags=["成绩"], ) async def reorder_score_pin( + db: Database, score_id: int = Path(description="成绩 ID"), after_score_id: int | None = Body(default=None, description="放在该成绩之后"), before_score_id: int | None = Body(default=None, description="放在该成绩之前"), current_user: User = Security(get_client_user), - db: AsyncSession = Depends(get_db), ): score_record = ( await db.exec( @@ -892,8 +892,8 @@ async def reorder_score_pin( ) async def download_score_replay( score_id: int, + db: Database, current_user: User = Security(get_current_user, scopes=["public"]), - db: AsyncSession = Depends(get_db), storage_service: StorageService = Depends(get_storage_service), ): score = (await db.exec(select(Score).where(Score.id == score_id))).first() diff --git a/app/router/v2/user.py b/app/router/v2/user.py index 123e120..e18cd0b 100644 --- a/app/router/v2/user.py +++ b/app/router/v2/user.py @@ -15,17 +15,16 @@ from app.database.events import EventResp from app.database.lazer_user import SEARCH_INCLUDED from app.database.pp_best_score import PPBestScore from app.database.score import Score, ScoreResp -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.dependencies.user import get_current_user from app.models.score import GameMode from app.models.user import BeatmapsetType from .router import router -from fastapi import Depends, HTTPException, Path, Query, Security +from fastapi import HTTPException, Path, Query, Security from pydantic import BaseModel from sqlmodel import exists, false, select -from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql.expression import col @@ -43,6 +42,7 @@ class BatchUserResponse(BaseModel): @router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False) @router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False) async def get_users( + session: Database, user_ids: list[int] = Query( default_factory=list, alias="ids[]", description="要查询的用户 ID 列表" ), @@ -50,7 +50,6 @@ async def get_users( include_variant_statistics: bool = Query( default=False, description="是否包含各模式的统计信息" ), # TODO: future use - session: AsyncSession = Depends(get_db), ): if user_ids: searched_users = ( @@ -79,9 +78,9 @@ async def get_users( tags=["用户"], ) async def get_user_info_ruleset( + session: Database, user_id: str = Path(description="用户 ID 或用户名"), ruleset: GameMode | None = Path(description="指定 ruleset"), - session: AsyncSession = Depends(get_db), # current_user: User = Security(get_current_user, scopes=["public"]), ): searched_user = ( @@ -112,8 +111,8 @@ async def get_user_info_ruleset( tags=["用户"], ) async def get_user_info( + session: Database, user_id: str = Path(description="用户 ID 或用户名"), - session: AsyncSession = Depends(get_db), # current_user: User = Security(get_current_user, scopes=["public"]), ): searched_user = ( @@ -142,10 +141,10 @@ async def get_user_info( tags=["用户"], ) async def get_user_beatmapsets( + session: Database, user_id: int = Path(description="用户 ID"), type: BeatmapsetType = Path(description="谱面集类型"), current_user: User = Security(get_current_user, scopes=["public"]), - session: AsyncSession = Depends(get_db), limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), offset: int = Query(0, ge=0, description="偏移量"), ): @@ -202,6 +201,7 @@ async def get_user_beatmapsets( tags=["用户"], ) async def get_user_scores( + session: Database, user_id: int = Path(description="用户 ID"), type: Literal["best", "recent", "firsts", "pinned"] = Path( description=( @@ -216,7 +216,6 @@ async def get_user_scores( ), limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), offset: int = Query(0, ge=0, description="偏移量"), - session: AsyncSession = Depends(get_db), current_user: User = Security(get_current_user, scopes=["public"]), ): db_user = await session.get(User, user_id) @@ -267,10 +266,10 @@ async def get_user_scores( "/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp] ) async def get_user_events( + session: Database, user: int, limit: int | None = Query(None), offset: str | None = Query(None), # TODO: 搞清楚并且添加这个奇怪的分页偏移 - session: AsyncSession = Depends(get_db), ): db_user = await session.get(User, user) if db_user is None or db_user.id == BANCHOBOT_ID: diff --git a/app/service/calculate_all_user_rank.py b/app/service/calculate_all_user_rank.py index 55ae61e..bc4e074 100644 --- a/app/service/calculate_all_user_rank.py +++ b/app/service/calculate_all_user_rank.py @@ -4,12 +4,11 @@ from datetime import UTC, datetime, timedelta from app.database import RankHistory, UserStatistics from app.database.rank_history import RankTop -from app.dependencies.database import engine +from app.dependencies.database import with_db from app.dependencies.scheduler import get_scheduler from app.models.score import GameMode from sqlmodel import col, exists, select, update -from sqlmodel.ext.asyncio.session import AsyncSession @get_scheduler().scheduled_job( @@ -18,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession async def calculate_user_rank(is_today: bool = False): today = datetime.now(UTC).date() target_date = today if is_today else today - timedelta(days=1) - async with AsyncSession(engine) as session: + async with with_db() as session: for gamemode in GameMode: users = await session.exec( select(UserStatistics) diff --git a/app/service/create_banchobot.py b/app/service/create_banchobot.py index aa89fe6..dceec18 100644 --- a/app/service/create_banchobot.py +++ b/app/service/create_banchobot.py @@ -3,15 +3,14 @@ from __future__ import annotations from app.const import BANCHOBOT_ID from app.database.lazer_user import User from app.database.statistics import UserStatistics -from app.dependencies.database import engine +from app.dependencies.database import with_db from app.models.score import GameMode from sqlmodel import exists, select -from sqlmodel.ext.asyncio.session import AsyncSession async def create_banchobot(): - async with AsyncSession(engine) as session: + async with with_db() as session: is_exist = ( await session.exec(select(exists()).where(User.id == BANCHOBOT_ID)) ).first() diff --git a/app/service/daily_challenge.py b/app/service/daily_challenge.py index f3dbf05..82fbb8f 100644 --- a/app/service/daily_challenge.py +++ b/app/service/daily_challenge.py @@ -6,7 +6,7 @@ import json from app.const import BANCHOBOT_ID from app.database.playlists import Playlist from app.database.room import Room -from app.dependencies.database import engine, get_redis +from app.dependencies.database import get_redis, with_db from app.dependencies.scheduler import get_scheduler from app.log import logger from app.models.metadata_hub import DailyChallengeInfo @@ -16,13 +16,12 @@ from app.models.room import RoomCategory from .room import create_playlist_room from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession async def create_daily_challenge_room( beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = [] ) -> Room: - async with AsyncSession(engine) as session: + async with with_db() as session: today = datetime.now(UTC).date() return await create_playlist_room( session=session, @@ -52,7 +51,7 @@ async def daily_challenge_job(): key = f"daily_challenge:{now.date()}" if not await redis.exists(key): return - async with AsyncSession(engine) as session: + async with with_db() as session: room = ( await session.exec( select(Room).where( diff --git a/app/service/osu_rx_statistics.py b/app/service/osu_rx_statistics.py index 60f94ce..b53082c 100644 --- a/app/service/osu_rx_statistics.py +++ b/app/service/osu_rx_statistics.py @@ -4,16 +4,15 @@ from app.config import settings from app.const import BANCHOBOT_ID from app.database.lazer_user import User from app.database.statistics import UserStatistics -from app.dependencies.database import engine +from app.dependencies.database import with_db from app.models.score import GameMode from sqlalchemy import exists from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession async def create_rx_statistics(): - async with AsyncSession(engine) as session: + async with with_db() as session: users = (await session.exec(select(User.id))).all() for i in users: if i == BANCHOBOT_ID: diff --git a/app/service/subscribers/score_processed.py b/app/service/subscribers/score_processed.py index 756993c..2b69740 100644 --- a/app/service/subscribers/score_processed.py +++ b/app/service/subscribers/score_processed.py @@ -4,13 +4,12 @@ from typing import TYPE_CHECKING from app.database import PlaylistBestScore, Score from app.database.playlist_best_score import get_position -from app.dependencies.database import engine +from app.dependencies.database import with_db from app.models.metadata_hub import MultiplayerRoomScoreSetEvent from .base import RedisSubscriber from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.signalr.hub import MetadataHub @@ -45,7 +44,7 @@ class ScoreSubscriber(RedisSubscriber): async def _notify_room_score_processed(self, score_id: int): if not self.metadata_hub: return - async with AsyncSession(engine) as session: + async with with_db() as session: score = await session.get(Score, score_id) if ( not score diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index e1888c2..a81c1c8 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -13,7 +13,7 @@ from app.database.playlist_best_score import PlaylistBestScore from app.database.playlists import Playlist from app.database.room import Room from app.database.score import Score -from app.dependencies.database import engine, get_redis +from app.dependencies.database import get_redis, with_db from app.models.metadata_hub import ( TOTAL_SCORE_DISTRIBUTION_BINS, DailyChallengeInfo, @@ -30,7 +30,6 @@ from app.service.subscribers.score_processed import ScoreSubscriber from .hub import Client, Hub from sqlmodel import col, select -from sqlmodel.ext.asyncio.session import AsyncSession ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers" @@ -97,7 +96,7 @@ class MetadataHub(Hub[MetadataClientState]): redis = get_redis() if await redis.exists(f"metadata:online:{state.connection_id}"): await redis.delete(f"metadata:online:{state.connection_id}") - async with AsyncSession(engine) as session: + async with with_db() as session: async with session.begin(): user = ( await session.exec( @@ -118,7 +117,7 @@ class MetadataHub(Hub[MetadataClientState]): user_id = int(client.connection_id) self.get_or_create_state(client) - async with AsyncSession(engine) as session: + async with with_db() as session: async with session.begin(): friends = ( await session.exec( @@ -233,7 +232,7 @@ class MetadataHub(Hub[MetadataClientState]): return list(stats.playlist_item_stats.values()) async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None: - async with AsyncSession(engine) as session: + async with with_db() as session: playlist_ids = ( await session.exec( select(Playlist.id).where( diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 81b2bdb..37012a4 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -12,7 +12,7 @@ from app.database.multiplayer_event import MultiplayerEvent from app.database.playlists import Playlist from app.database.relationship import Relationship, RelationshipType from app.database.room_participated_user import RoomParticipatedUser -from app.dependencies.database import engine, get_redis +from app.dependencies.database import get_redis, with_db from app.dependencies.fetcher import get_fetcher from app.exception import InvokeException from app.log import logger @@ -50,7 +50,6 @@ from .hub import Client, Hub from httpx import HTTPError from sqlalchemy import update from sqlmodel import col, exists, select -from sqlmodel.ext.asyncio.session import AsyncSession GAMEPLAY_LOAD_TIMEOUT = 30 @@ -61,7 +60,7 @@ class MultiplayerEventLogger: async def log_event(self, event: MultiplayerEvent): try: - async with AsyncSession(engine) as session: + async with with_db() as session: session.add(event) await session.commit() except Exception as e: @@ -192,7 +191,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): store = self.get_or_create_state(client) if store.room_id != 0: raise InvokeException("You are already in a room") - async with AsyncSession(engine) as session: + async with with_db() as session: async with session: db_room = Room( name=room.settings.name, @@ -282,7 +281,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await server_room.match_type_handler.handle_join(user) await self.event_logger.player_joined(room_id, user.user_id) - async with AsyncSession(engine) as session: + async with with_db() as session: async with session.begin(): if ( participated_user := ( @@ -398,7 +397,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async def change_db_settings(self, room: ServerMultiplayerRoom): - async with AsyncSession(engine) as session: + async with with_db() as session: await session.execute( update(Room) .where(col(Room.id) == room.room.room_id) @@ -477,7 +476,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room, user, ) - async with AsyncSession(engine) as session: + async with with_db() as session: try: beatmap = await Beatmap.get_or_fetch( session, fetcher, bid=room.queue.current_item.beatmap_id @@ -535,7 +534,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if not room.queue.current_item.freestyle: raise InvokeException("Current item does not allow free user styles.") - async with AsyncSession(engine) as session: + async with with_db() as session: item_beatmap = await session.get( Beatmap, room.queue.current_item.beatmap_id ) @@ -910,7 +909,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): redis = get_redis() await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}") - async with AsyncSession(engine) as session: + async with with_db() as session: async with session.begin(): participated_user = ( await session.exec( @@ -954,7 +953,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): async def end_room(self, room: ServerMultiplayerRoom): assert room.room.host - async with AsyncSession(engine) as session: + async with with_db() as session: await session.execute( update(Room) .where(col(Room.id) == room.room.room_id) @@ -1171,7 +1170,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if user is None: raise InvokeException("You are not in this room") - async with AsyncSession(engine) as session: + async with with_db() as session: db_user = await session.get(User, user_id) target_relationship = ( await session.exec( diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index a89020e..7fd682e 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -14,7 +14,7 @@ from app.database.failtime import FailTime, FailTimeResp from app.database.score import Score from app.database.score_token import ScoreToken from app.database.statistics import UserStatistics -from app.dependencies.database import engine, get_redis +from app.dependencies.database import get_redis, with_db from app.dependencies.fetcher import get_fetcher from app.dependencies.storage import get_storage_service from app.exception import InvokeException @@ -38,7 +38,6 @@ from .hub import Client, Hub from httpx import HTTPError from sqlalchemy.orm import joinedload from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession READ_SCORE_TIMEOUT = 30 REPLAY_LATEST_VER = 30000016 @@ -194,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]): return fetcher = await get_fetcher() - async with AsyncSession(engine) as session: + async with with_db() as session: async with session.begin(): try: beatmap = await Beatmap.get_or_fetch( @@ -285,7 +284,7 @@ class SpectatorHub(Hub[StoreClientState]): assert store.checksum is not None assert store.ruleset_id is not None assert store.score is not None - async with AsyncSession(engine) as session: + async with with_db() as session: async with session: start_time = time.time() score_record = None @@ -332,7 +331,7 @@ class SpectatorHub(Hub[StoreClientState]): self, user_id: int, state: SpectatorState, store: StoreClientState ) -> None: async def _add_failtime(): - async with AsyncSession(engine) as session: + async with with_db() as session: failtime = await session.get(FailTime, state.beatmap_id) total_length = ( await session.exec( @@ -366,7 +365,7 @@ class SpectatorHub(Hub[StoreClientState]): return before_time = int(messages[0][1]["time"]) await redis.delete(key) - async with AsyncSession(engine) as session: + async with with_db() as session: gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods) statistics = ( await session.exec( @@ -430,7 +429,7 @@ class SpectatorHub(Hub[StoreClientState]): self.add_to_group(client, self.group_id(target_id)) - async with AsyncSession(engine) as session: + async with with_db() as session: async with session.begin(): username = ( await session.exec(select(User.username).where(User.id == user_id)) diff --git a/app/signalr/router.py b/app/signalr/router.py index ec0bb43..8601279 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -8,7 +8,7 @@ import uuid from app.database import User as DBUser from app.dependencies import get_current_user -from app.dependencies.database import get_db +from app.dependencies.database import DBFactory, get_db_factory from app.models.signalr import NegotiateResponse, Transport from .hub import Hubs @@ -16,7 +16,6 @@ from .packet import PROTOCOLS, SEP from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket from fastapi.security import SecurityScopes -from sqlmodel.ext.asyncio.session import AsyncSession router = APIRouter(prefix="/signalr", include_in_schema=False) @@ -47,7 +46,7 @@ async def connect( websocket: WebSocket, id: str, authorization: str = Header(...), - db: AsyncSession = Depends(get_db), + factory: DBFactory = Depends(get_db_factory), ): token = authorization[7:] user_id = id.split(":")[0] @@ -56,13 +55,14 @@ async def connect( await websocket.close(code=1008) return try: - if ( - user := await get_current_user( - SecurityScopes(scopes=["*"]), db, token_pw=token - ) - ) is None or str(user.id) != user_id: - await websocket.close(code=1008) - return + async for session in factory(): + if ( + user := await get_current_user( + session, SecurityScopes(scopes=["*"]), token_pw=token + ) + ) is None or str(user.id) != user_id: + await websocket.close(code=1008) + return except HTTPException: await websocket.close(code=1008) return