chore(deps): auto fix by pre-commit hooks

This commit is contained in:
pre-commit-ci[bot]
2025-08-24 03:18:58 +00:00
committed by MingxuanGame
parent b4fd4e0256
commit 7625cd99f5
25 changed files with 241 additions and 320 deletions

View File

@@ -154,21 +154,19 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
expire = utcnow() + expires_delta expire = utcnow() + expires_delta
else: else:
expire = utcnow() + timedelta(minutes=settings.access_token_expire_minutes) expire = utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
# 添加标准JWT声明 # 添加标准JWT声明
to_encode.update({ to_encode.update({"exp": expire, "random": secrets.token_hex(16)})
"exp": expire, if hasattr(settings, "jwt_audience") and settings.jwt_audience:
"random": secrets.token_hex(16)
})
if hasattr(settings, 'jwt_audience') and settings.jwt_audience:
to_encode["aud"] = settings.jwt_audience to_encode["aud"] = settings.jwt_audience
if hasattr(settings, 'jwt_issuer') and settings.jwt_issuer: if hasattr(settings, "jwt_issuer") and settings.jwt_issuer:
to_encode["iss"] = settings.jwt_issuer to_encode["iss"] = settings.jwt_issuer
# 编码JWT # 编码JWT
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
return encoded_jwt return encoded_jwt
def generate_refresh_token() -> str: def generate_refresh_token() -> str:
"""生成刷新令牌""" """生成刷新令牌"""
length = 64 length = 64

View File

@@ -291,11 +291,7 @@ class UserResp(UserBase):
).one() ).one()
redis = get_redis() redis = get_redis()
u.is_online = await redis.exists(f"metadata:online:{obj.id}") u.is_online = await redis.exists(f"metadata:online:{obj.id}")
u.cover_url = ( u.cover_url = obj.cover.get("url", "") if obj.cover else ""
obj.cover.get("url", "")
if obj.cover
else ""
)
if "friends" in include: if "friends" in include:
u.friends = [ u.friends = [

View File

@@ -1,17 +1,17 @@
from datetime import datetime from datetime import datetime
from typing import Optional, ClassVar from typing import ClassVar
from sqlalchemy import Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import SQLModel, Field, Column, DateTime, BigInteger, ForeignKey
from app.models.model import UTCBaseModel from app.models.model import UTCBaseModel
from app.utils import utcnow from app.utils import utcnow
from sqlalchemy import Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, SQLModel
class MultiplayerRealtimeRoomEventBase(SQLModel, UTCBaseModel): class MultiplayerRealtimeRoomEventBase(SQLModel, UTCBaseModel):
event_type: str = Field(index=True) event_type: str = Field(index=True)
event_detail: Optional[str] = Field(default=None, sa_column=Column(Text)) event_detail: str | None = Field(default=None, sa_column=Column(Text))
class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase, table=True): class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase, table=True):
@@ -19,9 +19,7 @@ class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase,
id: int | None = Field(default=None, primary_key=True, index=True) id: int | None = Field(default=None, primary_key=True, index=True)
room_id: int = Field( room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), index=True, nullable=False))
sa_column=Column(ForeignKey("rooms.id"), index=True, nullable=False)
)
playlist_item_id: int | None = Field( playlist_item_id: int | None = Field(
default=None, default=None,
sa_column=Column(ForeignKey("playlists.id"), index=True, nullable=True), sa_column=Column(ForeignKey("playlists.id"), index=True, nullable=True),
@@ -31,9 +29,5 @@ class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase,
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True, nullable=True), sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True, nullable=True),
) )
created_at: datetime = Field( created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)), default_factory=utcnow)
sa_column=Column(DateTime(timezone=True)), default_factory=utcnow updated_at: datetime = Field(sa_column=Column(DateTime(timezone=True)), default_factory=utcnow)
)
updated_at: datetime = Field(
sa_column=Column(DateTime(timezone=True)), default_factory=utcnow
)

View File

@@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel from app.models.model import UTCBaseModel
from app.models.mods import APIMod from app.models.mods import APIMod
@@ -60,16 +60,9 @@ class Playlist(PlaylistBase, table=True):
} }
) )
room: "Room" = Relationship() room: "Room" = Relationship()
created_at: Optional[datetime] = Field( created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
default=None, updated_at: datetime | None = Field(
sa_column_kwargs={"server_default": func.now()} default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column_kwargs={
"server_default": func.now(),
"onupdate": func.now()
}
) )
@classmethod @classmethod
@@ -139,4 +132,4 @@ class PlaylistResp(PlaylistBase):
if "beatmap" in include: if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap) data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
resp = cls.model_validate(data) resp = cls.model_validate(data)
return resp return resp

View File

@@ -74,7 +74,6 @@ class Room(AsyncAttrs, RoomBase, table=True):
) )
class RoomResp(RoomBase): class RoomResp(RoomBase):
id: int id: int
has_password: bool = False has_password: bool = False

View File

@@ -68,9 +68,7 @@ class BaseFetcher:
if response.status_code == 401: if response.status_code == 401:
logger.warning(f"Received 401 error for {url}") logger.warning(f"Received 401 error for {url}")
await self._clear_tokens() await self._clear_tokens()
raise TokenAuthError( raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
f"Authentication failed. Please re-authorize using: {self.authorize_url}"
)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -146,7 +144,7 @@ class BaseFetcher:
清除所有 token 清除所有 token
""" """
logger.warning(f"Clearing tokens for client {self.client_id}") logger.warning(f"Clearing tokens for client {self.client_id}")
# 清除内存中的 token # 清除内存中的 token
self.access_token = "" self.access_token = ""
self.refresh_token = "" self.refresh_token = ""
@@ -167,4 +165,4 @@ class BaseFetcher:
"has_refresh_token": bool(self.refresh_token), "has_refresh_token": bool(self.refresh_token),
"token_expired": self.is_token_expired(), "token_expired": self.is_token_expired(),
"authorize_url": self.authorize_url, "authorize_url": self.authorize_url,
} }

View File

@@ -14,12 +14,13 @@ from app.utils import bg_tasks
from ._base import BaseFetcher from ._base import BaseFetcher
import redis.asyncio as redis
from httpx import AsyncClient from httpx import AsyncClient
import redis.asyncio as redis
class RateLimitError(Exception): class RateLimitError(Exception):
"""速率限制异常""" """速率限制异常"""
pass pass
@@ -73,9 +74,7 @@ class BeatmapsetFetcher(BaseFetcher):
if response.status_code == 401: if response.status_code == 401:
logger.warning(f"Received 401 error for {url}") logger.warning(f"Received 401 error for {url}")
await self._clear_tokens() await self._clear_tokens()
raise TokenAuthError( raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
f"Authentication failed. Please re-authorize using: {self.authorize_url}"
)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -205,7 +204,9 @@ class BeatmapsetFetcher(BaseFetcher):
try: try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1) await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
except RateLimitError: except RateLimitError:
logger.opt(colors=True).info("<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit") logger.opt(colors=True).info(
"<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit"
)
bg_tasks.add_task(delayed_prefetch) bg_tasks.add_task(delayed_prefetch)
@@ -352,4 +353,4 @@ class BeatmapsetFetcher(BaseFetcher):
except Exception as e: except Exception as e:
logger.opt(colors=True).error( logger.opt(colors=True).error(
f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}" f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}"
) )

View File

@@ -212,7 +212,7 @@ async def oauth_token(
geoip: GeoIPHelper = Depends(get_geoip_helper), geoip: GeoIPHelper = Depends(get_geoip_helper),
): ):
scopes = scope.split(" ") scopes = scope.split(" ")
client = ( client = (
await db.exec( await db.exec(
select(OAuthClient).where( select(OAuthClient).where(

View File

@@ -1,15 +1,11 @@
"""LIO (Legacy IO) router for osu-server-spectator compatibility.""" """LIO (Legacy IO) router for osu-server-spectator compatibility."""
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any, Dict, List from typing import Any
from fastapi import APIRouter, HTTPException, Request, status, Query, Depends
from pydantic import BaseModel
from sqlmodel import col, select, desc
from sqlalchemy import update, func
from redis.asyncio import Redis
from app.database.chat import ChannelType, ChatChannel # ChatChannel 模型 & 枚举
from app.database.lazer_user import User from app.database.lazer_user import User
from app.database.playlists import Playlist as DBPlaylist from app.database.playlists import Playlist as DBPlaylist
from app.database.room import Room from app.database.room import Room
@@ -17,13 +13,18 @@ from app.database.room_participated_user import RoomParticipatedUser
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher from app.fetcher import Fetcher
from app.log import logger
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
from app.models.room import MatchType, QueueMode, RoomStatus from app.models.room import MatchType, QueueMode, RoomStatus
from app.utils import utcnow from app.utils import utcnow
from app.database.chat import ChatChannel, ChannelType # ChatChannel 模型 & 枚举
from .notification.server import server
from app.log import logger
from .notification.server import server
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy import func, update
from sqlmodel import col, select
router = APIRouter(prefix="/_lio", tags=["LIO"]) router = APIRouter(prefix="/_lio", tags=["LIO"])
@@ -40,9 +41,7 @@ async def _ensure_room_chat_channel(
# 1) 按 channel_id 查是否已存在 # 1) 按 channel_id 查是否已存在
try: try:
# Use db.execute instead of db.exec for better async compatibility # Use db.execute instead of db.exec for better async compatibility
result = await db.execute( result = await db.execute(select(ChatChannel).where(ChatChannel.channel_id == room.channel_id))
select(ChatChannel).where(ChatChannel.channel_id == room.channel_id)
)
ch = result.scalar_one_or_none() ch = result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.debug(f"Error querying ChatChannel: {e}") logger.debug(f"Error querying ChatChannel: {e}")
@@ -59,8 +58,8 @@ async def _ensure_room_chat_channel(
channel_id_value = int(room.channel_id) channel_id_value = int(room.channel_id)
ch = ChatChannel( ch = ChatChannel(
channel_id=channel_id_value, # 与房间绑定的同一 channel_id确保为 int channel_id=channel_id_value, # 与房间绑定的同一 channel_id确保为 int
name=f"mp_{room.id}", # 频道名可自定义(注意唯一性) name=f"mp_{room.id}", # 频道名可自定义(注意唯一性)
description=f"Multiplayer room {room.id} chat", description=f"Multiplayer room {room.id} chat",
type=ChannelType.MULTIPLAYER, type=ChannelType.MULTIPLAYER,
) )
@@ -86,32 +85,34 @@ async def _alloc_channel_id(db: Database) -> int:
logger.debug(f"Error allocating channel_id: {e}") logger.debug(f"Error allocating channel_id: {e}")
# Fallback to a timestamp-based approach # Fallback to a timestamp-based approach
import time import time
return int(time.time()) % 1000000 + 100 return int(time.time()) % 1000000 + 100
class RoomCreateRequest(BaseModel): class RoomCreateRequest(BaseModel):
"""Request model for creating a multiplayer room.""" """Request model for creating a multiplayer room."""
name: str name: str
user_id: int user_id: int
password: str | None = None password: str | None = None
match_type: str = "HeadToHead" match_type: str = "HeadToHead"
queue_mode: str = "HostOnly" queue_mode: str = "HostOnly"
initial_playlist: List[Dict[str, Any]] = [] initial_playlist: list[dict[str, Any]] = []
playlist: List[Dict[str, Any]] = [] playlist: list[dict[str, Any]] = []
def verify_request_signature(request: Request, timestamp: str, body: bytes) -> bool: def verify_request_signature(request: Request, timestamp: str, body: bytes) -> bool:
""" """
Verify HMAC signature for shared interop requests. Verify HMAC signature for shared interop requests.
Args: Args:
request: FastAPI request object request: FastAPI request object
timestamp: Request timestamp timestamp: Request timestamp
body: Request body bytes body: Request body bytes
Returns: Returns:
bool: True if signature is valid bool: True if signature is valid
Note: Note:
Currently skips verification in development. Currently skips verification in development.
In production, implement proper HMAC verification. In production, implement proper HMAC verification.
@@ -124,13 +125,10 @@ async def _validate_user_exists(db: Database, user_id: int) -> User:
"""Validate that a user exists in the database.""" """Validate that a user exists in the database."""
user_result = await db.execute(select(User).where(User.id == user_id)) user_result = await db.execute(select(User).where(User.id == user_id))
user = user_result.scalar_one_or_none() user = user_result.scalar_one_or_none()
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with ID {user_id} not found")
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User with ID {user_id} not found"
)
return user return user
@@ -145,19 +143,19 @@ def _parse_room_enums(match_type: str, queue_mode: str) -> tuple[MatchType, Queu
queue_mode_enum = QueueMode(queue_mode.lower()) queue_mode_enum = QueueMode(queue_mode.lower())
except ValueError: except ValueError:
queue_mode_enum = QueueMode.HOST_ONLY queue_mode_enum = QueueMode.HOST_ONLY
return match_type_enum, queue_mode_enum return match_type_enum, queue_mode_enum
def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_user_id: int) -> Dict[str, Any]: def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_user_id: int) -> dict[str, Any]:
""" """
Normalize playlist item data with default values. Normalize playlist item data with default values.
Args: Args:
item_data: Raw playlist item data item_data: Raw playlist item data
default_order: Default playlist order default_order: Default playlist order
host_user_id: Host user ID for default owner host_user_id: Host user ID for default owner
Returns: Returns:
Dict with normalized playlist item data Dict with normalized playlist item data
""" """
@@ -165,7 +163,7 @@ def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_us
owner_id = item_data.get("owner_id", host_user_id) owner_id = item_data.get("owner_id", host_user_id)
if owner_id == 0: if owner_id == 0:
owner_id = host_user_id owner_id = host_user_id
return { return {
"owner_id": owner_id, "owner_id": owner_id,
"ruleset_id": item_data.get("ruleset_id", 0), "ruleset_id": item_data.get("ruleset_id", 0),
@@ -181,30 +179,28 @@ def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_us
} }
def _validate_playlist_items(items: List[Dict[str, Any]]) -> None: def _validate_playlist_items(items: list[dict[str, Any]]) -> None:
"""Validate playlist items data.""" """Validate playlist items data."""
if not items: if not items:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail="At least one playlist item is required to create a room"
detail="At least one playlist item is required to create a room"
) )
for idx, item in enumerate(items): for idx, item in enumerate(items):
if item["beatmap_id"] is None: if item["beatmap_id"] is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail=f"Playlist item at index {idx} missing beatmap_id"
detail=f"Playlist item at index {idx} missing beatmap_id"
) )
ruleset_id = item["ruleset_id"] ruleset_id = item["ruleset_id"]
if not isinstance(ruleset_id, int) or not (0 <= ruleset_id <= 3): if not isinstance(ruleset_id, int) or not (0 <= ruleset_id <= 3):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}" detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}",
) )
async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, int]: async def _create_room(db: Database, room_data: dict[str, Any]) -> tuple[Room, int]:
host_user_id = room_data.get("user_id") host_user_id = room_data.get("user_id")
room_name = room_data.get("name", "Unnamed Room") room_name = room_data.get("name", "Unnamed Room")
password = room_data.get("password") password = room_data.get("password")
@@ -232,7 +228,7 @@ async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, i
participant_count=1, participant_count=1,
auto_skip=False, auto_skip=False,
auto_start_duration=0, auto_start_duration=0,
channel_id=channel_id, channel_id=channel_id,
) )
db.add(room) db.add(room)
@@ -242,27 +238,27 @@ async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, i
return room, host_user_id return room, host_user_id
async def _add_playlist_items(db: Database, room_id: int, room_data: Dict[str, Any], host_user_id: int) -> None: async def _add_playlist_items(db: Database, room_id: int, room_data: dict[str, Any], host_user_id: int) -> None:
"""Add playlist items to the room.""" """Add playlist items to the room."""
initial_playlist = room_data.get("initial_playlist", []) initial_playlist = room_data.get("initial_playlist", [])
legacy_playlist = room_data.get("playlist", []) legacy_playlist = room_data.get("playlist", [])
items_raw: List[Dict[str, Any]] = [] items_raw: list[dict[str, Any]] = []
# Process initial playlist # Process initial playlist
for i, item in enumerate(initial_playlist): for i, item in enumerate(initial_playlist):
if hasattr(item, "dict"): if hasattr(item, "dict"):
item = item.dict() item = item.dict()
items_raw.append(_coerce_playlist_item(item, i, host_user_id)) items_raw.append(_coerce_playlist_item(item, i, host_user_id))
# Process legacy playlist # Process legacy playlist
start_index = len(items_raw) start_index = len(items_raw)
for j, item in enumerate(legacy_playlist, start=start_index): for j, item in enumerate(legacy_playlist, start=start_index):
items_raw.append(_coerce_playlist_item(item, j, host_user_id)) items_raw.append(_coerce_playlist_item(item, j, host_user_id))
# Validate playlist items # Validate playlist items
_validate_playlist_items(items_raw) _validate_playlist_items(items_raw)
# Insert playlist items # Insert playlist items
for item_data in items_raw: for item_data in items_raw:
hub_item = HubPlaylistItem( hub_item = HubPlaylistItem(
@@ -286,42 +282,31 @@ async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int
"""Add the host as a room participant and update participant count.""" """Add the host as a room participant and update participant count."""
participant = RoomParticipatedUser(room_id=room_id, user_id=host_user_id) participant = RoomParticipatedUser(room_id=room_id, user_id=host_user_id)
db.add(participant) db.add(participant)
await _update_room_participant_count(db, room_id) await _update_room_participant_count(db, room_id)
async def _verify_room_password(db: Database, room_id: int, provided_password: str | None) -> None: async def _verify_room_password(db: Database, room_id: int, provided_password: str | None) -> None:
"""Verify room password if required.""" """Verify room password if required."""
room_result = await db.execute( room_result = await db.execute(select(Room).where(col(Room.id) == room_id))
select(Room).where(col(Room.id) == room_id)
)
room = room_result.scalar_one_or_none() room = room_result.scalar_one_or_none()
if room is None: if room is None:
logger.debug(f"Room {room_id} not found") logger.debug(f"Room {room_id} not found")
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
status_code=status.HTTP_404_NOT_FOUND,
detail="Room not found"
)
logger.debug(f"Room {room_id} has password: {bool(room.password)}, provided: {bool(provided_password)}") logger.debug(f"Room {room_id} has password: {bool(room.password)}, provided: {bool(provided_password)}")
# If room has password but none provided # If room has password but none provided
if room.password and not provided_password: if room.password and not provided_password:
logger.debug(f"Room {room_id} requires password but none provided") logger.debug(f"Room {room_id} requires password but none provided")
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Password required")
status_code=status.HTTP_403_FORBIDDEN,
detail="Password required"
)
# If room has password and provided password doesn't match # If room has password and provided password doesn't match
if room.password and provided_password and provided_password != room.password: if room.password and provided_password and provided_password != room.password:
logger.debug(f"Room {room_id} password mismatch") logger.debug(f"Room {room_id} password mismatch")
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid password")
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid password"
)
logger.debug(f"Room {room_id} password verification passed") logger.debug(f"Room {room_id} password verification passed")
@@ -332,16 +317,16 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
select(RoomParticipatedUser.id).where( select(RoomParticipatedUser.id).where(
RoomParticipatedUser.room_id == room_id, RoomParticipatedUser.room_id == room_id,
RoomParticipatedUser.user_id == user_id, RoomParticipatedUser.user_id == user_id,
col(RoomParticipatedUser.left_at).is_(None) col(RoomParticipatedUser.left_at).is_(None),
) )
) )
existing_ids = existing_result.scalars().all() # 获取所有匹配的ID existing_ids = existing_result.scalars().all() # 获取所有匹配的ID
if existing_ids: if existing_ids:
# 如果存在多条记录,清理重复项,只保留最新的一条 # 如果存在多条记录,清理重复项,只保留最新的一条
if len(existing_ids) > 1: if len(existing_ids) > 1:
logger.debug(f"警告:用户 {user_id} 在房间 {room_id} 中发现 {len(existing_ids)} 条活跃参与记录") logger.debug(f"警告:用户 {user_id} 在房间 {room_id} 中发现 {len(existing_ids)} 条活跃参与记录")
# 将除第一条外的所有记录标记为已离开(清理重复记录) # 将除第一条外的所有记录标记为已离开(清理重复记录)
for extra_id in existing_ids[1:]: for extra_id in existing_ids[1:]:
await db.execute( await db.execute(
@@ -349,7 +334,7 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
.where(col(RoomParticipatedUser.id) == extra_id) .where(col(RoomParticipatedUser.id) == extra_id)
.values(left_at=utcnow()) .values(left_at=utcnow())
) )
# 更新剩余的活跃参与记录(刷新加入时间) # 更新剩余的活跃参与记录(刷新加入时间)
await db.execute( await db.execute(
update(RoomParticipatedUser) update(RoomParticipatedUser)
@@ -364,38 +349,36 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
class BeatmapEnsureRequest(BaseModel): class BeatmapEnsureRequest(BaseModel):
"""Request model for ensuring beatmap exists.""" """Request model for ensuring beatmap exists."""
beatmap_id: int beatmap_id: int
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> Dict[str, Any]: async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> dict[str, Any]:
""" """
确保谱面存在(包括元数据和原始文件缓存)。 确保谱面存在(包括元数据和原始文件缓存)。
Args: Args:
db: 数据库会话 db: 数据库会话
fetcher: API获取器 fetcher: API获取器
redis: Redis连接 redis: Redis连接
beatmap_id: 谱面ID beatmap_id: 谱面ID
Returns: Returns:
Dict: 包含状态信息的响应 Dict: 包含状态信息的响应
""" """
try: try:
# 1. 确保谱面元数据存在于数据库中 # 1. 确保谱面元数据存在于数据库中
from app.database.beatmap import Beatmap from app.database.beatmap import Beatmap
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
if not beatmap: if not beatmap:
return { return {"success": False, "error": f"Beatmap {beatmap_id} not found", "beatmap_id": beatmap_id}
"success": False,
"error": f"Beatmap {beatmap_id} not found",
"beatmap_id": beatmap_id
}
# 2. 预缓存谱面原始文件 # 2. 预缓存谱面原始文件
cache_key = f"beatmap:{beatmap_id}:raw" cache_key = f"beatmap:{beatmap_id}:raw"
cached = await redis.exists(cache_key) cached = await redis.exists(cache_key)
if not cached: if not cached:
# 异步预加载原始文件到缓存 # 异步预加载原始文件到缓存
try: try:
@@ -404,42 +387,34 @@ async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int)
except Exception as e: except Exception as e:
logger.debug(f"Warning: Failed to cache raw beatmap {beatmap_id}: {e}") logger.debug(f"Warning: Failed to cache raw beatmap {beatmap_id}: {e}")
# 即使原始文件缓存失败,也认为确保操作成功(因为元数据已存在) # 即使原始文件缓存失败,也认为确保操作成功(因为元数据已存在)
return { return {
"success": True, "success": True,
"beatmap_id": beatmap_id, "beatmap_id": beatmap_id,
"metadata_cached": True, "metadata_cached": True,
"raw_file_cached": await redis.exists(cache_key), "raw_file_cached": await redis.exists(cache_key),
"beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]" "beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]",
} }
except Exception as e: except Exception as e:
logger.debug(f"Error ensuring beatmap {beatmap_id}: {e}") logger.debug(f"Error ensuring beatmap {beatmap_id}: {e}")
return { return {"success": False, "error": str(e), "beatmap_id": beatmap_id}
"success": False,
"error": str(e),
"beatmap_id": beatmap_id
}
async def _update_room_participant_count(db: Database, room_id: int) -> int: async def _update_room_participant_count(db: Database, room_id: int) -> int:
"""更新房间参与者数量并返回当前数量。""" """更新房间参与者数量并返回当前数量。"""
# 统计活跃参与者 # 统计活跃参与者
active_participants_result = await db.execute( active_participants_result = await db.execute(
select(RoomParticipatedUser.user_id).where( select(RoomParticipatedUser.user_id).where(
RoomParticipatedUser.room_id == room_id, RoomParticipatedUser.room_id == room_id, col(RoomParticipatedUser.left_at).is_(None)
col(RoomParticipatedUser.left_at).is_(None)
) )
) )
active_participants = active_participants_result.all() active_participants = active_participants_result.all()
count = len(active_participants) count = len(active_participants)
# 更新房间参与者数量 # 更新房间参与者数量
await db.execute( await db.execute(update(Room).where(col(Room.id) == room_id).values(participant_count=count))
update(Room)
.where(col(Room.id) == room_id)
.values(participant_count=count)
)
return count return count
@@ -447,7 +422,7 @@ async def _end_room_if_empty(db: Database, room_id: int) -> bool:
"""如果房间为空,则标记房间结束。返回是否结束了房间。""" """如果房间为空,则标记房间结束。返回是否结束了房间。"""
# 检查房间是否还有活跃参与者 # 检查房间是否还有活跃参与者
participant_count = await _update_room_participant_count(db, room_id) participant_count = await _update_room_participant_count(db, room_id)
if participant_count == 0: if participant_count == 0:
# 房间为空,标记结束 # 房间为空,标记结束
now = utcnow() now = utcnow()
@@ -457,12 +432,12 @@ async def _end_room_if_empty(db: Database, room_id: int) -> bool:
.values( .values(
ends_at=now, ends_at=now,
status=RoomStatus.IDLE, # 或者使用 RoomStatus.ENDED 如果有这个状态 status=RoomStatus.IDLE, # 或者使用 RoomStatus.ENDED 如果有这个状态
participant_count=0 participant_count=0,
) )
) )
logger.debug(f"Room {room_id} ended automatically (no participants remaining)") logger.debug(f"Room {room_id} ended automatically (no participants remaining)")
return True return True
return False return False
@@ -474,32 +449,30 @@ async def _transfer_ownership_or_end_room(db: Database, room_id: int, leaving_us
.where( .where(
col(RoomParticipatedUser.room_id) == room_id, col(RoomParticipatedUser.room_id) == room_id,
col(RoomParticipatedUser.user_id) != leaving_user_id, col(RoomParticipatedUser.user_id) != leaving_user_id,
col(RoomParticipatedUser.left_at).is_(None) col(RoomParticipatedUser.left_at).is_(None),
) )
.order_by(col(RoomParticipatedUser.joined_at)) # 按加入时间排序 .order_by(col(RoomParticipatedUser.joined_at)) # 按加入时间排序
) )
remaining_participants = remaining_result.all() remaining_participants = remaining_result.all()
if remaining_participants: if remaining_participants:
# 将房主权限转让给最早加入的用户 # 将房主权限转让给最早加入的用户
new_owner_id = remaining_participants[0][0] # 获取 user_id new_owner_id = remaining_participants[0][0] # 获取 user_id
await db.execute( await db.execute(update(Room).where(col(Room.id) == room_id).values(host_id=new_owner_id))
update(Room)
.where(col(Room.id) == room_id)
.values(host_id=new_owner_id)
)
logger.debug(f"Room {room_id} ownership transferred from {leaving_user_id} to {new_owner_id}") logger.debug(f"Room {room_id} ownership transferred from {leaving_user_id} to {new_owner_id}")
return False # 房间继续存在 return False # 房间继续存在
else: else:
# 没有其他参与者,结束房间 # 没有其他参与者,结束房间
return await _end_room_if_empty(db, room_id) return await _end_room_if_empty(db, room_id)
# ===== API ENDPOINTS ===== # ===== API ENDPOINTS =====
@router.post("/multiplayer/rooms") @router.post("/multiplayer/rooms")
async def create_multiplayer_room( async def create_multiplayer_room(
request: Request, request: Request,
room_data: Dict[str, Any], room_data: dict[str, Any],
db: Database, db: Database,
timestamp: str = "", timestamp: str = "",
) -> int: ) -> int:
@@ -508,10 +481,7 @@ async def create_multiplayer_room(
# Verify request signature # Verify request signature
body = await request.body() body = await request.body()
if not verify_request_signature(request, str(timestamp), body): if not verify_request_signature(request, str(timestamp), body):
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid request signature"
)
# Parse room data if string # Parse room data if string
if isinstance(room_data, str): if isinstance(room_data, str):
@@ -532,13 +502,13 @@ async def create_multiplayer_room(
await server.batch_join_channel([host_user], channel, db) await server.batch_join_channel([host_user], channel, db)
# Add playlist items # Add playlist items
await _add_playlist_items(db, room_id, room_data, host_user_id) await _add_playlist_items(db, room_id, room_data, host_user_id)
# Add host as participant # Add host as participant
#await _add_host_as_participant(db, room_id, host_user_id) # await _add_host_as_participant(db, room_id, host_user_id)
await db.commit() await db.commit()
return room_id return room_id
except HTTPException: except HTTPException:
# Clean up room if playlist creation fails # Clean up room if playlist creation fails
await db.delete(room) await db.delete(room)
@@ -546,18 +516,11 @@ async def create_multiplayer_room(
raise raise
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON: {e!s}")
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid JSON: {str(e)}"
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create room: {e!s}")
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create room: {str(e)}"
)
@router.delete("/multiplayer/rooms/{room_id}/users/{user_id}") @router.delete("/multiplayer/rooms/{room_id}/users/{user_id}")
@@ -567,36 +530,28 @@ async def remove_user_from_room(
user_id: int, user_id: int,
db: Database, db: Database,
timestamp: int = Query(..., description="Unix 时间戳(秒)", ge=0), timestamp: int = Query(..., description="Unix 时间戳(秒)", ge=0),
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Remove a user from a multiplayer room.""" """Remove a user from a multiplayer room."""
try: try:
# Verify request signature # Verify request signature
body = await request.body() body = await request.body()
now = utcnow() now = utcnow()
if not verify_request_signature(request, str(timestamp), body): if not verify_request_signature(request, str(timestamp), body):
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid request signature"
)
# 检查房间是否存在 # 检查房间是否存在
room_result = await db.execute( room_result = await db.execute(select(Room).where(col(Room.id) == room_id))
select(Room).where(col(Room.id) == room_id)
)
room = room_result.scalar_one_or_none() room = room_result.scalar_one_or_none()
if room is None: if room is None:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
status_code=status.HTTP_404_NOT_FOUND,
detail="Room not found"
)
room_owner_id = room.host_id room_owner_id = room.host_id
room_status = room.status room_status = room.status
current_participant_count = room.participant_count current_participant_count = room.participant_count
ends_at = room.ends_at ends_at = room.ends_at
channel_id = room.channel_id channel_id = room.channel_id
# 如果房间已经结束,直接返回 # 如果房间已经结束,直接返回
if ends_at is not None: if ends_at is not None:
logger.debug(f"Room {room_id} is already ended") logger.debug(f"Room {room_id} is already ended")
@@ -604,11 +559,10 @@ async def remove_user_from_room(
# 检查用户是否在房间中 # 检查用户是否在房间中
participant_result = await db.execute( participant_result = await db.execute(
select(RoomParticipatedUser.id) select(RoomParticipatedUser.id).where(
.where(
col(RoomParticipatedUser.room_id) == room_id, col(RoomParticipatedUser.room_id) == room_id,
col(RoomParticipatedUser.user_id) == user_id, col(RoomParticipatedUser.user_id) == user_id,
col(RoomParticipatedUser.left_at).is_(None) col(RoomParticipatedUser.left_at).is_(None),
) )
) )
participant_query = participant_result.first() participant_query = participant_result.first()
@@ -634,13 +588,13 @@ async def remove_user_from_room(
.where( .where(
col(RoomParticipatedUser.room_id) == room_id, col(RoomParticipatedUser.room_id) == room_id,
col(RoomParticipatedUser.user_id) == user_id, col(RoomParticipatedUser.user_id) == user_id,
col(RoomParticipatedUser.left_at).is_(None) col(RoomParticipatedUser.left_at).is_(None),
) )
.values(left_at=now) .values(left_at=now)
) )
room_ended = False room_ended = False
# 检查是否是房主离开 # 检查是否是房主离开
if user_id == room_owner_id: if user_id == room_owner_id:
logger.debug(f"Host {user_id} is leaving room {room_id}") logger.debug(f"Host {user_id} is leaving room {room_id}")
@@ -648,10 +602,10 @@ async def remove_user_from_room(
else: else:
# 不是房主离开,只需检查房间是否为空 # 不是房主离开,只需检查房间是否为空
room_ended = await _end_room_if_empty(db, room_id) room_ended = await _end_room_if_empty(db, room_id)
await db.commit() await db.commit()
logger.debug(f"Successfully removed user {user_id} from room {room_id}, room_ended: {room_ended}") logger.debug(f"Successfully removed user {user_id} from room {room_id}, room_ended: {room_ended}")
# ===== 新增:提交后,把用户从聊天频道移除;若房间已结束,清理内存频道 ===== # ===== 新增:提交后,把用户从聊天频道移除;若房间已结束,清理内存频道 =====
try: try:
if channel_id: if channel_id:
@@ -660,17 +614,16 @@ async def remove_user_from_room(
server.channels.pop(int(channel_id), None) server.channels.pop(int(channel_id), None)
except Exception as e: except Exception as e:
logger.debug(f"[warn] failed to leave user {user_id} from channel {channel_id}: {e}") logger.debug(f"[warn] failed to leave user {user_id} from channel {channel_id}: {e}")
return {"success": True, "room_ended": room_ended} return {"success": True, "room_ended": room_ended}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.debug(f"Error removing user from room: {str(e)}") logger.debug(f"Error removing user from room: {e!s}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user from room: {e!s}"
detail=f"Failed to remove user from room: {str(e)}"
) )
@@ -681,32 +634,28 @@ async def add_user_to_room(
user_id: int, user_id: int,
db: Database, db: Database,
timestamp: str = "", timestamp: str = "",
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Add a user to a multiplayer room.""" """Add a user to a multiplayer room."""
logger.debug(f"Adding user {user_id} to room {room_id}") logger.debug(f"Adding user {user_id} to room {room_id}")
# Get request body and parse user_data # Get request body and parse user_data
body = await request.body() body = await request.body()
user_data = None user_data = None
if body: if body:
try: try:
user_data = json.loads(body.decode('utf-8')) user_data = json.loads(body.decode("utf-8"))
logger.debug(f"Parsed user_data: {user_data}") logger.debug(f"Parsed user_data: {user_data}")
except json.JSONDecodeError: except json.JSONDecodeError:
logger.debug("Failed to parse user_data from request body") logger.debug("Failed to parse user_data from request body")
user_data = None user_data = None
# Verify request signature # Verify request signature
if not verify_request_signature(request, timestamp, body): if not verify_request_signature(request, timestamp, body):
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid request signature"
)
# 检查房间是否已结束 # 检查房间是否已结束
room_result = await db.execute( room_result = await db.execute(
select(Room.id, Room.ends_at, Room.channel_id, Room.host_id) select(Room.id, Room.ends_at, Room.channel_id, Room.host_id).where(col(Room.id) == room_id)
.where(col(Room.id) == room_id)
) )
room_row = room_result.first() room_row = room_result.first()
if not room_row: if not room_row:
@@ -763,10 +712,10 @@ async def ensure_beatmap_present(
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
timestamp: str = "", timestamp: str = "",
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
确保谱面在服务器中存在(包括元数据和原始文件缓存)。 确保谱面在服务器中存在(包括元数据和原始文件缓存)。
这个接口用于 osu-server-spectator 确保谱面文件在服务器端可用, 这个接口用于 osu-server-spectator 确保谱面文件在服务器端可用,
避免在需要时才获取导致的延迟。 避免在需要时才获取导致的延迟。
""" """
@@ -774,20 +723,17 @@ async def ensure_beatmap_present(
# Verify request signature # Verify request signature
body = await request.body() body = await request.body()
if not verify_request_signature(request, timestamp, body): if not verify_request_signature(request, timestamp, body):
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid request signature"
)
beatmap_id = beatmap_data.beatmap_id beatmap_id = beatmap_data.beatmap_id
logger.debug(f"Ensuring beatmap {beatmap_id} is present") logger.debug(f"Ensuring beatmap {beatmap_id} is present")
# 确保谱面存在 # 确保谱面存在
result = await _ensure_beatmap_exists(db, fetcher, redis, beatmap_id) result = await _ensure_beatmap_exists(db, fetcher, redis, beatmap_id)
# 提交数据库更改 # 提交数据库更改
await db.commit() await db.commit()
logger.debug(f"Ensure beatmap {beatmap_id} result: {result}") logger.debug(f"Ensure beatmap {beatmap_id} result: {result}")
return result return result
@@ -795,8 +741,7 @@ async def ensure_beatmap_present(
raise raise
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.debug(f"Error ensuring beatmap: {str(e)}") logger.debug(f"Error ensuring beatmap: {e!s}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to ensure beatmap: {e!s}"
detail=f"Failed to ensure beatmap: {str(e)}" )
)

View File

@@ -96,6 +96,7 @@ async def send_message(
if channel_type == ChannelType.MULTIPLAYER: if channel_type == ChannelType.MULTIPLAYER:
try: try:
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
redis = get_redis() redis = get_redis()
key = f"channel:{channel_id}:messages" key = f"channel:{channel_id}:messages"
key_type = await redis.type(key) key_type = await redis.type(key)
@@ -162,13 +163,9 @@ async def get_message(
): ):
# 1) 查频道 # 1) 查频道
if channel.isdigit(): if channel.isdigit():
db_channel = (await session.exec( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)).first()
else: else:
db_channel = (await session.exec( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
select(ChatChannel).where(ChatChannel.name == channel)
)).first()
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
@@ -176,7 +173,6 @@ async def get_message(
channel_id = db_channel.channel_id channel_id = db_channel.channel_id
try: try:
messages = await redis_message_system.get_messages(channel_id, limit, since) messages = await redis_message_system.get_messages(channel_id, limit, since)
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id: if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
messages.reverse() messages.reverse()
@@ -188,11 +184,7 @@ async def get_message(
if since > 0 and until is None: if since > 0 and until is None:
# 向前加载新消息 → 直接 ASC # 向前加载新消息 → 直接 ASC
query = ( query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
base.where(col(ChatMessage.message_id) > since)
.order_by(col(ChatMessage.message_id).asc())
.limit(limit)
)
rows = (await session.exec(query)).all() rows = (await session.exec(query)).all()
resp = [await ChatMessageResp.from_db(m, session) for m in rows] resp = [await ChatMessageResp.from_db(m, session) for m in rows]
# 已经 ASC无需反转 # 已经 ASC无需反转
@@ -202,9 +194,7 @@ async def get_message(
if until is not None: if until is not None:
# 用 DESC 取最近的更早消息,再反转为 ASC # 用 DESC 取最近的更早消息,再反转为 ASC
query = ( query = (
base.where(col(ChatMessage.message_id) < until) base.where(col(ChatMessage.message_id) < until).order_by(col(ChatMessage.message_id).desc()).limit(limit)
.order_by(col(ChatMessage.message_id).desc())
.limit(limit)
) )
rows = (await session.exec(query)).all() rows = (await session.exec(query)).all()
rows = list(rows) rows = list(rows)
@@ -221,7 +211,6 @@ async def get_message(
return resp return resp
@router.put( @router.put(
"/chat/channels/{channel}/mark-as-read/{message}", "/chat/channels/{channel}/mark-as-read/{message}",
status_code=204, status_code=204,

View File

@@ -76,20 +76,21 @@ class ChatServer:
async def broadcast(self, channel_id: int, event: ChatEvent): async def broadcast(self, channel_id: int, event: ChatEvent):
users_in_channel = self.channels.get(channel_id, []) users_in_channel = self.channels.get(channel_id, [])
logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}") logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}")
# 如果频道中没有用户,检查是否是多人游戏频道 # 如果频道中没有用户,检查是否是多人游戏频道
if not users_in_channel: if not users_in_channel:
try: try:
async with with_db() as session: async with with_db() as session:
from sqlmodel import select
channel = await session.get(ChatChannel, channel_id) channel = await session.get(ChatChannel, channel_id)
if channel and channel.type == ChannelType.MULTIPLAYER: if channel and channel.type == ChannelType.MULTIPLAYER:
logger.warning(f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone") logger.warning(
f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone"
)
# 对于多人游戏房间,这可能是正常的(用户都离开了房间) # 对于多人游戏房间,这可能是正常的(用户都离开了房间)
# 但我们仍然记录这个情况以便调试 # 但我们仍然记录这个情况以便调试
except Exception as e: except Exception as e:
logger.error(f"Failed to check channel type for {channel_id}: {e}") logger.error(f"Failed to check channel type for {channel_id}: {e}")
for user_id in users_in_channel: for user_id in users_in_channel:
await self.send_event(user_id, event) await self.send_event(user_id, event)
logger.debug(f"Sent event to user {user_id} in channel {channel_id}") logger.debug(f"Sent event to user {user_id} in channel {channel_id}")

View File

@@ -13,8 +13,8 @@ from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher from app.fetcher import Fetcher
from app.models.beatmap import SearchQueryModel from app.models.beatmap import SearchQueryModel
from app.service.beatmap_download_service import BeatmapDownloadService
from app.service.asset_proxy_helper import process_response_assets from app.service.asset_proxy_helper import process_response_assets
from app.service.beatmap_download_service import BeatmapDownloadService
from .router import router from .router import router
@@ -97,7 +97,7 @@ async def search_beatmapset(
try: try:
sets = await fetcher.search_beatmapset(query, cursor, redis) sets = await fetcher.search_beatmapset(query, cursor, redis)
background_tasks.add_task(_save_to_db, sets) background_tasks.add_task(_save_to_db, sets)
# 处理资源代理 # 处理资源代理
processed_sets = await process_response_assets(sets, request) processed_sets = await process_response_assets(sets, request)
return processed_sets return processed_sets
@@ -121,7 +121,7 @@ async def lookup_beatmapset(
): ):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user) resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
# 处理资源代理 # 处理资源代理
processed_resp = await process_response_assets(resp, request) processed_resp = await process_response_assets(resp, request)
return processed_resp return processed_resp

View File

@@ -56,20 +56,22 @@ async def get_all_rooms(
db_category = category db_category = category
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == db_category] where_clauses: list[ColumnElement[bool]] = [col(Room.category) == db_category]
now = utcnow() now = utcnow()
if status is not None: if status is not None:
where_clauses.append(col(Room.status) == status) where_clauses.append(col(Room.status) == status)
#print(mode, category, status, current_user.id) # print(mode, category, status, current_user.id)
if mode == "open": if mode == "open":
# 修改为新的查询逻辑:状态为 idle 或 playingstarts_at 不为空ends_at 为空 # 修改为新的查询逻辑:状态为 idle 或 playingstarts_at 不为空ends_at 为空
where_clauses.extend([ where_clauses.extend(
col(Room.status).in_([RoomStatus.IDLE, RoomStatus.PLAYING]), [
col(Room.starts_at).is_not(None), col(Room.status).in_([RoomStatus.IDLE, RoomStatus.PLAYING]),
col(Room.ends_at).is_(None) col(Room.starts_at).is_not(None),
]) col(Room.ends_at).is_(None),
#if category == RoomCategory.REALTIME: ]
)
# if category == RoomCategory.REALTIME:
# where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys())) # where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
if mode == "participated": if mode == "participated":
where_clauses.append( where_clauses.append(
exists().where( exists().where(
@@ -77,10 +79,10 @@ async def get_all_rooms(
col(RoomParticipatedUser.user_id) == current_user.id, col(RoomParticipatedUser.user_id) == current_user.id,
) )
) )
if mode == "owned": if mode == "owned":
where_clauses.append(col(Room.host_id) == current_user.id) where_clauses.append(col(Room.host_id) == current_user.id)
if mode == "ended": if mode == "ended":
where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC))) where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
@@ -96,7 +98,7 @@ async def get_all_rooms(
.unique() .unique()
.all() .all()
) )
#print("Retrieved rooms:", db_rooms) # print("Retrieved rooms:", db_rooms)
for room in db_rooms: for room in db_rooms:
resp = await RoomResp.from_db(room, db) resp = await RoomResp.from_db(room, db)
resp.has_password = bool((room.password or "").strip()) resp.has_password = bool((room.password or "").strip())
@@ -424,4 +426,4 @@ async def get_room_events(
playlist_items=playlist_items_resps, playlist_items=playlist_items_resps,
room=room_resp, room=room_resp,
user=user_resps, user=user_resps,
) )

View File

@@ -5,27 +5,29 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any
from fastapi import Request
from app.config import settings from app.config import settings
from app.service.asset_proxy_service import get_asset_proxy_service from app.service.asset_proxy_service import get_asset_proxy_service
from fastapi import Request
async def process_response_assets(data: Any, request: Request) -> Any: async def process_response_assets(data: Any, request: Request) -> Any:
""" """
根据配置处理响应数据中的资源URL 根据配置处理响应数据中的资源URL
Args: Args:
data: API响应数据 data: API响应数据
request: FastAPI请求对象 request: FastAPI请求对象
Returns: Returns:
处理后的数据 处理后的数据
""" """
if not settings.enable_asset_proxy: if not settings.enable_asset_proxy:
return data return data
asset_service = get_asset_proxy_service() asset_service = get_asset_proxy_service()
# 仅URL替换模式 # 仅URL替换模式
return await asset_service.replace_asset_urls(data) return await asset_service.replace_asset_urls(data)
@@ -47,7 +49,7 @@ def should_process_asset_proxy(path: str) -> bool:
"/api/v2/beatmapsets/", "/api/v2/beatmapsets/",
# 可以根据需要添加更多端点 # 可以根据需要添加更多端点
] ]
return any(path.startswith(endpoint) for endpoint in asset_proxy_endpoints) return any(path.startswith(endpoint) for endpoint in asset_proxy_endpoints)
@@ -56,6 +58,7 @@ def asset_proxy_response(func):
""" """
装饰器自动处理响应中的资源URL 装饰器自动处理响应中的资源URL
""" """
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
# 获取request对象 # 获取request对象
request = None request = None
@@ -63,14 +66,14 @@ def asset_proxy_response(func):
if isinstance(arg, Request): if isinstance(arg, Request):
request = arg request = arg
break break
# 执行原函数 # 执行原函数
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
# 如果有request对象且启用了资源代理则处理响应 # 如果有request对象且启用了资源代理则处理响应
if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path): if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path):
result = await process_response_assets(result, request) result = await process_response_assets(result, request)
return result return result
return wrapper return wrapper

View File

@@ -7,8 +7,8 @@ from __future__ import annotations
import re import re
from typing import Any from typing import Any
from app.config import settings from app.config import settings
from app.log import logger
class AssetProxyService: class AssetProxyService:
@@ -26,7 +26,7 @@ class AssetProxyService:
递归替换数据中的osu!资源URL为自定义域名 递归替换数据中的osu!资源URL为自定义域名
""" """
# 处理Pydantic模型 # 处理Pydantic模型
if hasattr(data, 'model_dump'): if hasattr(data, "model_dump"):
# 转换为字典,处理后再转换回模型 # 转换为字典,处理后再转换回模型
data_dict = data.model_dump() data_dict = data.model_dump()
processed_dict = await self.replace_asset_urls(data_dict) processed_dict = await self.replace_asset_urls(data_dict)
@@ -46,35 +46,25 @@ class AssetProxyService:
elif isinstance(data, str): elif isinstance(data, str):
# 替换各种osu!资源域名 # 替换各种osu!资源域名
result = data result = data
# 替换 assets.ppy.sh (用户头像、封面、奖章等) # 替换 assets.ppy.sh (用户头像、封面、奖章等)
result = re.sub( result = re.sub(
r"https://assets\.ppy\.sh/", r"https://assets\.ppy\.sh/", f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/", result
f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/",
result
) )
# 替换 b.ppy.sh 预览音频 (保持//前缀) # 替换 b.ppy.sh 预览音频 (保持//前缀)
result = re.sub( result = re.sub(r"//b\.ppy\.sh/", f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/", result)
r"//b\.ppy\.sh/",
f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/",
result
)
# 替换 https://b.ppy.sh 预览音频 (转换为//前缀) # 替换 https://b.ppy.sh 预览音频 (转换为//前缀)
result = re.sub( result = re.sub(
r"https://b\.ppy\.sh/", r"https://b\.ppy\.sh/", f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/", result
f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/",
result
) )
# 替换 a.ppy.sh 头像 # 替换 a.ppy.sh 头像
result = re.sub( result = re.sub(
r"https://a\.ppy\.sh/", r"https://a\.ppy\.sh/", f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/", result
f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/",
result
) )
return result return result
else: else:
return data return data

View File

@@ -5,7 +5,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, UTC from datetime import UTC, datetime, timedelta
import json import json
from app.dependencies.database import get_redis, get_redis_message from app.dependencies.database import get_redis, get_redis_message

View File

@@ -14,8 +14,8 @@ from app.config import settings
from app.database.statistics import UserStatistics, UserStatisticsResp from app.database.statistics import UserStatistics, UserStatisticsResp
from app.log import logger from app.log import logger
from app.models.score import GameMode from app.models.score import GameMode
from app.utils import utcnow
from app.service.asset_proxy_service import get_asset_proxy_service from app.service.asset_proxy_service import get_asset_proxy_service
from app.utils import utcnow
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import col, select from sqlmodel import col, select
@@ -284,7 +284,7 @@ class RankingCacheService:
ranking_data = [] ranking_data = []
for statistics in statistics_data: for statistics in statistics_data:
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include) user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
# 应用资源代理处理 # 应用资源代理处理
if settings.enable_asset_proxy: if settings.enable_asset_proxy:
try: try:
@@ -292,7 +292,7 @@ class RankingCacheService:
user_stats_resp = await asset_proxy_service.replace_asset_urls(user_stats_resp) user_stats_resp = await asset_proxy_service.replace_asset_urls(user_stats_resp)
except Exception as e: except Exception as e:
logger.warning(f"Asset proxy processing failed for ranking cache: {e}") logger.warning(f"Asset proxy processing failed for ranking cache: {e}")
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题 # 将 UserStatisticsResp 转换为字典,处理所有序列化问题
user_dict = json.loads(user_stats_resp.model_dump_json()) user_dict = json.loads(user_stats_resp.model_dump_json())
ranking_data.append(user_dict) ranking_data.append(user_dict)

View File

@@ -254,14 +254,16 @@ class RedisMessageSystem:
# 键类型错误,需要清理 # 键类型错误,需要清理
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}") logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
await self._redis_exec(self.redis.delete, channel_messages_key) await self._redis_exec(self.redis.delete, channel_messages_key)
# 验证删除是否成功 # 验证删除是否成功
verify_type = await self._redis_exec(self.redis.type, channel_messages_key) verify_type = await self._redis_exec(self.redis.type, channel_messages_key)
if verify_type != "none": if verify_type != "none":
logger.error(f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}") logger.error(
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
)
# 强制删除 # 强制删除
await self._redis_exec(self.redis.unlink, channel_messages_key) await self._redis_exec(self.redis.unlink, channel_messages_key)
except Exception as type_check_error: except Exception as type_check_error:
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}") logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
# 如果检查失败,尝试强制删除键以确保清理 # 如果检查失败,尝试强制删除键以确保清理
@@ -597,13 +599,13 @@ class RedisMessageSystem:
elif key_type != "zset": elif key_type != "zset":
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}") logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
await self._redis_exec(self.redis.delete, key) await self._redis_exec(self.redis.delete, key)
# 验证删除是否成功 # 验证删除是否成功
verify_type = await self._redis_exec(self.redis.type, key) verify_type = await self._redis_exec(self.redis.type, key)
if verify_type != "none": if verify_type != "none":
logger.error(f"Failed to delete problematic key {key}, trying unlink...") logger.error(f"Failed to delete problematic key {key}, trying unlink...")
await self._redis_exec(self.redis.unlink, key) await self._redis_exec(self.redis.unlink, key)
fixed_count += 1 fixed_count += 1
except Exception as cleanup_error: except Exception as cleanup_error:
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}") logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
@@ -634,10 +636,10 @@ class RedisMessageSystem:
await asyncio.sleep(300) await asyncio.sleep(300)
if not self._running: if not self._running:
break break
logger.debug("Running periodic Redis keys cleanup...") logger.debug("Running periodic Redis keys cleanup...")
await self._cleanup_redis_keys() await self._cleanup_redis_keys()
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:

View File

@@ -299,7 +299,7 @@ class UserCacheService:
"""缓存单个用户""" """缓存单个用户"""
try: try:
user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED) user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
# 应用资源代理处理 # 应用资源代理处理
if settings.enable_asset_proxy: if settings.enable_asset_proxy:
try: try:
@@ -307,7 +307,7 @@ class UserCacheService:
user_resp = await asset_proxy_service.replace_asset_urls(user_resp) user_resp = await asset_proxy_service.replace_asset_urls(user_resp)
except Exception as e: except Exception as e:
logger.warning(f"Asset proxy processing failed for user cache {user.id}: {e}") logger.warning(f"Asset proxy processing failed for user cache {user.id}: {e}")
await self.cache_user(user_resp) await self.cache_user(user_resp)
except Exception as e: except Exception as e:
logger.error(f"Error caching single user {user.id}: {e}") logger.error(f"Error caching single user {user.id}: {e}")

View File

@@ -67,11 +67,11 @@ async def lifespan(app: FastAPI):
start_stats_scheduler() # 启动统计调度器 start_stats_scheduler() # 启动统计调度器
schedule_online_status_maintenance() # 启动在线状态维护任务 schedule_online_status_maintenance() # 启动在线状态维护任务
load_achievements() load_achievements()
# 显示资源代理状态 # 显示资源代理状态
if settings.enable_asset_proxy: if settings.enable_asset_proxy:
logger.info(f"Asset Proxy enabled - Domain: {settings.custom_asset_domain}") logger.info(f"Asset Proxy enabled - Domain: {settings.custom_asset_domain}")
# on shutdown # on shutdown
yield yield
bg_tasks.stop() bg_tasks.stop()

View File

@@ -5,6 +5,7 @@ Revises: 57bacf936413
Create Date: 2025-08-24 00:08:14.704724 Create Date: 2025-08-24 00:08:14.704724
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence

View File

@@ -5,6 +5,7 @@ Revises: 8d2af11343b9
Create Date: 2025-08-24 04:00:02.063347 Create Date: 2025-08-24 04:00:02.063347
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
@@ -22,10 +23,14 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
#op.drop_index(op.f("ix_lazer_user_achievements_achievement_id"), table_name="lazer_user_achievements") # op.drop_index(op.f("ix_lazer_user_achievements_achievement_id"), table_name="lazer_user_achievements")
#op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements") # op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements")
op.add_column("room_playlists", sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)) op.add_column(
op.add_column("room_playlists", sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)) "room_playlists", sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)
)
op.add_column(
"room_playlists", sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)
)
# ### end Alembic commands ### # ### end Alembic commands ###
@@ -35,5 +40,7 @@ def downgrade() -> None:
op.drop_column("room_playlists", "updated_at") op.drop_column("room_playlists", "updated_at")
op.drop_column("room_playlists", "created_at") op.drop_column("room_playlists", "created_at")
op.create_index(op.f("uq_user_achievement"), "lazer_user_achievements", ["user_id", "achievement_id"], unique=True) op.create_index(op.f("uq_user_achievement"), "lazer_user_achievements", ["user_id", "achievement_id"], unique=True)
op.create_index(op.f("ix_lazer_user_achievements_achievement_id"), "lazer_user_achievements", ["achievement_id"], unique=False) op.create_index(
op.f("ix_lazer_user_achievements_achievement_id"), "lazer_user_achievements", ["achievement_id"], unique=False
)
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -5,13 +5,13 @@ Revises: 178873984b22
Create Date: 2025-08-23 18:45:03.009632 Create Date: 2025-08-23 18:45:03.009632
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "57bacf936413" revision: str = "57bacf936413"

View File

@@ -5,6 +5,7 @@ Revises: 20c6df84813f
Create Date: 2025-08-24 00:08:42.419252 Create Date: 2025-08-24 00:08:42.419252
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence

View File

@@ -5,6 +5,7 @@ Revises: 7576ca1e056d
Create Date: 2025-08-24 00:11:05.064099 Create Date: 2025-08-24 00:11:05.064099
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence