refactor(app): update database code
This commit is contained in:
@@ -13,7 +13,7 @@ from app.database.playlist_best_score import PlaylistBestScore
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import Room
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import engine, get_redis
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.models.metadata_hub import (
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
DailyChallengeInfo,
|
||||
@@ -30,7 +30,6 @@ from app.service.subscribers.score_processed import ScoreSubscriber
|
||||
from .hub import Client, Hub
|
||||
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||
|
||||
@@ -97,7 +96,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
redis = get_redis()
|
||||
if await redis.exists(f"metadata:online:{state.connection_id}"):
|
||||
await redis.delete(f"metadata:online:{state.connection_id}")
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
user = (
|
||||
await session.exec(
|
||||
@@ -118,7 +117,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
user_id = int(client.connection_id)
|
||||
self.get_or_create_state(client)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
friends = (
|
||||
await session.exec(
|
||||
@@ -233,7 +232,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
return list(stats.playlist_item_stats.values())
|
||||
|
||||
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
playlist_ids = (
|
||||
await session.exec(
|
||||
select(Playlist.id).where(
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.database.multiplayer_event import MultiplayerEvent
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.relationship import Relationship, RelationshipType
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.dependencies.database import engine, get_redis
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.exception import InvokeException
|
||||
from app.log import logger
|
||||
@@ -50,7 +50,6 @@ from .hub import Client, Hub
|
||||
from httpx import HTTPError
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
GAMEPLAY_LOAD_TIMEOUT = 30
|
||||
|
||||
@@ -61,7 +60,7 @@ class MultiplayerEventLogger:
|
||||
|
||||
async def log_event(self, event: MultiplayerEvent):
|
||||
try:
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
@@ -192,7 +191,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
store = self.get_or_create_state(client)
|
||||
if store.room_id != 0:
|
||||
raise InvokeException("You are already in a room")
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
async with session:
|
||||
db_room = Room(
|
||||
name=room.settings.name,
|
||||
@@ -282,7 +281,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
await server_room.match_type_handler.handle_join(user)
|
||||
await self.event_logger.player_joined(room_id, user.user_id)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
if (
|
||||
participated_user := (
|
||||
@@ -398,7 +397,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
)
|
||||
|
||||
async def change_db_settings(self, room: ServerMultiplayerRoom):
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
await session.execute(
|
||||
update(Room)
|
||||
.where(col(Room.id) == room.room.room_id)
|
||||
@@ -477,7 +476,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
room,
|
||||
user,
|
||||
)
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=room.queue.current_item.beatmap_id
|
||||
@@ -535,7 +534,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
if not room.queue.current_item.freestyle:
|
||||
raise InvokeException("Current item does not allow free user styles.")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
item_beatmap = await session.get(
|
||||
Beatmap, room.queue.current_item.beatmap_id
|
||||
)
|
||||
@@ -910,7 +909,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
redis = get_redis()
|
||||
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
@@ -954,7 +953,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
|
||||
async def end_room(self, room: ServerMultiplayerRoom):
|
||||
assert room.room.host
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
await session.execute(
|
||||
update(Room)
|
||||
.where(col(Room.id) == room.room.room_id)
|
||||
@@ -1171,7 +1170,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
if user is None:
|
||||
raise InvokeException("You are not in this room")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
db_user = await session.get(User, user_id)
|
||||
target_relationship = (
|
||||
await session.exec(
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.database.failtime import FailTime, FailTimeResp
|
||||
from app.database.score import Score
|
||||
from app.database.score_token import ScoreToken
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import engine, get_redis
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.exception import InvokeException
|
||||
@@ -38,7 +38,6 @@ from .hub import Client, Hub
|
||||
from httpx import HTTPError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
READ_SCORE_TIMEOUT = 30
|
||||
REPLAY_LATEST_VER = 30000016
|
||||
@@ -194,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
return
|
||||
|
||||
fetcher = await get_fetcher()
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
@@ -285,7 +284,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
assert store.checksum is not None
|
||||
assert store.ruleset_id is not None
|
||||
assert store.score is not None
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
async with session:
|
||||
start_time = time.time()
|
||||
score_record = None
|
||||
@@ -332,7 +331,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
self, user_id: int, state: SpectatorState, store: StoreClientState
|
||||
) -> None:
|
||||
async def _add_failtime():
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
failtime = await session.get(FailTime, state.beatmap_id)
|
||||
total_length = (
|
||||
await session.exec(
|
||||
@@ -366,7 +365,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
return
|
||||
before_time = int(messages[0][1]["time"])
|
||||
await redis.delete(key)
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods)
|
||||
statistics = (
|
||||
await session.exec(
|
||||
@@ -430,7 +429,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
|
||||
self.add_to_group(client, self.group_id(target_id))
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
username = (
|
||||
await session.exec(select(User.username).where(User.id == user_id))
|
||||
|
||||
@@ -8,7 +8,7 @@ import uuid
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.database import DBFactory, get_db_factory
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
|
||||
from .hub import Hubs
|
||||
@@ -16,7 +16,6 @@ from .packet import PROTOCOLS, SEP
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
|
||||
from fastapi.security import SecurityScopes
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter(prefix="/signalr", include_in_schema=False)
|
||||
|
||||
@@ -47,7 +46,7 @@ async def connect(
|
||||
websocket: WebSocket,
|
||||
id: str,
|
||||
authorization: str = Header(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
token = authorization[7:]
|
||||
user_id = id.split(":")[0]
|
||||
@@ -56,13 +55,14 @@ async def connect(
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
try:
|
||||
if (
|
||||
user := await get_current_user(
|
||||
SecurityScopes(scopes=["*"]), db, token_pw=token
|
||||
)
|
||||
) is None or str(user.id) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
async for session in factory():
|
||||
if (
|
||||
user := await get_current_user(
|
||||
session, SecurityScopes(scopes=["*"]), token_pw=token
|
||||
)
|
||||
) is None or str(user.id) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
except HTTPException:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user