chore(deps): auto fix by pre-commit hooks

This commit is contained in:
pre-commit-ci[bot]
2025-08-24 03:18:58 +00:00
committed by MingxuanGame
parent b4fd4e0256
commit 7625cd99f5
25 changed files with 241 additions and 320 deletions

View File

@@ -156,19 +156,17 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
expire = utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
# 添加标准JWT声明
to_encode.update({
"exp": expire,
"random": secrets.token_hex(16)
})
if hasattr(settings, 'jwt_audience') and settings.jwt_audience:
to_encode.update({"exp": expire, "random": secrets.token_hex(16)})
if hasattr(settings, "jwt_audience") and settings.jwt_audience:
to_encode["aud"] = settings.jwt_audience
if hasattr(settings, 'jwt_issuer') and settings.jwt_issuer:
if hasattr(settings, "jwt_issuer") and settings.jwt_issuer:
to_encode["iss"] = settings.jwt_issuer
# 编码JWT
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
return encoded_jwt
def generate_refresh_token() -> str:
"""生成刷新令牌"""
length = 64

View File

@@ -291,11 +291,7 @@ class UserResp(UserBase):
).one()
redis = get_redis()
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
u.cover_url = (
obj.cover.get("url", "")
if obj.cover
else ""
)
u.cover_url = obj.cover.get("url", "") if obj.cover else ""
if "friends" in include:
u.friends = [

View File

@@ -1,17 +1,17 @@
from datetime import datetime
from typing import Optional, ClassVar
from sqlalchemy import Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import SQLModel, Field, Column, DateTime, BigInteger, ForeignKey
from typing import ClassVar
from app.models.model import UTCBaseModel
from app.utils import utcnow
from sqlalchemy import Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, SQLModel
class MultiplayerRealtimeRoomEventBase(SQLModel, UTCBaseModel):
event_type: str = Field(index=True)
event_detail: Optional[str] = Field(default=None, sa_column=Column(Text))
event_detail: str | None = Field(default=None, sa_column=Column(Text))
class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase, table=True):
@@ -19,9 +19,7 @@ class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase,
id: int | None = Field(default=None, primary_key=True, index=True)
room_id: int = Field(
sa_column=Column(ForeignKey("rooms.id"), index=True, nullable=False)
)
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), index=True, nullable=False))
playlist_item_id: int | None = Field(
default=None,
sa_column=Column(ForeignKey("playlists.id"), index=True, nullable=True),
@@ -31,9 +29,5 @@ class MultiplayerRealtimeRoomEvent(AsyncAttrs, MultiplayerRealtimeRoomEventBase,
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True, nullable=True),
)
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True)), default_factory=utcnow
)
updated_at: datetime = Field(
sa_column=Column(DateTime(timezone=True)), default_factory=utcnow
)
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True)), default_factory=utcnow)
updated_at: datetime = Field(sa_column=Column(DateTime(timezone=True)), default_factory=utcnow)

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
@@ -60,16 +60,9 @@ class Playlist(PlaylistBase, table=True):
}
)
room: "Room" = Relationship()
created_at: Optional[datetime] = Field(
default=None,
sa_column_kwargs={"server_default": func.now()}
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column_kwargs={
"server_default": func.now(),
"onupdate": func.now()
}
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
updated_at: datetime | None = Field(
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
)
@classmethod

View File

@@ -74,7 +74,6 @@ class Room(AsyncAttrs, RoomBase, table=True):
)
class RoomResp(RoomBase):
id: int
has_password: bool = False

View File

@@ -68,9 +68,7 @@ class BaseFetcher:
if response.status_code == 401:
logger.warning(f"Received 401 error for {url}")
await self._clear_tokens()
raise TokenAuthError(
f"Authentication failed. Please re-authorize using: {self.authorize_url}"
)
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
response.raise_for_status()
return response.json()

View File

@@ -14,12 +14,13 @@ from app.utils import bg_tasks
from ._base import BaseFetcher
import redis.asyncio as redis
from httpx import AsyncClient
import redis.asyncio as redis
class RateLimitError(Exception):
"""速率限制异常"""
pass
@@ -73,9 +74,7 @@ class BeatmapsetFetcher(BaseFetcher):
if response.status_code == 401:
logger.warning(f"Received 401 error for {url}")
await self._clear_tokens()
raise TokenAuthError(
f"Authentication failed. Please re-authorize using: {self.authorize_url}"
)
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
response.raise_for_status()
return response.json()
@@ -205,7 +204,9 @@ class BeatmapsetFetcher(BaseFetcher):
try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
except RateLimitError:
logger.opt(colors=True).info("<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit")
logger.opt(colors=True).info(
"<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit"
)
bg_tasks.add_task(delayed_prefetch)

View File

@@ -1,15 +1,11 @@
"""LIO (Legacy IO) router for osu-server-spectator compatibility."""
from __future__ import annotations
import json
from typing import Any, Dict, List
from fastapi import APIRouter, HTTPException, Request, status, Query, Depends
from pydantic import BaseModel
from sqlmodel import col, select, desc
from sqlalchemy import update, func
from redis.asyncio import Redis
from typing import Any
from app.database.chat import ChannelType, ChatChannel # ChatChannel 模型 & 枚举
from app.database.lazer_user import User
from app.database.playlists import Playlist as DBPlaylist
from app.database.room import Room
@@ -17,13 +13,18 @@ from app.database.room_participated_user import RoomParticipatedUser
from app.dependencies.database import Database, get_redis
from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher
from app.log import logger
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
from app.models.room import MatchType, QueueMode, RoomStatus
from app.utils import utcnow
from app.database.chat import ChatChannel, ChannelType # ChatChannel 模型 & 枚举
from .notification.server import server
from app.log import logger
from .notification.server import server
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlalchemy import func, update
from sqlmodel import col, select
router = APIRouter(prefix="/_lio", tags=["LIO"])
@@ -40,9 +41,7 @@ async def _ensure_room_chat_channel(
# 1) 按 channel_id 查是否已存在
try:
# Use db.execute instead of db.exec for better async compatibility
result = await db.execute(
select(ChatChannel).where(ChatChannel.channel_id == room.channel_id)
)
result = await db.execute(select(ChatChannel).where(ChatChannel.channel_id == room.channel_id))
ch = result.scalar_one_or_none()
except Exception as e:
logger.debug(f"Error querying ChatChannel: {e}")
@@ -59,8 +58,8 @@ async def _ensure_room_chat_channel(
channel_id_value = int(room.channel_id)
ch = ChatChannel(
channel_id=channel_id_value, # 与房间绑定的同一 channel_id确保为 int
name=f"mp_{room.id}", # 频道名可自定义(注意唯一性)
channel_id=channel_id_value, # 与房间绑定的同一 channel_id确保为 int
name=f"mp_{room.id}", # 频道名可自定义(注意唯一性)
description=f"Multiplayer room {room.id} chat",
type=ChannelType.MULTIPLAYER,
)
@@ -86,18 +85,20 @@ async def _alloc_channel_id(db: Database) -> int:
logger.debug(f"Error allocating channel_id: {e}")
# Fallback to a timestamp-based approach
import time
return int(time.time()) % 1000000 + 100
class RoomCreateRequest(BaseModel):
"""Request model for creating a multiplayer room."""
name: str
user_id: int
password: str | None = None
match_type: str = "HeadToHead"
queue_mode: str = "HostOnly"
initial_playlist: List[Dict[str, Any]] = []
playlist: List[Dict[str, Any]] = []
initial_playlist: list[dict[str, Any]] = []
playlist: list[dict[str, Any]] = []
def verify_request_signature(request: Request, timestamp: str, body: bytes) -> bool:
@@ -126,10 +127,7 @@ async def _validate_user_exists(db: Database, user_id: int) -> User:
user = user_result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User with ID {user_id} not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User with ID {user_id} not found")
return user
@@ -149,7 +147,7 @@ def _parse_room_enums(match_type: str, queue_mode: str) -> tuple[MatchType, Queu
return match_type_enum, queue_mode_enum
def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_user_id: int) -> Dict[str, Any]:
def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_user_id: int) -> dict[str, Any]:
"""
Normalize playlist item data with default values.
@@ -181,30 +179,28 @@ def _coerce_playlist_item(item_data: Dict[str, Any], default_order: int, host_us
}
def _validate_playlist_items(items: List[Dict[str, Any]]) -> None:
def _validate_playlist_items(items: list[dict[str, Any]]) -> None:
"""Validate playlist items data."""
if not items:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one playlist item is required to create a room"
status_code=status.HTTP_400_BAD_REQUEST, detail="At least one playlist item is required to create a room"
)
for idx, item in enumerate(items):
if item["beatmap_id"] is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Playlist item at index {idx} missing beatmap_id"
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Playlist item at index {idx} missing beatmap_id"
)
ruleset_id = item["ruleset_id"]
if not isinstance(ruleset_id, int) or not (0 <= ruleset_id <= 3):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}"
detail=f"Playlist item at index {idx} has invalid ruleset_id {ruleset_id}",
)
async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, int]:
async def _create_room(db: Database, room_data: dict[str, Any]) -> tuple[Room, int]:
host_user_id = room_data.get("user_id")
room_name = room_data.get("name", "Unnamed Room")
password = room_data.get("password")
@@ -242,12 +238,12 @@ async def _create_room(db: Database, room_data: Dict[str, Any]) -> tuple[Room, i
return room, host_user_id
async def _add_playlist_items(db: Database, room_id: int, room_data: Dict[str, Any], host_user_id: int) -> None:
async def _add_playlist_items(db: Database, room_id: int, room_data: dict[str, Any], host_user_id: int) -> None:
"""Add playlist items to the room."""
initial_playlist = room_data.get("initial_playlist", [])
legacy_playlist = room_data.get("playlist", [])
items_raw: List[Dict[str, Any]] = []
items_raw: list[dict[str, Any]] = []
# Process initial playlist
for i, item in enumerate(initial_playlist):
@@ -292,35 +288,24 @@ async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int
async def _verify_room_password(db: Database, room_id: int, provided_password: str | None) -> None:
"""Verify room password if required."""
room_result = await db.execute(
select(Room).where(col(Room.id) == room_id)
)
room_result = await db.execute(select(Room).where(col(Room.id) == room_id))
room = room_result.scalar_one_or_none()
if room is None:
logger.debug(f"Room {room_id} not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Room not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
logger.debug(f"Room {room_id} has password: {bool(room.password)}, provided: {bool(provided_password)}")
# If room has password but none provided
if room.password and not provided_password:
logger.debug(f"Room {room_id} requires password but none provided")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Password required"
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Password required")
# If room has password and provided password doesn't match
if room.password and provided_password and provided_password != room.password:
logger.debug(f"Room {room_id} password mismatch")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid password"
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid password")
logger.debug(f"Room {room_id} password verification passed")
@@ -332,7 +317,7 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
select(RoomParticipatedUser.id).where(
RoomParticipatedUser.room_id == room_id,
RoomParticipatedUser.user_id == user_id,
col(RoomParticipatedUser.left_at).is_(None)
col(RoomParticipatedUser.left_at).is_(None),
)
)
existing_ids = existing_result.scalars().all() # 获取所有匹配的ID
@@ -364,10 +349,11 @@ async def _add_or_update_participant(db: Database, room_id: int, user_id: int) -
class BeatmapEnsureRequest(BaseModel):
"""Request model for ensuring beatmap exists."""
beatmap_id: int
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> Dict[str, Any]:
async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int) -> dict[str, Any]:
"""
确保谱面存在(包括元数据和原始文件缓存)。
@@ -383,14 +369,11 @@ async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int)
try:
# 1. 确保谱面元数据存在于数据库中
from app.database.beatmap import Beatmap
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
if not beatmap:
return {
"success": False,
"error": f"Beatmap {beatmap_id} not found",
"beatmap_id": beatmap_id
}
return {"success": False, "error": f"Beatmap {beatmap_id} not found", "beatmap_id": beatmap_id}
# 2. 预缓存谱面原始文件
cache_key = f"beatmap:{beatmap_id}:raw"
@@ -410,35 +393,27 @@ async def _ensure_beatmap_exists(db: Database, fetcher, redis, beatmap_id: int)
"beatmap_id": beatmap_id,
"metadata_cached": True,
"raw_file_cached": await redis.exists(cache_key),
"beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]"
"beatmap_title": f"{beatmap.beatmapset.artist} - {beatmap.beatmapset.title} [{beatmap.version}]",
}
except Exception as e:
logger.debug(f"Error ensuring beatmap {beatmap_id}: {e}")
return {
"success": False,
"error": str(e),
"beatmap_id": beatmap_id
}
return {"success": False, "error": str(e), "beatmap_id": beatmap_id}
async def _update_room_participant_count(db: Database, room_id: int) -> int:
"""更新房间参与者数量并返回当前数量。"""
# 统计活跃参与者
active_participants_result = await db.execute(
select(RoomParticipatedUser.user_id).where(
RoomParticipatedUser.room_id == room_id,
col(RoomParticipatedUser.left_at).is_(None)
RoomParticipatedUser.room_id == room_id, col(RoomParticipatedUser.left_at).is_(None)
)
)
active_participants = active_participants_result.all()
count = len(active_participants)
# 更新房间参与者数量
await db.execute(
update(Room)
.where(col(Room.id) == room_id)
.values(participant_count=count)
)
await db.execute(update(Room).where(col(Room.id) == room_id).values(participant_count=count))
return count
@@ -457,7 +432,7 @@ async def _end_room_if_empty(db: Database, room_id: int) -> bool:
.values(
ends_at=now,
status=RoomStatus.IDLE, # 或者使用 RoomStatus.ENDED 如果有这个状态
participant_count=0
participant_count=0,
)
)
logger.debug(f"Room {room_id} ended automatically (no participants remaining)")
@@ -474,7 +449,7 @@ async def _transfer_ownership_or_end_room(db: Database, room_id: int, leaving_us
.where(
col(RoomParticipatedUser.room_id) == room_id,
col(RoomParticipatedUser.user_id) != leaving_user_id,
col(RoomParticipatedUser.left_at).is_(None)
col(RoomParticipatedUser.left_at).is_(None),
)
.order_by(col(RoomParticipatedUser.joined_at)) # 按加入时间排序
)
@@ -483,23 +458,21 @@ async def _transfer_ownership_or_end_room(db: Database, room_id: int, leaving_us
if remaining_participants:
# 将房主权限转让给最早加入的用户
new_owner_id = remaining_participants[0][0] # 获取 user_id
await db.execute(
update(Room)
.where(col(Room.id) == room_id)
.values(host_id=new_owner_id)
)
await db.execute(update(Room).where(col(Room.id) == room_id).values(host_id=new_owner_id))
logger.debug(f"Room {room_id} ownership transferred from {leaving_user_id} to {new_owner_id}")
return False # 房间继续存在
else:
# 没有其他参与者,结束房间
return await _end_room_if_empty(db, room_id)
# ===== API ENDPOINTS =====
@router.post("/multiplayer/rooms")
async def create_multiplayer_room(
request: Request,
room_data: Dict[str, Any],
room_data: dict[str, Any],
db: Database,
timestamp: str = "",
) -> int:
@@ -508,10 +481,7 @@ async def create_multiplayer_room(
# Verify request signature
body = await request.body()
if not verify_request_signature(request, str(timestamp), body):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid request signature"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
# Parse room data if string
if isinstance(room_data, str):
@@ -534,7 +504,7 @@ async def create_multiplayer_room(
await _add_playlist_items(db, room_id, room_data, host_user_id)
# Add host as participant
#await _add_host_as_participant(db, room_id, host_user_id)
# await _add_host_as_participant(db, room_id, host_user_id)
await db.commit()
return room_id
@@ -546,18 +516,11 @@ async def create_multiplayer_room(
raise
except json.JSONDecodeError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid JSON: {str(e)}"
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON: {e!s}")
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create room: {str(e)}"
)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create room: {e!s}")
@router.delete("/multiplayer/rooms/{room_id}/users/{user_id}")
@@ -567,29 +530,21 @@ async def remove_user_from_room(
user_id: int,
db: Database,
timestamp: int = Query(..., description="Unix 时间戳(秒)", ge=0),
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Remove a user from a multiplayer room."""
try:
# Verify request signature
body = await request.body()
now = utcnow()
if not verify_request_signature(request, str(timestamp), body):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid request signature"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
# 检查房间是否存在
room_result = await db.execute(
select(Room).where(col(Room.id) == room_id)
)
room_result = await db.execute(select(Room).where(col(Room.id) == room_id))
room = room_result.scalar_one_or_none()
if room is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Room not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
room_owner_id = room.host_id
room_status = room.status
@@ -604,11 +559,10 @@ async def remove_user_from_room(
# 检查用户是否在房间中
participant_result = await db.execute(
select(RoomParticipatedUser.id)
.where(
select(RoomParticipatedUser.id).where(
col(RoomParticipatedUser.room_id) == room_id,
col(RoomParticipatedUser.user_id) == user_id,
col(RoomParticipatedUser.left_at).is_(None)
col(RoomParticipatedUser.left_at).is_(None),
)
)
participant_query = participant_result.first()
@@ -634,7 +588,7 @@ async def remove_user_from_room(
.where(
col(RoomParticipatedUser.room_id) == room_id,
col(RoomParticipatedUser.user_id) == user_id,
col(RoomParticipatedUser.left_at).is_(None)
col(RoomParticipatedUser.left_at).is_(None),
)
.values(left_at=now)
)
@@ -667,10 +621,9 @@ async def remove_user_from_room(
raise
except Exception as e:
await db.rollback()
logger.debug(f"Error removing user from room: {str(e)}")
logger.debug(f"Error removing user from room: {e!s}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to remove user from room: {str(e)}"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user from room: {e!s}"
)
@@ -681,7 +634,7 @@ async def add_user_to_room(
user_id: int,
db: Database,
timestamp: str = "",
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Add a user to a multiplayer room."""
logger.debug(f"Adding user {user_id} to room {room_id}")
@@ -690,7 +643,7 @@ async def add_user_to_room(
user_data = None
if body:
try:
user_data = json.loads(body.decode('utf-8'))
user_data = json.loads(body.decode("utf-8"))
logger.debug(f"Parsed user_data: {user_data}")
except json.JSONDecodeError:
logger.debug("Failed to parse user_data from request body")
@@ -698,15 +651,11 @@ async def add_user_to_room(
# Verify request signature
if not verify_request_signature(request, timestamp, body):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid request signature"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
# 检查房间是否已结束
room_result = await db.execute(
select(Room.id, Room.ends_at, Room.channel_id, Room.host_id)
.where(col(Room.id) == room_id)
select(Room.id, Room.ends_at, Room.channel_id, Room.host_id).where(col(Room.id) == room_id)
)
room_row = room_result.first()
if not room_row:
@@ -763,7 +712,7 @@ async def ensure_beatmap_present(
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
timestamp: str = "",
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
确保谱面在服务器中存在(包括元数据和原始文件缓存)。
@@ -774,10 +723,7 @@ async def ensure_beatmap_present(
# Verify request signature
body = await request.body()
if not verify_request_signature(request, timestamp, body):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid request signature"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid request signature")
beatmap_id = beatmap_data.beatmap_id
logger.debug(f"Ensuring beatmap {beatmap_id} is present")
@@ -795,8 +741,7 @@ async def ensure_beatmap_present(
raise
except Exception as e:
await db.rollback()
logger.debug(f"Error ensuring beatmap: {str(e)}")
logger.debug(f"Error ensuring beatmap: {e!s}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to ensure beatmap: {str(e)}"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to ensure beatmap: {e!s}"
)

View File

@@ -96,6 +96,7 @@ async def send_message(
if channel_type == ChannelType.MULTIPLAYER:
try:
from app.dependencies.database import get_redis
redis = get_redis()
key = f"channel:{channel_id}:messages"
key_type = await redis.type(key)
@@ -162,13 +163,9 @@ async def get_message(
):
# 1) 查频道
if channel.isdigit():
db_channel = (await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -176,7 +173,6 @@ async def get_message(
channel_id = db_channel.channel_id
try:
messages = await redis_message_system.get_messages(channel_id, limit, since)
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
messages.reverse()
@@ -188,11 +184,7 @@ async def get_message(
if since > 0 and until is None:
# 向前加载新消息 → 直接 ASC
query = (
base.where(col(ChatMessage.message_id) > since)
.order_by(col(ChatMessage.message_id).asc())
.limit(limit)
)
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
rows = (await session.exec(query)).all()
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
# 已经 ASC无需反转
@@ -202,9 +194,7 @@ async def get_message(
if until is not None:
# 用 DESC 取最近的更早消息,再反转为 ASC
query = (
base.where(col(ChatMessage.message_id) < until)
.order_by(col(ChatMessage.message_id).desc())
.limit(limit)
base.where(col(ChatMessage.message_id) < until).order_by(col(ChatMessage.message_id).desc()).limit(limit)
)
rows = (await session.exec(query)).all()
rows = list(rows)
@@ -221,7 +211,6 @@ async def get_message(
return resp
@router.put(
"/chat/channels/{channel}/mark-as-read/{message}",
status_code=204,

View File

@@ -81,10 +81,11 @@ class ChatServer:
if not users_in_channel:
try:
async with with_db() as session:
from sqlmodel import select
channel = await session.get(ChatChannel, channel_id)
if channel and channel.type == ChannelType.MULTIPLAYER:
logger.warning(f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone")
logger.warning(
f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone"
)
# 对于多人游戏房间,这可能是正常的(用户都离开了房间)
# 但我们仍然记录这个情况以便调试
except Exception as e:

View File

@@ -13,8 +13,8 @@ from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import SearchQueryModel
from app.service.beatmap_download_service import BeatmapDownloadService
from app.service.asset_proxy_helper import process_response_assets
from app.service.beatmap_download_service import BeatmapDownloadService
from .router import router

View File

@@ -59,15 +59,17 @@ async def get_all_rooms(
if status is not None:
where_clauses.append(col(Room.status) == status)
#print(mode, category, status, current_user.id)
# print(mode, category, status, current_user.id)
if mode == "open":
# 修改为新的查询逻辑:状态为 idle 或 playingstarts_at 不为空ends_at 为空
where_clauses.extend([
col(Room.status).in_([RoomStatus.IDLE, RoomStatus.PLAYING]),
col(Room.starts_at).is_not(None),
col(Room.ends_at).is_(None)
])
#if category == RoomCategory.REALTIME:
where_clauses.extend(
[
col(Room.status).in_([RoomStatus.IDLE, RoomStatus.PLAYING]),
col(Room.starts_at).is_not(None),
col(Room.ends_at).is_(None),
]
)
# if category == RoomCategory.REALTIME:
# where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
if mode == "participated":
@@ -96,7 +98,7 @@ async def get_all_rooms(
.unique()
.all()
)
#print("Retrieved rooms:", db_rooms)
# print("Retrieved rooms:", db_rooms)
for room in db_rooms:
resp = await RoomResp.from_db(room, db)
resp.has_password = bool((room.password or "").strip())

View File

@@ -5,10 +5,12 @@
from __future__ import annotations
from typing import Any
from fastapi import Request
from app.config import settings
from app.service.asset_proxy_service import get_asset_proxy_service
from fastapi import Request
async def process_response_assets(data: Any, request: Request) -> Any:
"""
@@ -56,6 +58,7 @@ def asset_proxy_response(func):
"""
装饰器自动处理响应中的资源URL
"""
async def wrapper(*args, **kwargs):
# 获取request对象
request = None

View File

@@ -7,8 +7,8 @@ from __future__ import annotations
import re
from typing import Any
from app.config import settings
from app.log import logger
class AssetProxyService:
@@ -26,7 +26,7 @@ class AssetProxyService:
递归替换数据中的osu!资源URL为自定义域名
"""
# 处理Pydantic模型
if hasattr(data, 'model_dump'):
if hasattr(data, "model_dump"):
# 转换为字典,处理后再转换回模型
data_dict = data.model_dump()
processed_dict = await self.replace_asset_urls(data_dict)
@@ -49,30 +49,20 @@ class AssetProxyService:
# 替换 assets.ppy.sh (用户头像、封面、奖章等)
result = re.sub(
r"https://assets\.ppy\.sh/",
f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/",
result
r"https://assets\.ppy\.sh/", f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/", result
)
# 替换 b.ppy.sh 预览音频 (保持//前缀)
result = re.sub(
r"//b\.ppy\.sh/",
f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/",
result
)
result = re.sub(r"//b\.ppy\.sh/", f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/", result)
# 替换 https://b.ppy.sh 预览音频 (转换为//前缀)
result = re.sub(
r"https://b\.ppy\.sh/",
f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/",
result
r"https://b\.ppy\.sh/", f"//{self.beatmap_proxy_prefix}.{self.custom_asset_domain}/", result
)
# 替换 a.ppy.sh 头像
result = re.sub(
r"https://a\.ppy\.sh/",
f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/",
result
r"https://a\.ppy\.sh/", f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/", result
)
return result

View File

@@ -5,7 +5,7 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta, UTC
from datetime import UTC, datetime, timedelta
import json
from app.dependencies.database import get_redis, get_redis_message

View File

@@ -14,8 +14,8 @@ from app.config import settings
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.log import logger
from app.models.score import GameMode
from app.utils import utcnow
from app.service.asset_proxy_service import get_asset_proxy_service
from app.utils import utcnow
from redis.asyncio import Redis
from sqlmodel import col, select

View File

@@ -258,7 +258,9 @@ class RedisMessageSystem:
# 验证删除是否成功
verify_type = await self._redis_exec(self.redis.type, channel_messages_key)
if verify_type != "none":
logger.error(f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}")
logger.error(
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
)
# 强制删除
await self._redis_exec(self.redis.unlink, channel_messages_key)

View File

@@ -5,6 +5,7 @@ Revises: 57bacf936413
Create Date: 2025-08-24 00:08:14.704724
"""
from __future__ import annotations
from collections.abc import Sequence

View File

@@ -5,6 +5,7 @@ Revises: 8d2af11343b9
Create Date: 2025-08-24 04:00:02.063347
"""
from __future__ import annotations
from collections.abc import Sequence
@@ -22,10 +23,14 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
#op.drop_index(op.f("ix_lazer_user_achievements_achievement_id"), table_name="lazer_user_achievements")
#op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements")
op.add_column("room_playlists", sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True))
op.add_column("room_playlists", sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True))
# op.drop_index(op.f("ix_lazer_user_achievements_achievement_id"), table_name="lazer_user_achievements")
# op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements")
op.add_column(
"room_playlists", sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)
)
op.add_column(
"room_playlists", sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=True)
)
# ### end Alembic commands ###
@@ -35,5 +40,7 @@ def downgrade() -> None:
op.drop_column("room_playlists", "updated_at")
op.drop_column("room_playlists", "created_at")
op.create_index(op.f("uq_user_achievement"), "lazer_user_achievements", ["user_id", "achievement_id"], unique=True)
op.create_index(op.f("ix_lazer_user_achievements_achievement_id"), "lazer_user_achievements", ["achievement_id"], unique=False)
op.create_index(
op.f("ix_lazer_user_achievements_achievement_id"), "lazer_user_achievements", ["achievement_id"], unique=False
)
# ### end Alembic commands ###

View File

@@ -5,13 +5,13 @@ Revises: 178873984b22
Create Date: 2025-08-23 18:45:03.009632
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = "57bacf936413"

View File

@@ -5,6 +5,7 @@ Revises: 20c6df84813f
Create Date: 2025-08-24 00:08:42.419252
"""
from __future__ import annotations
from collections.abc import Sequence

View File

@@ -5,6 +5,7 @@ Revises: 7576ca1e056d
Create Date: 2025-08-24 00:11:05.064099
"""
from __future__ import annotations
from collections.abc import Sequence