refactor(app): update database code
This commit is contained in:
@@ -3,9 +3,11 @@ from __future__ import annotations
|
|||||||
from collections.abc import AsyncIterator, Callable
|
from collections.abc import AsyncIterator, Callable
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
import json
|
import json
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
@@ -52,7 +54,12 @@ async def get_db():
|
|||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
def with_db():
|
||||||
|
return AsyncSession(engine)
|
||||||
|
|
||||||
|
|
||||||
DBFactory = Callable[[], AsyncIterator[AsyncSession]]
|
DBFactory = Callable[[], AsyncIterator[AsyncSession]]
|
||||||
|
Database = Annotated[AsyncSession, Depends(get_db)]
|
||||||
|
|
||||||
|
|
||||||
async def get_db_factory() -> DBFactory:
|
async def get_db_factory() -> DBFactory:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from app.database import User
|
|||||||
from app.database.auth import V1APIKeys
|
from app.database.auth import V1APIKeys
|
||||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||||
|
|
||||||
from .database import get_db
|
from .database import Database
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi.security import (
|
from fastapi.security import (
|
||||||
@@ -19,7 +19,6 @@ from fastapi.security import (
|
|||||||
SecurityScopes,
|
SecurityScopes,
|
||||||
)
|
)
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
|
|
||||||
@@ -64,7 +63,7 @@ v1_api_key = APIKeyQuery(name="k", scheme_name="V1 API Key", description="v1 API
|
|||||||
|
|
||||||
|
|
||||||
async def v1_authorize(
|
async def v1_authorize(
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Database,
|
||||||
api_key: Annotated[str, Depends(v1_api_key)],
|
api_key: Annotated[str, Depends(v1_api_key)],
|
||||||
):
|
):
|
||||||
"""V1 API Key 授权"""
|
"""V1 API Key 授权"""
|
||||||
@@ -79,8 +78,8 @@ async def v1_authorize(
|
|||||||
|
|
||||||
|
|
||||||
async def get_client_user(
|
async def get_client_user(
|
||||||
|
db: Database,
|
||||||
token: Annotated[str, Depends(oauth2_password)],
|
token: Annotated[str, Depends(oauth2_password)],
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
|
||||||
):
|
):
|
||||||
token_record = await get_token_by_access_token(db, token)
|
token_record = await get_token_by_access_token(db, token)
|
||||||
if not token_record:
|
if not token_record:
|
||||||
@@ -95,8 +94,8 @@ async def get_client_user(
|
|||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
|
db: Database,
|
||||||
security_scopes: SecurityScopes,
|
security_scopes: SecurityScopes,
|
||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
|
||||||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||||||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||||||
token_client_credentials: Annotated[
|
token_client_credentials: Annotated[
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from app.database.beatmap import Beatmap
|
from app.database.beatmap import Beatmap
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import with_db
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.exception import InvokeException
|
from app.exception import InvokeException
|
||||||
|
|
||||||
@@ -41,7 +41,6 @@ from .signalr import (
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
from sqlmodel import col
|
from sqlmodel import col
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.database.room import Room
|
from app.database.room import Room
|
||||||
@@ -473,7 +472,7 @@ class MultiplayerQueue:
|
|||||||
(item for item in self.room.playlist if not item.expired),
|
(item for item in self.room.playlist if not item.expired),
|
||||||
key=lambda x: x.id,
|
key=lambda x: x.id,
|
||||||
)
|
)
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
for idx, item in enumerate(ordered_active_items):
|
for idx, item in enumerate(ordered_active_items):
|
||||||
if item.playlist_order == idx:
|
if item.playlist_order == idx:
|
||||||
continue
|
continue
|
||||||
@@ -522,7 +521,7 @@ class MultiplayerQueue:
|
|||||||
if item.freestyle and len(item.allowed_mods) > 0:
|
if item.freestyle and len(item.allowed_mods) > 0:
|
||||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
fetcher = await get_fetcher()
|
fetcher = await get_fetcher()
|
||||||
async with session:
|
async with session:
|
||||||
beatmap = await Beatmap.get_or_fetch(
|
beatmap = await Beatmap.get_or_fetch(
|
||||||
@@ -548,7 +547,7 @@ class MultiplayerQueue:
|
|||||||
if item.freestyle and len(item.allowed_mods) > 0:
|
if item.freestyle and len(item.allowed_mods) > 0:
|
||||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
fetcher = await get_fetcher()
|
fetcher = await get_fetcher()
|
||||||
async with session:
|
async with session:
|
||||||
beatmap = await Beatmap.get_or_fetch(
|
beatmap = await Beatmap.get_or_fetch(
|
||||||
@@ -622,7 +621,7 @@ class MultiplayerQueue:
|
|||||||
"Attempted to remove an item which has already been played"
|
"Attempted to remove an item which has already been played"
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
await Playlist.delete_item(item.id, self.room.room_id, session)
|
await Playlist.delete_item(item.id, self.room.room_id, session)
|
||||||
|
|
||||||
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
|
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
|
||||||
@@ -637,7 +636,7 @@ class MultiplayerQueue:
|
|||||||
async def finish_current_item(self):
|
async def finish_current_item(self):
|
||||||
from app.database import Playlist
|
from app.database import Playlist
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
played_at = datetime.now(UTC)
|
played_at = datetime.now(UTC)
|
||||||
await session.execute(
|
await session.execute(
|
||||||
update(Playlist)
|
update(Playlist)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class GameMode(str, Enum):
|
|||||||
def parse(cls, v: str | int) -> "GameMode | None":
|
def parse(cls, v: str | int) -> "GameMode | None":
|
||||||
if isinstance(v, int) or v.isdigit():
|
if isinstance(v, int) or v.isdigit():
|
||||||
return cls.from_int_extra(int(v))
|
return cls.from_int_extra(int(v))
|
||||||
v = v.lower()
|
v = v.upper()
|
||||||
try:
|
try:
|
||||||
return cls[v]
|
return cls[v]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ from app.config import settings
|
|||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database import DailyChallengeStats, OAuthClient, User
|
from app.database import DailyChallengeStats, OAuthClient, User
|
||||||
from app.database.statistics import UserStatistics
|
from app.database.statistics import UserStatistics
|
||||||
from app.dependencies import get_db
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.database import get_redis
|
|
||||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||||
from app.helpers.geoip_helper import GeoIPHelper
|
from app.helpers.geoip_helper import GeoIPHelper
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
@@ -37,7 +36,6 @@ from fastapi.responses import JSONResponse
|
|||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
def create_oauth_error_response(
|
def create_oauth_error_response(
|
||||||
@@ -89,11 +87,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
|
|||||||
description="用户注册接口",
|
description="用户注册接口",
|
||||||
)
|
)
|
||||||
async def register_user(
|
async def register_user(
|
||||||
|
db: Database,
|
||||||
request: Request,
|
request: Request,
|
||||||
user_username: str = Form(..., alias="user[username]", description="用户名"),
|
user_username: str = Form(..., alias="user[username]", description="用户名"),
|
||||||
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
||||||
user_password: str = Form(..., alias="user[password]", description="密码"),
|
user_password: str = Form(..., alias="user[password]", description="密码"),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||||
):
|
):
|
||||||
username_errors = validate_username(user_username)
|
username_errors = validate_username(user_username)
|
||||||
@@ -205,6 +203,7 @@ async def register_user(
|
|||||||
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
|
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
|
||||||
)
|
)
|
||||||
async def oauth_token(
|
async def oauth_token(
|
||||||
|
db: Database,
|
||||||
request: Request,
|
request: Request,
|
||||||
grant_type: Literal[
|
grant_type: Literal[
|
||||||
"authorization_code", "refresh_token", "password", "client_credentials"
|
"authorization_code", "refresh_token", "password", "client_credentials"
|
||||||
@@ -218,7 +217,6 @@ async def oauth_token(
|
|||||||
refresh_token: str | None = Form(
|
refresh_token: str | None = Form(
|
||||||
None, description="刷新令牌(仅刷新令牌模式需要)"
|
None, description="刷新令牌(仅刷新令牌模式需要)"
|
||||||
),
|
),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
scopes = scope.split(" ")
|
scopes = scope.split(" ")
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from app.database.chat import (
|
|||||||
UserSilenceResp,
|
UserSilenceResp,
|
||||||
)
|
)
|
||||||
from app.database.lazer_user import User, UserResp
|
from app.database.lazer_user import User, UserResp
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.param import BodyOrForm
|
from app.dependencies.param import BodyOrForm
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.router.v2 import api_v2_router as router
|
from app.router.v2 import api_v2_router as router
|
||||||
@@ -22,7 +22,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security
|
|||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateResponse(BaseModel):
|
class UpdateResponse(BaseModel):
|
||||||
@@ -38,6 +37,7 @@ class UpdateResponse(BaseModel):
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def get_update(
|
async def get_update(
|
||||||
|
session: Database,
|
||||||
history_since: int | None = Query(
|
history_since: int | None = Query(
|
||||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||||
),
|
),
|
||||||
@@ -46,7 +46,6 @@ async def get_update(
|
|||||||
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
|
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
|
||||||
),
|
),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
resp = UpdateResponse()
|
resp = UpdateResponse()
|
||||||
@@ -101,10 +100,10 @@ async def get_update(
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def join_channel(
|
async def join_channel(
|
||||||
|
session: Database,
|
||||||
channel: str = Path(..., description="频道 ID/名称"),
|
channel: str = Path(..., description="频道 ID/名称"),
|
||||||
user: str = Path(..., description="用户 ID"),
|
user: str = Path(..., description="用户 ID"),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
db_channel = await ChatChannel.get(channel, session)
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
|
|
||||||
@@ -121,10 +120,10 @@ async def join_channel(
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def leave_channel(
|
async def leave_channel(
|
||||||
|
session: Database,
|
||||||
channel: str = Path(..., description="频道 ID/名称"),
|
channel: str = Path(..., description="频道 ID/名称"),
|
||||||
user: str = Path(..., description="用户 ID"),
|
user: str = Path(..., description="用户 ID"),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
db_channel = await ChatChannel.get(channel, session)
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
|
|
||||||
@@ -142,8 +141,8 @@ async def leave_channel(
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def get_channel_list(
|
async def get_channel_list(
|
||||||
|
session: Database,
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
channels = (
|
channels = (
|
||||||
@@ -181,9 +180,9 @@ class GetChannelResp(BaseModel):
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def get_channel(
|
async def get_channel(
|
||||||
|
session: Database,
|
||||||
channel: str = Path(..., description="频道 ID/名称"),
|
channel: str = Path(..., description="频道 ID/名称"),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
db_channel = await ChatChannel.get(channel, session)
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
@@ -250,9 +249,9 @@ class CreateChannelReq(BaseModel):
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def create_channel(
|
async def create_channel(
|
||||||
|
session: Database,
|
||||||
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
|
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
if req.type == "PM":
|
if req.type == "PM":
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from app.database.chat import (
|
|||||||
UserSilenceResp,
|
UserSilenceResp,
|
||||||
)
|
)
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.param import BodyOrForm
|
from app.dependencies.param import BodyOrForm
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.router.v2 import api_v2_router as router
|
from app.router.v2 import api_v2_router as router
|
||||||
@@ -23,7 +23,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class KeepAliveResp(BaseModel):
|
class KeepAliveResp(BaseModel):
|
||||||
@@ -38,12 +37,12 @@ class KeepAliveResp(BaseModel):
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def keep_alive(
|
async def keep_alive(
|
||||||
|
session: Database,
|
||||||
history_since: int | None = Query(
|
history_since: int | None = Query(
|
||||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||||
),
|
),
|
||||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
resp = KeepAliveResp()
|
resp = KeepAliveResp()
|
||||||
if history_since:
|
if history_since:
|
||||||
@@ -84,10 +83,10 @@ class MessageReq(BaseModel):
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def send_message(
|
async def send_message(
|
||||||
|
session: Database,
|
||||||
channel: str = Path(..., description="频道 ID/名称"),
|
channel: str = Path(..., description="频道 ID/名称"),
|
||||||
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
db_channel = await ChatChannel.get(channel, session)
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
if db_channel is None:
|
if db_channel is None:
|
||||||
@@ -125,12 +124,12 @@ async def send_message(
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def get_message(
|
async def get_message(
|
||||||
|
session: Database,
|
||||||
channel: str,
|
channel: str,
|
||||||
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
|
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
|
||||||
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
|
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
|
||||||
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
|
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
db_channel = await ChatChannel.get(channel, session)
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
if db_channel is None:
|
if db_channel is None:
|
||||||
@@ -158,10 +157,10 @@ async def get_message(
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def mark_as_read(
|
async def mark_as_read(
|
||||||
|
session: Database,
|
||||||
channel: str = Path(..., description="频道 ID/名称"),
|
channel: str = Path(..., description="频道 ID/名称"),
|
||||||
message: int = Path(..., description="消息 ID"),
|
message: int = Path(..., description="消息 ID"),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
db_channel = await ChatChannel.get(channel, session)
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
if db_channel is None:
|
if db_channel is None:
|
||||||
@@ -191,9 +190,9 @@ class NewPMResp(BaseModel):
|
|||||||
tags=["聊天"],
|
tags=["聊天"],
|
||||||
)
|
)
|
||||||
async def create_new_pm(
|
async def create_new_pm(
|
||||||
|
session: Database,
|
||||||
req: PMReq = Depends(BodyOrForm(PMReq)),
|
req: PMReq = Depends(BodyOrForm(PMReq)),
|
||||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
user_id = current_user.id
|
user_id = current_user.id
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMes
|
|||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.dependencies.database import (
|
from app.dependencies.database import (
|
||||||
DBFactory,
|
DBFactory,
|
||||||
engine,
|
|
||||||
get_db_factory,
|
get_db_factory,
|
||||||
get_redis,
|
get_redis,
|
||||||
|
with_db,
|
||||||
)
|
)
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
@@ -200,7 +200,7 @@ class ChatServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
channel = await ChatChannel.get(channel_id, session)
|
channel = await ChatChannel.get(channel_id, session)
|
||||||
if channel is None:
|
if channel is None:
|
||||||
return
|
return
|
||||||
@@ -212,7 +212,7 @@ class ChatServer:
|
|||||||
await self.join_channel(user, channel, session)
|
await self.join_channel(user, channel, session)
|
||||||
|
|
||||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
channel = await ChatChannel.get(channel_id, session)
|
channel = await ChatChannel.get(channel_id, session)
|
||||||
if channel is None:
|
if channel is None:
|
||||||
return
|
return
|
||||||
@@ -268,7 +268,7 @@ async def chat_websocket(
|
|||||||
token = authorization[7:]
|
token = authorization[7:]
|
||||||
if (
|
if (
|
||||||
user := await get_current_user(
|
user := await get_current_user(
|
||||||
SecurityScopes(scopes=["chat.read"]), session, token_pw=token
|
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
|
||||||
)
|
)
|
||||||
) is None:
|
) is None:
|
||||||
await websocket.close(code=1008)
|
await websocket.close(code=1008)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import hashlib
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.storage import get_storage_service
|
from app.dependencies.storage import get_storage_service
|
||||||
from app.dependencies.user import get_client_user
|
from app.dependencies.user import get_client_user
|
||||||
from app.storage.base import StorageService
|
from app.storage.base import StorageService
|
||||||
@@ -13,7 +13,6 @@ from .router import router
|
|||||||
|
|
||||||
from fastapi import Depends, File, HTTPException, Security
|
from fastapi import Depends, File, HTTPException, Security
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -21,10 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
name="上传头像",
|
name="上传头像",
|
||||||
)
|
)
|
||||||
async def upload_avatar(
|
async def upload_avatar(
|
||||||
|
session: Database,
|
||||||
content: bytes = File(...),
|
content: bytes = File(...),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
storage: StorageService = Depends(get_storage_service),
|
storage: StorageService = Depends(get_storage_service),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
"""上传用户头像
|
"""上传用户头像
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import hashlib
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from app.database.lazer_user import User, UserProfileCover
|
from app.database.lazer_user import User, UserProfileCover
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.storage import get_storage_service
|
from app.dependencies.storage import get_storage_service
|
||||||
from app.dependencies.user import get_client_user
|
from app.dependencies.user import get_client_user
|
||||||
from app.storage.base import StorageService
|
from app.storage.base import StorageService
|
||||||
@@ -13,7 +13,6 @@ from .router import router
|
|||||||
|
|
||||||
from fastapi import Depends, File, HTTPException, Security
|
from fastapi import Depends, File, HTTPException, Security
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -21,10 +20,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
name="上传头图",
|
name="上传头图",
|
||||||
)
|
)
|
||||||
async def upload_cover(
|
async def upload_cover(
|
||||||
|
session: Database,
|
||||||
content: bytes = File(...),
|
content: bytes = File(...),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
storage: StorageService = Depends(get_storage_service),
|
storage: StorageService = Depends(get_storage_service),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
"""上传用户头图
|
"""上传用户头图
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import secrets
|
|||||||
|
|
||||||
from app.database.auth import OAuthClient, OAuthToken
|
from app.database.auth import OAuthClient, OAuthToken
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.user import get_client_user
|
from app.dependencies.user import get_client_user
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
@@ -12,7 +12,6 @@ from .router import router
|
|||||||
from fastapi import Body, Depends, HTTPException, Security
|
from fastapi import Body, Depends, HTTPException, Security
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import select, text
|
from sqlmodel import select, text
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -21,11 +20,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
description="创建一个新的 OAuth 应用程序,并生成客户端 ID 和密钥",
|
description="创建一个新的 OAuth 应用程序,并生成客户端 ID 和密钥",
|
||||||
)
|
)
|
||||||
async def create_oauth_app(
|
async def create_oauth_app(
|
||||||
|
session: Database,
|
||||||
name: str = Body(..., max_length=100, description="应用程序名称"),
|
name: str = Body(..., max_length=100, description="应用程序名称"),
|
||||||
description: str = Body("", description="应用程序描述"),
|
description: str = Body("", description="应用程序描述"),
|
||||||
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
|
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
result = await session.execute( # pyright: ignore[reportDeprecated]
|
result = await session.execute( # pyright: ignore[reportDeprecated]
|
||||||
text(
|
text(
|
||||||
@@ -61,8 +60,8 @@ async def create_oauth_app(
|
|||||||
description="通过客户端 ID 获取 OAuth 应用的详细信息",
|
description="通过客户端 ID 获取 OAuth 应用的详细信息",
|
||||||
)
|
)
|
||||||
async def get_oauth_app(
|
async def get_oauth_app(
|
||||||
|
session: Database,
|
||||||
client_id: int,
|
client_id: int,
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
oauth_app = await session.get(OAuthClient, client_id)
|
oauth_app = await session.get(OAuthClient, client_id)
|
||||||
@@ -82,7 +81,7 @@ async def get_oauth_app(
|
|||||||
description="获取当前用户创建的所有 OAuth 应用程序",
|
description="获取当前用户创建的所有 OAuth 应用程序",
|
||||||
)
|
)
|
||||||
async def get_user_oauth_apps(
|
async def get_user_oauth_apps(
|
||||||
session: AsyncSession = Depends(get_db),
|
session: Database,
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
oauth_apps = await session.exec(
|
oauth_apps = await session.exec(
|
||||||
@@ -106,8 +105,8 @@ async def get_user_oauth_apps(
|
|||||||
description="删除指定的 OAuth 应用程序及其关联的所有令牌",
|
description="删除指定的 OAuth 应用程序及其关联的所有令牌",
|
||||||
)
|
)
|
||||||
async def delete_oauth_app(
|
async def delete_oauth_app(
|
||||||
|
session: Database,
|
||||||
client_id: int,
|
client_id: int,
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
oauth_client = await session.get(OAuthClient, client_id)
|
oauth_client = await session.get(OAuthClient, client_id)
|
||||||
@@ -134,11 +133,11 @@ async def delete_oauth_app(
|
|||||||
description="更新指定 OAuth 应用的名称、描述和重定向 URI",
|
description="更新指定 OAuth 应用的名称、描述和重定向 URI",
|
||||||
)
|
)
|
||||||
async def update_oauth_app(
|
async def update_oauth_app(
|
||||||
|
session: Database,
|
||||||
client_id: int,
|
client_id: int,
|
||||||
name: str = Body(..., max_length=100, description="应用程序新名称"),
|
name: str = Body(..., max_length=100, description="应用程序新名称"),
|
||||||
description: str = Body("", description="应用程序新描述"),
|
description: str = Body("", description="应用程序新描述"),
|
||||||
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"),
|
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
oauth_client = await session.get(OAuthClient, client_id)
|
oauth_client = await session.get(OAuthClient, client_id)
|
||||||
@@ -169,8 +168,8 @@ async def update_oauth_app(
|
|||||||
description="为指定的 OAuth 应用生成新的客户端密钥,并使所有现有的令牌失效",
|
description="为指定的 OAuth 应用生成新的客户端密钥,并使所有现有的令牌失效",
|
||||||
)
|
)
|
||||||
async def refresh_secret(
|
async def refresh_secret(
|
||||||
|
session: Database,
|
||||||
client_id: int,
|
client_id: int,
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
oauth_client = await session.get(OAuthClient, client_id)
|
oauth_client = await session.get(OAuthClient, client_id)
|
||||||
@@ -204,11 +203,11 @@ async def refresh_secret(
|
|||||||
description="为特定用户和 OAuth 应用生成授权码,用于授权码授权流程",
|
description="为特定用户和 OAuth 应用生成授权码,用于授权码授权流程",
|
||||||
)
|
)
|
||||||
async def generate_oauth_code(
|
async def generate_oauth_code(
|
||||||
|
session: Database,
|
||||||
client_id: int,
|
client_id: int,
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
redirect_uri: str = Body(..., description="授权后重定向的 URI"),
|
redirect_uri: str = Body(..., description="授权后重定向的 URI"),
|
||||||
scopes: list[str] = Body(..., description="请求的权限范围列表"),
|
scopes: list[str] = Body(..., description="请求的权限范围列表"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
client = await session.get(OAuthClient, client_id)
|
client = await session.get(OAuthClient, client_id)
|
||||||
|
|||||||
@@ -2,15 +2,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from app.database import Relationship, User
|
from app.database import Relationship, User
|
||||||
from app.database.relationship import RelationshipType
|
from app.database.relationship import RelationshipType
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.user import get_client_user
|
from app.dependencies.user import get_client_user
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Path, Security
|
from fastapi import HTTPException, Path, Security
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class CheckResponse(BaseModel):
|
class CheckResponse(BaseModel):
|
||||||
@@ -26,9 +25,9 @@ class CheckResponse(BaseModel):
|
|||||||
response_model=CheckResponse,
|
response_model=CheckResponse,
|
||||||
)
|
)
|
||||||
async def check_user_relationship(
|
async def check_user_relationship(
|
||||||
|
db: Database,
|
||||||
user_id: int = Path(..., description="目标用户的 ID"),
|
user_id: int = Path(..., description="目标用户的 ID"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
if user_id == current_user.id:
|
if user_id == current_user.id:
|
||||||
raise HTTPException(422, "Cannot check relationship with yourself")
|
raise HTTPException(422, "Cannot check relationship with yourself")
|
||||||
|
|||||||
@@ -6,14 +6,13 @@ from app.auth import validate_username
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database.events import Event, EventType
|
from app.database.events import Event, EventType
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.user import get_client_user
|
from app.dependencies.user import get_client_user
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import Body, Depends, HTTPException, Security
|
from fastapi import Body, HTTPException, Security
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -21,8 +20,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
name="修改用户名",
|
name="修改用户名",
|
||||||
)
|
)
|
||||||
async def user_rename(
|
async def user_rename(
|
||||||
|
session: Database,
|
||||||
new_name: str = Body(..., description="新的用户名"),
|
new_name: str = Body(..., description="新的用户名"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
"""修改用户名
|
"""修改用户名
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from app.database.beatmap_playcounts import BeatmapPlaycounts
|
|||||||
from app.database.beatmapset import Beatmapset
|
from app.database.beatmapset import Beatmapset
|
||||||
from app.database.favourite_beatmapset import FavouriteBeatmapset
|
from app.database.favourite_beatmapset import FavouriteBeatmapset
|
||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.fetcher import Fetcher
|
from app.fetcher import Fetcher
|
||||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||||
@@ -149,6 +149,7 @@ class V1Beatmap(AllStrModel):
|
|||||||
description="根据指定条件搜索谱面。",
|
description="根据指定条件搜索谱面。",
|
||||||
)
|
)
|
||||||
async def get_beatmaps(
|
async def get_beatmaps(
|
||||||
|
session: Database,
|
||||||
since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"),
|
since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"),
|
||||||
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
|
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
|
||||||
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
|
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
|
||||||
@@ -163,7 +164,6 @@ async def get_beatmaps(
|
|||||||
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
|
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
|
||||||
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
|
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
|
||||||
mods: int = Query(0, description="应用到谱面属性的 MOD"),
|
mods: int = Query(0, description="应用到谱面属性的 MOD"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal
|
|||||||
|
|
||||||
from app.database.counts import ReplayWatchedCount
|
from app.database.counts import ReplayWatchedCount
|
||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.storage import get_storage_service
|
from app.dependencies.storage import get_storage_service
|
||||||
from app.models.mods import int_to_mods
|
from app.models.mods import int_to_mods
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
@@ -17,7 +17,6 @@ from .router import router
|
|||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class ReplayModel(BaseModel):
|
class ReplayModel(BaseModel):
|
||||||
@@ -32,6 +31,7 @@ class ReplayModel(BaseModel):
|
|||||||
description="获取指定谱面的回放文件。",
|
description="获取指定谱面的回放文件。",
|
||||||
)
|
)
|
||||||
async def download_replay(
|
async def download_replay(
|
||||||
|
session: Database,
|
||||||
beatmap: int = Query(..., alias="b", description="谱面 ID"),
|
beatmap: int = Query(..., alias="b", description="谱面 ID"),
|
||||||
user: str = Query(..., alias="u", description="用户"),
|
user: str = Query(..., alias="u", description="用户"),
|
||||||
ruleset_id: int | None = Query(
|
ruleset_id: int | None = Query(
|
||||||
@@ -45,7 +45,6 @@ async def download_replay(
|
|||||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||||
),
|
),
|
||||||
mods: int = Query(0, description="成绩的 MOD"),
|
mods: int = Query(0, description="成绩的 MOD"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
storage_service: StorageService = Depends(get_storage_service),
|
storage_service: StorageService = Depends(get_storage_service),
|
||||||
):
|
):
|
||||||
mods_ = int_to_mods(mods)
|
mods_ = int_to_mods(mods)
|
||||||
|
|||||||
@@ -5,16 +5,15 @@ from typing import Literal
|
|||||||
|
|
||||||
from app.database.pp_best_score import PPBestScore
|
from app.database.pp_best_score import PPBestScore
|
||||||
from app.database.score import Score, get_leaderboard
|
from app.database.score import Score, get_leaderboard
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.models.mods import int_to_mods, mod_to_save, mods_to_int
|
from app.models.mods import int_to_mods, mod_to_save, mods_to_int
|
||||||
from app.models.score import GameMode, LeaderboardType
|
from app.models.score import GameMode, LeaderboardType
|
||||||
|
|
||||||
from .router import AllStrModel, router
|
from .router import AllStrModel, router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import HTTPException, Query
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlmodel import col, exists, select
|
from sqlmodel import col, exists, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class V1Score(AllStrModel):
|
class V1Score(AllStrModel):
|
||||||
@@ -68,13 +67,13 @@ class V1Score(AllStrModel):
|
|||||||
description="获取指定用户的最好成绩。",
|
description="获取指定用户的最好成绩。",
|
||||||
)
|
)
|
||||||
async def get_user_best(
|
async def get_user_best(
|
||||||
|
session: Database,
|
||||||
user: str = Query(..., alias="u", description="用户"),
|
user: str = Query(..., alias="u", description="用户"),
|
||||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||||
type: Literal["string", "id"] | None = Query(
|
type: Literal["string", "id"] | None = Query(
|
||||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||||
),
|
),
|
||||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
scores = (
|
scores = (
|
||||||
@@ -104,13 +103,13 @@ async def get_user_best(
|
|||||||
description="获取指定用户的最近成绩。",
|
description="获取指定用户的最近成绩。",
|
||||||
)
|
)
|
||||||
async def get_user_recent(
|
async def get_user_recent(
|
||||||
|
session: Database,
|
||||||
user: str = Query(..., alias="u", description="用户"),
|
user: str = Query(..., alias="u", description="用户"),
|
||||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||||
type: Literal["string", "id"] | None = Query(
|
type: Literal["string", "id"] | None = Query(
|
||||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||||
),
|
),
|
||||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
scores = (
|
scores = (
|
||||||
@@ -140,6 +139,7 @@ async def get_user_recent(
|
|||||||
description="获取指定谱面的成绩。",
|
description="获取指定谱面的成绩。",
|
||||||
)
|
)
|
||||||
async def get_scores(
|
async def get_scores(
|
||||||
|
session: Database,
|
||||||
user: str | None = Query(None, alias="u", description="用户"),
|
user: str | None = Query(None, alias="u", description="用户"),
|
||||||
beatmap_id: int = Query(alias="b", description="谱面 ID"),
|
beatmap_id: int = Query(alias="b", description="谱面 ID"),
|
||||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||||
@@ -148,7 +148,6 @@ async def get_scores(
|
|||||||
),
|
),
|
||||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||||
mods: int = Query(0, description="成绩的 MOD"),
|
mods: int = Query(0, description="成绩的 MOD"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if user is not None:
|
if user is not None:
|
||||||
|
|||||||
@@ -5,14 +5,13 @@ from typing import Literal
|
|||||||
|
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
from app.database.statistics import UserStatistics, UserStatisticsResp
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from .router import AllStrModel, router
|
from .router import AllStrModel, router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import HTTPException, Query
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class V1User(AllStrModel):
|
class V1User(AllStrModel):
|
||||||
@@ -41,7 +40,7 @@ class V1User(AllStrModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_db(
|
async def from_db(
|
||||||
cls, session: AsyncSession, db_user: User, ruleset: GameMode | None = None
|
cls, session: Database, db_user: User, ruleset: GameMode | None = None
|
||||||
) -> "V1User":
|
) -> "V1User":
|
||||||
ruleset = ruleset or db_user.playmode
|
ruleset = ruleset or db_user.playmode
|
||||||
current_statistics: UserStatistics | None = None
|
current_statistics: UserStatistics | None = None
|
||||||
@@ -92,6 +91,7 @@ class V1User(AllStrModel):
|
|||||||
description="获取指定用户的信息。",
|
description="获取指定用户的信息。",
|
||||||
)
|
)
|
||||||
async def get_user(
|
async def get_user(
|
||||||
|
session: Database,
|
||||||
user: str = Query(..., alias="u", description="用户"),
|
user: str = Query(..., alias="u", description="用户"),
|
||||||
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
|
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
|
||||||
type: Literal["string", "id"] | None = Query(
|
type: Literal["string", "id"] | None = Query(
|
||||||
@@ -100,7 +100,6 @@ async def get_user(
|
|||||||
event_days: int = Query(
|
event_days: int = Query(
|
||||||
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
|
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
|
||||||
),
|
),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
db_user = (
|
db_user = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import json
|
|||||||
|
|
||||||
from app.database import Beatmap, BeatmapResp, User
|
from app.database import Beatmap, BeatmapResp, User
|
||||||
from app.database.beatmap import calculate_beatmap_attributes
|
from app.database.beatmap import calculate_beatmap_attributes
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.fetcher import Fetcher
|
from app.fetcher import Fetcher
|
||||||
@@ -24,7 +24,6 @@ from pydantic import BaseModel
|
|||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
import rosu_pp_py as rosu
|
import rosu_pp_py as rosu
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class BatchGetResp(BaseModel):
|
class BatchGetResp(BaseModel):
|
||||||
@@ -47,13 +46,13 @@ class BatchGetResp(BaseModel):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def lookup_beatmap(
|
async def lookup_beatmap(
|
||||||
|
db: Database,
|
||||||
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
|
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
|
||||||
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
|
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
|
||||||
filename: str | None = Query(
|
filename: str | None = Query(
|
||||||
default=None, alias="filename", description="谱面文件名"
|
default=None, alias="filename", description="谱面文件名"
|
||||||
),
|
),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
if id is None and md5 is None and filename is None:
|
if id is None and md5 is None and filename is None:
|
||||||
@@ -80,9 +79,9 @@ async def lookup_beatmap(
|
|||||||
description="获取单个谱面详情。",
|
description="获取单个谱面详情。",
|
||||||
)
|
)
|
||||||
async def get_beatmap(
|
async def get_beatmap(
|
||||||
|
db: Database,
|
||||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
@@ -103,11 +102,11 @@ async def get_beatmap(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
async def batch_get_beatmaps(
|
async def batch_get_beatmaps(
|
||||||
|
db: Database,
|
||||||
beatmap_ids: list[int] = Query(
|
beatmap_ids: list[int] = Query(
|
||||||
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
|
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
|
||||||
),
|
),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
if not beatmap_ids:
|
if not beatmap_ids:
|
||||||
@@ -157,6 +156,7 @@ async def batch_get_beatmaps(
|
|||||||
description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"),
|
description=("计算谱面指定 mods / ruleset 下谱面的难度属性 (难度/PP 相关属性)。"),
|
||||||
)
|
)
|
||||||
async def get_beatmap_attributes(
|
async def get_beatmap_attributes(
|
||||||
|
db: Database,
|
||||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
mods: list[str] = Query(
|
mods: list[str] = Query(
|
||||||
@@ -170,7 +170,6 @@ async def get_beatmap_attributes(
|
|||||||
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
|
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
|
||||||
),
|
),
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
mods_ = []
|
mods_ = []
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from urllib.parse import parse_qs
|
|||||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||||
from app.database.beatmapset import SearchBeatmapsetsResp
|
from app.database.beatmapset import SearchBeatmapsetsResp
|
||||||
from app.dependencies.beatmap_download import get_beatmap_download_service
|
from app.dependencies.beatmap_download import get_beatmap_download_service
|
||||||
from app.dependencies.database import engine, get_db, get_redis
|
from app.dependencies.database import Database, get_redis, with_db
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||||
from app.dependencies.user import get_client_user, get_current_user
|
from app.dependencies.user import get_client_user, get_current_user
|
||||||
@@ -30,11 +30,10 @@ from fastapi import (
|
|||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
from sqlmodel import exists, select
|
from sqlmodel import exists, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
async def _save_to_db(sets: SearchBeatmapsetsResp):
|
async def _save_to_db(sets: SearchBeatmapsetsResp):
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
for s in sets.beatmapsets:
|
for s in sets.beatmapsets:
|
||||||
if not (
|
if not (
|
||||||
await session.exec(select(exists()).where(Beatmapset.id == s.id))
|
await session.exec(select(exists()).where(Beatmapset.id == s.id))
|
||||||
@@ -49,13 +48,13 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
|
|||||||
response_model=SearchBeatmapsetsResp,
|
response_model=SearchBeatmapsetsResp,
|
||||||
)
|
)
|
||||||
async def search_beatmapset(
|
async def search_beatmapset(
|
||||||
|
db: Database,
|
||||||
query: Annotated[SearchQueryModel, Query(...)],
|
query: Annotated[SearchQueryModel, Query(...)],
|
||||||
request: Request,
|
request: Request,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
redis = Depends(get_redis),
|
redis=Depends(get_redis),
|
||||||
):
|
):
|
||||||
params = parse_qs(qs=request.url.query, keep_blank_values=True)
|
params = parse_qs(qs=request.url.query, keep_blank_values=True)
|
||||||
cursor = {}
|
cursor = {}
|
||||||
@@ -112,9 +111,9 @@ async def search_beatmapset(
|
|||||||
description=("通过谱面 ID 查询所属谱面集。"),
|
description=("通过谱面 ID 查询所属谱面集。"),
|
||||||
)
|
)
|
||||||
async def lookup_beatmapset(
|
async def lookup_beatmapset(
|
||||||
|
db: Database,
|
||||||
beatmap_id: int = Query(description="谱面 ID"),
|
beatmap_id: int = Query(description="谱面 ID"),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||||
@@ -132,9 +131,9 @@ async def lookup_beatmapset(
|
|||||||
description="获取单个谱面集详情。",
|
description="获取单个谱面集详情。",
|
||||||
)
|
)
|
||||||
async def get_beatmapset(
|
async def get_beatmapset(
|
||||||
|
db: Database,
|
||||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
@@ -196,12 +195,12 @@ async def download_beatmapset(
|
|||||||
description="**客户端专属**\n收藏或取消收藏指定谱面集。",
|
description="**客户端专属**\n收藏或取消收藏指定谱面集。",
|
||||||
)
|
)
|
||||||
async def favourite_beatmapset(
|
async def favourite_beatmapset(
|
||||||
|
db: Database,
|
||||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||||
action: Literal["favourite", "unfavourite"] = Form(
|
action: Literal["favourite", "unfavourite"] = Form(
|
||||||
description="操作类型:favourite 收藏 / unfavourite 取消收藏"
|
description="操作类型:favourite 收藏 / unfavourite 取消收藏"
|
||||||
),
|
),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
assert current_user.id is not None
|
assert current_user.id is not None
|
||||||
existing_favourite = (
|
existing_favourite = (
|
||||||
|
|||||||
@@ -3,13 +3,12 @@ from __future__ import annotations
|
|||||||
from app.database import User, UserResp
|
from app.database import User, UserResp
|
||||||
from app.database.lazer_user import ALL_INCLUDED
|
from app.database.lazer_user import ALL_INCLUDED
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import Depends, Path, Security
|
from fastapi import Path, Security
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -20,9 +19,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
async def get_user_info_with_ruleset(
|
async def get_user_info_with_ruleset(
|
||||||
|
session: Database,
|
||||||
ruleset: GameMode = Path(description="指定 ruleset"),
|
ruleset: GameMode = Path(description="指定 ruleset"),
|
||||||
current_user: User = Security(get_current_user, scopes=["identify"]),
|
current_user: User = Security(get_current_user, scopes=["identify"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
return await UserResp.from_db(
|
return await UserResp.from_db(
|
||||||
current_user,
|
current_user,
|
||||||
@@ -40,8 +39,8 @@ async def get_user_info_with_ruleset(
|
|||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
async def get_user_info_default(
|
async def get_user_info_default(
|
||||||
|
session: Database,
|
||||||
current_user: User = Security(get_current_user, scopes=["identify"]),
|
current_user: User = Security(get_current_user, scopes=["identify"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
return await UserResp.from_db(
|
return await UserResp.from_db(
|
||||||
current_user,
|
current_user,
|
||||||
|
|||||||
@@ -5,15 +5,14 @@ from typing import Literal
|
|||||||
from app.database import User
|
from app.database import User
|
||||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
from app.database.statistics import UserStatistics, UserStatisticsResp
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import Depends, Path, Query, Security
|
from fastapi import Path, Query, Security
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class CountryStatistics(BaseModel):
|
class CountryStatistics(BaseModel):
|
||||||
@@ -36,10 +35,10 @@ class CountryResponse(BaseModel):
|
|||||||
tags=["排行榜"],
|
tags=["排行榜"],
|
||||||
)
|
)
|
||||||
async def get_country_ranking(
|
async def get_country_ranking(
|
||||||
|
session: Database,
|
||||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||||
page: int = Query(1, ge=1, description="页码"), # TODO
|
page: int = Query(1, ge=1, description="页码"), # TODO
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
response = CountryResponse(ranking=[])
|
response = CountryResponse(ranking=[])
|
||||||
countries = (await session.exec(select(User.country_code).distinct())).all()
|
countries = (await session.exec(select(User.country_code).distinct())).all()
|
||||||
@@ -85,6 +84,7 @@ class TopUsersResponse(BaseModel):
|
|||||||
tags=["排行榜"],
|
tags=["排行榜"],
|
||||||
)
|
)
|
||||||
async def get_user_ranking(
|
async def get_user_ranking(
|
||||||
|
session: Database,
|
||||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||||
type: Literal["performance", "score"] = Path(
|
type: Literal["performance", "score"] = Path(
|
||||||
..., description="排名类型:performance 表现分 / score 计分成绩总分"
|
..., description="排名类型:performance 表现分 / score 计分成绩总分"
|
||||||
@@ -92,7 +92,6 @@ async def get_user_ranking(
|
|||||||
country: str | None = Query(None, description="国家代码"),
|
country: str | None = Query(None, description="国家代码"),
|
||||||
page: int = Query(1, ge=1, description="页码"),
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
wheres = [
|
wheres = [
|
||||||
col(UserStatistics.mode) == ruleset,
|
col(UserStatistics.mode) == ruleset,
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from app.database import Relationship, RelationshipResp, RelationshipType, User
|
from app.database import Relationship, RelationshipResp, RelationshipType, User
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.user import get_client_user, get_current_user
|
from app.dependencies.user import get_client_user, get_current_user
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Path, Query, Request, Security
|
from fastapi import HTTPException, Path, Query, Request, Security
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -27,9 +26,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
description="获取当前用户的屏蔽用户列表。",
|
description="获取当前用户的屏蔽用户列表。",
|
||||||
)
|
)
|
||||||
async def get_relationship(
|
async def get_relationship(
|
||||||
|
db: Database,
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Security(get_current_user, scopes=["friends.read"]),
|
current_user: User = Security(get_current_user, scopes=["friends.read"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
relationship_type = (
|
relationship_type = (
|
||||||
RelationshipType.FOLLOW
|
RelationshipType.FOLLOW
|
||||||
@@ -67,10 +66,10 @@ class AddFriendResp(BaseModel):
|
|||||||
description="**客户端专属**\n添加或更新与目标用户的屏蔽关系。",
|
description="**客户端专属**\n添加或更新与目标用户的屏蔽关系。",
|
||||||
)
|
)
|
||||||
async def add_relationship(
|
async def add_relationship(
|
||||||
|
db: Database,
|
||||||
request: Request,
|
request: Request,
|
||||||
target: int = Query(description="目标用户 ID"),
|
target: int = Query(description="目标用户 ID"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
assert current_user.id is not None
|
assert current_user.id is not None
|
||||||
relationship_type = (
|
relationship_type = (
|
||||||
@@ -141,10 +140,10 @@ async def add_relationship(
|
|||||||
description="**客户端专属**\n删除与目标用户的屏蔽关系。",
|
description="**客户端专属**\n删除与目标用户的屏蔽关系。",
|
||||||
)
|
)
|
||||||
async def delete_relationship(
|
async def delete_relationship(
|
||||||
|
db: Database,
|
||||||
request: Request,
|
request: Request,
|
||||||
target: int = Path(..., description="目标用户 ID"),
|
target: int = Path(..., description="目标用户 ID"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
relationship_type = (
|
relationship_type = (
|
||||||
RelationshipType.BLOCK
|
RelationshipType.BLOCK
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.database.playlists import Playlist, PlaylistResp
|
|||||||
from app.database.room import APIUploadedRoom, Room, RoomResp
|
from app.database.room import APIUploadedRoom, Room, RoomResp
|
||||||
from app.database.room_participated_user import RoomParticipatedUser
|
from app.database.room_participated_user import RoomParticipatedUser
|
||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.user import get_client_user, get_current_user
|
from app.dependencies.user import get_client_user, get_current_user
|
||||||
from app.models.room import RoomCategory, RoomStatus
|
from app.models.room import RoomCategory, RoomStatus
|
||||||
from app.service.room import create_playlist_room_from_api
|
from app.service.room import create_playlist_room_from_api
|
||||||
@@ -36,6 +36,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
description="获取房间列表。支持按状态/模式筛选",
|
description="获取房间列表。支持按状态/模式筛选",
|
||||||
)
|
)
|
||||||
async def get_all_rooms(
|
async def get_all_rooms(
|
||||||
|
db: Database,
|
||||||
mode: Literal["open", "ended", "participated", "owned", None] = Query(
|
mode: Literal["open", "ended", "participated", "owned", None] = Query(
|
||||||
default="open",
|
default="open",
|
||||||
description=(
|
description=(
|
||||||
@@ -51,7 +52,6 @@ async def get_all_rooms(
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
|
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
):
|
):
|
||||||
resp_list: list[RoomResp] = []
|
resp_list: list[RoomResp] = []
|
||||||
@@ -149,8 +149,8 @@ async def _participate_room(
|
|||||||
description="**客户端专属**\n创建一个新的房间。",
|
description="**客户端专属**\n创建一个新的房间。",
|
||||||
)
|
)
|
||||||
async def create_room(
|
async def create_room(
|
||||||
|
db: Database,
|
||||||
room: APIUploadedRoom,
|
room: APIUploadedRoom,
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
@@ -173,6 +173,7 @@ async def create_room(
|
|||||||
description="获取单个房间详情。",
|
description="获取单个房间详情。",
|
||||||
)
|
)
|
||||||
async def get_room(
|
async def get_room(
|
||||||
|
db: Database,
|
||||||
room_id: int = Path(..., description="房间 ID"),
|
room_id: int = Path(..., description="房间 ID"),
|
||||||
category: str = Query(
|
category: str = Query(
|
||||||
default="",
|
default="",
|
||||||
@@ -181,7 +182,6 @@ async def get_room(
|
|||||||
" / DAILY_CHALLENGE 每日挑战 (可选)"
|
" / DAILY_CHALLENGE 每日挑战 (可选)"
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
@@ -201,8 +201,8 @@ async def get_room(
|
|||||||
description="**客户端专属**\n结束歌单模式房间。",
|
description="**客户端专属**\n结束歌单模式房间。",
|
||||||
)
|
)
|
||||||
async def delete_room(
|
async def delete_room(
|
||||||
|
db: Database,
|
||||||
room_id: int = Path(..., description="房间 ID"),
|
room_id: int = Path(..., description="房间 ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||||
@@ -221,9 +221,9 @@ async def delete_room(
|
|||||||
description="**客户端专属**\n加入指定歌单模式房间。",
|
description="**客户端专属**\n加入指定歌单模式房间。",
|
||||||
)
|
)
|
||||||
async def add_user_to_room(
|
async def add_user_to_room(
|
||||||
|
db: Database,
|
||||||
room_id: int = Path(..., description="房间 ID"),
|
room_id: int = Path(..., description="房间 ID"),
|
||||||
user_id: int = Path(..., description="用户 ID"),
|
user_id: int = Path(..., description="用户 ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
@@ -245,9 +245,9 @@ async def add_user_to_room(
|
|||||||
description="**客户端专属**\n离开指定歌单模式房间。",
|
description="**客户端专属**\n离开指定歌单模式房间。",
|
||||||
)
|
)
|
||||||
async def remove_user_from_room(
|
async def remove_user_from_room(
|
||||||
|
db: Database,
|
||||||
room_id: int = Path(..., description="房间 ID"),
|
room_id: int = Path(..., description="房间 ID"),
|
||||||
user_id: int = Path(..., description="用户 ID"),
|
user_id: int = Path(..., description="用户 ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
@@ -289,8 +289,8 @@ class APILeaderboard(BaseModel):
|
|||||||
description="获取房间内累计得分排行榜。",
|
description="获取房间内累计得分排行榜。",
|
||||||
)
|
)
|
||||||
async def get_room_leaderboard(
|
async def get_room_leaderboard(
|
||||||
|
db: Database,
|
||||||
room_id: int = Path(..., description="房间 ID"),
|
room_id: int = Path(..., description="房间 ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
):
|
):
|
||||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||||
@@ -345,8 +345,8 @@ class RoomEvents(BaseModel):
|
|||||||
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
|
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
|
||||||
)
|
)
|
||||||
async def get_room_events(
|
async def get_room_events(
|
||||||
|
db: Database,
|
||||||
room_id: int = Path(..., description="房间 ID"),
|
room_id: int = Path(..., description="房间 ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||||
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"),
|
after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"),
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from app.database.score import (
|
|||||||
process_score,
|
process_score,
|
||||||
process_user,
|
process_user,
|
||||||
)
|
)
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.dependencies.storage import get_storage_service
|
from app.dependencies.storage import get_storage_service
|
||||||
from app.dependencies.user import get_client_user, get_current_user
|
from app.dependencies.user import get_client_user, get_current_user
|
||||||
@@ -220,6 +220,7 @@ class BeatmapScores(BaseModel):
|
|||||||
description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
|
description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
|
||||||
)
|
)
|
||||||
async def get_beatmap_scores(
|
async def get_beatmap_scores(
|
||||||
|
db: Database,
|
||||||
beatmap_id: int = Path(description="谱面 ID"),
|
beatmap_id: int = Path(description="谱面 ID"),
|
||||||
mode: GameMode = Query(description="指定 auleset"),
|
mode: GameMode = Query(description="指定 auleset"),
|
||||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||||
@@ -233,7 +234,6 @@ async def get_beatmap_scores(
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
|
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
|
||||||
):
|
):
|
||||||
if legacy_only:
|
if legacy_only:
|
||||||
@@ -277,13 +277,13 @@ class BeatmapUserScore(BaseModel):
|
|||||||
description="获取指定用户在指定谱面上的最高成绩。",
|
description="获取指定用户在指定谱面上的最高成绩。",
|
||||||
)
|
)
|
||||||
async def get_user_beatmap_score(
|
async def get_user_beatmap_score(
|
||||||
|
db: Database,
|
||||||
beatmap_id: int = Path(description="谱面 ID"),
|
beatmap_id: int = Path(description="谱面 ID"),
|
||||||
user_id: int = Path(description="用户 ID"),
|
user_id: int = Path(description="用户 ID"),
|
||||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"),
|
mode: GameMode | None = Query(None, description="指定 ruleset (可选)"),
|
||||||
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"),
|
mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
if legacy_only:
|
if legacy_only:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -322,12 +322,12 @@ async def get_user_beatmap_score(
|
|||||||
description="获取指定用户在指定谱面上的全部成绩列表。",
|
description="获取指定用户在指定谱面上的全部成绩列表。",
|
||||||
)
|
)
|
||||||
async def get_user_all_beatmap_scores(
|
async def get_user_all_beatmap_scores(
|
||||||
|
db: Database,
|
||||||
beatmap_id: int = Path(description="谱面 ID"),
|
beatmap_id: int = Path(description="谱面 ID"),
|
||||||
user_id: int = Path(description="用户 ID"),
|
user_id: int = Path(description="用户 ID"),
|
||||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||||
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"),
|
ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
if legacy_only:
|
if legacy_only:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -357,12 +357,12 @@ async def get_user_all_beatmap_scores(
|
|||||||
)
|
)
|
||||||
async def create_solo_score(
|
async def create_solo_score(
|
||||||
background_task: BackgroundTasks,
|
background_task: BackgroundTasks,
|
||||||
|
db: Database,
|
||||||
beatmap_id: int = Path(description="谱面 ID"),
|
beatmap_id: int = Path(description="谱面 ID"),
|
||||||
version_hash: str = Form("", description="游戏版本哈希"),
|
version_hash: str = Form("", description="游戏版本哈希"),
|
||||||
beatmap_hash: str = Form(description="谱面文件哈希"),
|
beatmap_hash: str = Form(description="谱面文件哈希"),
|
||||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
assert current_user.id is not None
|
assert current_user.id is not None
|
||||||
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
|
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
|
||||||
@@ -387,11 +387,11 @@ async def create_solo_score(
|
|||||||
)
|
)
|
||||||
async def submit_solo_score(
|
async def submit_solo_score(
|
||||||
req: Request,
|
req: Request,
|
||||||
|
db: Database,
|
||||||
beatmap_id: int = Path(description="谱面 ID"),
|
beatmap_id: int = Path(description="谱面 ID"),
|
||||||
token: int = Path(description="成绩令牌 ID"),
|
token: int = Path(description="成绩令牌 ID"),
|
||||||
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"),
|
info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
fetcher=Depends(get_fetcher),
|
fetcher=Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
@@ -407,6 +407,7 @@ async def submit_solo_score(
|
|||||||
description="**客户端专属**\n为房间游玩项目创建成绩提交令牌。",
|
description="**客户端专属**\n为房间游玩项目创建成绩提交令牌。",
|
||||||
)
|
)
|
||||||
async def create_playlist_score(
|
async def create_playlist_score(
|
||||||
|
session: Database,
|
||||||
background_task: BackgroundTasks,
|
background_task: BackgroundTasks,
|
||||||
room_id: int,
|
room_id: int,
|
||||||
playlist_id: int,
|
playlist_id: int,
|
||||||
@@ -415,7 +416,6 @@ async def create_playlist_score(
|
|||||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||||
version_hash: str = Form("", description="谱面版本哈希"),
|
version_hash: str = Form("", description="谱面版本哈希"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
assert current_user.id is not None
|
assert current_user.id is not None
|
||||||
room = await session.get(Room, room_id)
|
room = await session.get(Room, room_id)
|
||||||
@@ -483,12 +483,12 @@ async def create_playlist_score(
|
|||||||
description="**客户端专属**\n提交房间游玩项目成绩。",
|
description="**客户端专属**\n提交房间游玩项目成绩。",
|
||||||
)
|
)
|
||||||
async def submit_playlist_score(
|
async def submit_playlist_score(
|
||||||
|
session: Database,
|
||||||
room_id: int,
|
room_id: int,
|
||||||
playlist_id: int,
|
playlist_id: int,
|
||||||
token: int,
|
token: int,
|
||||||
info: SoloScoreSubmissionInfo,
|
info: SoloScoreSubmissionInfo,
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
fetcher: Fetcher = Depends(get_fetcher),
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
@@ -541,6 +541,7 @@ class IndexedScoreResp(MultiplayerScores):
|
|||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
)
|
)
|
||||||
async def index_playlist_scores(
|
async def index_playlist_scores(
|
||||||
|
session: Database,
|
||||||
room_id: int,
|
room_id: int,
|
||||||
playlist_id: int,
|
playlist_id: int,
|
||||||
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
|
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
|
||||||
@@ -548,7 +549,6 @@ async def index_playlist_scores(
|
|||||||
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
|
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
|
||||||
),
|
),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
room = await session.get(Room, room_id)
|
room = await session.get(Room, room_id)
|
||||||
if not room:
|
if not room:
|
||||||
@@ -607,11 +607,11 @@ async def index_playlist_scores(
|
|||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
)
|
)
|
||||||
async def show_playlist_score(
|
async def show_playlist_score(
|
||||||
|
session: Database,
|
||||||
room_id: int,
|
room_id: int,
|
||||||
playlist_id: int,
|
playlist_id: int,
|
||||||
score_id: int,
|
score_id: int,
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
room = await session.get(Room, room_id)
|
room = await session.get(Room, room_id)
|
||||||
@@ -678,11 +678,11 @@ async def show_playlist_score(
|
|||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
)
|
)
|
||||||
async def get_user_playlist_score(
|
async def get_user_playlist_score(
|
||||||
|
session: Database,
|
||||||
room_id: int,
|
room_id: int,
|
||||||
playlist_id: int,
|
playlist_id: int,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
score_record = None
|
score_record = None
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -716,9 +716,9 @@ async def get_user_playlist_score(
|
|||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
)
|
)
|
||||||
async def pin_score(
|
async def pin_score(
|
||||||
|
db: Database,
|
||||||
score_id: int = Path(description="成绩 ID"),
|
score_id: int = Path(description="成绩 ID"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
score_record = (
|
score_record = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
@@ -758,9 +758,9 @@ async def pin_score(
|
|||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
)
|
)
|
||||||
async def unpin_score(
|
async def unpin_score(
|
||||||
|
db: Database,
|
||||||
score_id: int = Path(description="成绩 ID"),
|
score_id: int = Path(description="成绩 ID"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
score_record = (
|
score_record = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
@@ -797,11 +797,11 @@ async def unpin_score(
|
|||||||
tags=["成绩"],
|
tags=["成绩"],
|
||||||
)
|
)
|
||||||
async def reorder_score_pin(
|
async def reorder_score_pin(
|
||||||
|
db: Database,
|
||||||
score_id: int = Path(description="成绩 ID"),
|
score_id: int = Path(description="成绩 ID"),
|
||||||
after_score_id: int | None = Body(default=None, description="放在该成绩之后"),
|
after_score_id: int | None = Body(default=None, description="放在该成绩之后"),
|
||||||
before_score_id: int | None = Body(default=None, description="放在该成绩之前"),
|
before_score_id: int | None = Body(default=None, description="放在该成绩之前"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
score_record = (
|
score_record = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
@@ -892,8 +892,8 @@ async def reorder_score_pin(
|
|||||||
)
|
)
|
||||||
async def download_score_replay(
|
async def download_score_replay(
|
||||||
score_id: int,
|
score_id: int,
|
||||||
|
db: Database,
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
storage_service: StorageService = Depends(get_storage_service),
|
storage_service: StorageService = Depends(get_storage_service),
|
||||||
):
|
):
|
||||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||||
|
|||||||
@@ -15,17 +15,16 @@ from app.database.events import EventResp
|
|||||||
from app.database.lazer_user import SEARCH_INCLUDED
|
from app.database.lazer_user import SEARCH_INCLUDED
|
||||||
from app.database.pp_best_score import PPBestScore
|
from app.database.pp_best_score import PPBestScore
|
||||||
from app.database.score import Score, ScoreResp
|
from app.database.score import Score, ScoreResp
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import Database
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
from app.models.user import BeatmapsetType
|
from app.models.user import BeatmapsetType
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
from fastapi import HTTPException, Path, Query, Security
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import exists, false, select
|
from sqlmodel import exists, false, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
from sqlmodel.sql.expression import col
|
from sqlmodel.sql.expression import col
|
||||||
|
|
||||||
|
|
||||||
@@ -43,6 +42,7 @@ class BatchUserResponse(BaseModel):
|
|||||||
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
|
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
|
||||||
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
||||||
async def get_users(
|
async def get_users(
|
||||||
|
session: Database,
|
||||||
user_ids: list[int] = Query(
|
user_ids: list[int] = Query(
|
||||||
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
|
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
|
||||||
),
|
),
|
||||||
@@ -50,7 +50,6 @@ async def get_users(
|
|||||||
include_variant_statistics: bool = Query(
|
include_variant_statistics: bool = Query(
|
||||||
default=False, description="是否包含各模式的统计信息"
|
default=False, description="是否包含各模式的统计信息"
|
||||||
), # TODO: future use
|
), # TODO: future use
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
if user_ids:
|
if user_ids:
|
||||||
searched_users = (
|
searched_users = (
|
||||||
@@ -79,9 +78,9 @@ async def get_users(
|
|||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
async def get_user_info_ruleset(
|
async def get_user_info_ruleset(
|
||||||
|
session: Database,
|
||||||
user_id: str = Path(description="用户 ID 或用户名"),
|
user_id: str = Path(description="用户 ID 或用户名"),
|
||||||
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
):
|
):
|
||||||
searched_user = (
|
searched_user = (
|
||||||
@@ -112,8 +111,8 @@ async def get_user_info_ruleset(
|
|||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
async def get_user_info(
|
async def get_user_info(
|
||||||
|
session: Database,
|
||||||
user_id: str = Path(description="用户 ID 或用户名"),
|
user_id: str = Path(description="用户 ID 或用户名"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
):
|
):
|
||||||
searched_user = (
|
searched_user = (
|
||||||
@@ -142,10 +141,10 @@ async def get_user_info(
|
|||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
async def get_user_beatmapsets(
|
async def get_user_beatmapsets(
|
||||||
|
session: Database,
|
||||||
user_id: int = Path(description="用户 ID"),
|
user_id: int = Path(description="用户 ID"),
|
||||||
type: BeatmapsetType = Path(description="谱面集类型"),
|
type: BeatmapsetType = Path(description="谱面集类型"),
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||||
offset: int = Query(0, ge=0, description="偏移量"),
|
offset: int = Query(0, ge=0, description="偏移量"),
|
||||||
):
|
):
|
||||||
@@ -202,6 +201,7 @@ async def get_user_beatmapsets(
|
|||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
)
|
)
|
||||||
async def get_user_scores(
|
async def get_user_scores(
|
||||||
|
session: Database,
|
||||||
user_id: int = Path(description="用户 ID"),
|
user_id: int = Path(description="用户 ID"),
|
||||||
type: Literal["best", "recent", "firsts", "pinned"] = Path(
|
type: Literal["best", "recent", "firsts", "pinned"] = Path(
|
||||||
description=(
|
description=(
|
||||||
@@ -216,7 +216,6 @@ async def get_user_scores(
|
|||||||
),
|
),
|
||||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||||
offset: int = Query(0, ge=0, description="偏移量"),
|
offset: int = Query(0, ge=0, description="偏移量"),
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
):
|
):
|
||||||
db_user = await session.get(User, user_id)
|
db_user = await session.get(User, user_id)
|
||||||
@@ -267,10 +266,10 @@ async def get_user_scores(
|
|||||||
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
|
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
|
||||||
)
|
)
|
||||||
async def get_user_events(
|
async def get_user_events(
|
||||||
|
session: Database,
|
||||||
user: int,
|
user: int,
|
||||||
limit: int | None = Query(None),
|
limit: int | None = Query(None),
|
||||||
offset: str | None = Query(None), # TODO: 搞清楚并且添加这个奇怪的分页偏移
|
offset: str | None = Query(None), # TODO: 搞清楚并且添加这个奇怪的分页偏移
|
||||||
session: AsyncSession = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
db_user = await session.get(User, user)
|
db_user = await session.get(User, user)
|
||||||
if db_user is None or db_user.id == BANCHOBOT_ID:
|
if db_user is None or db_user.id == BANCHOBOT_ID:
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ from datetime import UTC, datetime, timedelta
|
|||||||
|
|
||||||
from app.database import RankHistory, UserStatistics
|
from app.database import RankHistory, UserStatistics
|
||||||
from app.database.rank_history import RankTop
|
from app.database.rank_history import RankTop
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import with_db
|
||||||
from app.dependencies.scheduler import get_scheduler
|
from app.dependencies.scheduler import get_scheduler
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from sqlmodel import col, exists, select, update
|
from sqlmodel import col, exists, select, update
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
@get_scheduler().scheduled_job(
|
@get_scheduler().scheduled_job(
|
||||||
@@ -18,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
async def calculate_user_rank(is_today: bool = False):
|
async def calculate_user_rank(is_today: bool = False):
|
||||||
today = datetime.now(UTC).date()
|
today = datetime.now(UTC).date()
|
||||||
target_date = today if is_today else today - timedelta(days=1)
|
target_date = today if is_today else today - timedelta(days=1)
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
for gamemode in GameMode:
|
for gamemode in GameMode:
|
||||||
users = await session.exec(
|
users = await session.exec(
|
||||||
select(UserStatistics)
|
select(UserStatistics)
|
||||||
|
|||||||
@@ -3,15 +3,14 @@ from __future__ import annotations
|
|||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.database.statistics import UserStatistics
|
from app.database.statistics import UserStatistics
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import with_db
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from sqlmodel import exists, select
|
from sqlmodel import exists, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
async def create_banchobot():
|
async def create_banchobot():
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
is_exist = (
|
is_exist = (
|
||||||
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
|
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
|
||||||
).first()
|
).first()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import json
|
|||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database.playlists import Playlist
|
from app.database.playlists import Playlist
|
||||||
from app.database.room import Room
|
from app.database.room import Room
|
||||||
from app.dependencies.database import engine, get_redis
|
from app.dependencies.database import get_redis, with_db
|
||||||
from app.dependencies.scheduler import get_scheduler
|
from app.dependencies.scheduler import get_scheduler
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.models.metadata_hub import DailyChallengeInfo
|
from app.models.metadata_hub import DailyChallengeInfo
|
||||||
@@ -16,13 +16,12 @@ from app.models.room import RoomCategory
|
|||||||
from .room import create_playlist_room
|
from .room import create_playlist_room
|
||||||
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
async def create_daily_challenge_room(
|
async def create_daily_challenge_room(
|
||||||
beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = []
|
beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = []
|
||||||
) -> Room:
|
) -> Room:
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
today = datetime.now(UTC).date()
|
today = datetime.now(UTC).date()
|
||||||
return await create_playlist_room(
|
return await create_playlist_room(
|
||||||
session=session,
|
session=session,
|
||||||
@@ -52,7 +51,7 @@ async def daily_challenge_job():
|
|||||||
key = f"daily_challenge:{now.date()}"
|
key = f"daily_challenge:{now.date()}"
|
||||||
if not await redis.exists(key):
|
if not await redis.exists(key):
|
||||||
return
|
return
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
room = (
|
room = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(Room).where(
|
select(Room).where(
|
||||||
|
|||||||
@@ -4,16 +4,15 @@ from app.config import settings
|
|||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.database.statistics import UserStatistics
|
from app.database.statistics import UserStatistics
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import with_db
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from sqlalchemy import exists
|
from sqlalchemy import exists
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
async def create_rx_statistics():
|
async def create_rx_statistics():
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
users = (await session.exec(select(User.id))).all()
|
users = (await session.exec(select(User.id))).all()
|
||||||
for i in users:
|
for i in users:
|
||||||
if i == BANCHOBOT_ID:
|
if i == BANCHOBOT_ID:
|
||||||
|
|||||||
@@ -4,13 +4,12 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from app.database import PlaylistBestScore, Score
|
from app.database import PlaylistBestScore, Score
|
||||||
from app.database.playlist_best_score import get_position
|
from app.database.playlist_best_score import get_position
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import with_db
|
||||||
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
|
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
|
||||||
|
|
||||||
from .base import RedisSubscriber
|
from .base import RedisSubscriber
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.signalr.hub import MetadataHub
|
from app.signalr.hub import MetadataHub
|
||||||
@@ -45,7 +44,7 @@ class ScoreSubscriber(RedisSubscriber):
|
|||||||
async def _notify_room_score_processed(self, score_id: int):
|
async def _notify_room_score_processed(self, score_id: int):
|
||||||
if not self.metadata_hub:
|
if not self.metadata_hub:
|
||||||
return
|
return
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
score = await session.get(Score, score_id)
|
score = await session.get(Score, score_id)
|
||||||
if (
|
if (
|
||||||
not score
|
not score
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from app.database.playlist_best_score import PlaylistBestScore
|
|||||||
from app.database.playlists import Playlist
|
from app.database.playlists import Playlist
|
||||||
from app.database.room import Room
|
from app.database.room import Room
|
||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
from app.dependencies.database import engine, get_redis
|
from app.dependencies.database import get_redis, with_db
|
||||||
from app.models.metadata_hub import (
|
from app.models.metadata_hub import (
|
||||||
TOTAL_SCORE_DISTRIBUTION_BINS,
|
TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||||
DailyChallengeInfo,
|
DailyChallengeInfo,
|
||||||
@@ -30,7 +30,6 @@ from app.service.subscribers.score_processed import ScoreSubscriber
|
|||||||
from .hub import Client, Hub
|
from .hub import Client, Hub
|
||||||
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||||
|
|
||||||
@@ -97,7 +96,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
|||||||
redis = get_redis()
|
redis = get_redis()
|
||||||
if await redis.exists(f"metadata:online:{state.connection_id}"):
|
if await redis.exists(f"metadata:online:{state.connection_id}"):
|
||||||
await redis.delete(f"metadata:online:{state.connection_id}")
|
await redis.delete(f"metadata:online:{state.connection_id}")
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
user = (
|
user = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
@@ -118,7 +117,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
|||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
self.get_or_create_state(client)
|
self.get_or_create_state(client)
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
friends = (
|
friends = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
@@ -233,7 +232,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
|||||||
return list(stats.playlist_item_stats.values())
|
return list(stats.playlist_item_stats.values())
|
||||||
|
|
||||||
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
|
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
playlist_ids = (
|
playlist_ids = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(Playlist.id).where(
|
select(Playlist.id).where(
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.database.multiplayer_event import MultiplayerEvent
|
|||||||
from app.database.playlists import Playlist
|
from app.database.playlists import Playlist
|
||||||
from app.database.relationship import Relationship, RelationshipType
|
from app.database.relationship import Relationship, RelationshipType
|
||||||
from app.database.room_participated_user import RoomParticipatedUser
|
from app.database.room_participated_user import RoomParticipatedUser
|
||||||
from app.dependencies.database import engine, get_redis
|
from app.dependencies.database import get_redis, with_db
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.exception import InvokeException
|
from app.exception import InvokeException
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
@@ -50,7 +50,6 @@ from .hub import Client, Hub
|
|||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
from sqlmodel import col, exists, select
|
from sqlmodel import col, exists, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
GAMEPLAY_LOAD_TIMEOUT = 30
|
GAMEPLAY_LOAD_TIMEOUT = 30
|
||||||
|
|
||||||
@@ -61,7 +60,7 @@ class MultiplayerEventLogger:
|
|||||||
|
|
||||||
async def log_event(self, event: MultiplayerEvent):
|
async def log_event(self, event: MultiplayerEvent):
|
||||||
try:
|
try:
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
session.add(event)
|
session.add(event)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -192,7 +191,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
store = self.get_or_create_state(client)
|
store = self.get_or_create_state(client)
|
||||||
if store.room_id != 0:
|
if store.room_id != 0:
|
||||||
raise InvokeException("You are already in a room")
|
raise InvokeException("You are already in a room")
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
async with session:
|
async with session:
|
||||||
db_room = Room(
|
db_room = Room(
|
||||||
name=room.settings.name,
|
name=room.settings.name,
|
||||||
@@ -282,7 +281,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
await server_room.match_type_handler.handle_join(user)
|
await server_room.match_type_handler.handle_join(user)
|
||||||
await self.event_logger.player_joined(room_id, user.user_id)
|
await self.event_logger.player_joined(room_id, user.user_id)
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
if (
|
if (
|
||||||
participated_user := (
|
participated_user := (
|
||||||
@@ -398,7 +397,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def change_db_settings(self, room: ServerMultiplayerRoom):
|
async def change_db_settings(self, room: ServerMultiplayerRoom):
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
await session.execute(
|
await session.execute(
|
||||||
update(Room)
|
update(Room)
|
||||||
.where(col(Room.id) == room.room.room_id)
|
.where(col(Room.id) == room.room.room_id)
|
||||||
@@ -477,7 +476,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
room,
|
room,
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
try:
|
try:
|
||||||
beatmap = await Beatmap.get_or_fetch(
|
beatmap = await Beatmap.get_or_fetch(
|
||||||
session, fetcher, bid=room.queue.current_item.beatmap_id
|
session, fetcher, bid=room.queue.current_item.beatmap_id
|
||||||
@@ -535,7 +534,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
if not room.queue.current_item.freestyle:
|
if not room.queue.current_item.freestyle:
|
||||||
raise InvokeException("Current item does not allow free user styles.")
|
raise InvokeException("Current item does not allow free user styles.")
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
item_beatmap = await session.get(
|
item_beatmap = await session.get(
|
||||||
Beatmap, room.queue.current_item.beatmap_id
|
Beatmap, room.queue.current_item.beatmap_id
|
||||||
)
|
)
|
||||||
@@ -910,7 +909,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
redis = get_redis()
|
redis = get_redis()
|
||||||
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
|
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
participated_user = (
|
participated_user = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
@@ -954,7 +953,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
|
|
||||||
async def end_room(self, room: ServerMultiplayerRoom):
|
async def end_room(self, room: ServerMultiplayerRoom):
|
||||||
assert room.room.host
|
assert room.room.host
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
await session.execute(
|
await session.execute(
|
||||||
update(Room)
|
update(Room)
|
||||||
.where(col(Room.id) == room.room.room_id)
|
.where(col(Room.id) == room.room.room_id)
|
||||||
@@ -1171,7 +1170,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
if user is None:
|
if user is None:
|
||||||
raise InvokeException("You are not in this room")
|
raise InvokeException("You are not in this room")
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
db_user = await session.get(User, user_id)
|
db_user = await session.get(User, user_id)
|
||||||
target_relationship = (
|
target_relationship = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from app.database.failtime import FailTime, FailTimeResp
|
|||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
from app.database.score_token import ScoreToken
|
from app.database.score_token import ScoreToken
|
||||||
from app.database.statistics import UserStatistics
|
from app.database.statistics import UserStatistics
|
||||||
from app.dependencies.database import engine, get_redis
|
from app.dependencies.database import get_redis, with_db
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.dependencies.storage import get_storage_service
|
from app.dependencies.storage import get_storage_service
|
||||||
from app.exception import InvokeException
|
from app.exception import InvokeException
|
||||||
@@ -38,7 +38,6 @@ from .hub import Client, Hub
|
|||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
READ_SCORE_TIMEOUT = 30
|
READ_SCORE_TIMEOUT = 30
|
||||||
REPLAY_LATEST_VER = 30000016
|
REPLAY_LATEST_VER = 30000016
|
||||||
@@ -194,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
return
|
return
|
||||||
|
|
||||||
fetcher = await get_fetcher()
|
fetcher = await get_fetcher()
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
try:
|
try:
|
||||||
beatmap = await Beatmap.get_or_fetch(
|
beatmap = await Beatmap.get_or_fetch(
|
||||||
@@ -285,7 +284,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
assert store.checksum is not None
|
assert store.checksum is not None
|
||||||
assert store.ruleset_id is not None
|
assert store.ruleset_id is not None
|
||||||
assert store.score is not None
|
assert store.score is not None
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
async with session:
|
async with session:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
score_record = None
|
score_record = None
|
||||||
@@ -332,7 +331,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
self, user_id: int, state: SpectatorState, store: StoreClientState
|
self, user_id: int, state: SpectatorState, store: StoreClientState
|
||||||
) -> None:
|
) -> None:
|
||||||
async def _add_failtime():
|
async def _add_failtime():
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
failtime = await session.get(FailTime, state.beatmap_id)
|
failtime = await session.get(FailTime, state.beatmap_id)
|
||||||
total_length = (
|
total_length = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
@@ -366,7 +365,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
return
|
return
|
||||||
before_time = int(messages[0][1]["time"])
|
before_time = int(messages[0][1]["time"])
|
||||||
await redis.delete(key)
|
await redis.delete(key)
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods)
|
gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods)
|
||||||
statistics = (
|
statistics = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
@@ -430,7 +429,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
|
|
||||||
self.add_to_group(client, self.group_id(target_id))
|
self.add_to_group(client, self.group_id(target_id))
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with with_db() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
username = (
|
username = (
|
||||||
await session.exec(select(User.username).where(User.id == user_id))
|
await session.exec(select(User.username).where(User.id == user_id))
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import uuid
|
|||||||
|
|
||||||
from app.database import User as DBUser
|
from app.database import User as DBUser
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import DBFactory, get_db_factory
|
||||||
from app.models.signalr import NegotiateResponse, Transport
|
from app.models.signalr import NegotiateResponse, Transport
|
||||||
|
|
||||||
from .hub import Hubs
|
from .hub import Hubs
|
||||||
@@ -16,7 +16,6 @@ from .packet import PROTOCOLS, SEP
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
|
||||||
from fastapi.security import SecurityScopes
|
from fastapi.security import SecurityScopes
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/signalr", include_in_schema=False)
|
router = APIRouter(prefix="/signalr", include_in_schema=False)
|
||||||
|
|
||||||
@@ -47,7 +46,7 @@ async def connect(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
id: str,
|
id: str,
|
||||||
authorization: str = Header(...),
|
authorization: str = Header(...),
|
||||||
db: AsyncSession = Depends(get_db),
|
factory: DBFactory = Depends(get_db_factory),
|
||||||
):
|
):
|
||||||
token = authorization[7:]
|
token = authorization[7:]
|
||||||
user_id = id.split(":")[0]
|
user_id = id.split(":")[0]
|
||||||
@@ -56,13 +55,14 @@ async def connect(
|
|||||||
await websocket.close(code=1008)
|
await websocket.close(code=1008)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if (
|
async for session in factory():
|
||||||
user := await get_current_user(
|
if (
|
||||||
SecurityScopes(scopes=["*"]), db, token_pw=token
|
user := await get_current_user(
|
||||||
)
|
session, SecurityScopes(scopes=["*"]), token_pw=token
|
||||||
) is None or str(user.id) != user_id:
|
)
|
||||||
await websocket.close(code=1008)
|
) is None or str(user.id) != user_id:
|
||||||
return
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
await websocket.close(code=1008)
|
await websocket.close(code=1008)
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user