diff --git a/app/database/playlists.py b/app/database/playlists.py index 73755bf..adcf844 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -26,7 +26,11 @@ if TYPE_CHECKING: class PlaylistBase(SQLModel, UTCBaseModel): - id: int = Field(index=True) + id: int | None = Field( + default=None, + primary_key=True, + index=True, + ) owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) ruleset_id: int = Field(ge=0, le=3) expired: bool = Field(default=False) @@ -116,9 +120,11 @@ class Playlist(PlaylistBase, table=True): async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): db_playlist = await cls.from_hub(playlist, room_id, session) session.add(db_playlist) + await session.flush() + assert db_playlist.id is not None, "db_playlist.id should be set after flush" + playlist.id = db_playlist.id await session.commit() - await session.refresh(db_playlist) - playlist.id = db_playlist.id + @classmethod async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession): diff --git a/app/database/room.py b/app/database/room.py index 2fffab0..572e0d4 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -74,6 +74,7 @@ class Room(AsyncAttrs, RoomBase, table=True): ) + class RoomResp(RoomBase): id: int has_password: bool = False diff --git a/app/router/lio.py b/app/router/lio.py index d8ebcfd..322582f 100644 --- a/app/router/lio.py +++ b/app/router/lio.py @@ -4,16 +4,19 @@ from __future__ import annotations import json from typing import Any, Dict, List -from fastapi import APIRouter, HTTPException, Request, status, Query +from fastapi import APIRouter, HTTPException, Request, status, Query, Depends from pydantic import BaseModel from sqlmodel import col, select, desc from sqlalchemy import update +from redis.asyncio import Redis from app.database.lazer_user import User from app.database.playlists import Playlist as DBPlaylist from app.database.room import Room from app.database.room_participated_user import RoomParticipatedUser -from app.dependencies.database import Database +from app.dependencies.database import Database, get_redis +from app.dependencies.fetcher import get_fetcher +from app.fetcher import Fetcher from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem from app.models.room import MatchType, QueueMode, RoomStatus from app.utils import utcnow @@ -224,25 +227,6 @@ async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int await _update_room_participant_count(db, room_id) -async def _update_room_participant_count(db: Database, room_id: int) -> None: - """Update the participant count for a room.""" - # Count active participants - active_participants = await db.execute( - select(RoomParticipatedUser).where( - RoomParticipatedUser.room_id == room_id, - col(RoomParticipatedUser.left_at).is_(None) - ) - ) - count = len(active_participants.all()) - - # Update room participant count using SQLAlchemy update statement - await db.execute( - update(Room) - .where(col(Room.id) == room_id) - .values(participant_count=count) - ) - - async def _verify_room_password(db: Database, room_id: int, provided_password: str | None) -> None: """Verify room password if required.""" room_result = await db.execute( @@ -315,6 +299,66 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) - db.add(participant) +class BeatmapEnsureRequest(BaseModel): + """Request model for ensuring beatmap exists.""" + beatmap_id: int + + +async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> Dict[str, Any]: + """ + 确保谱面存在(包括元数据和原始文件缓存)。 + + Args: + db: 数据库会话 + fetcher: API获取器 + redis: Redis连接 + beatmap_id: 谱面ID + + Returns: + Dict: 包含状态信息的响应 + """ + try: + # 1. 确保谱面元数据存在于数据库中 + from app.database.beatmap import Beatmap + beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) + + if not beatmap: + return { + "success": False, + "error": f"Beatmap {beatmap_id} not found", + "beatmap_id": beatmap_id + } + + # 2. 预缓存谱面原始文件 + cache_key = f"beatmap:{beatmap_id}:raw" + cached = await redis.exists(cache_key) + + if not cached: + # 异步预加载原始文件到缓存 + try: + await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) + print(f"Successfully cached raw beatmap file for {beatmap_id}") + except Exception as e: + print(f"Warning: Failed to cache raw beatmap {beatmap_id}: {e}") + # 即使原始文件缓存失败,也认为确保操作成功(因为元数据已存在) + + return { + "success": True, + "beatmap_id": beatmap_id, + "metadata_cached": True, + "raw_file_cached": await redis.exists(cache_key), + "beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]" + } + + except Exception as e: + print(f"Error ensuring beatmap {beatmap_id}: {e}") + return { + "success": False, + "error": str(e), + "beatmap_id": beatmap_id + } + + # ===== API ENDPOINTS ===== @router.post("/multiplayer/rooms") @@ -325,42 +369,42 @@ async def create_multiplayer_room( timestamp: str = "", ) -> int: """Create a new multiplayer room with initial playlist.""" + #try: + # Verify request signature + body = await request.body() + if not verify_request_signature(request, str(timestamp), body): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid request signature" + ) + + # Parse room data if string + if isinstance(room_data, str): + room_data = json.loads(room_data) + + print(f"Creating room with data: {room_data}") + + # Create room + room, host_user_id = await _create_room(db, room_data) + room_id = room.id + try: - # Verify request signature - body = await request.body() - if not verify_request_signature(request, str(timestamp), body): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid request signature" - ) + # Add playlist items + await _add_playlist_items(db, room_id, room_data, host_user_id) + + # Add host as participant + #await _add_host_as_participant(db, room_id, host_user_id) + + await db.commit() + return room_id + + except HTTPException: + # Clean up room if playlist creation fails + await db.delete(room) + await db.commit() + raise - # Parse room data if string - if isinstance(room_data, str): - room_data = json.loads(room_data) - - print(f"Creating room with data: {room_data}") - - # Create room - room, host_user_id = await _create_room(db, room_data) - room_id = room.id - - try: - # Add playlist items - await _add_playlist_items(db, room_id, room_data, host_user_id) - - # Add host as participant - #await _add_host_as_participant(db, room_id, host_user_id) - - await db.commit() - return room_id - - except HTTPException: - # Clean up room if playlist creation fails - await db.delete(room) - await db.commit() - raise - - except json.JSONDecodeError as e: + """ except json.JSONDecodeError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON: {str(e)}" @@ -371,7 +415,7 @@ async def create_multiplayer_room( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create room: {str(e)}" - ) + ) """ async def _update_room_participant_count(db: Database, room_id: int) -> int: @@ -602,4 +646,51 @@ async def add_user_to_room( await db.commit() print(f"Successfully added user {user_id} to room {room_id}") - return {"success": True} \ No newline at end of file + return {"success": True} + + +@router.post("/beatmaps/ensure") +async def ensure_beatmap_present( + request: Request, + beatmap_data: BeatmapEnsureRequest, + db: Database, + redis: Redis = Depends(get_redis), + fetcher: Fetcher = Depends(get_fetcher), + timestamp: str = "", +) -> Dict[str, Any]: + """ + 确保谱面在服务器中存在(包括元数据和原始文件缓存)。 + + 这个接口用于 osu-server-spectator 确保谱面文件在服务器端可用, + 避免在需要时才获取导致的延迟。 + """ + try: + # Verify request signature + body = await request.body() + if not verify_request_signature(request, timestamp, body): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid request signature" + ) + + beatmap_id = beatmap_data.beatmap_id + print(f"Ensuring beatmap {beatmap_id} is present") + + # 确保谱面存在 + result = await _ensure_beatmap_exists(db, fetcher, redis, beatmap_id) + + # 提交数据库更改 + await db.commit() + + print(f"Ensure beatmap {beatmap_id} result: {result}") + return result + + except HTTPException: + raise + except Exception as e: + await db.rollback() + print(f"Error ensuring beatmap: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to ensure beatmap: {str(e)}" + ) \ No newline at end of file