refactor(app): update database code

This commit is contained in:
MingxuanGame
2025-08-18 16:37:30 +00:00
parent 6bae937e01
commit 1c65b21bb9
34 changed files with 167 additions and 188 deletions

View File

@@ -13,7 +13,7 @@ from app.database.playlist_best_score import PlaylistBestScore
from app.database.playlists import Playlist
from app.database.room import Room
from app.database.score import Score
from app.dependencies.database import engine, get_redis
from app.dependencies.database import get_redis, with_db
from app.models.metadata_hub import (
TOTAL_SCORE_DISTRIBUTION_BINS,
DailyChallengeInfo,
@@ -30,7 +30,6 @@ from app.service.subscribers.score_processed import ScoreSubscriber
from .hub import Client, Hub
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
@@ -97,7 +96,7 @@ class MetadataHub(Hub[MetadataClientState]):
redis = get_redis()
if await redis.exists(f"metadata:online:{state.connection_id}"):
await redis.delete(f"metadata:online:{state.connection_id}")
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
user = (
await session.exec(
@@ -118,7 +117,7 @@ class MetadataHub(Hub[MetadataClientState]):
user_id = int(client.connection_id)
self.get_or_create_state(client)
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
friends = (
await session.exec(
@@ -233,7 +232,7 @@ class MetadataHub(Hub[MetadataClientState]):
return list(stats.playlist_item_stats.values())
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
async with AsyncSession(engine) as session:
async with with_db() as session:
playlist_ids = (
await session.exec(
select(Playlist.id).where(

View File

@@ -12,7 +12,7 @@ from app.database.multiplayer_event import MultiplayerEvent
from app.database.playlists import Playlist
from app.database.relationship import Relationship, RelationshipType
from app.database.room_participated_user import RoomParticipatedUser
from app.dependencies.database import engine, get_redis
from app.dependencies.database import get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from app.log import logger
@@ -50,7 +50,6 @@ from .hub import Client, Hub
from httpx import HTTPError
from sqlalchemy import update
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
GAMEPLAY_LOAD_TIMEOUT = 30
@@ -61,7 +60,7 @@ class MultiplayerEventLogger:
async def log_event(self, event: MultiplayerEvent):
try:
async with AsyncSession(engine) as session:
async with with_db() as session:
session.add(event)
await session.commit()
except Exception as e:
@@ -192,7 +191,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
store = self.get_or_create_state(client)
if store.room_id != 0:
raise InvokeException("You are already in a room")
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session:
db_room = Room(
name=room.settings.name,
@@ -282,7 +281,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await server_room.match_type_handler.handle_join(user)
await self.event_logger.player_joined(room_id, user.user_id)
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
if (
participated_user := (
@@ -398,7 +397,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
)
async def change_db_settings(self, room: ServerMultiplayerRoom):
async with AsyncSession(engine) as session:
async with with_db() as session:
await session.execute(
update(Room)
.where(col(Room.id) == room.room.room_id)
@@ -477,7 +476,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
room,
user,
)
async with AsyncSession(engine) as session:
async with with_db() as session:
try:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=room.queue.current_item.beatmap_id
@@ -535,7 +534,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if not room.queue.current_item.freestyle:
raise InvokeException("Current item does not allow free user styles.")
async with AsyncSession(engine) as session:
async with with_db() as session:
item_beatmap = await session.get(
Beatmap, room.queue.current_item.beatmap_id
)
@@ -910,7 +909,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
redis = get_redis()
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
participated_user = (
await session.exec(
@@ -954,7 +953,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
async def end_room(self, room: ServerMultiplayerRoom):
assert room.room.host
async with AsyncSession(engine) as session:
async with with_db() as session:
await session.execute(
update(Room)
.where(col(Room.id) == room.room.room_id)
@@ -1171,7 +1170,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if user is None:
raise InvokeException("You are not in this room")
async with AsyncSession(engine) as session:
async with with_db() as session:
db_user = await session.get(User, user_id)
target_relationship = (
await session.exec(

View File

@@ -14,7 +14,7 @@ from app.database.failtime import FailTime, FailTimeResp
from app.database.score import Score
from app.database.score_token import ScoreToken
from app.database.statistics import UserStatistics
from app.dependencies.database import engine, get_redis
from app.dependencies.database import get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.exception import InvokeException
@@ -38,7 +38,6 @@ from .hub import Client, Hub
from httpx import HTTPError
from sqlalchemy.orm import joinedload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 30
REPLAY_LATEST_VER = 30000016
@@ -194,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]):
return
fetcher = await get_fetcher()
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
try:
beatmap = await Beatmap.get_or_fetch(
@@ -285,7 +284,7 @@ class SpectatorHub(Hub[StoreClientState]):
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.score is not None
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session:
start_time = time.time()
score_record = None
@@ -332,7 +331,7 @@ class SpectatorHub(Hub[StoreClientState]):
self, user_id: int, state: SpectatorState, store: StoreClientState
) -> None:
async def _add_failtime():
async with AsyncSession(engine) as session:
async with with_db() as session:
failtime = await session.get(FailTime, state.beatmap_id)
total_length = (
await session.exec(
@@ -366,7 +365,7 @@ class SpectatorHub(Hub[StoreClientState]):
return
before_time = int(messages[0][1]["time"])
await redis.delete(key)
async with AsyncSession(engine) as session:
async with with_db() as session:
gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods)
statistics = (
await session.exec(
@@ -430,7 +429,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.add_to_group(client, self.group_id(target_id))
async with AsyncSession(engine) as session:
async with with_db() as session:
async with session.begin():
username = (
await session.exec(select(User.username).where(User.id == user_id))

View File

@@ -8,7 +8,7 @@ import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.database import DBFactory, get_db_factory
from app.models.signalr import NegotiateResponse, Transport
from .hub import Hubs
@@ -16,7 +16,6 @@ from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
from fastapi.security import SecurityScopes
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(prefix="/signalr", include_in_schema=False)
@@ -47,7 +46,7 @@ async def connect(
websocket: WebSocket,
id: str,
authorization: str = Header(...),
db: AsyncSession = Depends(get_db),
factory: DBFactory = Depends(get_db_factory),
):
token = authorization[7:]
user_id = id.split(":")[0]
@@ -56,13 +55,14 @@ async def connect(
await websocket.close(code=1008)
return
try:
if (
user := await get_current_user(
SecurityScopes(scopes=["*"]), db, token_pw=token
)
) is None or str(user.id) != user_id:
await websocket.close(code=1008)
return
async for session in factory():
if (
user := await get_current_user(
session, SecurityScopes(scopes=["*"]), token_pw=token
)
) is None or str(user.id) != user_id:
await websocket.close(code=1008)
return
except HTTPException:
await websocket.close(code=1008)
return