Adds standard JWT claims (audience and issuer) to access tokens and updates config for these fields. Refactors multiplayer room chat channel logic to ensure reliable user join/leave with retry mechanisms, improves error handling and cleanup, and ensures host is correctly added as a participant. Updates Docker entrypoint for better compatibility and connection handling, modifies Docker Compose and Nginx config for improved deployment and proxy header forwarding.
886 lines
32 KiB
Python
886 lines
32 KiB
Python
"""LIO (Legacy IO) router for osu-server-spectator compatibility."""
|
||
from __future__ import annotations

import hmac
import json
from typing import Any, Dict, List

from fastapi import APIRouter, HTTPException, Request, status, Query, Depends
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy import update, func
from sqlmodel import col, select, desc

from app.database.lazer_user import User
from app.database.playlists import Playlist as DBPlaylist
from app.database.room import Room
from app.database.room_participated_user import RoomParticipatedUser
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
from app.models.room import MatchType, QueueMode, RoomStatus
from app.utils import utcnow
from app.database.chat import ChatChannel, ChannelType  # ChatChannel model & enum
from .notification.server import server
from app.log import logger
|
||
|
||
|
||
router = APIRouter(prefix="/_lio", tags=["LIO"])
|
||
|
||
|
||
async def _ensure_room_chat_channel(
    db: Database,
    room: Room,
    host_user_id: int,
) -> ChatChannel:
    """Return the chat channel bound to ``room``, creating it when missing.

    The channel is looked up by ``room.channel_id``. When no row matches, a
    new ``ChatChannel`` named ``mp_{room.id}`` is inserted and committed. If
    the room has no ``channel_id`` yet, one is allocated and written back to
    the room so that both records share the same id.

    Args:
        db: Database session used for the lookup and the insert.
        room: Room whose chat channel must exist.
        host_user_id: Accepted as part of the signature; not read here.

    Returns:
        The existing or newly created ``ChatChannel``.
    """
    lookup = await db.exec(
        select(ChatChannel).where(ChatChannel.channel_id == room.channel_id)
    )
    existing = lookup.first()
    if existing is not None:
        return existing

    if room.channel_id is not None:
        # ChatChannel.channel_id needs an int.
        bound_id = int(room.channel_id)
    else:
        bound_id = await _alloc_channel_id(db)
        # Write the id back so the room and the channel stay in sync.
        room.channel_id = bound_id
        db.add(room)

    created = ChatChannel(
        channel_id=bound_id,
        name=f"mp_{room.id}",
        description=f"Multiplayer room {room.id} chat",
        type=ChannelType.MULTIPLAYER,
    )
    db.add(created)
    await db.commit()
    await db.refresh(created)
    return created
|
||
|
||
|
||
async def _alloc_channel_id(db: Database) -> int:
    """Pick the next channel id for a room.

    Reads the largest ``Room.channel_id`` currently stored, treats a missing
    (or zero) maximum as 100, and returns that value plus one. With no rooms
    stored the first id handed out is therefore 101.

    NOTE(review): nothing in this function prevents two concurrent callers
    from computing the same id -- confirm whether callers serialise this.
    """
    max_query = select(func.max(Room.channel_id))
    highest = (await db.execute(max_query)).scalar()
    if not highest:
        highest = 100
    return int(highest) + 1
|
||
|
||
|
||
class RoomCreateRequest(BaseModel):
    """Payload describing a multiplayer room to be created."""

    # Room name and the id of the requesting (host) user are mandatory.
    name: str
    user_id: int

    # Optional settings with their defaults.
    password: str | None = None
    match_type: str = "HeadToHead"
    queue_mode: str = "HostOnly"

    # Two playlist fields are accepted; both default to an empty list.
    initial_playlist: List[Dict[str, Any]] = []
    playlist: List[Dict[str, Any]] = []
|
||
|
||
|
||
def verify_request_signature(request: Request, timestamp: str, body: bytes) -> bool:
    """Check the signature of a shared-interop request.

    This is currently a stub: none of the arguments is inspected and every
    request is accepted.

    Args:
        request: Incoming FastAPI request.
        timestamp: Timestamp sent with the request.
        body: Raw request body.

    Returns:
        Always ``True`` for now.

    NOTE(review): security-sensitive. Until HMAC verification is
    implemented, every caller of this function is effectively
    unauthenticated.
    """
    # TODO: Implement proper HMAC verification for production
    signature_ok = True
    return signature_ok
|
||
|
||
|
||
async def _validate_user_exists(db: Database, user_id: int) -> User:
    """Fetch the user with ``user_id`` or fail with HTTP 404.

    Args:
        db: Database session.
        user_id: Id of the user to look up.

    Returns:
        The matching ``User`` row.

    Raises:
        HTTPException: 404 when no such user exists.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    found = lookup.scalar_one_or_none()
    if found:
        return found

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"User with ID {user_id} not found"
    )
|
||
|
||
|
||
def _parse_room_enums(match_type: str, queue_mode: str) -> tuple[MatchType, QueueMode]:
    """Convert raw match-type and queue-mode values into their enums.

    Each value is lower-cased and passed to the enum constructor. Anything
    that cannot be converted -- an unknown name, or a value that is not a
    string at all (for example a JSON ``null``) -- falls back to the default
    instead of raising.

    Args:
        match_type: Raw match type from the request.
        queue_mode: Raw queue mode from the request.

    Returns:
        ``(MatchType, QueueMode)``; defaults are ``MatchType.HEAD_TO_HEAD``
        and ``QueueMode.HOST_ONLY``.
    """
    try:
        match_type_enum = MatchType(match_type.lower())
    except (ValueError, AttributeError):
        # AttributeError: the value has no .lower(), i.e. it is not a string.
        match_type_enum = MatchType.HEAD_TO_HEAD

    try:
        queue_mode_enum = QueueMode(queue_mode.lower())
    except (ValueError, AttributeError):
        queue_mode_enum = QueueMode.HOST_ONLY

    return match_type_enum, queue_mode_enum
|
||
|
||
|
||
def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_user_id: int) -> Dict[str, Any]:
|
||
"""
|
||
Normalize playlist item data with default values.
|
||
|
||
Args:
|
||
item_data: Raw playlist item data
|
||
default_order: Default playlist order
|
||
host_user_id: Host user ID for default owner
|
||
|
||
Returns:
|
||
Dict with normalized playlist item data
|
||
"""
|
||
# Use host_user_id if owner_id is 0 or not provided
|
||
owner_id = item_data.get("owner_id", host_user_id)
|
||
if owner_id == 0:
|
||
owner_id = host_user_id
|
||
|
||
return {
|
||
"owner_id": owner_id,
|
||
"ruleset_id": item_data.get("ruleset_id", 0),
|
||
"beatmap_id": item_data.get("beatmap_id"),
|
||
"required_mods": item_data.get("required_mods", []),
|
||
"allowed_mods": item_data.get("allowed_mods", []),
|
||
"expired": bool(item_data.get("expired", False)),
|
||
"playlist_order": item_data.get("playlist_order", default_order),
|
||
"played_at": item_data.get("played_at", None),
|
||
"freestyle": bool(item_data.get("freestyle", True)),
|
||
"beatmap_checksum": item_data.get("beatmap_checksum", ""),
|
||
"star_rating": item_data.get("star_rating", 0.0),
|
||
}
|
||
|
||
|
||
def _validate_playlist_items(items: List[Dict[str, Any]]) -> None:
|
||
"""Validate playlist items data."""
|
||
if not items:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="At least one playlist item is required to create a room"
|
||
)
|
||
|
||
for idx, item in enumerate(items):
|
||
if item["beatmap_id"] is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=f"Playlist item at index {idx} missing beatmap_id"
|
||
)
|
||
|
||
ruleset_id = item["ruleset_id"]
|
||
if not isinstance(ruleset_id, int) or not (0 <= ruleset_id <= 3):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}"
|
||
)
|
||
|
||
|
||
async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, int]:
    """Insert a new room row from the raw request payload.

    The host is taken from ``room_data["user_id"]`` and must be an existing
    user. A channel id is allocated for the room before it is inserted, and
    the room is committed and refreshed before returning.

    Args:
        db: Database session.
        room_data: Raw request payload (``user_id``, ``name``, ``password``,
            ``match_type``, ``queue_mode``).

    Returns:
        ``(room, host_user_id)``.

    Raises:
        HTTPException: 400 when ``user_id`` is missing or not an int; 404
            when the host user does not exist.
    """
    host_user_id = room_data.get("user_id")
    if not (host_user_id and isinstance(host_user_id, int)):
        raise HTTPException(status_code=400, detail="Missing or invalid user_id")

    await _validate_user_exists(db, host_user_id)

    room_type, room_queue_mode = _parse_room_enums(
        room_data.get("match_type", "HeadToHead"),
        room_data.get("queue_mode", "HostOnly"),
    )

    # Allocate a channel id for the room up front.
    allocated_channel_id = await _alloc_channel_id(db)

    new_room = Room(
        name=room_data.get("name", "Unnamed Room"),
        host_id=host_user_id,
        # An empty password is stored as None.
        password=room_data.get("password") or None,
        type=room_type,
        queue_mode=room_queue_mode,
        status=RoomStatus.IDLE,
        participant_count=1,
        auto_skip=False,
        auto_start_duration=0,
        channel_id=allocated_channel_id,
    )
    db.add(new_room)
    await db.commit()
    await db.refresh(new_room)

    return new_room, host_user_id
|
||
|
||
|
||
async def _add_playlist_items(db: Database, room_id: int, room_data: Dict[str, Any], host_user_id: int) -> None:
    """Normalise, validate and store the playlist items of a new room.

    Items come from ``room_data["initial_playlist"]`` followed by
    ``room_data["playlist"]``; default playlist order is the running index
    across both lists. Entries of the initial playlist that expose a
    ``dict`` attribute are converted with ``.dict()`` first.

    Raises:
        HTTPException: 400 when the combined playlist is empty or an item
            fails validation.
    """

    def _as_mapping(entry: Any) -> Any:
        # Objects exposing .dict() (e.g. model instances) become plain dicts.
        return entry.dict() if hasattr(entry, "dict") else entry

    initial_entries = room_data.get("initial_playlist", [])
    legacy_entries = room_data.get("playlist", [])

    normalized: List[Dict[str, Any]] = [
        _coerce_playlist_item(_as_mapping(entry), position, host_user_id)
        for position, entry in enumerate(initial_entries)
    ]

    # Legacy entries continue the numbering after the initial playlist.
    offset = len(normalized)
    normalized.extend(
        _coerce_playlist_item(entry, offset + position, host_user_id)
        for position, entry in enumerate(legacy_entries)
    )

    _validate_playlist_items(normalized)

    for fields in normalized:
        # id=-1 is a placeholder; the real id is assigned by add_to_db.
        hub_item = HubPlaylistItem(id=-1, **fields)
        await DBPlaylist.add_to_db(hub_item, room_id=room_id, session=db)
|
||
|
||
|
||
async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int) -> None:
    """Register the host as a participant and refresh the room's count.

    The participation record is only added to the session; committing is
    left to the caller.
    """
    db.add(RoomParticipatedUser(room_id=room_id, user_id=host_user_id))
    await _update_room_participant_count(db, room_id)
|
||
|
||
|
||
async def _verify_room_password(db: Database, room_id: int, provided_password: str | None) -> None:
    """Check a join attempt against the room's password, if it has one.

    Rooms without a password accept any caller. For protected rooms the
    supplied password is compared in constant time so that response timing
    does not reveal how much of a guess was correct.

    Args:
        db: Database session.
        room_id: Id of the room being joined.
        provided_password: Password sent by the caller, if any.

    Raises:
        HTTPException: 404 when the room does not exist; 403 when a password
            is required but missing, or does not match.
    """
    room_result = await db.execute(
        select(Room).where(col(Room.id) == room_id)
    )
    room = room_result.scalar_one_or_none()

    if room is None:
        logger.debug(f"Room {room_id} not found")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Room not found"
        )

    logger.debug(f"Room {room_id} has password: {bool(room.password)}, provided: {bool(provided_password)}")

    if room.password:
        if not provided_password:
            logger.debug(f"Room {room_id} requires password but none provided")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Password required"
            )

        # compare_digest only accepts ASCII str, so compare UTF-8 bytes. A
        # non-string value never matches, exactly as with the old `!=` check.
        password_matches = isinstance(provided_password, str) and hmac.compare_digest(
            provided_password.encode("utf-8"),
            str(room.password).encode("utf-8"),
        )
        if not password_matches:
            logger.debug(f"Room {room_id} password mismatch")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid password"
            )

    logger.debug(f"Room {room_id} password verification passed")
|
||
|
||
|
||
async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -> None:
    """Record that ``user_id`` is participating in ``room_id``.

    If the user has no active record (``left_at`` is NULL) a new one is
    added to the session. Otherwise the first active record returned by the
    query has its ``joined_at`` refreshed, and any further active records
    are treated as duplicates and marked as left.

    NOTE(review): the query has no ORDER BY, so which duplicate is kept is
    up to the database -- confirm whether "keep the newest" is required.
    """
    active_rows = await db.execute(
        select(RoomParticipatedUser.id).where(
            RoomParticipatedUser.room_id == room_id,
            RoomParticipatedUser.user_id == user_id,
            col(RoomParticipatedUser.left_at).is_(None)
        )
    )
    active_ids = active_rows.scalars().all()

    if not active_ids:
        # No active record yet: create one.
        db.add(RoomParticipatedUser(room_id=room_id, user_id=user_id))
        return

    kept_id, *duplicate_ids = active_ids

    if duplicate_ids:
        logger.debug(f"警告:用户 {user_id} 在房间 {room_id} 中发现 {len(active_ids)} 条活跃参与记录")
        # Close every record except the one being kept.
        for duplicate_id in duplicate_ids:
            await db.execute(
                update(RoomParticipatedUser)
                .where(col(RoomParticipatedUser.id) == duplicate_id)
                .values(left_at=utcnow())
            )

    # Refresh the join time on the surviving record.
    await db.execute(
        update(RoomParticipatedUser)
        .where(col(RoomParticipatedUser.id) == kept_id)
        .values(joined_at=utcnow())
    )
|
||
|
||
|
||
class BeatmapEnsureRequest(BaseModel):
    """Body of a beatmap-ensure request."""

    # Id of the beatmap that should be made available on the server.
    beatmap_id: int
|
||
|
||
|
||
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> Dict[str, Any]:
    """Make sure a beatmap's metadata is stored and try to cache its raw file.

    Args:
        db: Database session.
        fetcher: Fetcher used to obtain the beatmap and its raw file.
        redis: Redis connection holding the raw-file cache.
        beatmap_id: Id of the beatmap.

    Returns:
        On success a dict with ``success``, ``beatmap_id``,
        ``metadata_cached``, ``raw_file_cached`` and ``beatmap_title``. On
        failure a dict with ``success`` False, ``error`` and ``beatmap_id``.
        This function does not raise; every exception is turned into a
        failure dict.
    """

    def _failure(reason: str) -> Dict[str, Any]:
        return {
            "success": False,
            "error": reason,
            "beatmap_id": beatmap_id
        }

    try:
        # Step 1: the beatmap metadata must exist in the database.
        from app.database.beatmap import Beatmap

        beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
        if not beatmap:
            return _failure(f"Beatmap {beatmap_id} not found")

        # Step 2: warm the raw-file cache when it is not populated yet.
        raw_key = f"beatmap:{beatmap_id}:raw"
        already_cached = await redis.exists(raw_key)

        if not already_cached:
            try:
                await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
                logger.debug(f"Successfully cached raw beatmap file for {beatmap_id}")
            except Exception as e:
                # A raw-file failure does not fail the call: metadata exists.
                logger.debug(f"Warning: Failed to cache raw beatmap {beatmap_id}: {e}")

        return {
            "success": True,
            "beatmap_id": beatmap_id,
            "metadata_cached": True,
            "raw_file_cached": await redis.exists(raw_key),
            "beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]"
        }

    except Exception as e:
        logger.debug(f"Error ensuring beatmap {beatmap_id}: {e}")
        return _failure(str(e))
|
||
|
||
async def _update_room_participant_count(db: Database, room_id: int) -> int:
    """Recount a room's active participants and store the result on the room.

    Active participants are the ``RoomParticipatedUser`` rows of the room
    whose ``left_at`` is NULL. The count is computed by the database
    (``COUNT``) rather than by loading every row, and is then written to
    ``Room.participant_count``. Nothing is committed here.

    Args:
        db: Database session.
        room_id: Id of the room to recount.

    Returns:
        The number of active participation records.
    """
    count_query = select(func.count(col(RoomParticipatedUser.id))).where(
        RoomParticipatedUser.room_id == room_id,
        col(RoomParticipatedUser.left_at).is_(None)
    )
    count = int((await db.execute(count_query)).scalar() or 0)

    await db.execute(
        update(Room)
        .where(col(Room.id) == room_id)
        .values(participant_count=count)
    )

    return count
|
||
|
||
|
||
async def _end_room_if_empty(db: Database, room_id: int) -> bool:
    """End the room when it has no active participants left.

    The participant count is refreshed first. When it is zero the room gets
    ``ends_at`` set to the current time, its status set to
    ``RoomStatus.IDLE`` and its count set to 0. Nothing is committed here.

    Returns:
        True when the room was ended, False when participants remain.
    """
    remaining = await _update_room_participant_count(db, room_id)
    if remaining != 0:
        return False

    # NOTE(review): IDLE is used as the "ended" status; switch to a
    # dedicated ENDED status if RoomStatus has one.
    await db.execute(
        update(Room)
        .where(col(Room.id) == room_id)
        .values(
            ends_at=utcnow(),
            status=RoomStatus.IDLE,
            participant_count=0
        )
    )
    logger.debug(f"Room {room_id} ended automatically (no participants remaining)")
    return True
|
||
|
||
|
||
async def _transfer_ownership_or_end_room(db: Database, room_id: int, leaving_user_id: int) -> bool:
    """Handle the host leaving: hand the room over or end it.

    The earliest-joined active participant other than the leaving user
    becomes the new host. When nobody else is active the room is ended via
    ``_end_room_if_empty``. In both cases the stored participant count is
    refreshed; previously the hand-over path left it stale. Nothing is
    committed here.

    Args:
        db: Database session.
        room_id: Id of the room.
        leaving_user_id: Id of the departing host.

    Returns:
        True when the room was ended, False when ownership was transferred.
    """
    successor_query = (
        select(RoomParticipatedUser.user_id)
        .where(
            col(RoomParticipatedUser.room_id) == room_id,
            col(RoomParticipatedUser.user_id) != leaving_user_id,
            col(RoomParticipatedUser.left_at).is_(None)
        )
        .order_by(col(RoomParticipatedUser.joined_at))
        # Only the earliest joiner is needed.
        .limit(1)
    )
    successor_row = (await db.execute(successor_query)).first()

    if successor_row is None:
        # Nobody else is active: end the room (this also recounts).
        return await _end_room_if_empty(db, room_id)

    new_owner_id = successor_row[0]
    await db.execute(
        update(Room)
        .where(col(Room.id) == room_id)
        .values(host_id=new_owner_id)
    )
    # Keep participant_count in step with the departure of the old host.
    await _update_room_participant_count(db, room_id)
    logger.debug(f"Room {room_id} ownership transferred from {leaving_user_id} to {new_owner_id}")
    return False
|
||
|
||
|
||
async def _safely_join_channel(channel_id: int, user_id: int, max_retries: int = 3) -> bool:
    """Join a user to a chat channel, retrying on failure.

    Calls ``server.join_room_channel`` up to ``max_retries`` times with no
    delay between attempts. Exceptions are logged, never propagated.

    Returns:
        True as soon as one attempt succeeds; False when every attempt
        failed or ``max_retries`` is not positive.
    """
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            await server.join_room_channel(int(channel_id), int(user_id))
            logger.debug(f"Successfully joined user {user_id} to channel {channel_id} on attempt {attempt}")
            return True
        except Exception as e:
            logger.debug(f"Attempt {attempt} failed to join user {user_id} to channel {channel_id}: {e}")
            if attempt == max_retries:
                logger.debug(f"Failed to join user {user_id} to channel {channel_id} after {max_retries} attempts")
                return False
    return False
|
||
|
||
|
||
async def _safely_leave_channel(channel_id: int, user_id: int, max_retries: int = 3) -> bool:
    """Remove a user from a chat channel, retrying on failure.

    Calls ``server.leave_room_channel`` up to ``max_retries`` times with no
    delay between attempts. Exceptions are logged, never propagated.

    Returns:
        True as soon as one attempt succeeds; False when every attempt
        failed or ``max_retries`` is not positive.
    """
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            await server.leave_room_channel(int(channel_id), int(user_id))
            logger.debug(f"Successfully removed user {user_id} from channel {channel_id} on attempt {attempt}")
            return True
        except Exception as e:
            logger.debug(f"Attempt {attempt} failed to remove user {user_id} from channel {channel_id}: {e}")
            if attempt == max_retries:
                logger.debug(f"Failed to remove user {user_id} from channel {channel_id} after {max_retries} attempts")
                return False
    return False
|
||
|
||
|
||
# ===== API ENDPOINTS =====
|
||
|
||
async def _discard_failed_room(db: Database, room: Room, room_id: int | None) -> None:
    """Roll back the session and try to delete a room whose setup failed."""
    await db.rollback()
    try:
        await db.delete(room)
        await db.commit()
    except Exception as cleanup_error:
        # Cleanup is best-effort, but a failure is logged rather than hidden.
        # `except Exception` (not a bare except) lets task cancellation through.
        logger.warning(f"Failed to clean up room {room_id} after setup error: {cleanup_error}")


@router.post("/multiplayer/rooms")
async def create_multiplayer_room(
    request: Request,
    room_data: Dict[str, Any],
    db: Database,
    timestamp: str = "",
) -> int:
    """Create a multiplayer room with its chat channel and initial playlist.

    Steps: verify the request signature, insert the room, ensure its chat
    channel, store the playlist, register the host as a participant, then
    join the host to the chat channel. A failure to join the channel is
    logged but does not fail the request. If any earlier setup step fails
    after the room was inserted, the room is deleted again (best effort).

    Returns:
        The id of the new room.

    Raises:
        HTTPException: 401 for a bad signature, 400 for invalid JSON or
            payload, 404 for an unknown host, 500 for unexpected errors.
    """
    try:
        body = await request.body()
        if not verify_request_signature(request, str(timestamp), body):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid request signature"
            )

        # The payload may arrive as a JSON string.
        if isinstance(room_data, str):
            room_data = json.loads(room_data)

        # Never write the room password to the log.
        loggable = {
            key: ("***" if key == "password" and value else value)
            for key, value in room_data.items()
        }
        logger.debug(f"Creating room with data: {loggable}")

        room, host_user_id = await _create_room(db, room_data)
        room_id = room.id

        try:
            channel = await _ensure_room_chat_channel(db, room, host_user_id)
            await _add_playlist_items(db, room_id, room_data, host_user_id)
            await _add_host_as_participant(db, room_id, host_user_id)
            await db.commit()

            # Join the host to the chat channel only after the DB commit.
            host_user = await db.get(User, host_user_id)
            if host_user and channel:
                try:
                    await server.batch_join_channel([host_user], channel, db)
                    await db.commit()

                    # Also register the host in the in-memory channel.
                    success = await _safely_join_channel(channel.channel_id, host_user_id)
                    if not success:
                        logger.error(f"Critical: Failed to register host {host_user_id} in channel {channel.channel_id}")
                except Exception as e:
                    # Do not abort room creation; record the error instead.
                    logger.error(f"Failed to add host {host_user_id} to channel {channel.channel_id}: {e}")

            return room_id

        except HTTPException:
            await _discard_failed_room(db, room, room_id)
            raise
        except Exception as e:
            await _discard_failed_room(db, room, room_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to setup room: {str(e)}"
            ) from e

    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid JSON: {str(e)}"
        ) from e
    except HTTPException:
        raise
    except Exception as e:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create room: {str(e)}"
        ) from e
|
||
|
||
|
||
async def _cleanup_room_channel(channel_id: int | None, user_id: int, *, drop_channel: bool) -> None:
    """Remove a user from a room's chat channel and optionally drop the channel.

    Failures are logged and never raised, so channel cleanup cannot fail the
    request that triggered it.
    """
    if not channel_id:
        return

    left = await _safely_leave_channel(int(channel_id), int(user_id))
    if not left:
        logger.warning(f"Failed to remove user {user_id} from channel {channel_id}, but continuing")

    if drop_channel:
        try:
            # Forget the in-memory channel of a room that has ended.
            server.channels.pop(int(channel_id), None)
            logger.debug(f"Cleaned up channel {channel_id} from memory")
        except Exception as e:
            logger.debug(f"Warning: Failed to cleanup channel {channel_id} from memory: {e}")


@router.delete("/multiplayer/rooms/{room_id}/users/{user_id}")
async def remove_user_from_room(
    request: Request,
    room_id: int,
    user_id: int,
    db: Database,
    timestamp: int = Query(..., description="Unix 时间戳(秒)", ge=0),
) -> Dict[str, Any]:
    """Remove a user from a multiplayer room.

    The user's active participation records are marked as left. When the
    host leaves, ownership is transferred or the room is ended; otherwise
    the room is ended only if it became empty. The user is then removed
    from the room's chat channel, and the in-memory channel is dropped when
    the room ended. Requests for a user who is not in the room, or for a
    room that already ended, succeed without changing participation.

    Returns:
        ``{"success": True, "room_ended": bool}``.

    Raises:
        HTTPException: 401 for a bad signature, 404 for an unknown room,
            500 for unexpected errors.
    """
    try:
        body = await request.body()
        if not verify_request_signature(request, str(timestamp), body):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid request signature"
            )

        room_result = await db.execute(
            select(Room).where(col(Room.id) == room_id)
        )
        room = room_result.scalar_one_or_none()
        if room is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Room not found"
            )

        # Read what is needed now; the commits below may expire the instance.
        room_owner_id = room.host_id
        channel_id = room.channel_id

        if room.ends_at is not None:
            logger.debug(f"Room {room_id} is already ended")
            # Still try to clean up the user's channel membership.
            await _cleanup_room_channel(channel_id, user_id, drop_channel=False)
            return {"success": True, "room_ended": True}

        participant_result = await db.execute(
            select(RoomParticipatedUser.id)
            .where(
                col(RoomParticipatedUser.room_id) == room_id,
                col(RoomParticipatedUser.user_id) == user_id,
                col(RoomParticipatedUser.left_at).is_(None)
            )
        )

        if not participant_result.first():
            # The user is not in the room; only check whether it is empty.
            room_ended = await _end_room_if_empty(db, room_id)
            await db.commit()
            await _cleanup_room_channel(channel_id, user_id, drop_channel=room_ended)
            return {"success": True, "room_ended": room_ended}

        # Mark every active record of the user as left.
        await db.execute(
            update(RoomParticipatedUser)
            .where(
                col(RoomParticipatedUser.room_id) == room_id,
                col(RoomParticipatedUser.user_id) == user_id,
                col(RoomParticipatedUser.left_at).is_(None)
            )
            .values(left_at=utcnow())
        )

        if user_id == room_owner_id:
            logger.debug(f"Host {user_id} is leaving room {room_id}")
            room_ended = await _transfer_ownership_or_end_room(db, room_id, user_id)
        else:
            room_ended = await _end_room_if_empty(db, room_id)

        await db.commit()
        logger.debug(f"Successfully removed user {user_id} from room {room_id}, room_ended: {room_ended}")

        await _cleanup_room_channel(channel_id, user_id, drop_channel=room_ended)
        return {"success": True, "room_ended": room_ended}

    except HTTPException:
        raise
    except Exception as e:
        await db.rollback()
        logger.error(f"Error removing user from room: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to remove user from room: {str(e)}"
        ) from e
|
||
|
||
|
||
@router.put("/multiplayer/rooms/{room_id}/users/{user_id}")
async def add_user_to_room(
    request: Request,
    room_id: int,
    user_id: int,
    db: Database,
    timestamp: str = "",
) -> Dict[str, Any]:
    """Add a user to a multiplayer room and to its chat channel.

    The optional JSON body may carry a ``password``. Participation is
    committed before the chat channel is touched, so a chat failure never
    undoes the join; it is reported through a ``channel_error`` key in the
    response instead.

    Returns:
        ``{"success": True}``, plus ``channel_error`` when joining the chat
        channel raised.

    Raises:
        HTTPException: 401 for a bad signature, 404 for an unknown room or
            user, 410 for an ended room, 403 for a missing or wrong
            password, 500 for unexpected errors.
    """
    logger.debug(f"Adding user {user_id} to room {room_id}")

    body = await request.body()
    user_data = None
    if body:
        try:
            parsed = json.loads(body.decode('utf-8'))
        except (json.JSONDecodeError, UnicodeDecodeError):
            logger.debug("Failed to parse user_data from request body")
        else:
            # Only a JSON object can carry a password; ignore anything else.
            # The payload is deliberately not logged: it may hold the password.
            if isinstance(parsed, dict):
                user_data = parsed
                logger.debug("Parsed user_data from request body")

    if not verify_request_signature(request, timestamp, body):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid request signature"
        )

    try:
        room_result = await db.execute(
            select(Room.id, Room.ends_at, Room.channel_id, Room.host_id)
            .where(col(Room.id) == room_id)
        )
        room_row = room_result.first()
        if not room_row:
            raise HTTPException(status_code=404, detail="Room not found")

        _, ends_at, channel_id, host_user_id = room_row
        if ends_at is not None:
            logger.debug(f"User {user_id} attempted to join ended room {room_id}")
            raise HTTPException(status_code=410, detail="Room has ended and cannot accept new participants")

        provided_password = user_data.get("password") if user_data else None
        # Log only whether a password was supplied, never its value.
        logger.debug(f"Verifying room {room_id}, password provided: {bool(provided_password)}")
        await _verify_room_password(db, room_id, provided_password)

        await _validate_user_exists(db, user_id)

        await _add_or_update_participant(db, room_id, user_id)
        await _update_room_participant_count(db, room_id)

        # Commit participation first so it holds even if chat setup fails.
        await db.commit()
        logger.debug(f"Successfully added user {user_id} to room {room_id}")

        try:
            # Create the channel and write its id back if the room has none.
            if not channel_id:
                room = await db.get(Room, room_id)
                if room is None:
                    # NOTE: raised inside this try, so it is reported through
                    # channel_error below rather than as a 404 response.
                    raise HTTPException(status_code=404, detail="Room not found")
                await _ensure_room_chat_channel(db, room, host_user_id)
                await db.commit()
                await db.refresh(room)
                channel_id = room.channel_id

            if channel_id:
                success = await _safely_join_channel(int(channel_id), int(user_id))
                if success:
                    logger.debug(f"User {user_id} successfully joined channel {channel_id}")
                else:
                    # The user stays in the room; the failure is only logged.
                    logger.error(f"Critical: User {user_id} failed to join channel {channel_id}")
            else:
                logger.warning(f"Room {room_id} has no channel_id after ensure")

        except Exception as e:
            # A chat failure must not break the join itself.
            logger.error(f"Failed to join user {user_id} to channel of room {room_id}: {e}")
            return {
                "success": True,
                "channel_error": f"Failed to join chat channel: {str(e)}"
            }

        return {"success": True}

    except HTTPException:
        raise
    except Exception as e:
        await db.rollback()
        logger.error(f"Error adding user to room: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to add user to room: {str(e)}"
        ) from e
|
||
|
||
|
||
@router.post("/beatmaps/ensure")
async def ensure_beatmap_present(
    request: Request,
    beatmap_data: BeatmapEnsureRequest,
    db: Database,
    redis: Redis = Depends(get_redis),
    fetcher: Fetcher = Depends(get_fetcher),
    timestamp: str = "",
) -> Dict[str, Any]:
    """Ensure a beatmap's metadata and raw file are available on the server.

    Intended for osu-server-spectator, so that the beatmap is fetched ahead
    of time rather than at the moment it is needed.

    Returns:
        The status dict produced by ``_ensure_beatmap_exists``.

    Raises:
        HTTPException: 401 for a bad signature, 500 for unexpected errors.
    """
    try:
        raw_body = await request.body()
        signature_ok = verify_request_signature(request, timestamp, raw_body)
        if not signature_ok:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid request signature"
            )

        requested_id = beatmap_data.beatmap_id
        logger.debug(f"Ensuring beatmap {requested_id} is present")

        outcome = await _ensure_beatmap_exists(db, fetcher, redis, requested_id)

        # Persist whatever the ensure step wrote to the session.
        await db.commit()

        logger.debug(f"Ensure beatmap {requested_id} result: {outcome}")
        return outcome

    except HTTPException:
        raise
    except Exception as exc:
        await db.rollback()
        logger.debug(f"Error ensuring beatmap: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to ensure beatmap: {str(exc)}"
        )