chore(deps): auto fix by pre-commit hooks
This commit is contained in:
committed by
MingxuanGame
parent
b4fd4e0256
commit
7625cd99f5
14
app/auth.py
14
app/auth.py
@@ -154,21 +154,19 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
|
||||
expire = utcnow() + expires_delta
|
||||
else:
|
||||
expire = utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
|
||||
|
||||
# 添加标准JWT声明
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"random": secrets.token_hex(16)
|
||||
})
|
||||
if hasattr(settings, 'jwt_audience') and settings.jwt_audience:
|
||||
to_encode.update({"exp": expire, "random": secrets.token_hex(16)})
|
||||
if hasattr(settings, "jwt_audience") and settings.jwt_audience:
|
||||
to_encode["aud"] = settings.jwt_audience
|
||||
if hasattr(settings, 'jwt_issuer') and settings.jwt_issuer:
|
||||
if hasattr(settings, "jwt_issuer") and settings.jwt_issuer:
|
||||
to_encode["iss"] = settings.jwt_issuer
|
||||
|
||||
|
||||
# 编码JWT
|
||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""生成刷新令牌"""
|
||||
length = 64
|
||||
|
||||
@@ -291,11 +291,7 @@ class UserResp(UserBase):
|
||||
).one()
|
||||
redis = get_redis()
|
||||
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
|
||||
u.cover_url = (
|
||||
obj.cover.get("url", "")
|
||||
if obj.cover
|
||||
else ""
|
||||
)
|
||||
u.cover_url = obj.cover.get("url", "") if obj.cover else ""
|
||||
|
||||
if "friends" in include:
|
||||
u.friends = [
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, ClassVar
|
||||
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import SQLModel, Field, Column, DateTime, BigInteger, ForeignKey
|
||||
from typing import ClassVar
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.utils import utcnow
|
||||
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, SQLModel
|
||||
|
||||
|
||||
class MultiplayerRealtimeRoomEventBase(SQLModel, UTCBaseModel):
|
||||
event_type: str = Field(index=True)
|
||||
event_detail: Optional[str] = Field(default=None, sa_column=Column(Text))
|
||||
event_detail: str | None = Field(default=None, sa_column=Column(Text))
|
||||
|
||||
|
||||
class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase, table=True):
|
||||
@@ -19,9 +19,7 @@ class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase,
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
|
||||
room_id: int = Field(
|
||||
sa_column=Column(ForeignKey("rooms.id"), index=True, nullable=False)
|
||||
)
|
||||
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), index=True, nullable=False))
|
||||
playlist_item_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(ForeignKey("playlists.id"), index=True, nullable=True),
|
||||
@@ -31,9 +29,5 @@ class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase,
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True, nullable=True),
|
||||
)
|
||||
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True)), default_factory=utcnow
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True)), default_factory=utcnow
|
||||
)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)), default_factory=utcnow)
|
||||
updated_at: datetime = Field(sa_column=Column(DateTime(timezone=True)), default_factory=utcnow)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
@@ -60,16 +60,9 @@ class Playlist(PlaylistBase, table=True):
|
||||
}
|
||||
)
|
||||
room: "Room" = Relationship()
|
||||
created_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
sa_column_kwargs={"server_default": func.now()}
|
||||
)
|
||||
updated_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
sa_column_kwargs={
|
||||
"server_default": func.now(),
|
||||
"onupdate": func.now()
|
||||
}
|
||||
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
|
||||
updated_at: datetime | None = Field(
|
||||
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -139,4 +132,4 @@ class PlaylistResp(PlaylistBase):
|
||||
if "beatmap" in include:
|
||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
|
||||
resp = cls.model_validate(data)
|
||||
return resp
|
||||
return resp
|
||||
|
||||
@@ -74,7 +74,6 @@ class Room(AsyncAttrs, RoomBase, table=True):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class RoomResp(RoomBase):
|
||||
id: int
|
||||
has_password: bool = False
|
||||
|
||||
@@ -68,9 +68,7 @@ class BaseFetcher:
|
||||
if response.status_code == 401:
|
||||
logger.warning(f"Received 401 error for {url}")
|
||||
await self._clear_tokens()
|
||||
raise TokenAuthError(
|
||||
f"Authentication failed. Please re-authorize using: {self.authorize_url}"
|
||||
)
|
||||
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -146,7 +144,7 @@ class BaseFetcher:
|
||||
清除所有 token
|
||||
"""
|
||||
logger.warning(f"Clearing tokens for client {self.client_id}")
|
||||
|
||||
|
||||
# 清除内存中的 token
|
||||
self.access_token = ""
|
||||
self.refresh_token = ""
|
||||
@@ -167,4 +165,4 @@ class BaseFetcher:
|
||||
"has_refresh_token": bool(self.refresh_token),
|
||||
"token_expired": self.is_token_expired(),
|
||||
"authorize_url": self.authorize_url,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,12 +14,13 @@ from app.utils import bg_tasks
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
import redis.asyncio as redis
|
||||
from httpx import AsyncClient
|
||||
import redis.asyncio as redis
|
||||
|
||||
|
||||
class RateLimitError(Exception):
|
||||
"""速率限制异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -73,9 +74,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
if response.status_code == 401:
|
||||
logger.warning(f"Received 401 error for {url}")
|
||||
await self._clear_tokens()
|
||||
raise TokenAuthError(
|
||||
f"Authentication failed. Please re-authorize using: {self.authorize_url}"
|
||||
)
|
||||
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -205,7 +204,9 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
try:
|
||||
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
|
||||
except RateLimitError:
|
||||
logger.opt(colors=True).info("<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit")
|
||||
logger.opt(colors=True).info(
|
||||
"<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit"
|
||||
)
|
||||
|
||||
bg_tasks.add_task(delayed_prefetch)
|
||||
|
||||
@@ -352,4 +353,4 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).error(
|
||||
f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -212,7 +212,7 @@ async def oauth_token(
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
):
|
||||
scopes = scope.split(" ")
|
||||
|
||||
|
||||
client = (
|
||||
await db.exec(
|
||||
select(OAuthClient).where(
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
"""LIO (Legacy IO) router for osu-server-spectator compatibility."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, status, Query, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select, desc
|
||||
from sqlalchemy import update, func
|
||||
from redis.asyncio import Redis
|
||||
from typing import Any
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel # ChatChannel 模型 & 枚举
|
||||
from app.database.lazer_user import User
|
||||
from app.database.playlists import Playlist as DBPlaylist
|
||||
from app.database.room import Room
|
||||
@@ -17,13 +13,18 @@ from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.fetcher import Fetcher
|
||||
from app.log import logger
|
||||
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
|
||||
from app.models.room import MatchType, QueueMode, RoomStatus
|
||||
from app.utils import utcnow
|
||||
from app.database.chat import ChatChannel, ChannelType # ChatChannel 模型 & 枚举
|
||||
from .notification.server import server
|
||||
from app.log import logger
|
||||
|
||||
from .notification.server import server
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import func, update
|
||||
from sqlmodel import col, select
|
||||
|
||||
router = APIRouter(prefix="/_lio", tags=["LIO"])
|
||||
|
||||
@@ -40,9 +41,7 @@ async def _ensure_room_chat_channel(
|
||||
# 1) 按 channel_id 查是否已存在
|
||||
try:
|
||||
# Use db.execute instead of db.exec for better async compatibility
|
||||
result = await db.execute(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == room.channel_id)
|
||||
)
|
||||
result = await db.execute(select(ChatChannel).where(ChatChannel.channel_id == room.channel_id))
|
||||
ch = result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error querying ChatChannel: {e}")
|
||||
@@ -59,8 +58,8 @@ async def _ensure_room_chat_channel(
|
||||
channel_id_value = int(room.channel_id)
|
||||
|
||||
ch = ChatChannel(
|
||||
channel_id=channel_id_value, # 与房间绑定的同一 channel_id(确保为 int)
|
||||
name=f"mp_{room.id}", # 频道名可自定义(注意唯一性)
|
||||
channel_id=channel_id_value, # 与房间绑定的同一 channel_id(确保为 int)
|
||||
name=f"mp_{room.id}", # 频道名可自定义(注意唯一性)
|
||||
description=f"Multiplayer room {room.id} chat",
|
||||
type=ChannelType.MULTIPLAYER,
|
||||
)
|
||||
@@ -86,32 +85,34 @@ async def _alloc_channel_id(db: Database) -> int:
|
||||
logger.debug(f"Error allocating channel_id: {e}")
|
||||
# Fallback to a timestamp-based approach
|
||||
import time
|
||||
|
||||
return int(time.time()) % 1000000 + 100
|
||||
|
||||
|
||||
|
||||
|
||||
class RoomCreateRequest(BaseModel):
|
||||
"""Request model for creating a multiplayer room."""
|
||||
|
||||
name: str
|
||||
user_id: int
|
||||
password: str | None = None
|
||||
match_type: str = "HeadToHead"
|
||||
queue_mode: str = "HostOnly"
|
||||
initial_playlist: List[Dict[str, Any]] = []
|
||||
playlist: List[Dict[str, Any]] = []
|
||||
initial_playlist: list[dict[str, Any]] = []
|
||||
playlist: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
def verify_request_signature(request: Request, timestamp: str, body: bytes) -> bool:
|
||||
"""
|
||||
Verify HMAC signature for shared interop requests.
|
||||
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
timestamp: Request timestamp
|
||||
body: Request body bytes
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if signature is valid
|
||||
|
||||
|
||||
Note:
|
||||
Currently skips verification in development.
|
||||
In production, implement proper HMAC verification.
|
||||
@@ -124,13 +125,10 @@ async def _validate_user_exists(db: Database, user_id: int) -> User:
|
||||
"""Validate that a user exists in the database."""
|
||||
user_result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User with ID {user_id} not found"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with ID {user_id} not found")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@@ -145,19 +143,19 @@ def _parse_room_enums(match_type: str, queue_mode: str) -> tuple[MatchType, Queu
|
||||
queue_mode_enum = QueueMode(queue_mode.lower())
|
||||
except ValueError:
|
||||
queue_mode_enum = QueueMode.HOST_ONLY
|
||||
|
||||
|
||||
return match_type_enum, queue_mode_enum
|
||||
|
||||
|
||||
def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_user_id: int) -> Dict[str, Any]:
|
||||
def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_user_id: int) -> dict[str, Any]:
|
||||
"""
|
||||
Normalize playlist item data with default values.
|
||||
|
||||
|
||||
Args:
|
||||
item_data: Raw playlist item data
|
||||
default_order: Default playlist order
|
||||
host_user_id: Host user ID for default owner
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with normalized playlist item data
|
||||
"""
|
||||
@@ -165,7 +163,7 @@ def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_us
|
||||
owner_id = item_data.get("owner_id", host_user_id)
|
||||
if owner_id == 0:
|
||||
owner_id = host_user_id
|
||||
|
||||
|
||||
return {
|
||||
"owner_id": owner_id,
|
||||
"ruleset_id": item_data.get("ruleset_id", 0),
|
||||
@@ -181,30 +179,28 @@ def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_us
|
||||
}
|
||||
|
||||
|
||||
def _validate_playlist_items(items: List[Dict[str, Any]]) -> None:
|
||||
def _validate_playlist_items(items: list[dict[str, Any]]) -> None:
|
||||
"""Validate playlist items data."""
|
||||
if not items:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one playlist item is required to create a room"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="At least one playlist item is required to create a room"
|
||||
)
|
||||
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
if item["beatmap_id"] is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Playlist item at index {idx} missing beatmap_id"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Playlist item at index {idx} missing beatmap_id"
|
||||
)
|
||||
|
||||
|
||||
ruleset_id = item["ruleset_id"]
|
||||
if not isinstance(ruleset_id, int) or not (0 <= ruleset_id <= 3):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}"
|
||||
detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}",
|
||||
)
|
||||
|
||||
|
||||
async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, int]:
|
||||
async def _create_room(db: Database, room_data: dict[str, Any]) -> tuple[Room, int]:
|
||||
host_user_id = room_data.get("user_id")
|
||||
room_name = room_data.get("name", "Unnamed Room")
|
||||
password = room_data.get("password")
|
||||
@@ -232,7 +228,7 @@ async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, i
|
||||
participant_count=1,
|
||||
auto_skip=False,
|
||||
auto_start_duration=0,
|
||||
channel_id=channel_id,
|
||||
channel_id=channel_id,
|
||||
)
|
||||
|
||||
db.add(room)
|
||||
@@ -242,27 +238,27 @@ async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, i
|
||||
return room, host_user_id
|
||||
|
||||
|
||||
async def _add_playlist_items(db: Database, room_id: int, room_data: Dict[str, Any], host_user_id: int) -> None:
|
||||
async def _add_playlist_items(db: Database, room_id: int, room_data: dict[str, Any], host_user_id: int) -> None:
|
||||
"""Add playlist items to the room."""
|
||||
initial_playlist = room_data.get("initial_playlist", [])
|
||||
legacy_playlist = room_data.get("playlist", [])
|
||||
|
||||
items_raw: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
items_raw: list[dict[str, Any]] = []
|
||||
|
||||
# Process initial playlist
|
||||
for i, item in enumerate(initial_playlist):
|
||||
if hasattr(item, "dict"):
|
||||
item = item.dict()
|
||||
items_raw.append(_coerce_playlist_item(item, i, host_user_id))
|
||||
|
||||
|
||||
# Process legacy playlist
|
||||
start_index = len(items_raw)
|
||||
for j, item in enumerate(legacy_playlist, start=start_index):
|
||||
items_raw.append(_coerce_playlist_item(item, j, host_user_id))
|
||||
|
||||
|
||||
# Validate playlist items
|
||||
_validate_playlist_items(items_raw)
|
||||
|
||||
|
||||
# Insert playlist items
|
||||
for item_data in items_raw:
|
||||
hub_item = HubPlaylistItem(
|
||||
@@ -286,42 +282,31 @@ async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int
|
||||
"""Add the host as a room participant and update participant count."""
|
||||
participant = RoomParticipatedUser(room_id=room_id, user_id=host_user_id)
|
||||
db.add(participant)
|
||||
|
||||
|
||||
await _update_room_participant_count(db, room_id)
|
||||
|
||||
|
||||
async def _verify_room_password(db: Database, room_id: int, provided_password: str | None) -> None:
|
||||
"""Verify room password if required."""
|
||||
room_result = await db.execute(
|
||||
select(Room).where(col(Room.id) == room_id)
|
||||
)
|
||||
room_result = await db.execute(select(Room).where(col(Room.id) == room_id))
|
||||
room = room_result.scalar_one_or_none()
|
||||
|
||||
|
||||
if room is None:
|
||||
logger.debug(f"Room {room_id} not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Room not found"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
|
||||
|
||||
logger.debug(f"Room {room_id} has password: {bool(room.password)}, provided: {bool(provided_password)}")
|
||||
|
||||
|
||||
# If room has password but none provided
|
||||
if room.password and not provided_password:
|
||||
logger.debug(f"Room {room_id} requires password but none provided")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Password required"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Password required")
|
||||
|
||||
# If room has password and provided password doesn't match
|
||||
if room.password and provided_password and provided_password != room.password:
|
||||
logger.debug(f"Room {room_id} password mismatch")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid password"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid password")
|
||||
|
||||
logger.debug(f"Room {room_id} password verification passed")
|
||||
|
||||
|
||||
@@ -332,16 +317,16 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
|
||||
select(RoomParticipatedUser.id).where(
|
||||
RoomParticipatedUser.room_id == room_id,
|
||||
RoomParticipatedUser.user_id == user_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
)
|
||||
existing_ids = existing_result.scalars().all() # 获取所有匹配的ID
|
||||
|
||||
|
||||
if existing_ids:
|
||||
# 如果存在多条记录,清理重复项,只保留最新的一条
|
||||
if len(existing_ids) > 1:
|
||||
logger.debug(f"警告:用户 {user_id} 在房间 {room_id} 中发现 {len(existing_ids)} 条活跃参与记录")
|
||||
|
||||
|
||||
# 将除第一条外的所有记录标记为已离开(清理重复记录)
|
||||
for extra_id in existing_ids[1:]:
|
||||
await db.execute(
|
||||
@@ -349,7 +334,7 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
|
||||
.where(col(RoomParticipatedUser.id) == extra_id)
|
||||
.values(left_at=utcnow())
|
||||
)
|
||||
|
||||
|
||||
# 更新剩余的活跃参与记录(刷新加入时间)
|
||||
await db.execute(
|
||||
update(RoomParticipatedUser)
|
||||
@@ -364,38 +349,36 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
|
||||
|
||||
class BeatmapEnsureRequest(BaseModel):
|
||||
"""Request model for ensuring beatmap exists."""
|
||||
|
||||
beatmap_id: int
|
||||
|
||||
|
||||
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> Dict[str, Any]:
|
||||
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> dict[str, Any]:
|
||||
"""
|
||||
确保谱面存在(包括元数据和原始文件缓存)。
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
fetcher: API获取器
|
||||
redis: Redis连接
|
||||
beatmap_id: 谱面ID
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 包含状态信息的响应
|
||||
"""
|
||||
try:
|
||||
# 1. 确保谱面元数据存在于数据库中
|
||||
from app.database.beatmap import Beatmap
|
||||
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
|
||||
|
||||
if not beatmap:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Beatmap {beatmap_id} not found",
|
||||
"beatmap_id": beatmap_id
|
||||
}
|
||||
|
||||
return {"success": False, "error": f"Beatmap {beatmap_id} not found", "beatmap_id": beatmap_id}
|
||||
|
||||
# 2. 预缓存谱面原始文件
|
||||
cache_key = f"beatmap:{beatmap_id}:raw"
|
||||
cached = await redis.exists(cache_key)
|
||||
|
||||
|
||||
if not cached:
|
||||
# 异步预加载原始文件到缓存
|
||||
try:
|
||||
@@ -404,42 +387,34 @@ async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int)
|
||||
except Exception as e:
|
||||
logger.debug(f"Warning: Failed to cache raw beatmap {beatmap_id}: {e}")
|
||||
# 即使原始文件缓存失败,也认为确保操作成功(因为元数据已存在)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"beatmap_id": beatmap_id,
|
||||
"metadata_cached": True,
|
||||
"raw_file_cached": await redis.exists(cache_key),
|
||||
"beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]"
|
||||
"beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]",
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error ensuring beatmap {beatmap_id}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"beatmap_id": beatmap_id
|
||||
}
|
||||
return {"success": False, "error": str(e), "beatmap_id": beatmap_id}
|
||||
|
||||
|
||||
async def _update_room_participant_count(db: Database, room_id: int) -> int:
|
||||
"""更新房间参与者数量并返回当前数量。"""
|
||||
# 统计活跃参与者
|
||||
active_participants_result = await db.execute(
|
||||
select(RoomParticipatedUser.user_id).where(
|
||||
RoomParticipatedUser.room_id == room_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
RoomParticipatedUser.room_id == room_id, col(RoomParticipatedUser.left_at).is_(None)
|
||||
)
|
||||
)
|
||||
active_participants = active_participants_result.all()
|
||||
count = len(active_participants)
|
||||
|
||||
|
||||
# 更新房间参与者数量
|
||||
await db.execute(
|
||||
update(Room)
|
||||
.where(col(Room.id) == room_id)
|
||||
.values(participant_count=count)
|
||||
)
|
||||
|
||||
await db.execute(update(Room).where(col(Room.id) == room_id).values(participant_count=count))
|
||||
|
||||
return count
|
||||
|
||||
|
||||
@@ -447,7 +422,7 @@ async def _end_room_if_empty(db: Database, room_id: int) -> bool:
|
||||
"""如果房间为空,则标记房间结束。返回是否结束了房间。"""
|
||||
# 检查房间是否还有活跃参与者
|
||||
participant_count = await _update_room_participant_count(db, room_id)
|
||||
|
||||
|
||||
if participant_count == 0:
|
||||
# 房间为空,标记结束
|
||||
now = utcnow()
|
||||
@@ -457,12 +432,12 @@ async def _end_room_if_empty(db: Database, room_id: int) -> bool:
|
||||
.values(
|
||||
ends_at=now,
|
||||
status=RoomStatus.IDLE, # 或者使用 RoomStatus.ENDED 如果有这个状态
|
||||
participant_count=0
|
||||
participant_count=0,
|
||||
)
|
||||
)
|
||||
logger.debug(f"Room {room_id} ended automatically (no participants remaining)")
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -474,32 +449,30 @@ async def _transfer_ownership_or_end_room(db: Database, room_id: int, leaving_us
|
||||
.where(
|
||||
col(RoomParticipatedUser.room_id) == room_id,
|
||||
col(RoomParticipatedUser.user_id) != leaving_user_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
.order_by(col(RoomParticipatedUser.joined_at)) # 按加入时间排序
|
||||
)
|
||||
remaining_participants = remaining_result.all()
|
||||
|
||||
|
||||
if remaining_participants:
|
||||
# 将房主权限转让给最早加入的用户
|
||||
new_owner_id = remaining_participants[0][0] # 获取 user_id
|
||||
await db.execute(
|
||||
update(Room)
|
||||
.where(col(Room.id) == room_id)
|
||||
.values(host_id=new_owner_id)
|
||||
)
|
||||
await db.execute(update(Room).where(col(Room.id) == room_id).values(host_id=new_owner_id))
|
||||
logger.debug(f"Room {room_id} ownership transferred from {leaving_user_id} to {new_owner_id}")
|
||||
return False # 房间继续存在
|
||||
else:
|
||||
# 没有其他参与者,结束房间
|
||||
return await _end_room_if_empty(db, room_id)
|
||||
|
||||
|
||||
# ===== API ENDPOINTS =====
|
||||
|
||||
|
||||
@router.post("/multiplayer/rooms")
|
||||
async def create_multiplayer_room(
|
||||
request: Request,
|
||||
room_data: Dict[str, Any],
|
||||
room_data: dict[str, Any],
|
||||
db: Database,
|
||||
timestamp: str = "",
|
||||
) -> int:
|
||||
@@ -508,10 +481,7 @@ async def create_multiplayer_room(
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
if not verify_request_signature(request, str(timestamp), body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
|
||||
|
||||
# Parse room data if string
|
||||
if isinstance(room_data, str):
|
||||
@@ -532,13 +502,13 @@ async def create_multiplayer_room(
|
||||
await server.batch_join_channel([host_user], channel, db)
|
||||
# Add playlist items
|
||||
await _add_playlist_items(db, room_id, room_data, host_user_id)
|
||||
|
||||
|
||||
# Add host as participant
|
||||
#await _add_host_as_participant(db, room_id, host_user_id)
|
||||
|
||||
# await _add_host_as_participant(db, room_id, host_user_id)
|
||||
|
||||
await db.commit()
|
||||
return room_id
|
||||
|
||||
|
||||
except HTTPException:
|
||||
# Clean up room if playlist creation fails
|
||||
await db.delete(room)
|
||||
@@ -546,18 +516,11 @@ async def create_multiplayer_room(
|
||||
raise
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid JSON: {str(e)}"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON: {e!s}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create room: {str(e)}"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create room: {e!s}")
|
||||
|
||||
|
||||
@router.delete("/multiplayer/rooms/{room_id}/users/{user_id}")
|
||||
@@ -567,36 +530,28 @@ async def remove_user_from_room(
|
||||
user_id: int,
|
||||
db: Database,
|
||||
timestamp: int = Query(..., description="Unix 时间戳(秒)", ge=0),
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Remove a user from a multiplayer room."""
|
||||
try:
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
now = utcnow()
|
||||
if not verify_request_signature(request, str(timestamp), body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
|
||||
|
||||
# 检查房间是否存在
|
||||
room_result = await db.execute(
|
||||
select(Room).where(col(Room.id) == room_id)
|
||||
)
|
||||
room_result = await db.execute(select(Room).where(col(Room.id) == room_id))
|
||||
room = room_result.scalar_one_or_none()
|
||||
|
||||
|
||||
if room is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Room not found"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
|
||||
|
||||
room_owner_id = room.host_id
|
||||
room_status = room.status
|
||||
current_participant_count = room.participant_count
|
||||
ends_at = room.ends_at
|
||||
channel_id = room.channel_id
|
||||
|
||||
|
||||
# 如果房间已经结束,直接返回
|
||||
if ends_at is not None:
|
||||
logger.debug(f"Room {room_id} is already ended")
|
||||
@@ -604,11 +559,10 @@ async def remove_user_from_room(
|
||||
|
||||
# 检查用户是否在房间中
|
||||
participant_result = await db.execute(
|
||||
select(RoomParticipatedUser.id)
|
||||
.where(
|
||||
select(RoomParticipatedUser.id).where(
|
||||
col(RoomParticipatedUser.room_id) == room_id,
|
||||
col(RoomParticipatedUser.user_id) == user_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
)
|
||||
participant_query = participant_result.first()
|
||||
@@ -634,13 +588,13 @@ async def remove_user_from_room(
|
||||
.where(
|
||||
col(RoomParticipatedUser.room_id) == room_id,
|
||||
col(RoomParticipatedUser.user_id) == user_id,
|
||||
col(RoomParticipatedUser.left_at).is_(None)
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
.values(left_at=now)
|
||||
)
|
||||
|
||||
room_ended = False
|
||||
|
||||
|
||||
# 检查是否是房主离开
|
||||
if user_id == room_owner_id:
|
||||
logger.debug(f"Host {user_id} is leaving room {room_id}")
|
||||
@@ -648,10 +602,10 @@ async def remove_user_from_room(
|
||||
else:
|
||||
# 不是房主离开,只需检查房间是否为空
|
||||
room_ended = await _end_room_if_empty(db, room_id)
|
||||
|
||||
|
||||
await db.commit()
|
||||
logger.debug(f"Successfully removed user {user_id} from room {room_id}, room_ended: {room_ended}")
|
||||
|
||||
|
||||
# ===== 新增:提交后,把用户从聊天频道移除;若房间已结束,清理内存频道 =====
|
||||
try:
|
||||
if channel_id:
|
||||
@@ -660,17 +614,16 @@ async def remove_user_from_room(
|
||||
server.channels.pop(int(channel_id), None)
|
||||
except Exception as e:
|
||||
logger.debug(f"[warn] failed to leave user {user_id} from channel {channel_id}: {e}")
|
||||
|
||||
|
||||
return {"success": True, "room_ended": room_ended}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.debug(f"Error removing user from room: {str(e)}")
|
||||
logger.debug(f"Error removing user from room: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to remove user from room: {str(e)}"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user from room: {e!s}"
|
||||
)
|
||||
|
||||
|
||||
@@ -681,32 +634,28 @@ async def add_user_to_room(
|
||||
user_id: int,
|
||||
db: Database,
|
||||
timestamp: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Add a user to a multiplayer room."""
|
||||
logger.debug(f"Adding user {user_id} to room {room_id}")
|
||||
|
||||
|
||||
# Get request body and parse user_data
|
||||
body = await request.body()
|
||||
user_data = None
|
||||
if body:
|
||||
try:
|
||||
user_data = json.loads(body.decode('utf-8'))
|
||||
user_data = json.loads(body.decode("utf-8"))
|
||||
logger.debug(f"Parsed user_data: {user_data}")
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Failed to parse user_data from request body")
|
||||
user_data = None
|
||||
|
||||
|
||||
# Verify request signature
|
||||
if not verify_request_signature(request, timestamp, body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
|
||||
|
||||
# 检查房间是否已结束
|
||||
room_result = await db.execute(
|
||||
select(Room.id, Room.ends_at, Room.channel_id, Room.host_id)
|
||||
.where(col(Room.id) == room_id)
|
||||
select(Room.id, Room.ends_at, Room.channel_id, Room.host_id).where(col(Room.id) == room_id)
|
||||
)
|
||||
room_row = room_result.first()
|
||||
if not room_row:
|
||||
@@ -763,10 +712,10 @@ async def ensure_beatmap_present(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
timestamp: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
确保谱面在服务器中存在(包括元数据和原始文件缓存)。
|
||||
|
||||
|
||||
这个接口用于 osu-server-spectator 确保谱面文件在服务器端可用,
|
||||
避免在需要时才获取导致的延迟。
|
||||
"""
|
||||
@@ -774,20 +723,17 @@ async def ensure_beatmap_present(
|
||||
# Verify request signature
|
||||
body = await request.body()
|
||||
if not verify_request_signature(request, timestamp, body):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid request signature"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
|
||||
|
||||
beatmap_id = beatmap_data.beatmap_id
|
||||
logger.debug(f"Ensuring beatmap {beatmap_id} is present")
|
||||
|
||||
# 确保谱面存在
|
||||
result = await _ensure_beatmap_exists(db, fetcher, redis, beatmap_id)
|
||||
|
||||
|
||||
# 提交数据库更改
|
||||
await db.commit()
|
||||
|
||||
|
||||
logger.debug(f"Ensure beatmap {beatmap_id} result: {result}")
|
||||
return result
|
||||
|
||||
@@ -795,8 +741,7 @@ async def ensure_beatmap_present(
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.debug(f"Error ensuring beatmap: {str(e)}")
|
||||
logger.debug(f"Error ensuring beatmap: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to ensure beatmap: {str(e)}"
|
||||
)
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to ensure beatmap: {e!s}"
|
||||
)
|
||||
|
||||
@@ -96,6 +96,7 @@ async def send_message(
|
||||
if channel_type == ChannelType.MULTIPLAYER:
|
||||
try:
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
redis = get_redis()
|
||||
key = f"channel:{channel_id}:messages"
|
||||
key_type = await redis.type(key)
|
||||
@@ -162,13 +163,9 @@ async def get_message(
|
||||
):
|
||||
# 1) 查频道
|
||||
if channel.isdigit():
|
||||
db_channel = (await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -176,7 +173,6 @@ async def get_message(
|
||||
channel_id = db_channel.channel_id
|
||||
|
||||
try:
|
||||
|
||||
messages = await redis_message_system.get_messages(channel_id, limit, since)
|
||||
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
|
||||
messages.reverse()
|
||||
@@ -188,11 +184,7 @@ async def get_message(
|
||||
|
||||
if since > 0 and until is None:
|
||||
# 向前加载新消息 → 直接 ASC
|
||||
query = (
|
||||
base.where(col(ChatMessage.message_id) > since)
|
||||
.order_by(col(ChatMessage.message_id).asc())
|
||||
.limit(limit)
|
||||
)
|
||||
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||
rows = (await session.exec(query)).all()
|
||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
||||
# 已经 ASC,无需反转
|
||||
@@ -202,9 +194,7 @@ async def get_message(
|
||||
if until is not None:
|
||||
# 用 DESC 取最近的更早消息,再反转为 ASC
|
||||
query = (
|
||||
base.where(col(ChatMessage.message_id) < until)
|
||||
.order_by(col(ChatMessage.message_id).desc())
|
||||
.limit(limit)
|
||||
base.where(col(ChatMessage.message_id) < until).order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
)
|
||||
rows = (await session.exec(query)).all()
|
||||
rows = list(rows)
|
||||
@@ -221,7 +211,6 @@ async def get_message(
|
||||
return resp
|
||||
|
||||
|
||||
|
||||
@router.put(
|
||||
"/chat/channels/{channel}/mark-as-read/{message}",
|
||||
status_code=204,
|
||||
|
||||
@@ -76,20 +76,21 @@ class ChatServer:
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
users_in_channel = self.channels.get(channel_id, [])
|
||||
logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}")
|
||||
|
||||
|
||||
# 如果频道中没有用户,检查是否是多人游戏频道
|
||||
if not users_in_channel:
|
||||
try:
|
||||
async with with_db() as session:
|
||||
from sqlmodel import select
|
||||
channel = await session.get(ChatChannel, channel_id)
|
||||
if channel and channel.type == ChannelType.MULTIPLAYER:
|
||||
logger.warning(f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone")
|
||||
logger.warning(
|
||||
f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone"
|
||||
)
|
||||
# 对于多人游戏房间,这可能是正常的(用户都离开了房间)
|
||||
# 但我们仍然记录这个情况以便调试
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check channel type for {channel_id}: {e}")
|
||||
|
||||
|
||||
for user_id in users_in_channel:
|
||||
await self.send_event(user_id, event)
|
||||
logger.debug(f"Sent event to user {user_id} in channel {channel_id}")
|
||||
|
||||
@@ -13,8 +13,8 @@ from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.dependencies.user import get_client_user, get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import SearchQueryModel
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService
|
||||
from app.service.asset_proxy_helper import process_response_assets
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService
|
||||
|
||||
from .router import router
|
||||
|
||||
@@ -97,7 +97,7 @@ async def search_beatmapset(
|
||||
try:
|
||||
sets = await fetcher.search_beatmapset(query, cursor, redis)
|
||||
background_tasks.add_task(_save_to_db, sets)
|
||||
|
||||
|
||||
# 处理资源代理
|
||||
processed_sets = await process_response_assets(sets, request)
|
||||
return processed_sets
|
||||
@@ -121,7 +121,7 @@ async def lookup_beatmapset(
|
||||
):
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
|
||||
|
||||
|
||||
# 处理资源代理
|
||||
processed_resp = await process_response_assets(resp, request)
|
||||
return processed_resp
|
||||
|
||||
@@ -56,20 +56,22 @@ async def get_all_rooms(
|
||||
db_category = category
|
||||
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == db_category]
|
||||
now = utcnow()
|
||||
|
||||
|
||||
if status is not None:
|
||||
where_clauses.append(col(Room.status) == status)
|
||||
#print(mode, category, status, current_user.id)
|
||||
# print(mode, category, status, current_user.id)
|
||||
if mode == "open":
|
||||
# 修改为新的查询逻辑:状态为 idle 或 playing,starts_at 不为空,ends_at 为空
|
||||
where_clauses.extend([
|
||||
col(Room.status).in_([RoomStatus.IDLE, RoomStatus.PLAYING]),
|
||||
col(Room.starts_at).is_not(None),
|
||||
col(Room.ends_at).is_(None)
|
||||
])
|
||||
#if category == RoomCategory.REALTIME:
|
||||
where_clauses.extend(
|
||||
[
|
||||
col(Room.status).in_([RoomStatus.IDLE, RoomStatus.PLAYING]),
|
||||
col(Room.starts_at).is_not(None),
|
||||
col(Room.ends_at).is_(None),
|
||||
]
|
||||
)
|
||||
# if category == RoomCategory.REALTIME:
|
||||
# where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
|
||||
|
||||
|
||||
if mode == "participated":
|
||||
where_clauses.append(
|
||||
exists().where(
|
||||
@@ -77,10 +79,10 @@ async def get_all_rooms(
|
||||
col(RoomParticipatedUser.user_id) == current_user.id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if mode == "owned":
|
||||
where_clauses.append(col(Room.host_id) == current_user.id)
|
||||
|
||||
|
||||
if mode == "ended":
|
||||
where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
|
||||
|
||||
@@ -96,7 +98,7 @@ async def get_all_rooms(
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
#print("Retrieved rooms:", db_rooms)
|
||||
# print("Retrieved rooms:", db_rooms)
|
||||
for room in db_rooms:
|
||||
resp = await RoomResp.from_db(room, db)
|
||||
resp.has_password = bool((room.password or "").strip())
|
||||
@@ -424,4 +426,4 @@ async def get_room_events(
|
||||
playlist_items=playlist_items_resps,
|
||||
room=room_resp,
|
||||
user=user_resps,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,27 +5,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from fastapi import Request
|
||||
|
||||
from app.config import settings
|
||||
from app.service.asset_proxy_service import get_asset_proxy_service
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
async def process_response_assets(data: Any, request: Request) -> Any:
|
||||
"""
|
||||
根据配置处理响应数据中的资源URL
|
||||
|
||||
|
||||
Args:
|
||||
data: API响应数据
|
||||
request: FastAPI请求对象
|
||||
|
||||
|
||||
Returns:
|
||||
处理后的数据
|
||||
"""
|
||||
if not settings.enable_asset_proxy:
|
||||
return data
|
||||
|
||||
|
||||
asset_service = get_asset_proxy_service()
|
||||
|
||||
|
||||
# 仅URL替换模式
|
||||
return await asset_service.replace_asset_urls(data)
|
||||
|
||||
@@ -47,7 +49,7 @@ def should_process_asset_proxy(path: str) -> bool:
|
||||
"/api/v2/beatmapsets/",
|
||||
# 可以根据需要添加更多端点
|
||||
]
|
||||
|
||||
|
||||
return any(path.startswith(endpoint) for endpoint in asset_proxy_endpoints)
|
||||
|
||||
|
||||
@@ -56,6 +58,7 @@ def asset_proxy_response(func):
|
||||
"""
|
||||
装饰器:自动处理响应中的资源URL
|
||||
"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 获取request对象
|
||||
request = None
|
||||
@@ -63,14 +66,14 @@ def asset_proxy_response(func):
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
|
||||
# 执行原函数
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
|
||||
# 如果有request对象且启用了资源代理,则处理响应
|
||||
if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path):
|
||||
result = await process_response_assets(result, request)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -7,8 +7,8 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AssetProxyService:
|
||||
@@ -26,7 +26,7 @@ class AssetProxyService:
|
||||
递归替换数据中的osu!资源URL为自定义域名
|
||||
"""
|
||||
# 处理Pydantic模型
|
||||
if hasattr(data, 'model_dump'):
|
||||
if hasattr(data, "model_dump"):
|
||||
# 转换为字典,处理后再转换回模型
|
||||
data_dict = data.model_dump()
|
||||
processed_dict = await self.replace_asset_urls(data_dict)
|
||||
@@ -46,35 +46,25 @@ class AssetProxyService:
|
||||
elif isinstance(data, str):
|
||||
# 替换各种osu!资源域名
|
||||
result = data
|
||||
|
||||
|
||||
# 替换 assets.ppy.sh (用户头像、封面、奖章等)
|
||||
result = re.sub(
|
||||
r"https://assets\.ppy\.sh/",
|
||||
f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/",
|
||||
result
|
||||
r"https://assets\.ppy\.sh/", f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/", result
|
||||
)
|
||||
|
||||
|
||||
# 替换 b.ppy.sh 预览音频 (保持//前缀)
|
||||
result = re.sub(
|
||||
r"//b\.ppy\.sh/",
|
||||
f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/",
|
||||
result
|
||||
)
|
||||
|
||||
result = re.sub(r"//b\.ppy\.sh/", f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/", result)
|
||||
|
||||
# 替换 https://b.ppy.sh 预览音频 (转换为//前缀)
|
||||
result = re.sub(
|
||||
r"https://b\.ppy\.sh/",
|
||||
f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/",
|
||||
result
|
||||
r"https://b\.ppy\.sh/", f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/", result
|
||||
)
|
||||
|
||||
|
||||
# 替换 a.ppy.sh 头像
|
||||
result = re.sub(
|
||||
r"https://a\.ppy\.sh/",
|
||||
f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/",
|
||||
result
|
||||
r"https://a\.ppy\.sh/", f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/", result
|
||||
)
|
||||
|
||||
|
||||
return result
|
||||
else:
|
||||
return data
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.dependencies.database import get_redis, get_redis_message
|
||||
|
||||
@@ -14,8 +14,8 @@ from app.config import settings
|
||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
||||
from app.log import logger
|
||||
from app.models.score import GameMode
|
||||
from app.utils import utcnow
|
||||
from app.service.asset_proxy_service import get_asset_proxy_service
|
||||
from app.utils import utcnow
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
@@ -284,7 +284,7 @@ class RankingCacheService:
|
||||
ranking_data = []
|
||||
for statistics in statistics_data:
|
||||
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
|
||||
|
||||
|
||||
# 应用资源代理处理
|
||||
if settings.enable_asset_proxy:
|
||||
try:
|
||||
@@ -292,7 +292,7 @@ class RankingCacheService:
|
||||
user_stats_resp = await asset_proxy_service.replace_asset_urls(user_stats_resp)
|
||||
except Exception as e:
|
||||
logger.warning(f"Asset proxy processing failed for ranking cache: {e}")
|
||||
|
||||
|
||||
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题
|
||||
user_dict = json.loads(user_stats_resp.model_dump_json())
|
||||
ranking_data.append(user_dict)
|
||||
|
||||
@@ -254,14 +254,16 @@ class RedisMessageSystem:
|
||||
# 键类型错误,需要清理
|
||||
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
|
||||
|
||||
# 验证删除是否成功
|
||||
verify_type = await self._redis_exec(self.redis.type, channel_messages_key)
|
||||
if verify_type != "none":
|
||||
logger.error(f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}")
|
||||
logger.error(
|
||||
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
|
||||
)
|
||||
# 强制删除
|
||||
await self._redis_exec(self.redis.unlink, channel_messages_key)
|
||||
|
||||
|
||||
except Exception as type_check_error:
|
||||
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
||||
# 如果检查失败,尝试强制删除键以确保清理
|
||||
@@ -597,13 +599,13 @@ class RedisMessageSystem:
|
||||
elif key_type != "zset":
|
||||
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
|
||||
|
||||
# 验证删除是否成功
|
||||
verify_type = await self._redis_exec(self.redis.type, key)
|
||||
if verify_type != "none":
|
||||
logger.error(f"Failed to delete problematic key {key}, trying unlink...")
|
||||
await self._redis_exec(self.redis.unlink, key)
|
||||
|
||||
|
||||
fixed_count += 1
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
|
||||
@@ -634,10 +636,10 @@ class RedisMessageSystem:
|
||||
await asyncio.sleep(300)
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
|
||||
logger.debug("Running periodic Redis keys cleanup...")
|
||||
await self._cleanup_redis_keys()
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
|
||||
@@ -299,7 +299,7 @@ class UserCacheService:
|
||||
"""缓存单个用户"""
|
||||
try:
|
||||
user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
|
||||
|
||||
|
||||
# 应用资源代理处理
|
||||
if settings.enable_asset_proxy:
|
||||
try:
|
||||
@@ -307,7 +307,7 @@ class UserCacheService:
|
||||
user_resp = await asset_proxy_service.replace_asset_urls(user_resp)
|
||||
except Exception as e:
|
||||
logger.warning(f"Asset proxy processing failed for user cache {user.id}: {e}")
|
||||
|
||||
|
||||
await self.cache_user(user_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching single user {user.id}: {e}")
|
||||
|
||||
4
main.py
4
main.py
@@ -67,11 +67,11 @@ async def lifespan(app: FastAPI):
|
||||
start_stats_scheduler() # 启动统计调度器
|
||||
schedule_online_status_maintenance() # 启动在线状态维护任务
|
||||
load_achievements()
|
||||
|
||||
|
||||
# 显示资源代理状态
|
||||
if settings.enable_asset_proxy:
|
||||
logger.info(f"Asset Proxy enabled - Domain: {settings.custom_asset_domain}")
|
||||
|
||||
|
||||
# on shutdown
|
||||
yield
|
||||
bg_tasks.stop()
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 57bacf936413
|
||||
Create Date: 2025-08-24 00:08:14.704724
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 8d2af11343b9
|
||||
Create Date: 2025-08-24 04:00:02.063347
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
@@ -22,10 +23,14 @@ depends_on: str | Sequence[str] | None = None
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
#op.drop_index(op.f("ix_lazer_user_achievements_achievement_id"), table_name="lazer_user_achievements")
|
||||
#op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements")
|
||||
op.add_column("room_playlists", sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("room_playlists", sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True))
|
||||
# op.drop_index(op.f("ix_lazer_user_achievements_achievement_id"), table_name="lazer_user_achievements")
|
||||
# op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements")
|
||||
op.add_column(
|
||||
"room_playlists", sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"room_playlists", sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -35,5 +40,7 @@ def downgrade() -> None:
|
||||
op.drop_column("room_playlists", "updated_at")
|
||||
op.drop_column("room_playlists", "created_at")
|
||||
op.create_index(op.f("uq_user_achievement"), "lazer_user_achievements", ["user_id", "achievement_id"], unique=True)
|
||||
op.create_index(op.f("ix_lazer_user_achievements_achievement_id"), "lazer_user_achievements", ["achievement_id"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_lazer_user_achievements_achievement_id"), "lazer_user_achievements", ["achievement_id"], unique=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -5,13 +5,13 @@ Revises: 178873984b22
|
||||
Create Date: 2025-08-23 18:45:03.009632
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "57bacf936413"
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 20c6df84813f
|
||||
Create Date: 2025-08-24 00:08:42.419252
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
@@ -5,6 +5,7 @@ Revises: 7576ca1e056d
|
||||
Create Date: 2025-08-24 00:11:05.064099
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
Reference in New Issue
Block a user