refactor(app): update database code

This commit is contained in:
MingxuanGame
2025-08-18 16:37:30 +00:00
parent 6bae937e01
commit 1c65b21bb9
34 changed files with 167 additions and 188 deletions

View File

@@ -3,9 +3,11 @@ from __future__ import annotations
from collections.abc import AsyncIterator, Callable from collections.abc import AsyncIterator, Callable
from contextvars import ContextVar from contextvars import ContextVar
import json import json
from typing import Annotated
from app.config import settings from app.config import settings
from fastapi import Depends
from pydantic import BaseModel from pydantic import BaseModel
import redis.asyncio as redis import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
@@ -52,7 +54,12 @@ async def get_db():
yield session yield session
def with_db():
return AsyncSession(engine)
DBFactory = Callable[[], AsyncIterator[AsyncSession]] DBFactory = Callable[[], AsyncIterator[AsyncSession]]
Database = Annotated[AsyncSession, Depends(get_db)]
async def get_db_factory() -> DBFactory: async def get_db_factory() -> DBFactory:

View File

@@ -8,7 +8,7 @@ from app.database import User
from app.database.auth import V1APIKeys from app.database.auth import V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer from app.models.oauth import OAuth2ClientCredentialsBearer
from .database import get_db from .database import Database
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from fastapi.security import ( from fastapi.security import (
@@ -19,7 +19,6 @@ from fastapi.security import (
SecurityScopes, SecurityScopes,
) )
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer() security = HTTPBearer()
@@ -64,7 +63,7 @@ v1_api_key = APIKeyQuery(name="k", scheme_name="V1 API Key", description="v1 API
async def v1_authorize( async def v1_authorize(
db: Annotated[AsyncSession, Depends(get_db)], db: Database,
api_key: Annotated[str, Depends(v1_api_key)], api_key: Annotated[str, Depends(v1_api_key)],
): ):
"""V1 API Key 授权""" """V1 API Key 授权"""
@@ -79,8 +78,8 @@ async def v1_authorize(
async def get_client_user( async def get_client_user(
db: Database,
token: Annotated[str, Depends(oauth2_password)], token: Annotated[str, Depends(oauth2_password)],
db: Annotated[AsyncSession, Depends(get_db)],
): ):
token_record = await get_token_by_access_token(db, token) token_record = await get_token_by_access_token(db, token)
if not token_record: if not token_record:
@@ -95,8 +94,8 @@ async def get_client_user(
async def get_current_user( async def get_current_user(
db: Database,
security_scopes: SecurityScopes, security_scopes: SecurityScopes,
db: Annotated[AsyncSession, Depends(get_db)],
token_pw: Annotated[str | None, Depends(oauth2_password)] = None, token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None, token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[ token_client_credentials: Annotated[

View File

@@ -18,7 +18,7 @@ from typing import (
) )
from app.database.beatmap import Beatmap from app.database.beatmap import Beatmap
from app.dependencies.database import engine from app.dependencies.database import with_db
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException from app.exception import InvokeException
@@ -41,7 +41,6 @@ from .signalr import (
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import update from sqlalchemy import update
from sqlmodel import col from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from app.database.room import Room from app.database.room import Room
@@ -473,7 +472,7 @@ class MultiplayerQueue:
(item for item in self.room.playlist if not item.expired), (item for item in self.room.playlist if not item.expired),
key=lambda x: x.id, key=lambda x: x.id,
) )
async with AsyncSession(engine) as session: async with with_db() as session:
for idx, item in enumerate(ordered_active_items): for idx, item in enumerate(ordered_active_items):
if item.playlist_order == idx: if item.playlist_order == idx:
continue continue
@@ -522,7 +521,7 @@ class MultiplayerQueue:
if item.freestyle and len(item.allowed_mods) > 0: if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods") raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session: async with with_db() as session:
fetcher = await get_fetcher() fetcher = await get_fetcher()
async with session: async with session:
beatmap = await Beatmap.get_or_fetch( beatmap = await Beatmap.get_or_fetch(
@@ -548,7 +547,7 @@ class MultiplayerQueue:
if item.freestyle and len(item.allowed_mods) > 0: if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods") raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session: async with with_db() as session:
fetcher = await get_fetcher() fetcher = await get_fetcher()
async with session: async with session:
beatmap = await Beatmap.get_or_fetch( beatmap = await Beatmap.get_or_fetch(
@@ -622,7 +621,7 @@ class MultiplayerQueue:
"Attempted to remove an item which has already been played" "Attempted to remove an item which has already been played"
) )
async with AsyncSession(engine) as session: async with with_db() as session:
await Playlist.delete_item(item.id, self.room.room_id, session) await Playlist.delete_item(item.id, self.room.room_id, session)
found_item = next((i for i in self.room.playlist if i.id == item.id), None) found_item = next((i for i in self.room.playlist if i.id == item.id), None)
@@ -637,7 +636,7 @@ class MultiplayerQueue:
async def finish_current_item(self): async def finish_current_item(self):
from app.database import Playlist from app.database import Playlist
async with AsyncSession(engine) as session: async with with_db() as session:
played_at = datetime.now(UTC) played_at = datetime.now(UTC)
await session.execute( await session.execute(
update(Playlist) update(Playlist)

View File

@@ -92,7 +92,7 @@ class GameMode(str, Enum):
def parse(cls, v: str | int) -> "GameMode | None": def parse(cls, v: str | int) -> "GameMode | None":
if isinstance(v, int) or v.isdigit(): if isinstance(v, int) or v.isdigit():
return cls.from_int_extra(int(v)) return cls.from_int_extra(int(v))
v = v.lower() v = v.upper()
try: try:
return cls[v] return cls[v]
except ValueError: except ValueError:

View File

@@ -18,8 +18,7 @@ from app.config import settings
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
from app.database import DailyChallengeStats, OAuthClient, User from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.dependencies import get_db from app.dependencies.database import Database, get_redis
from app.dependencies.database import get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.helpers.geoip_helper import GeoIPHelper from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger from app.log import logger
@@ -37,7 +36,6 @@ from fastapi.responses import JSONResponse
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlalchemy import text from sqlalchemy import text
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
def create_oauth_error_response( def create_oauth_error_response(
@@ -89,11 +87,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
description="用户注册接口", description="用户注册接口",
) )
async def register_user( async def register_user(
db: Database,
request: Request, request: Request,
user_username: str = Form(..., alias="user[username]", description="用户名"), user_username: str = Form(..., alias="user[username]", description="用户名"),
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"), user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
user_password: str = Form(..., alias="user[password]", description="密码"), user_password: str = Form(..., alias="user[password]", description="密码"),
db: AsyncSession = Depends(get_db),
geoip: GeoIPHelper = Depends(get_geoip_helper), geoip: GeoIPHelper = Depends(get_geoip_helper),
): ):
username_errors = validate_username(user_username) username_errors = validate_username(user_username)
@@ -205,6 +203,7 @@ async def register_user(
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。", description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
) )
async def oauth_token( async def oauth_token(
db: Database,
request: Request, request: Request,
grant_type: Literal[ grant_type: Literal[
"authorization_code", "refresh_token", "password", "client_credentials" "authorization_code", "refresh_token", "password", "client_credentials"
@@ -218,7 +217,6 @@ async def oauth_token(
refresh_token: str | None = Form( refresh_token: str | None = Form(
None, description="刷新令牌(仅刷新令牌模式需要)" None, description="刷新令牌(仅刷新令牌模式需要)"
), ),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
scopes = scope.split(" ") scopes = scope.split(" ")

View File

@@ -11,7 +11,7 @@ from app.database.chat import (
UserSilenceResp, UserSilenceResp,
) )
from app.database.lazer_user import User, UserResp from app.database.lazer_user import User, UserResp
from app.dependencies.database import get_db, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.param import BodyOrForm from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router from app.router.v2 import api_v2_router as router
@@ -22,7 +22,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class UpdateResponse(BaseModel): class UpdateResponse(BaseModel):
@@ -38,6 +37,7 @@ class UpdateResponse(BaseModel):
tags=["聊天"], tags=["聊天"],
) )
async def get_update( async def get_update(
session: Database,
history_since: int | None = Query( history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录" None, description="获取自此禁言 ID 之后的禁言记录"
), ),
@@ -46,7 +46,6 @@ async def get_update(
["presence", "silences"], alias="includes[]", description="要包含的更新类型" ["presence", "silences"], alias="includes[]", description="要包含的更新类型"
), ),
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
resp = UpdateResponse() resp = UpdateResponse()
@@ -101,10 +100,10 @@ async def get_update(
tags=["聊天"], tags=["聊天"],
) )
async def join_channel( async def join_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"), user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
): ):
db_channel = await ChatChannel.get(channel, session) db_channel = await ChatChannel.get(channel, session)
@@ -121,10 +120,10 @@ async def join_channel(
tags=["聊天"], tags=["聊天"],
) )
async def leave_channel( async def leave_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"), user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
): ):
db_channel = await ChatChannel.get(channel, session) db_channel = await ChatChannel.get(channel, session)
@@ -142,8 +141,8 @@ async def leave_channel(
tags=["聊天"], tags=["聊天"],
) )
async def get_channel_list( async def get_channel_list(
session: Database,
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
channels = ( channels = (
@@ -181,9 +180,9 @@ class GetChannelResp(BaseModel):
tags=["聊天"], tags=["聊天"],
) )
async def get_channel( async def get_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: str = Path(..., description="频道 ID/名称"),
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
db_channel = await ChatChannel.get(channel, session) db_channel = await ChatChannel.get(channel, session)
@@ -250,9 +249,9 @@ class CreateChannelReq(BaseModel):
tags=["聊天"], tags=["聊天"],
) )
async def create_channel( async def create_channel(
session: Database,
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)), req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
if req.type == "PM": if req.type == "PM":

View File

@@ -11,7 +11,7 @@ from app.database.chat import (
UserSilenceResp, UserSilenceResp,
) )
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import get_db, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.param import BodyOrForm from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router from app.router.v2 import api_v2_router as router
@@ -23,7 +23,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class KeepAliveResp(BaseModel): class KeepAliveResp(BaseModel):
@@ -38,12 +37,12 @@ class KeepAliveResp(BaseModel):
tags=["聊天"], tags=["聊天"],
) )
async def keep_alive( async def keep_alive(
session: Database,
history_since: int | None = Query( history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录" None, description="获取自此禁言 ID 之后的禁言记录"
), ),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
): ):
resp = KeepAliveResp() resp = KeepAliveResp()
if history_since: if history_since:
@@ -84,10 +83,10 @@ class MessageReq(BaseModel):
tags=["聊天"], tags=["聊天"],
) )
async def send_message( async def send_message(
session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: str = Path(..., description="频道 ID/名称"),
req: MessageReq = Depends(BodyOrForm(MessageReq)), req: MessageReq = Depends(BodyOrForm(MessageReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]), current_user: User = Security(get_current_user, scopes=["chat.write"]),
session: AsyncSession = Depends(get_db),
): ):
db_channel = await ChatChannel.get(channel, session) db_channel = await ChatChannel.get(channel, session)
if db_channel is None: if db_channel is None:
@@ -125,12 +124,12 @@ async def send_message(
tags=["聊天"], tags=["聊天"],
) )
async def get_message( async def get_message(
session: Database,
channel: str, channel: str,
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"), limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"), since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"), until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
): ):
db_channel = await ChatChannel.get(channel, session) db_channel = await ChatChannel.get(channel, session)
if db_channel is None: if db_channel is None:
@@ -158,10 +157,10 @@ async def get_message(
tags=["聊天"], tags=["聊天"],
) )
async def mark_as_read( async def mark_as_read(
session: Database,
channel: str = Path(..., description="频道 ID/名称"), channel: str = Path(..., description="频道 ID/名称"),
message: int = Path(..., description="消息 ID"), message: int = Path(..., description="消息 ID"),
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
): ):
db_channel = await ChatChannel.get(channel, session) db_channel = await ChatChannel.get(channel, session)
if db_channel is None: if db_channel is None:
@@ -191,9 +190,9 @@ class NewPMResp(BaseModel):
tags=["聊天"], tags=["聊天"],
) )
async def create_new_pm( async def create_new_pm(
session: Database,
req: PMReq = Depends(BodyOrForm(PMReq)), req: PMReq = Depends(BodyOrForm(PMReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]), current_user: User = Security(get_current_user, scopes=["chat.write"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
user_id = current_user.id user_id = current_user.id

View File

@@ -6,9 +6,9 @@ from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMes
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import ( from app.dependencies.database import (
DBFactory, DBFactory,
engine,
get_db_factory, get_db_factory,
get_redis, get_redis,
with_db,
) )
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.log import logger from app.log import logger
@@ -200,7 +200,7 @@ class ChatServer:
) )
async def join_room_channel(self, channel_id: int, user_id: int): async def join_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session: async with with_db() as session:
channel = await ChatChannel.get(channel_id, session) channel = await ChatChannel.get(channel_id, session)
if channel is None: if channel is None:
return return
@@ -212,7 +212,7 @@ class ChatServer:
await self.join_channel(user, channel, session) await self.join_channel(user, channel, session)
async def leave_room_channel(self, channel_id: int, user_id: int): async def leave_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session: async with with_db() as session:
channel = await ChatChannel.get(channel_id, session) channel = await ChatChannel.get(channel_id, session)
if channel is None: if channel is None:
return return
@@ -268,7 +268,7 @@ async def chat_websocket(
token = authorization[7:] token = authorization[7:]
if ( if (
user := await get_current_user( user := await get_current_user(
SecurityScopes(scopes=["chat.read"]), session, token_pw=token session, SecurityScopes(scopes=["chat.read"]), token_pw=token
) )
) is None: ) is None:
await websocket.close(code=1008) await websocket.close(code=1008)

View File

@@ -4,7 +4,7 @@ import hashlib
from io import BytesIO from io import BytesIO
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user from app.dependencies.user import get_client_user
from app.storage.base import StorageService from app.storage.base import StorageService
@@ -13,7 +13,6 @@ from .router import router
from fastapi import Depends, File, HTTPException, Security from fastapi import Depends, File, HTTPException, Security
from PIL import Image from PIL import Image
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post( @router.post(
@@ -21,10 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
name="上传头像", name="上传头像",
) )
async def upload_avatar( async def upload_avatar(
session: Database,
content: bytes = File(...), content: bytes = File(...),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
storage: StorageService = Depends(get_storage_service), storage: StorageService = Depends(get_storage_service),
session: AsyncSession = Depends(get_db),
): ):
"""上传用户头像 """上传用户头像

View File

@@ -4,7 +4,7 @@ import hashlib
from io import BytesIO from io import BytesIO
from app.database.lazer_user import User, UserProfileCover from app.database.lazer_user import User, UserProfileCover
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user from app.dependencies.user import get_client_user
from app.storage.base import StorageService from app.storage.base import StorageService
@@ -13,7 +13,6 @@ from .router import router
from fastapi import Depends, File, HTTPException, Security from fastapi import Depends, File, HTTPException, Security
from PIL import Image from PIL import Image
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post( @router.post(
@@ -21,10 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
name="上传头图", name="上传头图",
) )
async def upload_cover( async def upload_cover(
session: Database,
content: bytes = File(...), content: bytes = File(...),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
storage: StorageService = Depends(get_storage_service), storage: StorageService = Depends(get_storage_service),
session: AsyncSession = Depends(get_db),
): ):
"""上传用户头图 """上传用户头图

View File

@@ -4,7 +4,7 @@ import secrets
from app.database.auth import OAuthClient, OAuthToken from app.database.auth import OAuthClient, OAuthToken
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import get_db, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user from app.dependencies.user import get_client_user
from .router import router from .router import router
@@ -12,7 +12,6 @@ from .router import router
from fastapi import Body, Depends, HTTPException, Security from fastapi import Body, Depends, HTTPException, Security
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import select, text from sqlmodel import select, text
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post( @router.post(
@@ -21,11 +20,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
description="创建一个新的 OAuth 应用程序,并生成客户端 ID 和密钥", description="创建一个新的 OAuth 应用程序,并生成客户端 ID 和密钥",
) )
async def create_oauth_app( async def create_oauth_app(
session: Database,
name: str = Body(..., max_length=100, description="应用程序名称"), name: str = Body(..., max_length=100, description="应用程序名称"),
description: str = Body("", description="应用程序描述"), description: str = Body("", description="应用程序描述"),
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"), redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
): ):
result = await session.execute( # pyright: ignore[reportDeprecated] result = await session.execute( # pyright: ignore[reportDeprecated]
text( text(
@@ -61,8 +60,8 @@ async def create_oauth_app(
description="通过客户端 ID 获取 OAuth 应用的详细信息", description="通过客户端 ID 获取 OAuth 应用的详细信息",
) )
async def get_oauth_app( async def get_oauth_app(
session: Database,
client_id: int, client_id: int,
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
oauth_app = await session.get(OAuthClient, client_id) oauth_app = await session.get(OAuthClient, client_id)
@@ -82,7 +81,7 @@ async def get_oauth_app(
description="获取当前用户创建的所有 OAuth 应用程序", description="获取当前用户创建的所有 OAuth 应用程序",
) )
async def get_user_oauth_apps( async def get_user_oauth_apps(
session: AsyncSession = Depends(get_db), session: Database,
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
oauth_apps = await session.exec( oauth_apps = await session.exec(
@@ -106,8 +105,8 @@ async def get_user_oauth_apps(
description="删除指定的 OAuth 应用程序及其关联的所有令牌", description="删除指定的 OAuth 应用程序及其关联的所有令牌",
) )
async def delete_oauth_app( async def delete_oauth_app(
session: Database,
client_id: int, client_id: int,
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
@@ -134,11 +133,11 @@ async def delete_oauth_app(
description="更新指定 OAuth 应用的名称、描述和重定向 URI", description="更新指定 OAuth 应用的名称、描述和重定向 URI",
) )
async def update_oauth_app( async def update_oauth_app(
session: Database,
client_id: int, client_id: int,
name: str = Body(..., max_length=100, description="应用程序新名称"), name: str = Body(..., max_length=100, description="应用程序新名称"),
description: str = Body("", description="应用程序新描述"), description: str = Body("", description="应用程序新描述"),
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"), redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
@@ -169,8 +168,8 @@ async def update_oauth_app(
description="为指定的 OAuth 应用生成新的客户端密钥,并使所有现有的令牌失效", description="为指定的 OAuth 应用生成新的客户端密钥,并使所有现有的令牌失效",
) )
async def refresh_secret( async def refresh_secret(
session: Database,
client_id: int, client_id: int,
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
@@ -204,11 +203,11 @@ async def refresh_secret(
description="为特定用户和 OAuth 应用生成授权码,用于授权码授权流程", description="为特定用户和 OAuth 应用生成授权码,用于授权码授权流程",
) )
async def generate_oauth_code( async def generate_oauth_code(
session: Database,
client_id: int, client_id: int,
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redirect_uri: str = Body(..., description="授权后重定向的 URI"), redirect_uri: str = Body(..., description="授权后重定向的 URI"),
scopes: list[str] = Body(..., description="请求的权限范围列表"), scopes: list[str] = Body(..., description="请求的权限范围列表"),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
client = await session.get(OAuthClient, client_id) client = await session.get(OAuthClient, client_id)

View File

@@ -2,15 +2,14 @@ from __future__ import annotations
from app.database import Relationship, User from app.database import Relationship, User
from app.database.relationship import RelationshipType from app.database.relationship import RelationshipType
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.dependencies.user import get_client_user from app.dependencies.user import get_client_user
from .router import router from .router import router
from fastapi import Depends, HTTPException, Path, Security from fastapi import HTTPException, Path, Security
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
class CheckResponse(BaseModel): class CheckResponse(BaseModel):
@@ -26,9 +25,9 @@ class CheckResponse(BaseModel):
response_model=CheckResponse, response_model=CheckResponse,
) )
async def check_user_relationship( async def check_user_relationship(
db: Database,
user_id: int = Path(..., description="目标用户的 ID"), user_id: int = Path(..., description="目标用户的 ID"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
): ):
if user_id == current_user.id: if user_id == current_user.id:
raise HTTPException(422, "Cannot check relationship with yourself") raise HTTPException(422, "Cannot check relationship with yourself")

View File

@@ -6,14 +6,13 @@ from app.auth import validate_username
from app.config import settings from app.config import settings
from app.database.events import Event, EventType from app.database.events import Event, EventType
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.dependencies.user import get_client_user from app.dependencies.user import get_client_user
from .router import router from .router import router
from fastapi import Body, Depends, HTTPException, Security from fastapi import Body, HTTPException, Security
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.post( @router.post(
@@ -21,8 +20,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
name="修改用户名", name="修改用户名",
) )
async def user_rename( async def user_rename(
session: Database,
new_name: str = Body(..., description="新的用户名"), new_name: str = Body(..., description="新的用户名"),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
"""修改用户名 """修改用户名

View File

@@ -8,7 +8,7 @@ from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.database.beatmapset import Beatmapset from app.database.beatmapset import Beatmapset
from app.database.favourite_beatmapset import FavouriteBeatmapset from app.database.favourite_beatmapset import FavouriteBeatmapset
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import get_db, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher from app.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.beatmap import BeatmapRankStatus, Genre, Language
@@ -149,6 +149,7 @@ class V1Beatmap(AllStrModel):
description="根据指定条件搜索谱面。", description="根据指定条件搜索谱面。",
) )
async def get_beatmaps( async def get_beatmaps(
session: Database,
since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"), since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"),
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"), beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"), beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
@@ -163,7 +164,6 @@ async def get_beatmaps(
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"), checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"), limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
mods: int = Query(0, description="应用到谱面属性的 MOD"), mods: int = Query(0, description="应用到谱面属性的 MOD"),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):

View File

@@ -6,7 +6,7 @@ from typing import Literal
from app.database.counts import ReplayWatchedCount from app.database.counts import ReplayWatchedCount
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service from app.dependencies.storage import get_storage_service
from app.models.mods import int_to_mods from app.models.mods import int_to_mods
from app.models.score import GameMode from app.models.score import GameMode
@@ -17,7 +17,6 @@ from .router import router
from fastapi import Depends, HTTPException, Query from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class ReplayModel(BaseModel): class ReplayModel(BaseModel):
@@ -32,6 +31,7 @@ class ReplayModel(BaseModel):
description="获取指定谱面的回放文件。", description="获取指定谱面的回放文件。",
) )
async def download_replay( async def download_replay(
session: Database,
beatmap: int = Query(..., alias="b", description="谱面 ID"), beatmap: int = Query(..., alias="b", description="谱面 ID"),
user: str = Query(..., alias="u", description="用户"), user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query( ruleset_id: int | None = Query(
@@ -45,7 +45,6 @@ async def download_replay(
None, description="用户类型string 用户名称 / id 用户 ID" None, description="用户类型string 用户名称 / id 用户 ID"
), ),
mods: int = Query(0, description="成绩的 MOD"), mods: int = Query(0, description="成绩的 MOD"),
session: AsyncSession = Depends(get_db),
storage_service: StorageService = Depends(get_storage_service), storage_service: StorageService = Depends(get_storage_service),
): ):
mods_ = int_to_mods(mods) mods_ = int_to_mods(mods)

View File

@@ -5,16 +5,15 @@ from typing import Literal
from app.database.pp_best_score import PPBestScore from app.database.pp_best_score import PPBestScore
from app.database.score import Score, get_leaderboard from app.database.score import Score, get_leaderboard
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.models.mods import int_to_mods, mod_to_save, mods_to_int from app.models.mods import int_to_mods, mod_to_save, mods_to_int
from app.models.score import GameMode, LeaderboardType from app.models.score import GameMode, LeaderboardType
from .router import AllStrModel, router from .router import AllStrModel, router
from fastapi import Depends, HTTPException, Query from fastapi import HTTPException, Query
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlmodel import col, exists, select from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
class V1Score(AllStrModel): class V1Score(AllStrModel):
@@ -68,13 +67,13 @@ class V1Score(AllStrModel):
description="获取指定用户的最好成绩。", description="获取指定用户的最好成绩。",
) )
async def get_user_best( async def get_user_best(
session: Database,
user: str = Query(..., alias="u", description="用户"), user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID" None, description="用户类型string 用户名称 / id 用户 ID"
), ),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
session: AsyncSession = Depends(get_db),
): ):
try: try:
scores = ( scores = (
@@ -104,13 +103,13 @@ async def get_user_best(
description="获取指定用户的最近成绩。", description="获取指定用户的最近成绩。",
) )
async def get_user_recent( async def get_user_recent(
session: Database,
user: str = Query(..., alias="u", description="用户"), user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID" None, description="用户类型string 用户名称 / id 用户 ID"
), ),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
session: AsyncSession = Depends(get_db),
): ):
try: try:
scores = ( scores = (
@@ -140,6 +139,7 @@ async def get_user_recent(
description="获取指定谱面的成绩。", description="获取指定谱面的成绩。",
) )
async def get_scores( async def get_scores(
session: Database,
user: str | None = Query(None, alias="u", description="用户"), user: str | None = Query(None, alias="u", description="用户"),
beatmap_id: int = Query(alias="b", description="谱面 ID"), beatmap_id: int = Query(alias="b", description="谱面 ID"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
@@ -148,7 +148,6 @@ async def get_scores(
), ),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
mods: int = Query(0, description="成绩的 MOD"), mods: int = Query(0, description="成绩的 MOD"),
session: AsyncSession = Depends(get_db),
): ):
try: try:
if user is not None: if user is not None:

View File

@@ -5,14 +5,13 @@ from typing import Literal
from app.database.lazer_user import User from app.database.lazer_user import User
from app.database.statistics import UserStatistics, UserStatisticsResp from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.models.score import GameMode from app.models.score import GameMode
from .router import AllStrModel, router from .router import AllStrModel, router
from fastapi import Depends, HTTPException, Query from fastapi import HTTPException, Query
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
class V1User(AllStrModel): class V1User(AllStrModel):
@@ -41,7 +40,7 @@ class V1User(AllStrModel):
@classmethod @classmethod
async def from_db( async def from_db(
cls, session: AsyncSession, db_user: User, ruleset: GameMode | None = None cls, session: Database, db_user: User, ruleset: GameMode | None = None
) -> "V1User": ) -> "V1User":
ruleset = ruleset or db_user.playmode ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None current_statistics: UserStatistics | None = None
@@ -92,6 +91,7 @@ class V1User(AllStrModel):
description="获取指定用户的信息。", description="获取指定用户的信息。",
) )
async def get_user( async def get_user(
session: Database,
user: str = Query(..., alias="u", description="用户"), user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0), ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(
@@ -100,7 +100,6 @@ async def get_user(
event_days: int = Query( event_days: int = Query(
default=1, ge=1, le=31, description="从现在起所有事件的最大天数" default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
), ),
session: AsyncSession = Depends(get_db),
): ):
db_user = ( db_user = (
await session.exec( await session.exec(

View File

@@ -6,7 +6,7 @@ import json
from app.database import Beatmap, BeatmapResp, User from app.database import Beatmap, BeatmapResp, User
from app.database.beatmap import calculate_beatmap_attributes from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import get_db, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.fetcher import Fetcher from app.fetcher import Fetcher
@@ -24,7 +24,6 @@ from pydantic import BaseModel
from redis.asyncio import Redis from redis.asyncio import Redis
import rosu_pp_py as rosu import rosu_pp_py as rosu
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class BatchGetResp(BaseModel): class BatchGetResp(BaseModel):
@@ -47,13 +46,13 @@ class BatchGetResp(BaseModel):
), ),
) )
async def lookup_beatmap( async def lookup_beatmap(
db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"), id: int | None = Query(default=None, alias="id", description="谱面 ID"),
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"), md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
filename: str | None = Query( filename: str | None = Query(
default=None, alias="filename", description="谱面文件名" default=None, alias="filename", description="谱面文件名"
), ),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
if id is None and md5 is None and filename is None: if id is None and md5 is None and filename is None:
@@ -80,9 +79,9 @@ async def lookup_beatmap(
description="获取单个谱面详情。", description="获取单个谱面详情。",
) )
async def get_beatmap( async def get_beatmap(
db: Database,
beatmap_id: int = Path(..., description="谱面 ID"), beatmap_id: int = Path(..., description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
try: try:
@@ -103,11 +102,11 @@ async def get_beatmap(
), ),
) )
async def batch_get_beatmaps( async def batch_get_beatmaps(
db: Database,
beatmap_ids: list[int] = Query( beatmap_ids: list[int] = Query(
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)" alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
), ),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
if not beatmap_ids: if not beatmap_ids:
@@ -157,6 +156,7 @@ async def batch_get_beatmaps(
description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"), description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"),
) )
async def get_beatmap_attributes( async def get_beatmap_attributes(
db: Database,
beatmap_id: int = Path(..., description="谱面 ID"), beatmap_id: int = Path(..., description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
mods: list[str] = Query( mods: list[str] = Query(
@@ -170,7 +170,6 @@ async def get_beatmap_attributes(
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3 default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
), ),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
mods_ = [] mods_ = []

View File

@@ -7,7 +7,7 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service from app.dependencies.beatmap_download import get_beatmap_download_service
from app.dependencies.database import engine, get_db, get_redis from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import get_client_user, get_current_user
@@ -30,11 +30,10 @@ from fastapi import (
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from httpx import HTTPError from httpx import HTTPError
from sqlmodel import exists, select from sqlmodel import exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def _save_to_db(sets: SearchBeatmapsetsResp): async def _save_to_db(sets: SearchBeatmapsetsResp):
async with AsyncSession(engine) as session: async with with_db() as session:
for s in sets.beatmapsets: for s in sets.beatmapsets:
if not ( if not (
await session.exec(select(exists()).where(Beatmapset.id == s.id)) await session.exec(select(exists()).where(Beatmapset.id == s.id))
@@ -49,13 +48,13 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
response_model=SearchBeatmapsetsResp, response_model=SearchBeatmapsetsResp,
) )
async def search_beatmapset( async def search_beatmapset(
db: Database,
query: Annotated[SearchQueryModel, Query(...)], query: Annotated[SearchQueryModel, Query(...)],
request: Request, request: Request,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
redis = Depends(get_redis), redis=Depends(get_redis),
): ):
params = parse_qs(qs=request.url.query, keep_blank_values=True) params = parse_qs(qs=request.url.query, keep_blank_values=True)
cursor = {} cursor = {}
@@ -112,9 +111,9 @@ async def search_beatmapset(
description=("通过谱面 ID 查询所属谱面集。"), description=("通过谱面 ID 查询所属谱面集。"),
) )
async def lookup_beatmapset( async def lookup_beatmapset(
db: Database,
beatmap_id: int = Query(description="谱面 ID"), beatmap_id: int = Query(description="谱面 ID"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
@@ -132,9 +131,9 @@ async def lookup_beatmapset(
description="获取单个谱面集详情。", description="获取单个谱面集详情。",
) )
async def get_beatmapset( async def get_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"), beatmapset_id: int = Path(..., description="谱面集 ID"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
try: try:
@@ -196,12 +195,12 @@ async def download_beatmapset(
description="**客户端专属**\n收藏或取消收藏指定谱面集。", description="**客户端专属**\n收藏或取消收藏指定谱面集。",
) )
async def favourite_beatmapset( async def favourite_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"), beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form( action: Literal["favourite", "unfavourite"] = Form(
description="操作类型favourite 收藏 / unfavourite 取消收藏" description="操作类型favourite 收藏 / unfavourite 取消收藏"
), ),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
): ):
assert current_user.id is not None assert current_user.id is not None
existing_favourite = ( existing_favourite = (

View File

@@ -3,13 +3,12 @@ from __future__ import annotations
from app.database import User, UserResp from app.database import User, UserResp
from app.database.lazer_user import ALL_INCLUDED from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.models.score import GameMode from app.models.score import GameMode
from .router import router from .router import router
from fastapi import Depends, Path, Security from fastapi import Path, Security
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get( @router.get(
@@ -20,9 +19,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
tags=["用户"], tags=["用户"],
) )
async def get_user_info_with_ruleset( async def get_user_info_with_ruleset(
session: Database,
ruleset: GameMode = Path(description="指定 ruleset"), ruleset: GameMode = Path(description="指定 ruleset"),
current_user: User = Security(get_current_user, scopes=["identify"]), current_user: User = Security(get_current_user, scopes=["identify"]),
session: AsyncSession = Depends(get_db),
): ):
return await UserResp.from_db( return await UserResp.from_db(
current_user, current_user,
@@ -40,8 +39,8 @@ async def get_user_info_with_ruleset(
tags=["用户"], tags=["用户"],
) )
async def get_user_info_default( async def get_user_info_default(
session: Database,
current_user: User = Security(get_current_user, scopes=["identify"]), current_user: User = Security(get_current_user, scopes=["identify"]),
session: AsyncSession = Depends(get_db),
): ):
return await UserResp.from_db( return await UserResp.from_db(
current_user, current_user,

View File

@@ -5,15 +5,14 @@ from typing import Literal
from app.database import User from app.database import User
from app.database.statistics import UserStatistics, UserStatisticsResp from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.models.score import GameMode from app.models.score import GameMode
from .router import router from .router import router
from fastapi import Depends, Path, Query, Security from fastapi import Path, Query, Security
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class CountryStatistics(BaseModel): class CountryStatistics(BaseModel):
@@ -36,10 +35,10 @@ class CountryResponse(BaseModel):
tags=["排行榜"], tags=["排行榜"],
) )
async def get_country_ranking( async def get_country_ranking(
session: Database,
ruleset: GameMode = Path(..., description="指定 ruleset"), ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"), # TODO page: int = Query(1, ge=1, description="页码"), # TODO
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
): ):
response = CountryResponse(ranking=[]) response = CountryResponse(ranking=[])
countries = (await session.exec(select(User.country_code).distinct())).all() countries = (await session.exec(select(User.country_code).distinct())).all()
@@ -85,6 +84,7 @@ class TopUsersResponse(BaseModel):
tags=["排行榜"], tags=["排行榜"],
) )
async def get_user_ranking( async def get_user_ranking(
session: Database,
ruleset: GameMode = Path(..., description="指定 ruleset"), ruleset: GameMode = Path(..., description="指定 ruleset"),
type: Literal["performance", "score"] = Path( type: Literal["performance", "score"] = Path(
..., description="排名类型performance 表现分 / score 计分成绩总分" ..., description="排名类型performance 表现分 / score 计分成绩总分"
@@ -92,7 +92,6 @@ async def get_user_ranking(
country: str | None = Query(None, description="国家代码"), country: str | None = Query(None, description="国家代码"),
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
): ):
wheres = [ wheres = [
col(UserStatistics.mode) == ruleset, col(UserStatistics.mode) == ruleset,

View File

@@ -1,15 +1,14 @@
from __future__ import annotations from __future__ import annotations
from app.database import Relationship, RelationshipResp, RelationshipType, User from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import get_client_user, get_current_user
from .router import router from .router import router
from fastapi import Depends, HTTPException, Path, Query, Request, Security from fastapi import HTTPException, Path, Query, Request, Security
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get( @router.get(
@@ -27,9 +26,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
description="获取当前用户的屏蔽用户列表。", description="获取当前用户的屏蔽用户列表。",
) )
async def get_relationship( async def get_relationship(
db: Database,
request: Request, request: Request,
current_user: User = Security(get_current_user, scopes=["friends.read"]), current_user: User = Security(get_current_user, scopes=["friends.read"]),
db: AsyncSession = Depends(get_db),
): ):
relationship_type = ( relationship_type = (
RelationshipType.FOLLOW RelationshipType.FOLLOW
@@ -67,10 +66,10 @@ class AddFriendResp(BaseModel):
description="**客户端专属**\n添加或更新与目标用户的屏蔽关系。", description="**客户端专属**\n添加或更新与目标用户的屏蔽关系。",
) )
async def add_relationship( async def add_relationship(
db: Database,
request: Request, request: Request,
target: int = Query(description="目标用户 ID"), target: int = Query(description="目标用户 ID"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
): ):
assert current_user.id is not None assert current_user.id is not None
relationship_type = ( relationship_type = (
@@ -141,10 +140,10 @@ async def add_relationship(
description="**客户端专属**\n删除与目标用户的屏蔽关系。", description="**客户端专属**\n删除与目标用户的屏蔽关系。",
) )
async def delete_relationship( async def delete_relationship(
db: Database,
request: Request, request: Request,
target: int = Path(..., description="目标用户 ID"), target: int = Path(..., description="目标用户 ID"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
): ):
relationship_type = ( relationship_type = (
RelationshipType.BLOCK RelationshipType.BLOCK

View File

@@ -12,7 +12,7 @@ from app.database.playlists import Playlist, PlaylistResp
from app.database.room import APIUploadedRoom, Room, RoomResp from app.database.room import APIUploadedRoom, Room, RoomResp
from app.database.room_participated_user import RoomParticipatedUser from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import get_db, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import get_client_user, get_current_user
from app.models.room import RoomCategory, RoomStatus from app.models.room import RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api from app.service.room import create_playlist_room_from_api
@@ -36,6 +36,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
description="获取房间列表。支持按状态/模式筛选", description="获取房间列表。支持按状态/模式筛选",
) )
async def get_all_rooms( async def get_all_rooms(
db: Database,
mode: Literal["open", "ended", "participated", "owned", None] = Query( mode: Literal["open", "ended", "participated", "owned", None] = Query(
default="open", default="open",
description=( description=(
@@ -51,7 +52,6 @@ async def get_all_rooms(
), ),
), ),
status: RoomStatus | None = Query(None, description="房间状态(可选)"), status: RoomStatus | None = Query(None, description="房间状态(可选)"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
): ):
resp_list: list[RoomResp] = [] resp_list: list[RoomResp] = []
@@ -149,8 +149,8 @@ async def _participate_room(
description="**客户端专属**\n创建一个新的房间。", description="**客户端专属**\n创建一个新的房间。",
) )
async def create_room( async def create_room(
db: Database,
room: APIUploadedRoom, room: APIUploadedRoom,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
@@ -173,6 +173,7 @@ async def create_room(
description="获取单个房间详情。", description="获取单个房间详情。",
) )
async def get_room( async def get_room(
db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: int = Path(..., description="房间 ID"),
category: str = Query( category: str = Query(
default="", default="",
@@ -181,7 +182,6 @@ async def get_room(
" / DAILY_CHALLENGE 每日挑战 (可选)" " / DAILY_CHALLENGE 每日挑战 (可选)"
), ),
), ),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
@@ -201,8 +201,8 @@ async def get_room(
description="**客户端专属**\n结束歌单模式房间。", description="**客户端专属**\n结束歌单模式房间。",
) )
async def delete_room( async def delete_room(
db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: int = Path(..., description="房间 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
@@ -221,9 +221,9 @@ async def delete_room(
description="**客户端专属**\n加入指定歌单模式房间。", description="**客户端专属**\n加入指定歌单模式房间。",
) )
async def add_user_to_room( async def add_user_to_room(
db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"), user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
@@ -245,9 +245,9 @@ async def add_user_to_room(
description="**客户端专属**\n离开指定歌单模式房间。", description="**客户端专属**\n离开指定歌单模式房间。",
) )
async def remove_user_from_room( async def remove_user_from_room(
db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"), user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
@@ -289,8 +289,8 @@ class APILeaderboard(BaseModel):
description="获取房间内累计得分排行榜。", description="获取房间内累计得分排行榜。",
) )
async def get_room_leaderboard( async def get_room_leaderboard(
db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: int = Path(..., description="房间 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
@@ -345,8 +345,8 @@ class RoomEvents(BaseModel):
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。", description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
) )
async def get_room_events( async def get_room_events(
db: Database,
room_id: int = Path(..., description="房间 ID"), room_id: int = Path(..., description="房间 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"), after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"),

View File

@@ -32,7 +32,7 @@ from app.database.score import (
process_score, process_score,
process_user, process_user,
) )
from app.dependencies.database import get_db, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import get_client_user, get_current_user
@@ -220,6 +220,7 @@ class BeatmapScores(BaseModel):
description="获取指定谱面在特定条件下的排行榜及当前用户成绩。", description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
) )
async def get_beatmap_scores( async def get_beatmap_scores(
db: Database,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: int = Path(description="谱面 ID"),
mode: GameMode = Query(description="指定 auleset"), mode: GameMode = Query(description="指定 auleset"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
@@ -233,7 +234,6 @@ async def get_beatmap_scores(
), ),
), ),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"), limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
): ):
if legacy_only: if legacy_only:
@@ -277,13 +277,13 @@ class BeatmapUserScore(BaseModel):
description="获取指定用户在指定谱面上的最高成绩。", description="获取指定用户在指定谱面上的最高成绩。",
) )
async def get_user_beatmap_score( async def get_user_beatmap_score(
db: Database,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: int = Path(description="谱面 ID"),
user_id: int = Path(description="用户 ID"), user_id: int = Path(description="用户 ID"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"), mode: GameMode | None = Query(None, description="指定 ruleset (可选)"),
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"), mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
): ):
if legacy_only: if legacy_only:
raise HTTPException( raise HTTPException(
@@ -322,12 +322,12 @@ async def get_user_beatmap_score(
description="获取指定用户在指定谱面上的全部成绩列表。", description="获取指定用户在指定谱面上的全部成绩列表。",
) )
async def get_user_all_beatmap_scores( async def get_user_all_beatmap_scores(
db: Database,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: int = Path(description="谱面 ID"),
user_id: int = Path(description="用户 ID"), user_id: int = Path(description="用户 ID"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"), ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
): ):
if legacy_only: if legacy_only:
raise HTTPException( raise HTTPException(
@@ -357,12 +357,12 @@ async def get_user_all_beatmap_scores(
) )
async def create_solo_score( async def create_solo_score(
background_task: BackgroundTasks, background_task: BackgroundTasks,
db: Database,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: int = Path(description="谱面 ID"),
version_hash: str = Form("", description="游戏版本哈希"), version_hash: str = Form("", description="游戏版本哈希"),
beatmap_hash: str = Form(description="谱面文件哈希"), beatmap_hash: str = Form(description="谱面文件哈希"),
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
): ):
assert current_user.id is not None assert current_user.id is not None
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id) background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
@@ -387,11 +387,11 @@ async def create_solo_score(
) )
async def submit_solo_score( async def submit_solo_score(
req: Request, req: Request,
db: Database,
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: int = Path(description="谱面 ID"),
token: int = Path(description="成绩令牌 ID"), token: int = Path(description="成绩令牌 ID"),
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"), info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher), fetcher=Depends(get_fetcher),
): ):
@@ -407,6 +407,7 @@ async def submit_solo_score(
description="**客户端专属**\n为房间游玩项目创建成绩提交令牌。", description="**客户端专属**\n为房间游玩项目创建成绩提交令牌。",
) )
async def create_playlist_score( async def create_playlist_score(
session: Database,
background_task: BackgroundTasks, background_task: BackgroundTasks,
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
@@ -415,7 +416,6 @@ async def create_playlist_score(
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
version_hash: str = Form("", description="谱面版本哈希"), version_hash: str = Form("", description="谱面版本哈希"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
): ):
assert current_user.id is not None assert current_user.id is not None
room = await session.get(Room, room_id) room = await session.get(Room, room_id)
@@ -483,12 +483,12 @@ async def create_playlist_score(
description="**客户端专属**\n提交房间游玩项目成绩。", description="**客户端专属**\n提交房间游玩项目成绩。",
) )
async def submit_playlist_score( async def submit_playlist_score(
session: Database,
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
token: int, token: int,
info: SoloScoreSubmissionInfo, info: SoloScoreSubmissionInfo,
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
@@ -541,6 +541,7 @@ class IndexedScoreResp(MultiplayerScores):
tags=["成绩"], tags=["成绩"],
) )
async def index_playlist_scores( async def index_playlist_scores(
session: Database,
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"), limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
@@ -548,7 +549,6 @@ async def index_playlist_scores(
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)" 2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
), ),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
): ):
room = await session.get(Room, room_id) room = await session.get(Room, room_id)
if not room: if not room:
@@ -607,11 +607,11 @@ async def index_playlist_scores(
tags=["成绩"], tags=["成绩"],
) )
async def show_playlist_score( async def show_playlist_score(
session: Database,
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
score_id: int, score_id: int,
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
room = await session.get(Room, room_id) room = await session.get(Room, room_id)
@@ -678,11 +678,11 @@ async def show_playlist_score(
tags=["成绩"], tags=["成绩"],
) )
async def get_user_playlist_score( async def get_user_playlist_score(
session: Database,
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
user_id: int, user_id: int,
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
session: AsyncSession = Depends(get_db),
): ):
score_record = None score_record = None
start_time = time.time() start_time = time.time()
@@ -716,9 +716,9 @@ async def get_user_playlist_score(
tags=["成绩"], tags=["成绩"],
) )
async def pin_score( async def pin_score(
db: Database,
score_id: int = Path(description="成绩 ID"), score_id: int = Path(description="成绩 ID"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
): ):
score_record = ( score_record = (
await db.exec( await db.exec(
@@ -758,9 +758,9 @@ async def pin_score(
tags=["成绩"], tags=["成绩"],
) )
async def unpin_score( async def unpin_score(
db: Database,
score_id: int = Path(description="成绩 ID"), score_id: int = Path(description="成绩 ID"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
): ):
score_record = ( score_record = (
await db.exec( await db.exec(
@@ -797,11 +797,11 @@ async def unpin_score(
tags=["成绩"], tags=["成绩"],
) )
async def reorder_score_pin( async def reorder_score_pin(
db: Database,
score_id: int = Path(description="成绩 ID"), score_id: int = Path(description="成绩 ID"),
after_score_id: int | None = Body(default=None, description="放在该成绩之后"), after_score_id: int | None = Body(default=None, description="放在该成绩之后"),
before_score_id: int | None = Body(default=None, description="放在该成绩之前"), before_score_id: int | None = Body(default=None, description="放在该成绩之前"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
db: AsyncSession = Depends(get_db),
): ):
score_record = ( score_record = (
await db.exec( await db.exec(
@@ -892,8 +892,8 @@ async def reorder_score_pin(
) )
async def download_score_replay( async def download_score_replay(
score_id: int, score_id: int,
db: Database,
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
storage_service: StorageService = Depends(get_storage_service), storage_service: StorageService = Depends(get_storage_service),
): ):
score = (await db.exec(select(Score).where(Score.id == score_id))).first() score = (await db.exec(select(Score).where(Score.id == score_id))).first()

View File

@@ -15,17 +15,16 @@ from app.database.events import EventResp
from app.database.lazer_user import SEARCH_INCLUDED from app.database.lazer_user import SEARCH_INCLUDED
from app.database.pp_best_score import PPBestScore from app.database.pp_best_score import PPBestScore
from app.database.score import Score, ScoreResp from app.database.score import Score, ScoreResp
from app.dependencies.database import get_db from app.dependencies.database import Database
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.models.score import GameMode from app.models.score import GameMode
from app.models.user import BeatmapsetType from app.models.user import BeatmapsetType
from .router import router from .router import router
from fastapi import Depends, HTTPException, Path, Query, Security from fastapi import HTTPException, Path, Query, Security
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import exists, false, select from sqlmodel import exists, false, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col from sqlmodel.sql.expression import col
@@ -43,6 +42,7 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False) @router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False) @router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
async def get_users( async def get_users(
session: Database,
user_ids: list[int] = Query( user_ids: list[int] = Query(
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表" default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
), ),
@@ -50,7 +50,6 @@ async def get_users(
include_variant_statistics: bool = Query( include_variant_statistics: bool = Query(
default=False, description="是否包含各模式的统计信息" default=False, description="是否包含各模式的统计信息"
), # TODO: future use ), # TODO: future use
session: AsyncSession = Depends(get_db),
): ):
if user_ids: if user_ids:
searched_users = ( searched_users = (
@@ -79,9 +78,9 @@ async def get_users(
tags=["用户"], tags=["用户"],
) )
async def get_user_info_ruleset( async def get_user_info_ruleset(
session: Database,
user_id: str = Path(description="用户 ID 或用户名"), user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"), ruleset: GameMode | None = Path(description="指定 ruleset"),
session: AsyncSession = Depends(get_db),
# current_user: User = Security(get_current_user, scopes=["public"]), # current_user: User = Security(get_current_user, scopes=["public"]),
): ):
searched_user = ( searched_user = (
@@ -112,8 +111,8 @@ async def get_user_info_ruleset(
tags=["用户"], tags=["用户"],
) )
async def get_user_info( async def get_user_info(
session: Database,
user_id: str = Path(description="用户 ID 或用户名"), user_id: str = Path(description="用户 ID 或用户名"),
session: AsyncSession = Depends(get_db),
# current_user: User = Security(get_current_user, scopes=["public"]), # current_user: User = Security(get_current_user, scopes=["public"]),
): ):
searched_user = ( searched_user = (
@@ -142,10 +141,10 @@ async def get_user_info(
tags=["用户"], tags=["用户"],
) )
async def get_user_beatmapsets( async def get_user_beatmapsets(
session: Database,
user_id: int = Path(description="用户 ID"), user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"), type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"), offset: int = Query(0, ge=0, description="偏移量"),
): ):
@@ -202,6 +201,7 @@ async def get_user_beatmapsets(
tags=["用户"], tags=["用户"],
) )
async def get_user_scores( async def get_user_scores(
session: Database,
user_id: int = Path(description="用户 ID"), user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path( type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=( description=(
@@ -216,7 +216,6 @@ async def get_user_scores(
), ),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"), offset: int = Query(0, ge=0, description="偏移量"),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
): ):
db_user = await session.get(User, user_id) db_user = await session.get(User, user_id)
@@ -267,10 +266,10 @@ async def get_user_scores(
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp] "/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
) )
async def get_user_events( async def get_user_events(
session: Database,
user: int, user: int,
limit: int | None = Query(None), limit: int | None = Query(None),
offset: str | None = Query(None), # TODO: 搞清楚并且添加这个奇怪的分页偏移 offset: str | None = Query(None), # TODO: 搞清楚并且添加这个奇怪的分页偏移
session: AsyncSession = Depends(get_db),
): ):
db_user = await session.get(User, user) db_user = await session.get(User, user)
if db_user is None or db_user.id == BANCHOBOT_ID: if db_user is None or db_user.id == BANCHOBOT_ID:

View File

@@ -4,12 +4,11 @@ from datetime import UTC, datetime, timedelta
from app.database import RankHistory, UserStatistics from app.database import RankHistory, UserStatistics
from app.database.rank_history import RankTop from app.database.rank_history import RankTop
from app.dependencies.database import engine from app.dependencies.database import with_db
from app.dependencies.scheduler import get_scheduler from app.dependencies.scheduler import get_scheduler
from app.models.score import GameMode from app.models.score import GameMode
from sqlmodel import col, exists, select, update from sqlmodel import col, exists, select, update
from sqlmodel.ext.asyncio.session import AsyncSession
@get_scheduler().scheduled_job( @get_scheduler().scheduled_job(
@@ -18,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
async def calculate_user_rank(is_today: bool = False): async def calculate_user_rank(is_today: bool = False):
today = datetime.now(UTC).date() today = datetime.now(UTC).date()
target_date = today if is_today else today - timedelta(days=1) target_date = today if is_today else today - timedelta(days=1)
async with AsyncSession(engine) as session: async with with_db() as session:
for gamemode in GameMode: for gamemode in GameMode:
users = await session.exec( users = await session.exec(
select(UserStatistics) select(UserStatistics)

View File

@@ -3,15 +3,14 @@ from __future__ import annotations
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
from app.database.lazer_user import User from app.database.lazer_user import User
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.dependencies.database import engine from app.dependencies.database import with_db
from app.models.score import GameMode from app.models.score import GameMode
from sqlmodel import exists, select from sqlmodel import exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_banchobot(): async def create_banchobot():
async with AsyncSession(engine) as session: async with with_db() as session:
is_exist = ( is_exist = (
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID)) await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
).first() ).first()

View File

@@ -6,7 +6,7 @@ import json
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
from app.database.playlists import Playlist from app.database.playlists import Playlist
from app.database.room import Room from app.database.room import Room
from app.dependencies.database import engine, get_redis from app.dependencies.database import get_redis, with_db
from app.dependencies.scheduler import get_scheduler from app.dependencies.scheduler import get_scheduler
from app.log import logger from app.log import logger
from app.models.metadata_hub import DailyChallengeInfo from app.models.metadata_hub import DailyChallengeInfo
@@ -16,13 +16,12 @@ from app.models.room import RoomCategory
from .room import create_playlist_room from .room import create_playlist_room
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_daily_challenge_room( async def create_daily_challenge_room(
beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = [] beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = []
) -> Room: ) -> Room:
async with AsyncSession(engine) as session: async with with_db() as session:
today = datetime.now(UTC).date() today = datetime.now(UTC).date()
return await create_playlist_room( return await create_playlist_room(
session=session, session=session,
@@ -52,7 +51,7 @@ async def daily_challenge_job():
key = f"daily_challenge:{now.date()}" key = f"daily_challenge:{now.date()}"
if not await redis.exists(key): if not await redis.exists(key):
return return
async with AsyncSession(engine) as session: async with with_db() as session:
room = ( room = (
await session.exec( await session.exec(
select(Room).where( select(Room).where(

View File

@@ -4,16 +4,15 @@ from app.config import settings
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
from app.database.lazer_user import User from app.database.lazer_user import User
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.dependencies.database import engine from app.dependencies.database import with_db
from app.models.score import GameMode from app.models.score import GameMode
from sqlalchemy import exists from sqlalchemy import exists
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_rx_statistics(): async def create_rx_statistics():
async with AsyncSession(engine) as session: async with with_db() as session:
users = (await session.exec(select(User.id))).all() users = (await session.exec(select(User.id))).all()
for i in users: for i in users:
if i == BANCHOBOT_ID: if i == BANCHOBOT_ID:

View File

@@ -4,13 +4,12 @@ from typing import TYPE_CHECKING
from app.database import PlaylistBestScore, Score from app.database import PlaylistBestScore, Score
from app.database.playlist_best_score import get_position from app.database.playlist_best_score import get_position
from app.dependencies.database import engine from app.dependencies.database import with_db
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
from .base import RedisSubscriber from .base import RedisSubscriber
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from app.signalr.hub import MetadataHub from app.signalr.hub import MetadataHub
@@ -45,7 +44,7 @@ class ScoreSubscriber(RedisSubscriber):
async def _notify_room_score_processed(self, score_id: int): async def _notify_room_score_processed(self, score_id: int):
if not self.metadata_hub: if not self.metadata_hub:
return return
async with AsyncSession(engine) as session: async with with_db() as session:
score = await session.get(Score, score_id) score = await session.get(Score, score_id)
if ( if (
not score not score

View File

@@ -13,7 +13,7 @@ from app.database.playlist_best_score import PlaylistBestScore
from app.database.playlists import Playlist from app.database.playlists import Playlist
from app.database.room import Room from app.database.room import Room
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import engine, get_redis from app.dependencies.database import get_redis, with_db
from app.models.metadata_hub import ( from app.models.metadata_hub import (
TOTAL_SCORE_DISTRIBUTION_BINS, TOTAL_SCORE_DISTRIBUTION_BINS,
DailyChallengeInfo, DailyChallengeInfo,
@@ -30,7 +30,6 @@ from app.service.subscribers.score_processed import ScoreSubscriber
from .hub import Client, Hub from .hub import Client, Hub
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers" ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
@@ -97,7 +96,7 @@ class MetadataHub(Hub[MetadataClientState]):
redis = get_redis() redis = get_redis()
if await redis.exists(f"metadata:online:{state.connection_id}"): if await redis.exists(f"metadata:online:{state.connection_id}"):
await redis.delete(f"metadata:online:{state.connection_id}") await redis.delete(f"metadata:online:{state.connection_id}")
async with AsyncSession(engine) as session: async with with_db() as session:
async with session.begin(): async with session.begin():
user = ( user = (
await session.exec( await session.exec(
@@ -118,7 +117,7 @@ class MetadataHub(Hub[MetadataClientState]):
user_id = int(client.connection_id) user_id = int(client.connection_id)
self.get_or_create_state(client) self.get_or_create_state(client)
async with AsyncSession(engine) as session: async with with_db() as session:
async with session.begin(): async with session.begin():
friends = ( friends = (
await session.exec( await session.exec(
@@ -233,7 +232,7 @@ class MetadataHub(Hub[MetadataClientState]):
return list(stats.playlist_item_stats.values()) return list(stats.playlist_item_stats.values())
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None: async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
async with AsyncSession(engine) as session: async with with_db() as session:
playlist_ids = ( playlist_ids = (
await session.exec( await session.exec(
select(Playlist.id).where( select(Playlist.id).where(

View File

@@ -12,7 +12,7 @@ from app.database.multiplayer_event import MultiplayerEvent
from app.database.playlists import Playlist from app.database.playlists import Playlist
from app.database.relationship import Relationship, RelationshipType from app.database.relationship import Relationship, RelationshipType
from app.database.room_participated_user import RoomParticipatedUser from app.database.room_participated_user import RoomParticipatedUser
from app.dependencies.database import engine, get_redis from app.dependencies.database import get_redis, with_db
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException from app.exception import InvokeException
from app.log import logger from app.log import logger
@@ -50,7 +50,6 @@ from .hub import Client, Hub
from httpx import HTTPError from httpx import HTTPError
from sqlalchemy import update from sqlalchemy import update
from sqlmodel import col, exists, select from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
GAMEPLAY_LOAD_TIMEOUT = 30 GAMEPLAY_LOAD_TIMEOUT = 30
@@ -61,7 +60,7 @@ class MultiplayerEventLogger:
async def log_event(self, event: MultiplayerEvent): async def log_event(self, event: MultiplayerEvent):
try: try:
async with AsyncSession(engine) as session: async with with_db() as session:
session.add(event) session.add(event)
await session.commit() await session.commit()
except Exception as e: except Exception as e:
@@ -192,7 +191,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
store = self.get_or_create_state(client) store = self.get_or_create_state(client)
if store.room_id != 0: if store.room_id != 0:
raise InvokeException("You are already in a room") raise InvokeException("You are already in a room")
async with AsyncSession(engine) as session: async with with_db() as session:
async with session: async with session:
db_room = Room( db_room = Room(
name=room.settings.name, name=room.settings.name,
@@ -282,7 +281,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await server_room.match_type_handler.handle_join(user) await server_room.match_type_handler.handle_join(user)
await self.event_logger.player_joined(room_id, user.user_id) await self.event_logger.player_joined(room_id, user.user_id)
async with AsyncSession(engine) as session: async with with_db() as session:
async with session.begin(): async with session.begin():
if ( if (
participated_user := ( participated_user := (
@@ -398,7 +397,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
) )
async def change_db_settings(self, room: ServerMultiplayerRoom): async def change_db_settings(self, room: ServerMultiplayerRoom):
async with AsyncSession(engine) as session: async with with_db() as session:
await session.execute( await session.execute(
update(Room) update(Room)
.where(col(Room.id) == room.room.room_id) .where(col(Room.id) == room.room.room_id)
@@ -477,7 +476,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
room, room,
user, user,
) )
async with AsyncSession(engine) as session: async with with_db() as session:
try: try:
beatmap = await Beatmap.get_or_fetch( beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=room.queue.current_item.beatmap_id session, fetcher, bid=room.queue.current_item.beatmap_id
@@ -535,7 +534,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if not room.queue.current_item.freestyle: if not room.queue.current_item.freestyle:
raise InvokeException("Current item does not allow free user styles.") raise InvokeException("Current item does not allow free user styles.")
async with AsyncSession(engine) as session: async with with_db() as session:
item_beatmap = await session.get( item_beatmap = await session.get(
Beatmap, room.queue.current_item.beatmap_id Beatmap, room.queue.current_item.beatmap_id
) )
@@ -910,7 +909,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
redis = get_redis() redis = get_redis()
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}") await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
async with AsyncSession(engine) as session: async with with_db() as session:
async with session.begin(): async with session.begin():
participated_user = ( participated_user = (
await session.exec( await session.exec(
@@ -954,7 +953,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
async def end_room(self, room: ServerMultiplayerRoom): async def end_room(self, room: ServerMultiplayerRoom):
assert room.room.host assert room.room.host
async with AsyncSession(engine) as session: async with with_db() as session:
await session.execute( await session.execute(
update(Room) update(Room)
.where(col(Room.id) == room.room.room_id) .where(col(Room.id) == room.room.room_id)
@@ -1171,7 +1170,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if user is None: if user is None:
raise InvokeException("You are not in this room") raise InvokeException("You are not in this room")
async with AsyncSession(engine) as session: async with with_db() as session:
db_user = await session.get(User, user_id) db_user = await session.get(User, user_id)
target_relationship = ( target_relationship = (
await session.exec( await session.exec(

View File

@@ -14,7 +14,7 @@ from app.database.failtime import FailTime, FailTimeResp
from app.database.score import Score from app.database.score import Score
from app.database.score_token import ScoreToken from app.database.score_token import ScoreToken
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.dependencies.database import engine, get_redis from app.dependencies.database import get_redis, with_db
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service from app.dependencies.storage import get_storage_service
from app.exception import InvokeException from app.exception import InvokeException
@@ -38,7 +38,6 @@ from .hub import Client, Hub
from httpx import HTTPError from httpx import HTTPError
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 30 READ_SCORE_TIMEOUT = 30
REPLAY_LATEST_VER = 30000016 REPLAY_LATEST_VER = 30000016
@@ -194,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]):
return return
fetcher = await get_fetcher() fetcher = await get_fetcher()
async with AsyncSession(engine) as session: async with with_db() as session:
async with session.begin(): async with session.begin():
try: try:
beatmap = await Beatmap.get_or_fetch( beatmap = await Beatmap.get_or_fetch(
@@ -285,7 +284,7 @@ class SpectatorHub(Hub[StoreClientState]):
assert store.checksum is not None assert store.checksum is not None
assert store.ruleset_id is not None assert store.ruleset_id is not None
assert store.score is not None assert store.score is not None
async with AsyncSession(engine) as session: async with with_db() as session:
async with session: async with session:
start_time = time.time() start_time = time.time()
score_record = None score_record = None
@@ -332,7 +331,7 @@ class SpectatorHub(Hub[StoreClientState]):
self, user_id: int, state: SpectatorState, store: StoreClientState self, user_id: int, state: SpectatorState, store: StoreClientState
) -> None: ) -> None:
async def _add_failtime(): async def _add_failtime():
async with AsyncSession(engine) as session: async with with_db() as session:
failtime = await session.get(FailTime, state.beatmap_id) failtime = await session.get(FailTime, state.beatmap_id)
total_length = ( total_length = (
await session.exec( await session.exec(
@@ -366,7 +365,7 @@ class SpectatorHub(Hub[StoreClientState]):
return return
before_time = int(messages[0][1]["time"]) before_time = int(messages[0][1]["time"])
await redis.delete(key) await redis.delete(key)
async with AsyncSession(engine) as session: async with with_db() as session:
gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods) gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods)
statistics = ( statistics = (
await session.exec( await session.exec(
@@ -430,7 +429,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.add_to_group(client, self.group_id(target_id)) self.add_to_group(client, self.group_id(target_id))
async with AsyncSession(engine) as session: async with with_db() as session:
async with session.begin(): async with session.begin():
username = ( username = (
await session.exec(select(User.username).where(User.id == user_id)) await session.exec(select(User.username).where(User.id == user_id))

View File

@@ -8,7 +8,7 @@ import uuid
from app.database import User as DBUser from app.database import User as DBUser
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.dependencies.database import get_db from app.dependencies.database import DBFactory, get_db_factory
from app.models.signalr import NegotiateResponse, Transport from app.models.signalr import NegotiateResponse, Transport
from .hub import Hubs from .hub import Hubs
@@ -16,7 +16,6 @@ from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
from fastapi.security import SecurityScopes from fastapi.security import SecurityScopes
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/signalr", include_in_schema=False) router = APIRouter(prefix="/signalr", include_in_schema=False)
@@ -47,7 +46,7 @@ async def connect(
websocket: WebSocket, websocket: WebSocket,
id: str, id: str,
authorization: str = Header(...), authorization: str = Header(...),
db: AsyncSession = Depends(get_db), factory: DBFactory = Depends(get_db_factory),
): ):
token = authorization[7:] token = authorization[7:]
user_id = id.split(":")[0] user_id = id.split(":")[0]
@@ -56,13 +55,14 @@ async def connect(
await websocket.close(code=1008) await websocket.close(code=1008)
return return
try: try:
if ( async for session in factory():
user := await get_current_user( if (
SecurityScopes(scopes=["*"]), db, token_pw=token user := await get_current_user(
) session, SecurityScopes(scopes=["*"]), token_pw=token
) is None or str(user.id) != user_id: )
await websocket.close(code=1008) ) is None or str(user.id) != user_id:
return await websocket.close(code=1008)
return
except HTTPException: except HTTPException:
await websocket.close(code=1008) await websocket.close(code=1008)
return return