Add endpoint to ensure beatmap presence and cache
Introduces a new /beatmaps/ensure API endpoint to verify and cache beatmap metadata and raw files. Updates Playlist model to use auto-incrementing primary key and improves playlist DB insertion logic. Minor formatting and import changes in room and lio modules.
This commit is contained in:
@@ -26,7 +26,11 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
id: int = Field(index=True)
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
primary_key=True,
|
||||
index=True,
|
||||
)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
ruleset_id: int = Field(ge=0, le=3)
|
||||
expired: bool = Field(default=False)
|
||||
@@ -116,9 +120,11 @@ class Playlist(PlaylistBase, table=True):
|
||||
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.flush()
|
||||
assert db_playlist.id is not None, "db_playlist.id should be set after flush"
|
||||
playlist.id = db_playlist.id
|
||||
await session.commit()
|
||||
await session.refresh(db_playlist)
|
||||
playlist.id = db_playlist.id
|
||||
|
||||
|
||||
@classmethod
|
||||
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
|
||||
|
||||
@@ -74,6 +74,7 @@ class Room(AsyncAttrs, RoomBase, table=True):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class RoomResp(RoomBase):
|
||||
id: int
|
||||
has_password: bool = False
|
||||
|
||||
@@ -4,16 +4,19 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, status, Query
|
||||
from fastapi import APIRouter, HTTPException, Request, status, Query, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select, desc
|
||||
from sqlalchemy import update
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.database.lazer_user import User
|
||||
from app.database.playlists import Playlist as DBPlaylist
|
||||
from app.database.room import Room
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
|
||||
from app.models.room import MatchType, QueueMode, RoomStatus
|
||||
from app.utils import utcnow
|
||||
@@ -224,25 +227,6 @@ async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int
|
||||
await _update_room_participant_count(db, room_id)
|
||||
|
||||
|
||||
async def _update_room_participant_count(db: Database, room_id: int) -> None:
|
||||
"""Update the participant count for a room."""
|
||||
# Count active participants
|
||||
active_participants = await db.execute(
|
||||
select(RoomParticipatedUser).where(
|
||||
RoomParticipatedUser.room_id == room_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
)
|
||||
)
|
||||
count = len(active_participants.all())
|
||||
|
||||
# Update room participant count using SQLAlchemy update statement
|
||||
await db.execute(
|
||||
update(Room)
|
||||
.where(col(Room.id) == room_id)
|
||||
.values(participant_count=count)
|
||||
)
|
||||
|
||||
|
||||
async def _verify_room_password(db: Database, room_id: int, provided_password: str | None) -> None:
|
||||
"""Verify room password if required."""
|
||||
room_result = await db.execute(
|
||||
@@ -315,6 +299,66 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
|
||||
db.add(participant)
|
||||
|
||||
|
||||
class BeatmapEnsureRequest(BaseModel):
|
||||
"""Request model for ensuring beatmap exists."""
|
||||
beatmap_id: int
|
||||
|
||||
|
||||
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> Dict[str, Any]:
|
||||
"""
|
||||
确保谱面存在(包括元数据和原始文件缓存)。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
fetcher: API获取器
|
||||
redis: Redis连接
|
||||
beatmap_id: 谱面ID
|
||||
|
||||
Returns:
|
||||
Dict: 包含状态信息的响应
|
||||
"""
|
||||
try:
|
||||
# 1. 确保谱面元数据存在于数据库中
|
||||
from app.database.beatmap import Beatmap
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
|
||||
if not beatmap:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Beatmap {beatmap_id} not found",
|
||||
"beatmap_id": beatmap_id
|
||||
}
|
||||
|
||||
# 2. 预缓存谱面原始文件
|
||||
cache_key = f"beatmap:{beatmap_id}:raw"
|
||||
cached = await redis.exists(cache_key)
|
||||
|
||||
if not cached:
|
||||
# 异步预加载原始文件到缓存
|
||||
try:
|
||||
await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||
print(f"Successfully cached raw beatmap file for {beatmap_id}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to cache raw beatmap {beatmap_id}: {e}")
|
||||
# 即使原始文件缓存失败,也认为确保操作成功(因为元数据已存在)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"beatmap_id": beatmap_id,
|
||||
"metadata_cached": True,
|
||||
"raw_file_cached": await redis.exists(cache_key),
|
||||
"beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error ensuring beatmap {beatmap_id}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"beatmap_id": beatmap_id
|
||||
}
|
||||
|
||||
|
||||
# ===== API ENDPOINTS =====
|
||||
|
||||
@router.post("/multiplayer/rooms")
|
||||
@@ -325,42 +369,42 @@ async def create_multiplayer_room(
|
||||
timestamp: str = "",
|
||||
) -> int:
|
||||
"""Create a new multiplayer room with initial playlist."""
|
||||
#try:
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
if not verify_request_signature(request, str(timestamp), body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
|
||||
# Parse room data if string
|
||||
if isinstance(room_data, str):
|
||||
room_data = json.loads(room_data)
|
||||
|
||||
print(f"Creating room with data: {room_data}")
|
||||
|
||||
# Create room
|
||||
room, host_user_id = await _create_room(db, room_data)
|
||||
room_id = room.id
|
||||
|
||||
try:
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
if not verify_request_signature(request, str(timestamp), body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
# Add playlist items
|
||||
await _add_playlist_items(db, room_id, room_data, host_user_id)
|
||||
|
||||
# Add host as participant
|
||||
#await _add_host_as_participant(db, room_id, host_user_id)
|
||||
|
||||
await db.commit()
|
||||
return room_id
|
||||
|
||||
except HTTPException:
|
||||
# Clean up room if playlist creation fails
|
||||
await db.delete(room)
|
||||
await db.commit()
|
||||
raise
|
||||
|
||||
# Parse room data if string
|
||||
if isinstance(room_data, str):
|
||||
room_data = json.loads(room_data)
|
||||
|
||||
print(f"Creating room with data: {room_data}")
|
||||
|
||||
# Create room
|
||||
room, host_user_id = await _create_room(db, room_data)
|
||||
room_id = room.id
|
||||
|
||||
try:
|
||||
# Add playlist items
|
||||
await _add_playlist_items(db, room_id, room_data, host_user_id)
|
||||
|
||||
# Add host as participant
|
||||
#await _add_host_as_participant(db, room_id, host_user_id)
|
||||
|
||||
await db.commit()
|
||||
return room_id
|
||||
|
||||
except HTTPException:
|
||||
# Clean up room if playlist creation fails
|
||||
await db.delete(room)
|
||||
await db.commit()
|
||||
raise
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
""" except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid JSON: {str(e)}"
|
||||
@@ -371,7 +415,7 @@ async def create_multiplayer_room(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create room: {str(e)}"
|
||||
)
|
||||
) """
|
||||
|
||||
|
||||
async def _update_room_participant_count(db: Database, room_id: int) -> int:
|
||||
@@ -602,4 +646,51 @@ async def add_user_to_room(
|
||||
|
||||
await db.commit()
|
||||
print(f"Successfully added user {user_id} to room {room_id}")
|
||||
return {"success": True}
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.post("/beatmaps/ensure")
|
||||
async def ensure_beatmap_present(
|
||||
request: Request,
|
||||
beatmap_data: BeatmapEnsureRequest,
|
||||
db: Database,
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
timestamp: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
确保谱面在服务器中存在(包括元数据和原始文件缓存)。
|
||||
|
||||
这个接口用于 osu-server-spectator 确保谱面文件在服务器端可用,
|
||||
避免在需要时才获取导致的延迟。
|
||||
"""
|
||||
try:
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
if not verify_request_signature(request, timestamp, body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
|
||||
beatmap_id = beatmap_data.beatmap_id
|
||||
print(f"Ensuring beatmap {beatmap_id} is present")
|
||||
|
||||
# 确保谱面存在
|
||||
result = await _ensure_beatmap_exists(db, fetcher, redis, beatmap_id)
|
||||
|
||||
# 提交数据库更改
|
||||
await db.commit()
|
||||
|
||||
print(f"Ensure beatmap {beatmap_id} result: {result}")
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
print(f"Error ensuring beatmap: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to ensure beatmap: {str(e)}"
|
||||
)
|
||||
Reference in New Issue
Block a user