chore(deps): auto fix by pre-commit hooks
This commit is contained in:
committed by
MingxuanGame
parent
b4fd4e0256
commit
7625cd99f5
10
app/auth.py
10
app/auth.py
@@ -156,19 +156,17 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
|
||||
expire = utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
|
||||
# 添加标准JWT声明
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"random": secrets.token_hex(16)
|
||||
})
|
||||
if hasattr(settings, 'jwt_audience') and settings.jwt_audience:
|
||||
to_encode.update({"exp": expire, "random": secrets.token_hex(16)})
|
||||
if hasattr(settings, "jwt_audience") and settings.jwt_audience:
|
||||
to_encode["aud"] = settings.jwt_audience
|
||||
if hasattr(settings, 'jwt_issuer') and settings.jwt_issuer:
|
||||
if hasattr(settings, "jwt_issuer") and settings.jwt_issuer:
|
||||
to_encode["iss"] = settings.jwt_issuer
|
||||
|
||||
# 编码JWT
|
||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""生成刷新令牌"""
|
||||
length = 64
|
||||
|
||||
@@ -291,11 +291,7 @@ class UserResp(UserBase):
|
||||
).one()
|
||||
redis = get_redis()
|
||||
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
|
||||
u.cover_url = (
|
||||
obj.cover.get("url", "")
|
||||
if obj.cover
|
||||
else ""
|
||||
)
|
||||
u.cover_url = obj.cover.get("url", "") if obj.cover else ""
|
||||
|
||||
if "friends" in include:
|
||||
u.friends = [
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, ClassVar
|
||||
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import SQLModel, Field, Column, DateTime, BigInteger, ForeignKey
|
||||
from typing import ClassVar
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.utils import utcnow
|
||||
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, SQLModel
|
||||
|
||||
|
||||
class MultiplayerRealtimeRoomEventBase(SQLModel, UTCBaseModel):
|
||||
event_type: str = Field(index=True)
|
||||
event_detail: Optional[str] = Field(default=None, sa_column=Column(Text))
|
||||
event_detail: str | None = Field(default=None, sa_column=Column(Text))
|
||||
|
||||
|
||||
class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase, table=True):
|
||||
@@ -19,9 +19,7 @@ class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase,
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
|
||||
room_id: int = Field(
|
||||
sa_column=Column(ForeignKey("rooms.id"), index=True, nullable=False)
|
||||
)
|
||||
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), index=True, nullable=False))
|
||||
playlist_item_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(ForeignKey("playlists.id"), index=True, nullable=True),
|
||||
@@ -31,9 +29,5 @@ class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase,
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True, nullable=True),
|
||||
)
|
||||
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True)), default_factory=utcnow
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True)), default_factory=utcnow
|
||||
)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)), default_factory=utcnow)
|
||||
updated_at: datetime = Field(sa_column=Column(DateTime(timezone=True)), default_factory=utcnow)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
@@ -60,16 +60,9 @@ class Playlist(PlaylistBase, table=True):
|
||||
}
|
||||
)
|
||||
room: "Room" = Relationship()
|
||||
created_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
sa_column_kwargs={"server_default": func.now()}
|
||||
)
|
||||
updated_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
sa_column_kwargs={
|
||||
"server_default": func.now(),
|
||||
"onupdate": func.now()
|
||||
}
|
||||
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
|
||||
updated_at: datetime | None = Field(
|
||||
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -74,7 +74,6 @@ class Room(AsyncAttrs, RoomBase, table=True):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class RoomResp(RoomBase):
|
||||
id: int
|
||||
has_password: bool = False
|
||||
|
||||
@@ -68,9 +68,7 @@ class BaseFetcher:
|
||||
if response.status_code == 401:
|
||||
logger.warning(f"Received 401 error for {url}")
|
||||
await self._clear_tokens()
|
||||
raise TokenAuthError(
|
||||
f"Authentication failed. Please re-authorize using: {self.authorize_url}"
|
||||
)
|
||||
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@@ -14,12 +14,13 @@ from app.utils import bg_tasks
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
import redis.asyncio as redis
|
||||
from httpx import AsyncClient
|
||||
import redis.asyncio as redis
|
||||
|
||||
|
||||
class RateLimitError(Exception):
|
||||
"""速率限制异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -73,9 +74,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
if response.status_code == 401:
|
||||
logger.warning(f"Received 401 error for {url}")
|
||||
await self._clear_tokens()
|
||||
raise TokenAuthError(
|
||||
f"Authentication failed. Please re-authorize using: {self.authorize_url}"
|
||||
)
|
||||
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -205,7 +204,9 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
try:
|
||||
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
|
||||
except RateLimitError:
|
||||
logger.opt(colors=True).info("<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit")
|
||||
logger.opt(colors=True).info(
|
||||
"<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit"
|
||||
)
|
||||
|
||||
bg_tasks.add_task(delayed_prefetch)
|
||||
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
"""LIO (Legacy IO) router for osu-server-spectator compatibility."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, status, Query, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select, desc
|
||||
from sqlalchemy import update, func
|
||||
from redis.asyncio import Redis
|
||||
from typing import Any
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel # ChatChannel 模型 & 枚举
|
||||
from app.database.lazer_user import User
|
||||
from app.database.playlists import Playlist as DBPlaylist
|
||||
from app.database.room import Room
|
||||
@@ -17,13 +13,18 @@ from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.fetcher import Fetcher
|
||||
from app.log import logger
|
||||
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
|
||||
from app.models.room import MatchType, QueueMode, RoomStatus
|
||||
from app.utils import utcnow
|
||||
from app.database.chat import ChatChannel, ChannelType # ChatChannel 模型 & 枚举
|
||||
from .notification.server import server
|
||||
from app.log import logger
|
||||
|
||||
from .notification.server import server
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import func, update
|
||||
from sqlmodel import col, select
|
||||
|
||||
router = APIRouter(prefix="/_lio", tags=["LIO"])
|
||||
|
||||
@@ -40,9 +41,7 @@ async def _ensure_room_chat_channel(
|
||||
# 1) 按 channel_id 查是否已存在
|
||||
try:
|
||||
# Use db.execute instead of db.exec for better async compatibility
|
||||
result = await db.execute(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == room.channel_id)
|
||||
)
|
||||
result = await db.execute(select(ChatChannel).where(ChatChannel.channel_id == room.channel_id))
|
||||
ch = result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error querying ChatChannel: {e}")
|
||||
@@ -59,8 +58,8 @@ async def _ensure_room_chat_channel(
|
||||
channel_id_value = int(room.channel_id)
|
||||
|
||||
ch = ChatChannel(
|
||||
channel_id=channel_id_value, # 与房间绑定的同一 channel_id(确保为 int)
|
||||
name=f"mp_{room.id}", # 频道名可自定义(注意唯一性)
|
||||
channel_id=channel_id_value, # 与房间绑定的同一 channel_id(确保为 int)
|
||||
name=f"mp_{room.id}", # 频道名可自定义(注意唯一性)
|
||||
description=f"Multiplayer room {room.id} chat",
|
||||
type=ChannelType.MULTIPLAYER,
|
||||
)
|
||||
@@ -86,18 +85,20 @@ async def _alloc_channel_id(db: Database) -> int:
|
||||
logger.debug(f"Error allocating channel_id: {e}")
|
||||
# Fallback to a timestamp-based approach
|
||||
import time
|
||||
|
||||
return int(time.time()) % 1000000 + 100
|
||||
|
||||
|
||||
class RoomCreateRequest(BaseModel):
|
||||
"""Request model for creating a multiplayer room."""
|
||||
|
||||
name: str
|
||||
user_id: int
|
||||
password: str | None = None
|
||||
match_type: str = "HeadToHead"
|
||||
queue_mode: str = "HostOnly"
|
||||
initial_playlist: List[Dict[str, Any]] = []
|
||||
playlist: List[Dict[str, Any]] = []
|
||||
initial_playlist: list[dict[str, Any]] = []
|
||||
playlist: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
def verify_request_signature(request: Request, timestamp: str, body: bytes) -> bool:
|
||||
@@ -126,10 +127,7 @@ async def _validate_user_exists(db: Database, user_id: int) -> User:
|
||||
user = user_result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User with ID {user_id} not found"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with ID {user_id} not found")
|
||||
|
||||
return user
|
||||
|
||||
@@ -149,7 +147,7 @@ def _parse_room_enums(match_type: str, queue_mode: str) -> tuple[MatchType, Queu
|
||||
return match_type_enum, queue_mode_enum
|
||||
|
||||
|
||||
def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_user_id: int) -> Dict[str, Any]:
|
||||
def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_user_id: int) -> dict[str, Any]:
|
||||
"""
|
||||
Normalize playlist item data with default values.
|
||||
|
||||
@@ -181,30 +179,28 @@ def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_us
|
||||
}
|
||||
|
||||
|
||||
def _validate_playlist_items(items: List[Dict[str, Any]]) -> None:
|
||||
def _validate_playlist_items(items: list[dict[str, Any]]) -> None:
|
||||
"""Validate playlist items data."""
|
||||
if not items:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one playlist item is required to create a room"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="At least one playlist item is required to create a room"
|
||||
)
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
if item["beatmap_id"] is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Playlist item at index {idx} missing beatmap_id"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Playlist item at index {idx} missing beatmap_id"
|
||||
)
|
||||
|
||||
ruleset_id = item["ruleset_id"]
|
||||
if not isinstance(ruleset_id, int) or not (0 <= ruleset_id <= 3):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}"
|
||||
detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}",
|
||||
)
|
||||
|
||||
|
||||
async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, int]:
|
||||
async def _create_room(db: Database, room_data: dict[str, Any]) -> tuple[Room, int]:
|
||||
host_user_id = room_data.get("user_id")
|
||||
room_name = room_data.get("name", "Unnamed Room")
|
||||
password = room_data.get("password")
|
||||
@@ -242,12 +238,12 @@ async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, i
|
||||
return room, host_user_id
|
||||
|
||||
|
||||
async def _add_playlist_items(db: Database, room_id: int, room_data: Dict[str, Any], host_user_id: int) -> None:
|
||||
async def _add_playlist_items(db: Database, room_id: int, room_data: dict[str, Any], host_user_id: int) -> None:
|
||||
"""Add playlist items to the room."""
|
||||
initial_playlist = room_data.get("initial_playlist", [])
|
||||
legacy_playlist = room_data.get("playlist", [])
|
||||
|
||||
items_raw: List[Dict[str, Any]] = []
|
||||
items_raw: list[dict[str, Any]] = []
|
||||
|
||||
# Process initial playlist
|
||||
for i, item in enumerate(initial_playlist):
|
||||
@@ -292,35 +288,24 @@ async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int
|
||||
|
||||
async def _verify_room_password(db: Database, room_id: int, provided_password: str | None) -> None:
|
||||
"""Verify room password if required."""
|
||||
room_result = await db.execute(
|
||||
select(Room).where(col(Room.id) == room_id)
|
||||
)
|
||||
room_result = await db.execute(select(Room).where(col(Room.id) == room_id))
|
||||
room = room_result.scalar_one_or_none()
|
||||
|
||||
if room is None:
|
||||
logger.debug(f"Room {room_id} not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Room not found"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
|
||||
|
||||
logger.debug(f"Room {room_id} has password: {bool(room.password)}, provided: {bool(provided_password)}")
|
||||
|
||||
# If room has password but none provided
|
||||
if room.password and not provided_password:
|
||||
logger.debug(f"Room {room_id} requires password but none provided")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Password required"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Password required")
|
||||
|
||||
# If room has password and provided password doesn't match
|
||||
if room.password and provided_password and provided_password != room.password:
|
||||
logger.debug(f"Room {room_id} password mismatch")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid password"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid password")
|
||||
|
||||
logger.debug(f"Room {room_id} password verification passed")
|
||||
|
||||
@@ -332,7 +317,7 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
|
||||
select(RoomParticipatedUser.id).where(
|
||||
RoomParticipatedUser.room_id == room_id,
|
||||
RoomParticipatedUser.user_id == user_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
)
|
||||
existing_ids = existing_result.scalars().all() # 获取所有匹配的ID
|
||||
@@ -364,10 +349,11 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
|
||||
|
||||
class BeatmapEnsureRequest(BaseModel):
|
||||
"""Request model for ensuring beatmap exists."""
|
||||
|
||||
beatmap_id: int
|
||||
|
||||
|
||||
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> Dict[str, Any]:
|
||||
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> dict[str, Any]:
|
||||
"""
|
||||
确保谱面存在(包括元数据和原始文件缓存)。
|
||||
|
||||
@@ -383,14 +369,11 @@ async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int)
|
||||
try:
|
||||
# 1. 确保谱面元数据存在于数据库中
|
||||
from app.database.beatmap import Beatmap
|
||||
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
|
||||
if not beatmap:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Beatmap {beatmap_id} not found",
|
||||
"beatmap_id": beatmap_id
|
||||
}
|
||||
return {"success": False, "error": f"Beatmap {beatmap_id} not found", "beatmap_id": beatmap_id}
|
||||
|
||||
# 2. 预缓存谱面原始文件
|
||||
cache_key = f"beatmap:{beatmap_id}:raw"
|
||||
@@ -410,35 +393,27 @@ async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int)
|
||||
"beatmap_id": beatmap_id,
|
||||
"metadata_cached": True,
|
||||
"raw_file_cached": await redis.exists(cache_key),
|
||||
"beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]"
|
||||
"beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error ensuring beatmap {beatmap_id}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"beatmap_id": beatmap_id
|
||||
}
|
||||
return {"success": False, "error": str(e), "beatmap_id": beatmap_id}
|
||||
|
||||
|
||||
async def _update_room_participant_count(db: Database, room_id: int) -> int:
|
||||
"""更新房间参与者数量并返回当前数量。"""
|
||||
# 统计活跃参与者
|
||||
active_participants_result = await db.execute(
|
||||
select(RoomParticipatedUser.user_id).where(
|
||||
RoomParticipatedUser.room_id == room_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
RoomParticipatedUser.room_id == room_id, col(RoomParticipatedUser.left_at).is_(None)
|
||||
)
|
||||
)
|
||||
active_participants = active_participants_result.all()
|
||||
count = len(active_participants)
|
||||
|
||||
# 更新房间参与者数量
|
||||
await db.execute(
|
||||
update(Room)
|
||||
.where(col(Room.id) == room_id)
|
||||
.values(participant_count=count)
|
||||
)
|
||||
await db.execute(update(Room).where(col(Room.id) == room_id).values(participant_count=count))
|
||||
|
||||
return count
|
||||
|
||||
@@ -457,7 +432,7 @@ async def _end_room_if_empty(db: Database, room_id: int) -> bool:
|
||||
.values(
|
||||
ends_at=now,
|
||||
status=RoomStatus.IDLE, # 或者使用 RoomStatus.ENDED 如果有这个状态
|
||||
participant_count=0
|
||||
participant_count=0,
|
||||
)
|
||||
)
|
||||
logger.debug(f"Room {room_id} ended automatically (no participants remaining)")
|
||||
@@ -474,7 +449,7 @@ async def _transfer_ownership_or_end_room(db: Database, room_id: int, leaving_us
|
||||
.where(
|
||||
col(RoomParticipatedUser.room_id) == room_id,
|
||||
col(RoomParticipatedUser.user_id) != leaving_user_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
.order_by(col(RoomParticipatedUser.joined_at)) # 按加入时间排序
|
||||
)
|
||||
@@ -483,23 +458,21 @@ async def _transfer_ownership_or_end_room(db: Database, room_id: int, leaving_us
|
||||
if remaining_participants:
|
||||
# 将房主权限转让给最早加入的用户
|
||||
new_owner_id = remaining_participants[0][0] # 获取 user_id
|
||||
await db.execute(
|
||||
update(Room)
|
||||
.where(col(Room.id) == room_id)
|
||||
.values(host_id=new_owner_id)
|
||||
)
|
||||
await db.execute(update(Room).where(col(Room.id) == room_id).values(host_id=new_owner_id))
|
||||
logger.debug(f"Room {room_id} ownership transferred from {leaving_user_id} to {new_owner_id}")
|
||||
return False # 房间继续存在
|
||||
else:
|
||||
# 没有其他参与者,结束房间
|
||||
return await _end_room_if_empty(db, room_id)
|
||||
|
||||
|
||||
# ===== API ENDPOINTS =====
|
||||
|
||||
|
||||
@router.post("/multiplayer/rooms")
|
||||
async def create_multiplayer_room(
|
||||
request: Request,
|
||||
room_data: Dict[str, Any],
|
||||
room_data: dict[str, Any],
|
||||
db: Database,
|
||||
timestamp: str = "",
|
||||
) -> int:
|
||||
@@ -508,10 +481,7 @@ async def create_multiplayer_room(
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
if not verify_request_signature(request, str(timestamp), body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
|
||||
|
||||
# Parse room data if string
|
||||
if isinstance(room_data, str):
|
||||
@@ -534,7 +504,7 @@ async def create_multiplayer_room(
|
||||
await _add_playlist_items(db, room_id, room_data, host_user_id)
|
||||
|
||||
# Add host as participant
|
||||
#await _add_host_as_participant(db, room_id, host_user_id)
|
||||
# await _add_host_as_participant(db, room_id, host_user_id)
|
||||
|
||||
await db.commit()
|
||||
return room_id
|
||||
@@ -546,18 +516,11 @@ async def create_multiplayer_room(
|
||||
raise
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid JSON: {str(e)}"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON: {e!s}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create room: {str(e)}"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create room: {e!s}")
|
||||
|
||||
|
||||
@router.delete("/multiplayer/rooms/{room_id}/users/{user_id}")
|
||||
@@ -567,29 +530,21 @@ async def remove_user_from_room(
|
||||
user_id: int,
|
||||
db: Database,
|
||||
timestamp: int = Query(..., description="Unix 时间戳(秒)", ge=0),
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Remove a user from a multiplayer room."""
|
||||
try:
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
now = utcnow()
|
||||
if not verify_request_signature(request, str(timestamp), body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
|
||||
|
||||
# 检查房间是否存在
|
||||
room_result = await db.execute(
|
||||
select(Room).where(col(Room.id) == room_id)
|
||||
)
|
||||
room_result = await db.execute(select(Room).where(col(Room.id) == room_id))
|
||||
room = room_result.scalar_one_or_none()
|
||||
|
||||
if room is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Room not found"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
|
||||
|
||||
room_owner_id = room.host_id
|
||||
room_status = room.status
|
||||
@@ -604,11 +559,10 @@ async def remove_user_from_room(
|
||||
|
||||
# 检查用户是否在房间中
|
||||
participant_result = await db.execute(
|
||||
select(RoomParticipatedUser.id)
|
||||
.where(
|
||||
select(RoomParticipatedUser.id).where(
|
||||
col(RoomParticipatedUser.room_id) == room_id,
|
||||
col(RoomParticipatedUser.user_id) == user_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
)
|
||||
participant_query = participant_result.first()
|
||||
@@ -634,7 +588,7 @@ async def remove_user_from_room(
|
||||
.where(
|
||||
col(RoomParticipatedUser.room_id) == room_id,
|
||||
col(RoomParticipatedUser.user_id) == user_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
.values(left_at=now)
|
||||
)
|
||||
@@ -667,10 +621,9 @@ async def remove_user_from_room(
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.debug(f"Error removing user from room: {str(e)}")
|
||||
logger.debug(f"Error removing user from room: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to remove user from room: {str(e)}"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user from room: {e!s}"
|
||||
)
|
||||
|
||||
|
||||
@@ -681,7 +634,7 @@ async def add_user_to_room(
|
||||
user_id: int,
|
||||
db: Database,
|
||||
timestamp: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Add a user to a multiplayer room."""
|
||||
logger.debug(f"Adding user {user_id} to room {room_id}")
|
||||
|
||||
@@ -690,7 +643,7 @@ async def add_user_to_room(
|
||||
user_data = None
|
||||
if body:
|
||||
try:
|
||||
user_data = json.loads(body.decode('utf-8'))
|
||||
user_data = json.loads(body.decode("utf-8"))
|
||||
logger.debug(f"Parsed user_data: {user_data}")
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Failed to parse user_data from request body")
|
||||
@@ -698,15 +651,11 @@ async def add_user_to_room(
|
||||
|
||||
# Verify request signature
|
||||
if not verify_request_signature(request, timestamp, body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
|
||||
|
||||
# 检查房间是否已结束
|
||||
room_result = await db.execute(
|
||||
select(Room.id, Room.ends_at, Room.channel_id, Room.host_id)
|
||||
.where(col(Room.id) == room_id)
|
||||
select(Room.id, Room.ends_at, Room.channel_id, Room.host_id).where(col(Room.id) == room_id)
|
||||
)
|
||||
room_row = room_result.first()
|
||||
if not room_row:
|
||||
@@ -763,7 +712,7 @@ async def ensure_beatmap_present(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
timestamp: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
确保谱面在服务器中存在(包括元数据和原始文件缓存)。
|
||||
|
||||
@@ -774,10 +723,7 @@ async def ensure_beatmap_present(
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
if not verify_request_signature(request, timestamp, body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
|
||||
|
||||
beatmap_id = beatmap_data.beatmap_id
|
||||
logger.debug(f"Ensuring beatmap {beatmap_id} is present")
|
||||
@@ -795,8 +741,7 @@ async def ensure_beatmap_present(
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.debug(f"Error ensuring beatmap: {str(e)}")
|
||||
logger.debug(f"Error ensuring beatmap: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to ensure beatmap: {str(e)}"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to ensure beatmap: {e!s}"
|
||||
)
|
||||
@@ -96,6 +96,7 @@ async def send_message(
|
||||
if channel_type == ChannelType.MULTIPLAYER:
|
||||
try:
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
redis = get_redis()
|
||||
key = f"channel:{channel_id}:messages"
|
||||
key_type = await redis.type(key)
|
||||
@@ -162,13 +163,9 @@ async def get_message(
|
||||
):
|
||||
# 1) 查频道
|
||||
if channel.isdigit():
|
||||
db_channel = (await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -176,7 +173,6 @@ async def get_message(
|
||||
channel_id = db_channel.channel_id
|
||||
|
||||
try:
|
||||
|
||||
messages = await redis_message_system.get_messages(channel_id, limit, since)
|
||||
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
|
||||
messages.reverse()
|
||||
@@ -188,11 +184,7 @@ async def get_message(
|
||||
|
||||
if since > 0 and until is None:
|
||||
# 向前加载新消息 → 直接 ASC
|
||||
query = (
|
||||
base.where(col(ChatMessage.message_id) > since)
|
||||
.order_by(col(ChatMessage.message_id).asc())
|
||||
.limit(limit)
|
||||
)
|
||||
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||
rows = (await session.exec(query)).all()
|
||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
||||
# 已经 ASC,无需反转
|
||||
@@ -202,9 +194,7 @@ async def get_message(
|
||||
if until is not None:
|
||||
# 用 DESC 取最近的更早消息,再反转为 ASC
|
||||
query = (
|
||||
base.where(col(ChatMessage.message_id) < until)
|
||||
.order_by(col(ChatMessage.message_id).desc())
|
||||
.limit(limit)
|
||||
base.where(col(ChatMessage.message_id) < until).order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
)
|
||||
rows = (await session.exec(query)).all()
|
||||
rows = list(rows)
|
||||
@@ -221,7 +211,6 @@ async def get_message(
|
||||
return resp
|
||||
|
||||
|
||||
|
||||
@router.put(
|
||||
"/chat/channels/{channel}/mark-as-read/{message}",
|
||||
status_code=204,
|
||||
|
||||
@@ -81,10 +81,11 @@ class ChatServer:
|
||||
if not users_in_channel:
|
||||
try:
|
||||
async with with_db() as session:
|
||||
from sqlmodel import select
|
||||
channel = await session.get(ChatChannel, channel_id)
|
||||
if channel and channel.type == ChannelType.MULTIPLAYER:
|
||||
logger.warning(f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone")
|
||||
logger.warning(
|
||||
f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone"
|
||||
)
|
||||
# 对于多人游戏房间,这可能是正常的(用户都离开了房间)
|
||||
# 但我们仍然记录这个情况以便调试
|
||||
except Exception as e:
|
||||
|
||||
@@ -13,8 +13,8 @@ from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import SearchQueryModel
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService
|
||||
from app.service.asset_proxy_helper import process_response_assets
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService
|
||||
|
||||
from .router import router
|
||||
|
||||
|
||||
@@ -59,15 +59,17 @@ async def get_all_rooms(
|
||||
|
||||
if status is not None:
|
||||
where_clauses.append(col(Room.status) == status)
|
||||
#print(mode, category, status, current_user.id)
|
||||
# print(mode, category, status, current_user.id)
|
||||
if mode == "open":
|
||||
# 修改为新的查询逻辑:状态为 idle 或 playing,starts_at 不为空,ends_at 为空
|
||||
where_clauses.extend([
|
||||
col(Room.status).in_([RoomStatus.IDLE, RoomStatus.PLAYING]),
|
||||
col(Room.starts_at).is_not(None),
|
||||
col(Room.ends_at).is_(None)
|
||||
])
|
||||
#if category == RoomCategory.REALTIME:
|
||||
where_clauses.extend(
|
||||
[
|
||||
col(Room.status).in_([RoomStatus.IDLE, RoomStatus.PLAYING]),
|
||||
col(Room.starts_at).is_not(None),
|
||||
col(Room.ends_at).is_(None),
|
||||
]
|
||||
)
|
||||
# if category == RoomCategory.REALTIME:
|
||||
# where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
|
||||
|
||||
if mode == "participated":
|
||||
@@ -96,7 +98,7 @@ async def get_all_rooms(
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
#print("Retrieved rooms:", db_rooms)
|
||||
# print("Retrieved rooms:", db_rooms)
|
||||
for room in db_rooms:
|
||||
resp = await RoomResp.from_db(room, db)
|
||||
resp.has_password = bool((room.password or "").strip())
|
||||
|
||||
@@ -5,10 +5,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from fastapi import Request
|
||||
|
||||
from app.config import settings
|
||||
from app.service.asset_proxy_service import get_asset_proxy_service
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
async def process_response_assets(data: Any, request: Request) -> Any:
|
||||
"""
|
||||
@@ -56,6 +58,7 @@ def asset_proxy_response(func):
|
||||
"""
|
||||
装饰器:自动处理响应中的资源URL
|
||||
"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 获取request对象
|
||||
request = None
|
||||
|
||||
@@ -7,8 +7,8 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AssetProxyService:
|
||||
@@ -26,7 +26,7 @@ class AssetProxyService:
|
||||
递归替换数据中的osu!资源URL为自定义域名
|
||||
"""
|
||||
# 处理Pydantic模型
|
||||
if hasattr(data, 'model_dump'):
|
||||
if hasattr(data, "model_dump"):
|
||||
# 转换为字典,处理后再转换回模型
|
||||
data_dict = data.model_dump()
|
||||
processed_dict = await self.replace_asset_urls(data_dict)
|
||||
@@ -49,30 +49,20 @@ class AssetProxyService:
|
||||
|
||||
# 替换 assets.ppy.sh (用户头像、封面、奖章等)
|
||||
result = re.sub(
|
||||
r"https://assets\.ppy\.sh/",
|
||||
f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/",
|
||||
result
|
||||
r"https://assets\.ppy\.sh/", f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/", result
|
||||
)
|
||||
|
||||
# 替换 b.ppy.sh 预览音频 (保持//前缀)
|
||||
result = re.sub(
|
||||
r"//b\.ppy\.sh/",
|
||||
f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/",
|
||||
result
|
||||
)
|
||||
result = re.sub(r"//b\.ppy\.sh/", f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/", result)
|
||||
|
||||
# 替换 https://b.ppy.sh 预览音频 (转换为//前缀)
|
||||
result = re.sub(
|
||||
r"https://b\.ppy\.sh/",
|
||||
f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/",
|
||||
result
|
||||
r"https://b\.ppy\.sh/", f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/", result
|
||||
)
|
||||
|
||||
# 替换 a.ppy.sh 头像
|
||||
result = re.sub(
|
||||
r"https://a\.ppy\.sh/",
|
||||
f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/",
|
||||
result
|
||||
r"https://a\.ppy\.sh/", f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/", result
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.dependencies.database import get_redis, get_redis_message
|
||||
|
||||
@@ -14,8 +14,8 @@ from app.config import settings
|
||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
||||
from app.log import logger
|
||||
from app.models.score import GameMode
|
||||
from app.utils import utcnow
|
||||
from app.service.asset_proxy_service import get_asset_proxy_service
|
||||
from app.utils import utcnow
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -258,7 +258,9 @@ class RedisMessageSystem:
|
||||
# 验证删除是否成功
|
||||
verify_type = await self._redis_exec(self.redis.type, channel_messages_key)
|
||||
if verify_type != "none":
|
||||
logger.error(f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}")
|
||||
logger.error(
|
||||
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
|
||||
)
|
||||
# 强制删除
|
||||
await self._redis_exec(self.redis.unlink, channel_messages_key)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 57bacf936413
|
||||
Create Date: 2025-08-24 00:08:14.704724
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 8d2af11343b9
|
||||
Create Date: 2025-08-24 04:00:02.063347
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
@@ -22,10 +23,14 @@ depends_on: str | Sequence[str] | None = None
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
#op.drop_index(op.f("ix_lazer_user_achievements_achievement_id"), table_name="lazer_user_achievements")
|
||||
#op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements")
|
||||
op.add_column("room_playlists", sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("room_playlists", sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True))
|
||||
# op.drop_index(op.f("ix_lazer_user_achievements_achievement_id"), table_name="lazer_user_achievements")
|
||||
# op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements")
|
||||
op.add_column(
|
||||
"room_playlists", sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"room_playlists", sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -35,5 +40,7 @@ def downgrade() -> None:
|
||||
op.drop_column("room_playlists", "updated_at")
|
||||
op.drop_column("room_playlists", "created_at")
|
||||
op.create_index(op.f("uq_user_achievement"), "lazer_user_achievements", ["user_id", "achievement_id"], unique=True)
|
||||
op.create_index(op.f("ix_lazer_user_achievements_achievement_id"), "lazer_user_achievements", ["achievement_id"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_lazer_user_achievements_achievement_id"), "lazer_user_achievements", ["achievement_id"], unique=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -5,13 +5,13 @@ Revises: 178873984b22
|
||||
Create Date: 2025-08-23 18:45:03.009632
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "57bacf936413"
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 20c6df84813f
|
||||
Create Date: 2025-08-24 00:08:42.419252
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 7576ca1e056d
|
||||
Create Date: 2025-08-24 00:11:05.064099
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
Reference in New Issue
Block a user