From ce465aa0490bd0e9f439c7a898b5fcb9a1b7f851 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= <74496778+GooGuJiang@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:57:28 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B4=E7=90=86=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/database/score.py | 12 +- app/dependencies/database.py | 6 +- app/models/score.py | 6 +- app/models/stats.py | 5 +- app/router/notification/channel.py | 30 +- app/router/notification/message.py | 48 ++- app/router/notification/server.py | 30 +- app/router/v2/router.py | 1 - app/router/v2/score.py | 55 ++-- app/router/v2/stats.py | 148 +++++---- app/service/enhanced_interval_stats.py | 360 +++++++++++++--------- app/service/message_queue.py | 154 +++++---- app/service/message_queue_processor.py | 204 +++++++----- app/service/optimized_message.py | 79 ++--- app/service/redis_message_system.py | 411 +++++++++++++++---------- app/service/stats_cleanup.py | 53 ++-- app/service/stats_scheduler.py | 87 +++--- app/signalr/hub/metadata.py | 6 +- app/signalr/hub/multiplayer.py | 89 +++--- app/signalr/hub/spectator.py | 93 +++--- 20 files changed, 1078 insertions(+), 799 deletions(-) diff --git a/app/database/score.py b/app/database/score.py index 0f06547..da9cbc4 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -45,7 +45,7 @@ from .relationship import ( ) from .score_token import ScoreToken -from pydantic import field_validator, field_serializer +from pydantic import field_serializer, field_validator from redis.asyncio import Redis from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime from sqlalchemy.ext.asyncio import AsyncAttrs @@ -126,7 +126,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): if isinstance(v, dict): serialized = {} for key, value in v.items(): - if hasattr(key, 'value'): + if hasattr(key, "value"): # 如果是枚举,使用其值 serialized[key.value] = value else: @@ -138,7 +138,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): @field_serializer("rank", when_used="json") def serialize_rank(self, v): """序列化等级,确保枚举值正确转换为字符串""" - if hasattr(v, 'value'): + if hasattr(v, "value"): return v.value return str(v) @@ -188,7 +188,7 @@ class Score(ScoreBase, table=True): @field_serializer("gamemode", when_used="json") def serialize_gamemode(self, v): """序列化游戏模式,确保枚举值正确转换为字符串""" - if hasattr(v, 'value'): + if hasattr(v, "value"): return v.value return str(v) @@ -281,7 +281,7 @@ class ScoreResp(ScoreBase): if isinstance(v, dict): serialized = {} for key, value in v.items(): - if hasattr(key, 'value'): + if hasattr(key, "value"): # 如果是枚举,使用其值 serialized[key.value] = value else: @@ -294,7 +294,7 @@ class ScoreResp(ScoreBase): async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": # 确保 score 对象完全加载,避免懒加载问题 await session.refresh(score) - + s = cls.model_validate(score.model_dump()) assert score.id await score.awaitable_attrs.beatmap diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 4ffefd0..77648d1 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -10,8 +10,8 @@ from app.config import settings from fastapi import Depends from pydantic import BaseModel -import redis.asyncio as redis import redis as sync_redis +import redis.asyncio as redis from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession @@ -40,7 +40,9 @@ engine = create_async_engine( redis_client = redis.from_url(settings.redis_url, decode_responses=True) # Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行 -redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1) +redis_message_client = sync_redis.from_url( + settings.redis_url, decode_responses=True, db=1 +) # 数据库依赖 diff --git a/app/models/score.py b/app/models/score.py index f75e827..4e42baf 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -7,7 +7,7 @@ from app.config import settings from .mods import API_MODS, APIMod -from pydantic import BaseModel, Field, ValidationInfo, field_validator, field_serializer +from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator if TYPE_CHECKING: import rosu_pp_py as rosu @@ -212,7 +212,7 @@ class SoloScoreSubmissionInfo(BaseModel): if isinstance(v, dict): serialized = {} for key, value in v.items(): - if hasattr(key, 'value'): + if hasattr(key, "value"): # 如果是枚举,使用其值 serialized[key.value] = value else: @@ -224,7 +224,7 @@ class SoloScoreSubmissionInfo(BaseModel): @field_serializer("rank", when_used="json") def serialize_rank(self, v): """序列化等级,确保枚举值正确转换为字符串""" - if hasattr(v, 'value'): + if hasattr(v, "value"): return v.value return str(v) diff --git a/app/models/stats.py b/app/models/stats.py index 4e277b1..ee79ba2 100644 --- a/app/models/stats.py +++ b/app/models/stats.py @@ -1,13 +1,13 @@ from __future__ import annotations from datetime import datetime -from typing import Any from pydantic import BaseModel class OnlineStats(BaseModel): """在线统计信息""" + registered_users: int online_users: int playing_users: int @@ -16,6 +16,7 @@ class OnlineStats(BaseModel): class OnlineHistoryPoint(BaseModel): """在线历史数据点""" + timestamp: datetime online_count: int playing_count: int @@ -23,12 +24,14 @@ class OnlineHistoryPoint(BaseModel): class OnlineHistoryStats(BaseModel): """24小时在线历史统计""" + history: list[OnlineHistoryPoint] current_stats: OnlineStats class ServerStatistics(BaseModel): """服务器统计信息""" + total_users: int online_users: int playing_users: int diff --git a/app/router/notification/channel.py b/app/router/notification/channel.py index 9c5e941..f2018dc 100644 --- a/app/router/notification/channel.py +++ b/app/router/notification/channel.py @@ -62,7 +62,7 @@ async def get_update( if db_channel: # 提取必要的属性避免惰性加载 channel_type = db_channel.type - + resp.presence.append( await ChatChannelResp.from_db( db_channel, @@ -122,9 +122,7 @@ async def join_channel( ).first() else: db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.name == channel) - ) + await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) ).first() if db_channel is None: @@ -154,9 +152,7 @@ async def leave_channel( ).first() else: db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.name == channel) - ) + await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) ).first() if db_channel is None: @@ -187,7 +183,7 @@ async def get_channel_list( # 提取必要的属性避免惰性加载 channel_id = channel.channel_id channel_type = channel.type - + assert channel_id is not None results.append( await ChatChannelResp.from_db( @@ -230,19 +226,17 @@ async def get_channel( ).first() else: db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.name == channel) - ) + await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) ).first() - + if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") - + # 立即提取需要的属性 channel_id = db_channel.channel_id channel_type = db_channel.type channel_name = db_channel.name - + assert channel_id is not None users = [] @@ -325,7 +319,9 @@ async def create_channel( channel_name = f"pm_{current_user.id}_{req.target_id}" else: channel_name = req.channel.name if req.channel else "Unnamed Channel" - result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name)) + result = await session.exec( + select(ChatChannel).where(ChatChannel.name == channel_name) + ) channel = result.first() if channel is None: @@ -350,11 +346,11 @@ async def create_channel( await server.batch_join_channel([*target_users, current_user], channel, session) await server.join_channel(current_user, channel, session) - + # 提取必要的属性避免惰性加载 channel_id = channel.channel_id assert channel_id - + return await ChatChannelResp.from_db( channel, session, diff --git a/app/router/notification/message.py b/app/router/notification/message.py index ec3b419..390db6b 100644 --- a/app/router/notification/message.py +++ b/app/router/notification/message.py @@ -1,10 +1,5 @@ from __future__ import annotations -import json -import uuid -from datetime import datetime -from typing import Optional - from app.database import ChatMessageResp from app.database.chat import ( ChannelType, @@ -16,14 +11,13 @@ from app.database.chat import ( UserSilenceResp, ) from app.database.lazer_user import User -from app.dependencies.database import Database, get_redis, get_redis_message +from app.dependencies.database import Database, get_redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user +from app.log import logger from app.models.notification import ChannelMessage, ChannelMessageTeam from app.router.v2 import api_v2_router as router -from app.service.optimized_message import optimized_message_service from app.service.redis_message_system import redis_message_system -from app.log import logger from .banchobot import bot from .server import server @@ -106,11 +100,9 @@ async def send_message( ).first() else: db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.name == channel) - ) + await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) ).first() - + if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") @@ -118,29 +110,29 @@ async def send_message( channel_id = db_channel.channel_id channel_type = db_channel.type channel_name = db_channel.name - + assert channel_id is not None assert current_user.id - + # 使用 Redis 消息系统发送消息 - 立即返回 resp = await redis_message_system.send_message( channel_id=channel_id, user=current_user, content=req.message, is_action=req.is_action, - user_uuid=req.uuid + user_uuid=req.uuid, ) - + # 立即广播消息给所有客户端 is_bot_command = req.message.startswith("!") await server.send_message_to_channel( resp, is_bot_command and channel_type == ChannelType.PUBLIC ) - + # 处理机器人命令 if is_bot_command: await bot.try_handle(current_user, db_channel, req.message, session) - + # 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道) if channel_type in [ChannelType.PM, ChannelType.TEAM]: temp_msg = ChatMessage( @@ -151,7 +143,7 @@ async def send_message( type=MessageType.ACTION if req.is_action else MessageType.PLAIN, uuid=req.uuid, ) - + if channel_type == ChannelType.PM: user_ids = channel_name.split("_")[1:] await server.new_private_notification( @@ -163,7 +155,7 @@ async def send_message( await server.new_private_notification( ChannelMessageTeam.init(temp_msg, current_user) ) - + return resp @@ -191,11 +183,9 @@ async def get_message( ).first() else: db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.name == channel) - ) + await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) ).first() - + if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") @@ -218,7 +208,7 @@ async def get_message( query = query.where(col(ChatMessage.message_id) > since) if until is not None: query = query.where(col(ChatMessage.message_id) < until) - + query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit) messages = (await session.exec(query)).all() resp = [await ChatMessageResp.from_db(msg, session) for msg in messages] @@ -247,14 +237,12 @@ async def mark_as_read( ).first() else: db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.name == channel) - ) + await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) ).first() - + if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") - + # 立即提取需要的属性 channel_id = db_channel.channel_id assert channel_id diff --git a/app/router/notification/server.py b/app/router/notification/server.py index e24106f..71a5106 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -96,8 +96,10 @@ class ChatServer: async def send_message_to_channel( self, message: ChatMessageResp, is_bot_command: bool = False ): - logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}") - + logger.info( + f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}" + ) + event = ChatEvent( event="chat.message.new", data={"messages": [message], "users": [message.sender]}, @@ -107,24 +109,32 @@ class ChatServer: self._add_task(self.send_event(message.sender_id, event)) else: # 总是广播消息,无论是临时ID还是真实ID - logger.info(f"Broadcasting message to all users in channel {message.channel_id}") + logger.info( + f"Broadcasting message to all users in channel {message.channel_id}" + ) self._add_task( self.broadcast( message.channel_id, event, ) ) - + # 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息 # Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理 if message.message_id and message.message_id > 0: await self.mark_as_read( message.channel_id, message.sender_id, message.message_id ) - await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id) - logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}") + await self.redis.set( + f"chat:{message.channel_id}:last_msg", message.message_id + ) + logger.info( + f"Updated last message ID for channel {message.channel_id} to {message.message_id}" + ) else: - logger.debug(f"Skipping last message update for message ID: {message.message_id}") + logger.debug( + f"Skipping last message update for message ID: {message.message_id}" + ) async def batch_join_channel( self, users: list[User], channel: ChatChannel, session: AsyncSession @@ -340,11 +350,9 @@ async def chat_websocket( server.connect(user_id, websocket) # 使用明确的查询避免延迟加载 db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == 1) - ) + await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1)) ).first() if db_channel is not None: await server.join_channel(user, db_channel, session) - + await _listen_stop(websocket, user_id, factory) diff --git a/app/router/v2/router.py b/app/router/v2/router.py index ffd22c4..e4a6b43 100644 --- a/app/router/v2/router.py +++ b/app/router/v2/router.py @@ -5,4 +5,3 @@ from fastapi import APIRouter router = APIRouter(prefix="/api/v2") # 导入所有子路由模块来注册路由 -from . import stats # 统计路由 diff --git a/app/router/v2/score.py b/app/router/v2/score.py index 8ec01d8..8838d8f 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -75,9 +75,10 @@ READ_SCORE_TIMEOUT = 10 async def process_user_achievement(score_id: int): - from sqlmodel.ext.asyncio.session import AsyncSession from app.dependencies.database import engine - + + from sqlmodel.ext.asyncio.session import AsyncSession + session = AsyncSession(engine) try: await process_achievements(session, get_redis(), score_id) @@ -99,7 +100,7 @@ async def submit_score( ): # 立即获取用户ID,避免后续的懒加载问题 user_id = current_user.id - + if not info.passed: info.rank = Rank.F score_token = ( @@ -166,13 +167,15 @@ async def submit_score( has_pp, has_leaderboard, ) - score = (await db.exec( - select(Score) - .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType] - .where(Score.id == score_id) - )).first() + score = ( + await db.exec( + select(Score) + .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType] + .where(Score.id == score_id) + ) + ).first() assert score is not None - + resp = await ScoreResp.from_db(db, score) total_users = (await db.exec(select(func.count()).select_from(User))).first() assert total_users is not None @@ -202,13 +205,10 @@ async def submit_score( # 确保score对象已刷新,避免在后台任务中触发延迟加载 await db.refresh(score) score_gamemode = score.gamemode - + if user_id is not None: background_task.add_task( - _refresh_user_cache_background, - redis, - user_id, - score_gamemode + _refresh_user_cache_background, redis, user_id, score_gamemode ) background_task.add_task(process_user_achievement, resp.id) return resp @@ -217,9 +217,10 @@ async def submit_score( async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameMode): """后台任务:刷新用户缓存""" try: - from sqlmodel.ext.asyncio.session import AsyncSession from app.dependencies.database import engine - + + from sqlmodel.ext.asyncio.session import AsyncSession + user_cache_service = get_user_cache_service(redis) # 创建独立的数据库会话 session = AsyncSession(engine) @@ -422,7 +423,7 @@ async def create_solo_score( assert current_user.id is not None # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id) async with db: score_token = ScoreToken( @@ -480,7 +481,7 @@ async def create_playlist_score( assert current_user.id is not None # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + room = await session.get(Room, room_id) if not room: raise HTTPException(status_code=404, detail="Room not found") @@ -557,10 +558,10 @@ async def submit_playlist_score( fetcher: Fetcher = Depends(get_fetcher), ): assert current_user.id is not None - + # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + item = ( await session.exec( select(Playlist).where( @@ -627,7 +628,7 @@ async def index_playlist_scores( ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + room = await session.get(Room, room_id) if not room: raise HTTPException(status_code=404, detail="Room not found") @@ -694,7 +695,7 @@ async def show_playlist_score( ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + room = await session.get(Room, room_id) if not room: raise HTTPException(status_code=404, detail="Room not found") @@ -803,7 +804,7 @@ async def pin_score( ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + score_record = ( await db.exec( select(Score).where( @@ -848,7 +849,7 @@ async def unpin_score( ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + score_record = ( await db.exec( select(Score).where(Score.id == score_id, Score.user_id == user_id) @@ -892,7 +893,7 @@ async def reorder_score_pin( ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + score_record = ( await db.exec( select(Score).where(Score.id == score_id, Score.user_id == user_id) @@ -986,9 +987,9 @@ async def download_score_replay( current_user: User = Security(get_current_user, scopes=["public"]), storage_service: StorageService = Depends(get_storage_service), ): - # 立即获取用户ID,避免懒加载问题 + # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - + score = (await db.exec(select(Score).where(Score.id == score_id))).first() if not score: raise HTTPException(status_code=404, detail="Score not found") diff --git a/app/router/v2/stats.py b/app/router/v2/stats.py index e8e90e2..b79dd54 100644 --- a/app/router/v2/stats.py +++ b/app/router/v2/stats.py @@ -1,42 +1,45 @@ from __future__ import annotations import asyncio -from datetime import datetime, timedelta -import json -from typing import Any from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +import json from app.dependencies.database import get_redis, get_redis_message from app.log import logger from .router import router -from fastapi import APIRouter from pydantic import BaseModel # Redis key constants REDIS_ONLINE_USERS_KEY = "server:online_users" -REDIS_PLAYING_USERS_KEY = "server:playing_users" +REDIS_PLAYING_USERS_KEY = "server:playing_users" REDIS_REGISTERED_USERS_KEY = "server:registered_users" REDIS_ONLINE_HISTORY_KEY = "server:online_history" # 线程池用于同步Redis操作 _executor = ThreadPoolExecutor(max_workers=2) + async def _redis_exec(func, *args, **kwargs): """在线程池中执行同步Redis操作""" loop = asyncio.get_event_loop() return await loop.run_in_executor(_executor, func, *args, **kwargs) + class ServerStats(BaseModel): """服务器统计信息响应模型""" + registered_users: int online_users: int playing_users: int timestamp: datetime + class OnlineHistoryPoint(BaseModel): """在线历史数据点""" + timestamp: datetime online_count: int playing_count: int @@ -44,33 +47,36 @@ class OnlineHistoryPoint(BaseModel): peak_playing: int | None = None # 峰值游玩数(增强数据) total_samples: int | None = None # 采样次数(增强数据) + class OnlineHistoryResponse(BaseModel): """24小时在线历史响应模型""" + history: list[OnlineHistoryPoint] current_stats: ServerStats + @router.get("/stats", response_model=ServerStats, tags=["统计"]) async def get_server_stats() -> ServerStats: """ 获取服务器实时统计信息 - + 返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息 """ redis = get_redis() - + try: # 并行获取所有统计数据 registered_count, online_count, playing_count = await asyncio.gather( _get_registered_users_count(redis), _get_online_users_count(redis), - _get_playing_users_count(redis) + _get_playing_users_count(redis), ) - + return ServerStats( registered_users=registered_count, online_users=online_count, playing_users=playing_count, - timestamp=datetime.utcnow() + timestamp=datetime.utcnow(), ) except Exception as e: logger.error(f"Error getting server stats: {e}") @@ -79,75 +85,86 @@ async def get_server_stats() -> ServerStats: registered_users=0, online_users=0, playing_users=0, - timestamp=datetime.utcnow() + timestamp=datetime.utcnow(), ) + @router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"]) async def get_online_history() -> OnlineHistoryResponse: """ 获取最近24小时在线统计历史 - + 返回过去24小时内每小时的在线用户数和游玩用户数统计, 包含当前实时数据作为最新数据点 """ try: # 获取历史数据 - 使用同步Redis客户端 redis_sync = get_redis_message() - history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1) + history_data = await _redis_exec( + redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1 + ) history_points = [] - + # 处理历史数据 for data in history_data: try: point_data = json.loads(data) # 支持新旧格式的历史数据 - history_points.append(OnlineHistoryPoint( - timestamp=datetime.fromisoformat(point_data["timestamp"]), - online_count=point_data["online_count"], - playing_count=point_data["playing_count"], - peak_online=point_data.get("peak_online"), # 新字段,可能不存在 - peak_playing=point_data.get("peak_playing"), # 新字段,可能不存在 - total_samples=point_data.get("total_samples") # 新字段,可能不存在 - )) + history_points.append( + OnlineHistoryPoint( + timestamp=datetime.fromisoformat(point_data["timestamp"]), + online_count=point_data["online_count"], + playing_count=point_data["playing_count"], + peak_online=point_data.get("peak_online"), # 新字段,可能不存在 + peak_playing=point_data.get( + "peak_playing" + ), # 新字段,可能不存在 + total_samples=point_data.get( + "total_samples" + ), # 新字段,可能不存在 + ) + ) except (json.JSONDecodeError, KeyError, ValueError) as e: logger.warning(f"Invalid history data point: {data}, error: {e}") continue - + # 获取当前实时统计信息 current_stats = await get_server_stats() - + # 如果历史数据为空或者最新数据超过15分钟,添加当前数据点 if not history_points or ( - history_points and - (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60 + history_points + and ( + current_stats.timestamp + - max(history_points, key=lambda x: x.timestamp).timestamp + ).total_seconds() + > 15 * 60 ): - history_points.append(OnlineHistoryPoint( - timestamp=current_stats.timestamp, - online_count=current_stats.online_users, - playing_count=current_stats.playing_users, - peak_online=current_stats.online_users, # 当前实时数据作为峰值 - peak_playing=current_stats.playing_users, - total_samples=1 - )) - + history_points.append( + OnlineHistoryPoint( + timestamp=current_stats.timestamp, + online_count=current_stats.online_users, + playing_count=current_stats.playing_users, + peak_online=current_stats.online_users, # 当前实时数据作为峰值 + peak_playing=current_stats.playing_users, + total_samples=1, + ) + ) + # 按时间排序(最新的在前) history_points.sort(key=lambda x: x.timestamp, reverse=True) - + # 限制到最多48个数据点(24小时) history_points = history_points[:48] - + return OnlineHistoryResponse( - history=history_points, - current_stats=current_stats + history=history_points, current_stats=current_stats ) except Exception as e: logger.error(f"Error getting online history: {e}") # 返回空历史和当前状态 current_stats = await get_server_stats() - return OnlineHistoryResponse( - history=[], - current_stats=current_stats - ) + return OnlineHistoryResponse(history=[], current_stats=current_stats) async def _get_registered_users_count(redis) -> int: @@ -159,6 +176,7 @@ async def _get_registered_users_count(redis) -> int: logger.error(f"Error getting registered users count: {e}") return 0 + async def _get_online_users_count(redis) -> int: """获取当前在线用户数""" try: @@ -168,6 +186,7 @@ async def _get_online_users_count(redis) -> int: logger.error(f"Error getting online users count: {e}") return 0 + async def _get_playing_users_count(redis) -> int: """获取当前游玩用户数""" try: @@ -177,14 +196,16 @@ async def _get_playing_users_count(redis) -> int: logger.error(f"Error getting playing users count: {e}") return 0 + # 统计更新功能 async def update_registered_users_count() -> None: """更新注册用户数缓存""" - from app.dependencies.database import with_db - from app.database import User from app.const import BANCHOBOT_ID - from sqlmodel import select, func - + from app.database import User + from app.dependencies.database import with_db + + from sqlmodel import func, select + redis = get_redis() try: async with with_db() as db: @@ -198,6 +219,7 @@ async def update_registered_users_count() -> None: except Exception as e: logger.error(f"Error updating registered users count: {e}") + async def add_online_user(user_id: int) -> None: """添加在线用户""" redis_sync = get_redis_message() @@ -209,14 +231,16 @@ async def add_online_user(user_id: int) -> None: if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期 await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期 logger.debug(f"Added online user {user_id}") - + # 立即更新当前区间统计 from app.service.enhanced_interval_stats import update_user_activity_in_interval + asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False)) - + except Exception as e: logger.error(f"Error adding online user {user_id}: {e}") + async def remove_online_user(user_id: int) -> None: """移除在线用户""" redis_sync = get_redis_message() @@ -226,6 +250,7 @@ async def remove_online_user(user_id: int) -> None: except Exception as e: logger.error(f"Error removing online user {user_id}: {e}") + async def add_playing_user(user_id: int) -> None: """添加游玩用户""" redis_sync = get_redis_message() @@ -237,14 +262,16 @@ async def add_playing_user(user_id: int) -> None: if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期 await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期 logger.debug(f"Added playing user {user_id}") - + # 立即更新当前区间统计 from app.service.enhanced_interval_stats import update_user_activity_in_interval + asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True)) - + except Exception as e: logger.error(f"Error adding playing user {user_id}: {e}") + async def remove_playing_user(user_id: int) -> None: """移除游玩用户""" redis_sync = get_redis_message() @@ -253,6 +280,7 @@ async def remove_playing_user(user_id: int) -> None: except Exception as e: logger.error(f"Error removing playing user {user_id}: {e}") + async def record_hourly_stats() -> None: """记录统计数据 - 简化版本,主要作为fallback使用""" redis_sync = get_redis_message() @@ -260,10 +288,10 @@ async def record_hourly_stats() -> None: try: # 先确保Redis连接正常 await redis_async.ping() - + online_count = await _get_online_users_count(redis_async) playing_count = await _get_playing_users_count(redis_async) - + current_time = datetime.utcnow() history_point = { "timestamp": current_time.isoformat(), @@ -271,16 +299,20 @@ async def record_hourly_stats() -> None: "playing_count": playing_count, "peak_online": online_count, "peak_playing": playing_count, - "total_samples": 1 + "total_samples": 1, } - + # 添加到历史记录 - await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)) + await _redis_exec( + redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point) + ) # 只保留48个数据点(24小时,每30分钟一个点) await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47) # 设置过期时间为26小时,确保有足够缓冲 await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600) - - logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}") + + logger.info( + f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}" + ) except Exception as e: logger.error(f"Error recording fallback stats: {e}") diff --git a/app/service/enhanced_interval_stats.py b/app/service/enhanced_interval_stats.py index 21d97fb..60ac4bc 100644 --- a/app/service/enhanced_interval_stats.py +++ b/app/service/enhanced_interval_stats.py @@ -4,11 +4,9 @@ from __future__ import annotations -import json -import asyncio +from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Dict, Set, Optional, List -from dataclasses import dataclass, asdict +import json from app.dependencies.database import get_redis, get_redis_message from app.log import logger @@ -16,7 +14,7 @@ from app.router.v2.stats import ( REDIS_ONLINE_HISTORY_KEY, _get_online_users_count, _get_playing_users_count, - _redis_exec + _redis_exec, ) # Redis keys for interval statistics @@ -29,34 +27,36 @@ CURRENT_INTERVAL_INFO_KEY = "server:current_interval_info" # 当前区间信息 @dataclass class IntervalInfo: """区间信息""" + start_time: datetime end_time: datetime interval_key: str - + def is_current(self) -> bool: """检查是否是当前区间""" now = datetime.utcnow() return self.start_time <= now < self.end_time - - def to_dict(self) -> Dict: + + def to_dict(self) -> dict: return { - 'start_time': self.start_time.isoformat(), - 'end_time': self.end_time.isoformat(), - 'interval_key': self.interval_key + "start_time": self.start_time.isoformat(), + "end_time": self.end_time.isoformat(), + "interval_key": self.interval_key, } - + @classmethod - def from_dict(cls, data: Dict) -> 'IntervalInfo': + def from_dict(cls, data: dict) -> "IntervalInfo": return cls( - start_time=datetime.fromisoformat(data['start_time']), - end_time=datetime.fromisoformat(data['end_time']), - interval_key=data['interval_key'] + start_time=datetime.fromisoformat(data["start_time"]), + end_time=datetime.fromisoformat(data["end_time"]), + interval_key=data["interval_key"], ) @dataclass class IntervalStats: """区间统计数据""" + interval_key: str start_time: datetime end_time: datetime @@ -66,38 +66,38 @@ class IntervalStats: peak_playing_count: int # 区间内游玩用户数峰值 total_samples: int # 采样次数 created_at: datetime - - def to_dict(self) -> Dict: + + def to_dict(self) -> dict: return { - 'interval_key': self.interval_key, - 'start_time': self.start_time.isoformat(), - 'end_time': self.end_time.isoformat(), - 'unique_online_users': self.unique_online_users, - 'unique_playing_users': self.unique_playing_users, - 'peak_online_count': self.peak_online_count, - 'peak_playing_count': self.peak_playing_count, - 'total_samples': self.total_samples, - 'created_at': self.created_at.isoformat() + "interval_key": self.interval_key, + "start_time": self.start_time.isoformat(), + "end_time": self.end_time.isoformat(), + "unique_online_users": self.unique_online_users, + "unique_playing_users": self.unique_playing_users, + "peak_online_count": self.peak_online_count, + "peak_playing_count": self.peak_playing_count, + "total_samples": self.total_samples, + "created_at": self.created_at.isoformat(), } - + @classmethod - def from_dict(cls, data: Dict) -> 'IntervalStats': + def from_dict(cls, data: dict) -> "IntervalStats": return cls( - interval_key=data['interval_key'], - start_time=datetime.fromisoformat(data['start_time']), - end_time=datetime.fromisoformat(data['end_time']), - unique_online_users=data['unique_online_users'], - unique_playing_users=data['unique_playing_users'], - peak_online_count=data['peak_online_count'], - peak_playing_count=data['peak_playing_count'], - total_samples=data['total_samples'], - created_at=datetime.fromisoformat(data['created_at']) + interval_key=data["interval_key"], + start_time=datetime.fromisoformat(data["start_time"]), + end_time=datetime.fromisoformat(data["end_time"]), + unique_online_users=data["unique_online_users"], + unique_playing_users=data["unique_playing_users"], + peak_online_count=data["peak_online_count"], + peak_playing_count=data["peak_playing_count"], + total_samples=data["total_samples"], + created_at=datetime.fromisoformat(data["created_at"]), ) class EnhancedIntervalStatsManager: """增强的区间统计管理器 - 真正统计半小时区间内的用户活跃情况""" - + @staticmethod def get_current_interval_boundaries() -> tuple[datetime, datetime]: """获取当前30分钟区间的边界""" @@ -108,49 +108,53 @@ class EnhancedIntervalStatsManager: # 区间结束时间 end_time = start_time + timedelta(minutes=30) return start_time, end_time - + @staticmethod def generate_interval_key(start_time: datetime) -> str: """生成区间唯一标识""" return f"{INTERVAL_STATS_BASE_KEY}:{start_time.strftime('%Y%m%d_%H%M')}" - + @staticmethod async def get_current_interval_info() -> IntervalInfo: """获取当前区间信息""" - start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries() - interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time) - - return IntervalInfo( - start_time=start_time, - end_time=end_time, - interval_key=interval_key + start_time, end_time = ( + EnhancedIntervalStatsManager.get_current_interval_boundaries() ) - + interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time) + + return IntervalInfo( + start_time=start_time, end_time=end_time, interval_key=interval_key + ) + @staticmethod async def initialize_current_interval() -> None: """初始化当前区间""" redis_sync = get_redis_message() redis_async = get_redis() - + try: - current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() - + current_interval = ( + await EnhancedIntervalStatsManager.get_current_interval_info() + ) + # 存储当前区间信息 await _redis_exec( - redis_sync.set, - CURRENT_INTERVAL_INFO_KEY, - json.dumps(current_interval.to_dict()) + redis_sync.set, + CURRENT_INTERVAL_INFO_KEY, + json.dumps(current_interval.to_dict()), ) await redis_async.expire(CURRENT_INTERVAL_INFO_KEY, 35 * 60) # 35分钟过期 - + # 初始化区间用户集合(如果不存在) online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" - playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" - + playing_key = ( + f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" + ) + # 设置过期时间为35分钟 await redis_async.expire(online_key, 35 * 60) await redis_async.expire(playing_key, 35 * 60) - + # 初始化区间统计记录 stats = IntervalStats( interval_key=current_interval.interval_key, @@ -161,157 +165,193 @@ class EnhancedIntervalStatsManager: peak_online_count=0, peak_playing_count=0, total_samples=0, - created_at=datetime.utcnow() + created_at=datetime.utcnow(), ) - + await _redis_exec( redis_sync.set, current_interval.interval_key, - json.dumps(stats.to_dict()) + json.dumps(stats.to_dict()), ) await redis_async.expire(current_interval.interval_key, 35 * 60) - + # 如果历史记录为空,自动填充前24小时数据为0 await EnhancedIntervalStatsManager._ensure_24h_history_exists() - - logger.info(f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}") - + + logger.info( + f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}" + ) + except Exception as e: logger.error(f"Error initializing current interval: {e}") - + @staticmethod async def _ensure_24h_history_exists() -> None: """确保24小时历史数据存在,不存在则用0填充""" redis_sync = get_redis_message() redis_async = get_redis() - + try: # 检查现有历史数据数量 - history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY) - + history_length = await _redis_exec( + redis_sync.llen, REDIS_ONLINE_HISTORY_KEY + ) + if history_length < 48: # 少于48个数据点(24小时*2) - logger.info(f"History has only {history_length} points, filling with zeros for 24h") - + logger.info( + f"History has only {history_length} points, filling with zeros for 24h" + ) + # 计算需要填充的数据点数量 needed_points = 48 - history_length - + # 从当前时间往前推,创建缺失的时间点(都填充为0) current_time = datetime.utcnow() - current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries() - + current_interval_start, _ = ( + EnhancedIntervalStatsManager.get_current_interval_boundaries() + ) + # 从当前区间开始往前推,创建历史数据点(确保时间对齐到30分钟边界) fill_points = [] for i in range(needed_points): # 每次往前推30分钟,确保时间对齐 - point_time = current_interval_start - timedelta(minutes=30 * (i + 1)) - + point_time = current_interval_start - timedelta( + minutes=30 * (i + 1) + ) + # 确保时间对齐到30分钟边界 aligned_minute = (point_time.minute // 30) * 30 - point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0) - + point_time = point_time.replace( + minute=aligned_minute, second=0, microsecond=0 + ) + history_point = { "timestamp": point_time.isoformat(), "online_count": 0, "playing_count": 0, "peak_online": 0, "peak_playing": 0, - "total_samples": 0 + "total_samples": 0, } fill_points.append(json.dumps(history_point)) - + # 将填充数据添加到历史记录末尾(最旧的数据) if fill_points: # 先将现有数据转移到临时位置 temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp" if history_length > 0: # 复制现有数据到临时key - existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1) + existing_data = await _redis_exec( + redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1 + ) if existing_data: for data in existing_data: await _redis_exec(redis_sync.rpush, temp_key, data) - + # 清空原有key await redis_async.delete(REDIS_ONLINE_HISTORY_KEY) - + # 先添加填充数据(最旧的) for point in reversed(fill_points): # 反向添加,最旧的在最后 - await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point) - + await _redis_exec( + redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point + ) + # 再添加原有数据(较新的) if history_length > 0: - existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1) + existing_data = await _redis_exec( + redis_sync.lrange, temp_key, 0, -1 + ) for data in existing_data: - await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data) - + await _redis_exec( + redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data + ) + # 清理临时key await redis_async.delete(temp_key) - + # 确保只保留48个数据点 await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47) - + # 设置过期时间 await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600) - - logger.info(f"Filled {len(fill_points)} historical data points with zeros") - + + logger.info( + f"Filled {len(fill_points)} historical data points with zeros" + ) + except Exception as e: logger.error(f"Error ensuring 24h history exists: {e}") - + @staticmethod async def add_user_to_interval(user_id: int, is_playing: bool = False) -> None: """添加用户到当前区间统计 - 实时更新当前运行的区间""" redis_sync = get_redis_message() redis_async = get_redis() - + try: - current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() - + current_interval = ( + await EnhancedIntervalStatsManager.get_current_interval_info() + ) + # 添加到区间在线用户集合 online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" await _redis_exec(redis_sync.sadd, online_key, str(user_id)) await redis_async.expire(online_key, 35 * 60) - + # 如果用户在游玩,也添加到游玩用户集合 if is_playing: - playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" + playing_key = ( + f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" + ) await _redis_exec(redis_sync.sadd, playing_key, str(user_id)) await redis_async.expire(playing_key, 35 * 60) - + # 立即更新区间统计(同步更新,确保数据实时性) await EnhancedIntervalStatsManager._update_interval_stats() - - logger.debug(f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}") - + + logger.debug( + f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}" + ) + except Exception as e: logger.error(f"Error adding user {user_id} to interval: {e}") - + @staticmethod async def _update_interval_stats() -> None: """更新当前区间统计 - 立即同步更新""" redis_sync = get_redis_message() redis_async = get_redis() - + try: - current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() - + current_interval = ( + await EnhancedIntervalStatsManager.get_current_interval_info() + ) + # 获取区间内独特用户数 online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" - playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" - + playing_key = ( + f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" + ) + unique_online = await _redis_exec(redis_sync.scard, online_key) unique_playing = await _redis_exec(redis_sync.scard, playing_key) - + # 获取当前实时用户数作为峰值参考 current_online = await _get_online_users_count(redis_async) current_playing = await _get_playing_users_count(redis_async) - + # 获取现有统计数据 - existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key) + existing_data = await _redis_exec( + redis_sync.get, current_interval.interval_key + ) if existing_data: stats = IntervalStats.from_dict(json.loads(existing_data)) # 更新峰值 stats.peak_online_count = max(stats.peak_online_count, current_online) - stats.peak_playing_count = max(stats.peak_playing_count, current_playing) + stats.peak_playing_count = max( + stats.peak_playing_count, current_playing + ) stats.total_samples += 1 else: # 创建新的统计记录 @@ -324,46 +364,52 @@ class EnhancedIntervalStatsManager: peak_online_count=current_online, peak_playing_count=current_playing, total_samples=1, - created_at=datetime.utcnow() + created_at=datetime.utcnow(), ) - + # 更新独特用户数 stats.unique_online_users = unique_online stats.unique_playing_users = unique_playing - + # 立即保存更新的统计数据 await _redis_exec( redis_sync.set, current_interval.interval_key, - json.dumps(stats.to_dict()) + json.dumps(stats.to_dict()), ) await redis_async.expire(current_interval.interval_key, 35 * 60) - - logger.debug(f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}") - + + logger.debug( + f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}" + ) + except Exception as e: logger.error(f"Error updating interval stats: {e}") - + @staticmethod - async def finalize_interval() -> Optional[IntervalStats]: + async def finalize_interval() -> IntervalStats | None: """完成当前区间统计并保存到历史""" redis_sync = get_redis_message() redis_async = get_redis() - + try: - current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() - + current_interval = ( + await EnhancedIntervalStatsManager.get_current_interval_info() + ) + # 最后一次更新统计 await EnhancedIntervalStatsManager._update_interval_stats() - + # 获取最终统计数据 - stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key) + stats_data = await _redis_exec( + redis_sync.get, current_interval.interval_key + ) if not stats_data: logger.warning("No interval stats found to finalize") return None - + stats = IntervalStats.from_dict(json.loads(stats_data)) - + # 创建历史记录点(使用区间结束时间作为时间戳,确保时间对齐) history_point = { "timestamp": current_interval.end_time.isoformat(), @@ -371,16 +417,18 @@ class EnhancedIntervalStatsManager: "playing_count": stats.unique_playing_users, "peak_online": stats.peak_online_count, "peak_playing": stats.peak_playing_count, - "total_samples": stats.total_samples + "total_samples": stats.total_samples, } - + # 添加到历史记录 - await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)) + await _redis_exec( + redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point) + ) # 只保留48个数据点(24小时,每30分钟一个点) await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47) # 设置过期时间为26小时,确保有足够缓冲 await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600) - + logger.info( f"Finalized interval stats: " f"unique_online={stats.unique_online_users}, " @@ -390,64 +438,70 @@ class EnhancedIntervalStatsManager: f"samples={stats.total_samples} " f"for {stats.start_time.strftime('%H:%M')}-{stats.end_time.strftime('%H:%M')}" ) - + return stats - + except Exception as e: logger.error(f"Error finalizing interval stats: {e}") return None - + @staticmethod - async def get_current_interval_stats() -> Optional[IntervalStats]: + async def get_current_interval_stats() -> IntervalStats | None: """获取当前区间统计""" redis_sync = get_redis_message() - + try: - current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() - stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key) - + current_interval = ( + await EnhancedIntervalStatsManager.get_current_interval_info() + ) + stats_data = await _redis_exec( + redis_sync.get, current_interval.interval_key + ) + if stats_data: return IntervalStats.from_dict(json.loads(stats_data)) return None - + except Exception as e: logger.error(f"Error getting current interval stats: {e}") return None - + @staticmethod async def cleanup_old_intervals() -> None: """清理过期的区间数据""" redis_async = get_redis() - + try: # 删除过期的区间统计数据(超过2小时的) cutoff_time = datetime.utcnow() - timedelta(hours=2) pattern = f"{INTERVAL_STATS_BASE_KEY}:*" - + keys = await redis_async.keys(pattern) for key in keys: try: # 从key中提取时间 - time_part = key.decode().split(':')[-1] # YYYYMMDD_HHMM格式 - key_time = datetime.strptime(time_part, '%Y%m%d_%H%M') - + time_part = key.decode().split(":")[-1] # YYYYMMDD_HHMM格式 + key_time = datetime.strptime(time_part, "%Y%m%d_%H%M") + if key_time < cutoff_time: await redis_async.delete(key) # 也删除对应的用户集合 await redis_async.delete(f"{INTERVAL_ONLINE_USERS_KEY}:{key}") await redis_async.delete(f"{INTERVAL_PLAYING_USERS_KEY}:{key}") - + except (ValueError, IndexError): # 忽略解析错误的key continue - + logger.debug("Cleaned up old interval data") - + except Exception as e: logger.error(f"Error cleaning up old intervals: {e}") # 便捷函数,用于替换现有的统计更新函数 -async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None: +async def update_user_activity_in_interval( + user_id: int, is_playing: bool = False +) -> None: """用户活动时更新区间统计(在登录、开始游玩等时调用)""" await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing) diff --git a/app/service/message_queue.py b/app/service/message_queue.py index cc921d5..0665796 100644 --- a/app/service/message_queue.py +++ b/app/service/message_queue.py @@ -3,53 +3,52 @@ Redis 消息队列服务 用于实现实时消息推送和异步数据库持久化 """ -import asyncio -import json -import uuid -from datetime import datetime -from functools import partial -from typing import Optional, Union -import concurrent.futures +from __future__ import annotations -from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType +import asyncio +import concurrent.futures +from datetime import datetime +import uuid + +from app.database.chat import ChatMessage, MessageType from app.dependencies.database import get_redis, with_db from app.log import logger class MessageQueue: """Redis 消息队列服务""" - + def __init__(self): self.redis = get_redis() self._processing = False self._batch_size = 50 # 批量处理大小 self._batch_timeout = 1.0 # 批量处理超时时间(秒) self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) - + async def _run_in_executor(self, func, *args): """在线程池中运行同步 Redis 操作""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self._executor, func, *args) - + async def start_processing(self): """启动消息处理任务""" if not self._processing: self._processing = True asyncio.create_task(self._process_message_queue()) logger.info("Message queue processing started") - + async def stop_processing(self): """停止消息处理""" self._processing = False logger.info("Message queue processing stopped") - + async def enqueue_message(self, message_data: dict) -> str: """ 将消息加入 Redis 队列(实时响应) - + Args: message_data: 消息数据字典,包含所有必要的字段 - + Returns: 消息的临时 UUID """ @@ -58,36 +57,42 @@ class MessageQueue: message_data["temp_uuid"] = temp_uuid message_data["timestamp"] = datetime.now().isoformat() message_data["status"] = "pending" # pending, processing, completed, failed - + # 将消息存储到 Redis await self._run_in_executor( lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data) ) - await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1小时过期 - + await self._run_in_executor( + self.redis.expire, f"msg:{temp_uuid}", 3600 + ) # 1小时过期 + # 加入处理队列 await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid) - + logger.info(f"Message enqueued with temp_uuid: {temp_uuid}") return temp_uuid - - async def get_message_status(self, temp_uuid: str) -> Optional[dict]: + + async def get_message_status(self, temp_uuid: str) -> dict | None: """获取消息状态""" - message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}") + message_data = await self._run_in_executor( + self.redis.hgetall, f"msg:{temp_uuid}" + ) if not message_data: return None - + return message_data - - async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: + + async def get_cached_messages( + self, channel_id: int, limit: int = 50, since: int = 0 + ) -> list[dict]: """ 从 Redis 获取缓存的消息 - + Args: channel_id: 频道 ID limit: 限制数量 since: 获取自此消息 ID 之后的消息 - + Returns: 消息列表 """ @@ -95,29 +100,39 @@ class MessageQueue: message_uuids = await self._run_in_executor( self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1 ) - + messages = [] for uuid_str in message_uuids: - message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}") + message_data = await self._run_in_executor( + self.redis.hgetall, f"msg:{uuid_str}" + ) if message_data: # 检查是否满足 since 条件 if since > 0 and "message_id" in message_data: if int(message_data["message_id"]) <= since: continue - + messages.append(message_data) - + return messages[::-1] # 返回时间顺序 - - async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100): + + async def cache_channel_message( + self, channel_id: int, temp_uuid: str, max_cache: int = 100 + ): """将消息 UUID 缓存到频道消息列表""" # 添加到频道消息列表开头 - await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid) + await self._run_in_executor( + self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid + ) # 限制缓存大小 - await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1) + await self._run_in_executor( + self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1 + ) # 设置过期时间(24小时) - await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400) - + await self._run_in_executor( + self.redis.expire, f"channel:{channel_id}:messages", 86400 + ) + async def _process_message_queue(self): """异步处理消息队列,批量写入数据库""" while self._processing: @@ -132,75 +147,90 @@ class MessageQueue: message_uuids.append(result[1]) else: break - + if message_uuids: await self._process_message_batch(message_uuids) else: # 没有消息时短暂等待 await asyncio.sleep(0.1) - + except Exception as e: logger.error(f"Error processing message queue: {e}") await asyncio.sleep(1) # 错误时等待1秒再重试 - + async def _process_message_batch(self, message_uuids: list[str]): """批量处理消息写入数据库""" async with with_db() as session: messages_to_insert = [] - + for temp_uuid in message_uuids: try: # 获取消息数据 - message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}") + message_data = await self._run_in_executor( + self.redis.hgetall, f"msg:{temp_uuid}" + ) if not message_data: continue - + # 更新状态为处理中 - await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing") - + await self._run_in_executor( + self.redis.hset, f"msg:{temp_uuid}", "status", "processing" + ) + # 创建数据库消息对象 msg = ChatMessage( channel_id=int(message_data["channel_id"]), content=message_data["content"], sender_id=int(message_data["sender_id"]), type=MessageType(message_data["type"]), - uuid=message_data.get("user_uuid") # 用户提供的 UUID(如果有) + uuid=message_data.get("user_uuid"), # 用户提供的 UUID(如果有) ) - + messages_to_insert.append((msg, temp_uuid)) - + except Exception as e: logger.error(f"Error preparing message {temp_uuid}: {e}") - await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed") - + await self._run_in_executor( + self.redis.hset, f"msg:{temp_uuid}", "status", "failed" + ) + if messages_to_insert: try: # 批量插入数据库 for msg, temp_uuid in messages_to_insert: session.add(msg) - + await session.commit() - + # 更新所有消息状态和真实 ID for msg, temp_uuid in messages_to_insert: await session.refresh(msg) await self._run_in_executor( - lambda: self.redis.hset(f"msg:{temp_uuid}", mapping={ - "status": "completed", - "message_id": str(msg.message_id), - "created_at": msg.timestamp.isoformat() if msg.timestamp else "" - }) + lambda: self.redis.hset( + f"msg:{temp_uuid}", + mapping={ + "status": "completed", + "message_id": str(msg.message_id), + "created_at": msg.timestamp.isoformat() + if msg.timestamp + else "", + }, + ) ) - - logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}") - + + logger.info( + f"Message {temp_uuid} persisted to DB with ID {msg.message_id}" + ) + except Exception as e: logger.error(f"Error inserting messages to database: {e}") await session.rollback() - + # 标记所有消息为失败 for _, temp_uuid in messages_to_insert: - await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed") + await self._run_in_executor( + self.redis.hset, f"msg:{temp_uuid}", "status", "failed" + ) # 全局消息队列实例 diff --git a/app/service/message_queue_processor.py b/app/service/message_queue_processor.py index 3f4180f..3ecce55 100644 --- a/app/service/message_queue_processor.py +++ b/app/service/message_queue_processor.py @@ -3,12 +3,12 @@ 专门处理 Redis 消息队列的异步写入数据库 """ +from __future__ import annotations + import asyncio -import json -import uuid from concurrent.futures import ThreadPoolExecutor from datetime import datetime -from typing import Optional +import json from app.database.chat import ChatMessage, MessageType from app.dependencies.database import get_redis_message, with_db @@ -17,103 +17,132 @@ from app.log import logger class MessageQueueProcessor: """消息队列处理器""" - + def __init__(self): self.redis_message = get_redis_message() self.executor = ThreadPoolExecutor(max_workers=2) self._processing = False self._queue_task = None - + async def _redis_exec(self, func, *args, **kwargs): """在线程池中执行 Redis 操作""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs)) - + async def cache_message(self, channel_id: int, message_data: dict, temp_uuid: str): """将消息缓存到 Redis""" try: # 存储消息数据 - await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data) - await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600) # 1小时过期 - + await self._redis_exec( + self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data + ) + await self._redis_exec( + self.redis_message.expire, f"msg:{temp_uuid}", 3600 + ) # 1小时过期 + # 加入频道消息列表 - await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid) - await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99) # 保持最新100条 - await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400) # 24小时过期 - + await self._redis_exec( + self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid + ) + await self._redis_exec( + self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99 + ) # 保持最新100条 + await self._redis_exec( + self.redis_message.expire, f"channel:{channel_id}:messages", 86400 + ) # 24小时过期 + # 加入异步处理队列 - await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid) - + await self._redis_exec( + self.redis_message.lpush, "message_write_queue", temp_uuid + ) + logger.info(f"Message cached to Redis: {temp_uuid}") except Exception as e: logger.error(f"Failed to cache message to Redis: {e}") - - async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: + + async def get_cached_messages( + self, channel_id: int, limit: int = 50, since: int = 0 + ) -> list[dict]: """从 Redis 获取缓存的消息""" try: message_uuids = await self._redis_exec( - self.redis_message.lrange, f"channel:{channel_id}:messages", 0, limit - 1 + self.redis_message.lrange, + f"channel:{channel_id}:messages", + 0, + limit - 1, ) - + messages = [] for temp_uuid in message_uuids: # 解码 UUID 如果它是字节类型 if isinstance(temp_uuid, bytes): - temp_uuid = temp_uuid.decode('utf-8') - - raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") + temp_uuid = temp_uuid.decode("utf-8") + + raw_data = await self._redis_exec( + self.redis_message.hgetall, f"msg:{temp_uuid}" + ) if raw_data: # 解码 Redis 返回的字节数据 message_data = { - k.decode('utf-8') if isinstance(k, bytes) else k: - v.decode('utf-8') if isinstance(v, bytes) else v + k.decode("utf-8") if isinstance(k, bytes) else k: v.decode( + "utf-8" + ) + if isinstance(v, bytes) + else v for k, v in raw_data.items() } - + # 检查 since 条件 if since > 0 and message_data.get("message_id"): if int(message_data["message_id"]) <= since: continue messages.append(message_data) - + return messages[::-1] # 按时间顺序返回 except Exception as e: logger.error(f"Failed to get cached messages: {e}") return [] - - async def update_message_status(self, temp_uuid: str, status: str, message_id: Optional[int] = None): + + async def update_message_status( + self, temp_uuid: str, status: str, message_id: int | None = None + ): """更新消息状态""" try: update_data = {"status": status} if message_id: update_data["message_id"] = str(message_id) update_data["db_timestamp"] = datetime.now().isoformat() - - await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data) + + await self._redis_exec( + self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data + ) except Exception as e: logger.error(f"Failed to update message status: {e}") - - async def get_message_status(self, temp_uuid: str) -> Optional[dict]: + + async def get_message_status(self, temp_uuid: str) -> dict | None: """获取消息状态""" try: - raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") + raw_data = await self._redis_exec( + self.redis_message.hgetall, f"msg:{temp_uuid}" + ) if not raw_data: return None - + # 解码 Redis 返回的字节数据 return { - k.decode('utf-8') if isinstance(k, bytes) else k: - v.decode('utf-8') if isinstance(v, bytes) else v + k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") + if isinstance(v, bytes) + else v for k, v in raw_data.items() } except Exception as e: logger.error(f"Failed to get message status: {e}") return None - + async def _process_message_queue(self): """处理消息队列,异步写入数据库""" logger.info("Message queue processing started") - + while self._processing: try: # 批量获取消息 @@ -126,47 +155,52 @@ class MessageQueueProcessor: # result是 (queue_name, value) 的元组,需要解码 uuid_value = result[1] if isinstance(uuid_value, bytes): - uuid_value = uuid_value.decode('utf-8') + uuid_value = uuid_value.decode("utf-8") message_uuids.append(uuid_value) else: break - + if not message_uuids: await asyncio.sleep(0.5) continue - + # 批量写入数据库 await self._process_message_batch(message_uuids) - + except Exception as e: logger.error(f"Error in message queue processing: {e}") await asyncio.sleep(1) - + logger.info("Message queue processing stopped") - + async def _process_message_batch(self, message_uuids: list[str]): """批量处理消息写入数据库""" async with with_db() as session: for temp_uuid in message_uuids: try: # 获取消息数据并解码 - raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") + raw_data = await self._redis_exec( + self.redis_message.hgetall, f"msg:{temp_uuid}" + ) if not raw_data: continue - + # 解码 Redis 返回的字节数据 message_data = { - k.decode('utf-8') if isinstance(k, bytes) else k: - v.decode('utf-8') if isinstance(v, bytes) else v + k.decode("utf-8") if isinstance(k, bytes) else k: v.decode( + "utf-8" + ) + if isinstance(v, bytes) + else v for k, v in raw_data.items() } - + if message_data.get("status") != "pending": continue - + # 更新状态为处理中 await self.update_message_status(temp_uuid, "processing") - + # 创建数据库消息 msg = ChatMessage( channel_id=int(message_data["channel_id"]), @@ -175,15 +209,17 @@ class MessageQueueProcessor: type=MessageType(message_data["type"]), uuid=message_data.get("user_uuid") or None, ) - + session.add(msg) await session.commit() await session.refresh(msg) - + # 更新成功状态,包含临时消息ID映射 assert msg.message_id is not None - await self.update_message_status(temp_uuid, "completed", msg.message_id) - + await self.update_message_status( + temp_uuid, "completed", msg.message_id + ) + # 如果有临时消息ID,存储映射关系并通知客户端更新 if message_data.get("temp_message_id"): temp_msg_id = int(message_data["temp_message_id"]) @@ -191,53 +227,65 @@ class MessageQueueProcessor: self.redis_message.set, f"temp_to_real:{temp_msg_id}", str(msg.message_id), - ex=3600 # 1小时过期 + ex=3600, # 1小时过期 ) - + # 发送消息ID更新通知到频道 channel_id = int(message_data["channel_id"]) - await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data) - - logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}") - + await self._notify_message_update( + channel_id, temp_msg_id, msg.message_id, message_data + ) + + logger.info( + f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}" + ) + except Exception as e: logger.error(f"Failed to process message {temp_uuid}: {e}") await self.update_message_status(temp_uuid, "failed") - - async def _notify_message_update(self, channel_id: int, temp_message_id: int, real_message_id: int, message_data: dict): + + async def _notify_message_update( + self, + channel_id: int, + temp_message_id: int, + real_message_id: int, + message_data: dict, + ): """通知客户端消息ID已更新""" try: # 这里我们需要通过 SignalR 发送消息更新通知 # 但为了避免循环依赖,我们将通过 Redis 发布消息更新事件 update_event = { - "event": "chat.message.update", + "event": "chat.message.update", "data": { "channel_id": channel_id, "temp_message_id": temp_message_id, "real_message_id": real_message_id, - "timestamp": message_data.get("timestamp") - } + "timestamp": message_data.get("timestamp"), + }, } - + # 发布到 Redis 频道,让 SignalR 服务处理 await self._redis_exec( self.redis_message.publish, f"chat_updates:{channel_id}", - json.dumps(update_event) + json.dumps(update_event), ) - - logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}") - + + logger.info( + f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}" + ) + except Exception as e: logger.error(f"Failed to notify message update: {e}") - + def start_processing(self): """启动消息队列处理""" if not self._processing: self._processing = True self._queue_task = asyncio.create_task(self._process_message_queue()) logger.info("Message queue processor started") - + def stop_processing(self): """停止消息队列处理""" if self._processing: @@ -246,10 +294,10 @@ class MessageQueueProcessor: self._queue_task.cancel() self._queue_task = None logger.info("Message queue processor stopped") - + def __del__(self): """清理资源""" - if hasattr(self, 'executor'): + if hasattr(self, "executor"): self.executor.shutdown(wait=False) @@ -272,11 +320,13 @@ async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid: await message_queue_processor.cache_message(channel_id, message_data, temp_uuid) -async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: +async def get_cached_messages( + channel_id: int, limit: int = 50, since: int = 0 +) -> list[dict]: """从 Redis 获取缓存的消息 - 便捷接口""" return await message_queue_processor.get_cached_messages(channel_id, limit, since) -async def get_message_status(temp_uuid: str) -> Optional[dict]: +async def get_message_status(temp_uuid: str) -> dict | None: """获取消息状态 - 便捷接口""" return await message_queue_processor.get_message_status(temp_uuid) diff --git a/app/service/optimized_message.py b/app/service/optimized_message.py index 6ffd10d..76eb6b4 100644 --- a/app/service/optimized_message.py +++ b/app/service/optimized_message.py @@ -3,23 +3,26 @@ 结合 Redis 缓存和异步数据库写入实现实时消息传送 """ -from typing import Optional -from fastapi import HTTPException +from __future__ import annotations -from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType, ChatMessageResp +from app.database.chat import ( + ChannelType, + ChatMessageResp, + MessageType, +) from app.database.lazer_user import User -from app.router.notification.server import server -from app.service.message_queue import message_queue from app.log import logger +from app.service.message_queue import message_queue + from sqlalchemy.ext.asyncio import AsyncSession class OptimizedMessageService: """优化的消息服务""" - + def __init__(self): self.message_queue = message_queue - + async def send_message_fast( self, channel_id: int, @@ -28,12 +31,12 @@ class OptimizedMessageService: content: str, sender: User, is_action: bool = False, - user_uuid: Optional[str] = None, - session: Optional[AsyncSession] = None + user_uuid: str | None = None, + session: AsyncSession | None = None, ) -> ChatMessageResp: """ 快速发送消息(先缓存到 Redis,异步写入数据库) - + Args: channel_id: 频道 ID channel_type: 频道类型 @@ -43,12 +46,12 @@ class OptimizedMessageService: is_action: 是否为动作消息 user_uuid: 用户提供的 UUID session: 数据库会话(可选,用于一些验证) - + Returns: 消息响应对象 """ assert sender.id is not None - + # 准备消息数据 message_data = { "channel_id": str(channel_id), @@ -57,27 +60,28 @@ class OptimizedMessageService: "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value, "user_uuid": user_uuid or "", "channel_type": channel_type.value, - "channel_name": channel_name + "channel_name": channel_name, } - + # 立即将消息加入 Redis 队列(实时响应) temp_uuid = await self.message_queue.enqueue_message(message_data) - + # 缓存到频道消息列表 await self.message_queue.cache_channel_message(channel_id, temp_uuid) - + # 创建临时响应对象(简化版本,用于立即响应) from datetime import datetime + from app.database.lazer_user import UserResp - + # 创建基本的用户响应对象 user_resp = UserResp( id=sender.id, username=sender.username, - country_code=getattr(sender, 'country_code', 'XX'), + country_code=getattr(sender, "country_code", "XX"), # 基本字段,其他复杂字段可以后续异步加载 ) - + temp_response = ChatMessageResp( message_id=0, # 临时 ID,等数据库写入后会更新 channel_id=channel_id, @@ -86,63 +90,62 @@ class OptimizedMessageService: sender_id=sender.id, sender=user_resp, is_action=is_action, - uuid=user_uuid + uuid=user_uuid, ) temp_response.temp_uuid = temp_uuid # 添加临时 UUID 用于后续更新 - + logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}") return temp_response - + async def get_cached_messages( - self, - channel_id: int, - limit: int = 50, - since: int = 0 + self, channel_id: int, limit: int = 50, since: int = 0 ) -> list[dict]: """ 获取缓存的消息 - + Args: channel_id: 频道 ID limit: 限制数量 since: 获取自此消息 ID 之后的消息 - + Returns: 消息列表 """ return await self.message_queue.get_cached_messages(channel_id, limit, since) - - async def get_message_status(self, temp_uuid: str) -> Optional[dict]: + + async def get_message_status(self, temp_uuid: str) -> dict | None: """ 获取消息状态 - + Args: temp_uuid: 临时消息 UUID - + Returns: 消息状态信息 """ return await self.message_queue.get_message_status(temp_uuid) - - async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> Optional[dict]: + + async def wait_for_message_persisted( + self, temp_uuid: str, timeout: int = 30 + ) -> dict | None: """ 等待消息持久化到数据库 - + Args: temp_uuid: 临时消息 UUID timeout: 超时时间(秒) - + Returns: 完成后的消息状态 """ import asyncio - + for _ in range(timeout * 10): # 每100ms检查一次 status = await self.get_message_status(temp_uuid) if status and status.get("status") in ["completed", "failed"]: return status await asyncio.sleep(0.1) - + return None diff --git a/app/service/redis_message_system.py b/app/service/redis_message_system.py index cd58ffb..c6322f2 100644 --- a/app/service/redis_message_system.py +++ b/app/service/redis_message_system.py @@ -5,59 +5,66 @@ - 支持消息状态同步和故障恢复 """ +from __future__ import annotations + import asyncio +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime import json import time -import uuid -from datetime import datetime -from typing import Optional, List, Dict, Any -from concurrent.futures import ThreadPoolExecutor +from typing import Any -from app.database.chat import ChatMessage, MessageType, ChatMessageResp -from app.database.lazer_user import User, UserResp, RANKING_INCLUDES +from app.database.chat import ChatMessage, ChatMessageResp, MessageType +from app.database.lazer_user import RANKING_INCLUDES, User, UserResp from app.dependencies.database import get_redis_message, with_db from app.log import logger class RedisMessageSystem: """Redis 消息系统""" - + def __init__(self): self.redis = get_redis_message() self.executor = ThreadPoolExecutor(max_workers=2) - self._batch_timer: Optional[asyncio.Task] = None + self._batch_timer: asyncio.Task | None = None self._running = False self.batch_interval = 5.0 # 5秒批量存储一次 self.max_batch_size = 100 # 每批最多处理100条消息 - + async def _redis_exec(self, func, *args, **kwargs): """在线程池中执行 Redis 操作""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs)) - - async def send_message(self, channel_id: int, user: User, content: str, - is_action: bool = False, user_uuid: Optional[str] = None) -> ChatMessageResp: + + async def send_message( + self, + channel_id: int, + user: User, + content: str, + is_action: bool = False, + user_uuid: str | None = None, + ) -> ChatMessageResp: """ 发送消息 - 立即存储到 Redis 并返回 - + Args: channel_id: 频道ID user: 发送用户 content: 消息内容 is_action: 是否为动作消息 user_uuid: 用户UUID - + Returns: ChatMessageResp: 消息响应对象 """ # 生成消息ID和时间戳 message_id = await self._generate_message_id(channel_id) timestamp = datetime.now() - + # 确保用户ID存在 if not user.id: raise ValueError("User ID is required") - + # 准备消息数据 message_data = { "message_id": message_id, @@ -68,19 +75,20 @@ class RedisMessageSystem: "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value, "uuid": user_uuid or "", "status": "cached", # Redis 缓存状态 - "created_at": time.time() + "created_at": time.time(), } - + # 立即存储到 Redis await self._store_to_redis(message_id, channel_id, message_data) - + # 创建响应对象 async with with_db() as session: user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES) - + # 确保 statistics 不为空 if user_resp.statistics is None: from app.database.statistics import UserStatisticsResp + user_resp.statistics = UserStatisticsResp( mode=user.playmode, global_rank=0, @@ -96,9 +104,9 @@ class RedisMessageSystem: replays_watched_by_others=0, is_ranked=False, grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0}, - level={"current": 1, "progress": 0} + level={"current": 1, "progress": 0}, ) - + response = ChatMessageResp( message_id=message_id, channel_id=channel_id, @@ -107,51 +115,71 @@ class RedisMessageSystem: sender_id=user.id, sender=user_resp, is_action=is_action, - uuid=user_uuid + uuid=user_uuid, + ) + + logger.info( + f"Message {message_id} sent to Redis cache for channel {channel_id}" ) - - logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}") return response - - async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> List[ChatMessageResp]: + + async def get_messages( + self, channel_id: int, limit: int = 50, since: int = 0 + ) -> list[ChatMessageResp]: """ 获取频道消息 - 优先从 Redis 获取最新消息 - + Args: channel_id: 频道ID limit: 消息数量限制 since: 起始消息ID - + Returns: List[ChatMessageResp]: 消息列表 """ messages = [] - + try: # 从 Redis 获取最新消息 redis_messages = await self._get_from_redis(channel_id, limit, since) - + # 为每条消息构建响应对象 async with with_db() as session: for msg_data in redis_messages: # 获取发送者信息 sender = await session.get(User, msg_data["sender_id"]) if sender: - user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES) - + user_resp = await UserResp.from_db( + sender, session, RANKING_INCLUDES + ) + if user_resp.statistics is None: from app.database.statistics import UserStatisticsResp + user_resp.statistics = UserStatisticsResp( mode=sender.playmode, - global_rank=0, country_rank=0, pp=0.0, - ranked_score=0, hit_accuracy=0.0, play_count=0, - play_time=0, total_score=0, total_hits=0, - maximum_combo=0, replays_watched_by_others=0, + global_rank=0, + country_rank=0, + pp=0.0, + ranked_score=0, + hit_accuracy=0.0, + play_count=0, + play_time=0, + total_score=0, + total_hits=0, + maximum_combo=0, + replays_watched_by_others=0, is_ranked=False, - grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0}, - level={"current": 1, "progress": 0} + grade_counts={ + "ssh": 0, + "ss": 0, + "sh": 0, + "s": 0, + "a": 0, + }, + level={"current": 1, "progress": 0}, ) - + message_resp = ChatMessageResp( message_id=msg_data["message_id"], channel_id=msg_data["channel_id"], @@ -160,77 +188,97 @@ class RedisMessageSystem: sender_id=msg_data["sender_id"], sender=user_resp, is_action=msg_data["type"] == MessageType.ACTION.value, - uuid=msg_data.get("uuid") or None + uuid=msg_data.get("uuid") or None, ) messages.append(message_resp) - + # 如果 Redis 消息不够,从数据库补充 if len(messages) < limit and since == 0: await self._backfill_from_database(channel_id, messages, limit) - + except Exception as e: logger.error(f"Failed to get messages from Redis: {e}") # 回退到数据库查询 messages = await self._get_from_database_only(channel_id, limit, since) - + return messages[:limit] - + async def _generate_message_id(self, channel_id: int) -> int: """生成唯一的消息ID - 确保全局唯一且严格递增""" # 使用全局计数器确保所有频道的消息ID都是严格递增的 - message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter") - + message_id = await self._redis_exec( + self.redis.incr, "global_message_id_counter" + ) + # 同时更新频道的最后消息ID,用于客户端状态同步 - await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id) - + await self._redis_exec( + self.redis.set, f"channel:{channel_id}:last_msg_id", message_id + ) + return message_id - - async def _store_to_redis(self, message_id: int, channel_id: int, message_data: Dict[str, Any]): + + async def _store_to_redis( + self, message_id: int, channel_id: int, message_data: dict[str, Any] + ): """存储消息到 Redis""" try: # 存储消息数据 await self._redis_exec( - self.redis.hset, - f"msg:{channel_id}:{message_id}", - mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) - for k, v in message_data.items()} + self.redis.hset, + f"msg:{channel_id}:{message_id}", + mapping={ + k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) + for k, v in message_data.items() + }, ) - + # 设置消息过期时间(7天) - await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800) - + await self._redis_exec( + self.redis.expire, f"msg:{channel_id}:{message_id}", 604800 + ) + # 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序) channel_messages_key = f"channel:{channel_id}:messages" - + # 检查键的类型,如果不是 zset 类型则删除 try: key_type = await self._redis_exec(self.redis.type, channel_messages_key) if key_type and key_type != "zset": - logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}") + logger.warning( + f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}" + ) await self._redis_exec(self.redis.delete, channel_messages_key) except Exception as type_check_error: - logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}") + logger.warning( + f"Failed to check key type for {channel_messages_key}: {type_check_error}" + ) # 如果检查失败,直接删除键以确保清理 await self._redis_exec(self.redis.delete, channel_messages_key) - + # 添加到频道消息列表(sorted set) await self._redis_exec( - self.redis.zadd, - channel_messages_key, - {f"msg:{channel_id}:{message_id}": message_id} + self.redis.zadd, + channel_messages_key, + {f"msg:{channel_id}:{message_id}": message_id}, ) - + # 保持频道消息列表大小(最多1000条) - await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001) - + await self._redis_exec( + self.redis.zremrangebyrank, channel_messages_key, 0, -1001 + ) + # 添加到待持久化队列 - await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}") - + await self._redis_exec( + self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}" + ) + except Exception as e: logger.error(f"Failed to store message to Redis: {e}") raise - - async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> List[Dict[str, Any]]: + + async def _get_from_redis( + self, channel_id: int, limit: int = 50, since: int = 0 + ) -> list[dict[str, Any]]: """从 Redis 获取消息""" try: # 获取消息键列表,按消息ID排序 @@ -239,22 +287,22 @@ class RedisMessageSystem: message_keys = await self._redis_exec( self.redis.zrangebyscore, f"channel:{channel_id}:messages", - since + 1, "+inf", - start=0, num=limit + since + 1, + "+inf", + start=0, + num=limit, ) else: # 获取最新的消息(倒序获取,然后反转) message_keys = await self._redis_exec( - self.redis.zrevrange, - f"channel:{channel_id}:messages", - 0, limit - 1 + self.redis.zrevrange, f"channel:{channel_id}:messages", 0, limit - 1 ) - + messages = [] for key in message_keys: if isinstance(key, bytes): - key = key.decode('utf-8') - + key = key.decode("utf-8") + # 获取消息数据 raw_data = await self._redis_exec(self.redis.hgetall, key) if raw_data: @@ -262,106 +310,118 @@ class RedisMessageSystem: message_data = {} for k, v in raw_data.items(): if isinstance(k, bytes): - k = k.decode('utf-8') + k = k.decode("utf-8") if isinstance(v, bytes): - v = v.decode('utf-8') - + v = v.decode("utf-8") + # 尝试解析 JSON try: - if k in ['grade_counts', 'level'] or v.startswith(('{', '[')): + if k in ["grade_counts", "level"] or v.startswith( + ("{", "[") + ): message_data[k] = json.loads(v) - elif k in ['message_id', 'channel_id', 'sender_id']: + elif k in ["message_id", "channel_id", "sender_id"]: message_data[k] = int(v) - elif k == 'created_at': + elif k == "created_at": message_data[k] = float(v) else: message_data[k] = v except (json.JSONDecodeError, ValueError): message_data[k] = v - + messages.append(message_data) - + # 确保消息按ID正序排序(时间顺序) - messages.sort(key=lambda x: x.get('message_id', 0)) - + messages.sort(key=lambda x: x.get("message_id", 0)) + # 如果是获取最新消息(since=0),需要保持倒序(最新的在前面) if since == 0: messages.reverse() - + return messages - + except Exception as e: logger.error(f"Failed to get messages from Redis: {e}") return [] - - async def _backfill_from_database(self, channel_id: int, existing_messages: List[ChatMessageResp], limit: int): + + async def _backfill_from_database( + self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int + ): """从数据库补充历史消息""" try: # 找到最小的消息ID - min_id = float('inf') + min_id = float("inf") if existing_messages: for msg in existing_messages: if msg.message_id is not None and msg.message_id < min_id: min_id = msg.message_id - + needed = limit - len(existing_messages) - + if needed <= 0: return - + async with with_db() as session: - from sqlmodel import select, col - query = select(ChatMessage).where( - ChatMessage.channel_id == channel_id - ) - - if min_id != float('inf'): + from sqlmodel import col, select + + query = select(ChatMessage).where(ChatMessage.channel_id == channel_id) + + if min_id != float("inf"): query = query.where(col(ChatMessage.message_id) < min_id) - + query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed) - + db_messages = (await session.exec(query)).all() - + for msg in reversed(db_messages): # 按时间正序插入 msg_resp = await ChatMessageResp.from_db(msg, session) existing_messages.insert(0, msg_resp) - + except Exception as e: logger.error(f"Failed to backfill from database: {e}") - - async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> List[ChatMessageResp]: + + async def _get_from_database_only( + self, channel_id: int, limit: int, since: int + ) -> list[ChatMessageResp]: """仅从数据库获取消息(回退方案)""" try: async with with_db() as session: - from sqlmodel import select, col + from sqlmodel import col, select + query = select(ChatMessage).where(ChatMessage.channel_id == channel_id) - + if since > 0: # 获取指定ID之后的消息,按ID正序 query = query.where(col(ChatMessage.message_id) > since) - query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit) + query = query.order_by(col(ChatMessage.message_id).asc()).limit( + limit + ) else: # 获取最新消息,按ID倒序(最新的在前面) - query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit) - + query = query.order_by(col(ChatMessage.message_id).desc()).limit( + limit + ) + messages = (await session.exec(query)).all() - - results = [await ChatMessageResp.from_db(msg, session) for msg in messages] - + + results = [ + await ChatMessageResp.from_db(msg, session) for msg in messages + ] + # 如果是 since > 0,保持正序;否则反转为时间正序 if since == 0: results.reverse() - + return results - + except Exception as e: logger.error(f"Failed to get messages from database: {e}") return [] - + async def _batch_persist_to_database(self): """批量持久化消息到数据库""" logger.info("Starting batch persistence to database") - + while self._running: try: # 获取待处理的消息 @@ -374,52 +434,52 @@ class RedisMessageSystem: # key 是 (queue_name, value) 的元组 value = key[1] if isinstance(value, bytes): - value = value.decode('utf-8') + value = value.decode("utf-8") message_keys.append(value) else: break - + if message_keys: await self._process_message_batch(message_keys) else: await asyncio.sleep(self.batch_interval) - + except Exception as e: logger.error(f"Error in batch persistence: {e}") await asyncio.sleep(1) - + logger.info("Stopped batch persistence to database") - - async def _process_message_batch(self, message_keys: List[str]): + + async def _process_message_batch(self, message_keys: list[str]): """处理消息批次""" async with with_db() as session: for key in message_keys: try: # 解析频道ID和消息ID - channel_id, message_id = map(int, key.split(':')) - + channel_id, message_id = map(int, key.split(":")) + # 从 Redis 获取消息数据 raw_data = await self._redis_exec( self.redis.hgetall, f"msg:{channel_id}:{message_id}" ) - + if not raw_data: continue - + # 解码数据 message_data = {} for k, v in raw_data.items(): if isinstance(k, bytes): - k = k.decode('utf-8') + k = k.decode("utf-8") if isinstance(v, bytes): - v = v.decode('utf-8') + v = v.decode("utf-8") message_data[k] = v - + # 检查消息是否已存在于数据库 existing = await session.get(ChatMessage, int(message_id)) if existing: continue - + # 创建数据库消息 - 使用 Redis 生成的正数ID db_message = ChatMessage( message_id=int(message_id), # 使用 Redis 系统生成的正数ID @@ -428,31 +488,34 @@ class RedisMessageSystem: content=message_data["content"], timestamp=datetime.fromisoformat(message_data["timestamp"]), type=MessageType(message_data["type"]), - uuid=message_data.get("uuid") or None + uuid=message_data.get("uuid") or None, ) - + session.add(db_message) - + # 更新 Redis 中的状态 await self._redis_exec( self.redis.hset, f"msg:{channel_id}:{message_id}", - "status", "persisted" + "status", + "persisted", ) - + logger.debug(f"Message {message_id} persisted to database") - + except Exception as e: logger.error(f"Failed to process message {key}: {e}") - + # 提交批次 try: await session.commit() - logger.info(f"Batch of {len(message_keys)} messages committed to database") + logger.info( + f"Batch of {len(message_keys)} messages committed to database" + ) except Exception as e: logger.error(f"Failed to commit message batch: {e}") await session.rollback() - + def start(self): """启动系统""" if not self._running: @@ -461,63 +524,71 @@ class RedisMessageSystem: # 启动时初始化消息ID计数器 asyncio.create_task(self._initialize_message_counter()) logger.info("Redis message system started") - + async def _initialize_message_counter(self): """初始化全局消息ID计数器,确保从数据库最大ID开始""" try: # 清理可能存在的问题键 await self._cleanup_redis_keys() - + async with with_db() as session: - from sqlmodel import select, func - + from sqlmodel import func, select + # 获取数据库中最大的消息ID - result = await session.exec( - select(func.max(ChatMessage.message_id)) - ) + result = await session.exec(select(func.max(ChatMessage.message_id))) max_id = result.one() or 0 - + # 检查 Redis 中的计数器值 - current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter") + current_counter = await self._redis_exec( + self.redis.get, "global_message_id_counter" + ) current_counter = int(current_counter) if current_counter else 0 - + # 设置计数器为两者中的最大值 initial_counter = max(max_id, current_counter) - await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter) - - logger.info(f"Initialized global message ID counter to {initial_counter}") - + await self._redis_exec( + self.redis.set, "global_message_id_counter", initial_counter + ) + + logger.info( + f"Initialized global message ID counter to {initial_counter}" + ) + except Exception as e: logger.error(f"Failed to initialize message counter: {e}") # 如果初始化失败,设置一个安全的起始值 - await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000) - + await self._redis_exec( + self.redis.setnx, "global_message_id_counter", 1000000 + ) + async def _cleanup_redis_keys(self): """清理可能存在问题的 Redis 键""" try: # 扫描所有 channel:*:messages 键并检查类型 keys_pattern = "channel:*:messages" keys = await self._redis_exec(self.redis.keys, keys_pattern) - + for key in keys: if isinstance(key, bytes): - key = key.decode('utf-8') - + key = key.decode("utf-8") + try: key_type = await self._redis_exec(self.redis.type, key) if key_type and key_type != "zset": - logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}") + logger.warning( + f"Cleaning up Redis key {key} with wrong type: {key_type}" + ) await self._redis_exec(self.redis.delete, key) except Exception as cleanup_error: logger.warning(f"Failed to cleanup key {key}: {cleanup_error}") # 强制删除问题键 await self._redis_exec(self.redis.delete, key) - + logger.info("Redis keys cleanup completed") - + except Exception as e: logger.error(f"Failed to cleanup Redis keys: {e}") - + def stop(self): """停止系统""" if self._running: @@ -526,10 +597,10 @@ class RedisMessageSystem: self._batch_timer.cancel() self._batch_timer = None logger.info("Redis message system stopped") - + def __del__(self): """清理资源""" - if hasattr(self, 'executor'): + if hasattr(self, "executor"): self.executor.shutdown(wait=False) diff --git a/app/service/stats_cleanup.py b/app/service/stats_cleanup.py index f87a03a..a3856a7 100644 --- a/app/service/stats_cleanup.py +++ b/app/service/stats_cleanup.py @@ -1,81 +1,94 @@ from __future__ import annotations from datetime import datetime, timedelta + from app.dependencies.database import get_redis, get_redis_message from app.log import logger -from app.router.v2.stats import REDIS_ONLINE_USERS_KEY, REDIS_PLAYING_USERS_KEY, _redis_exec +from app.router.v2.stats import ( + REDIS_ONLINE_USERS_KEY, + REDIS_PLAYING_USERS_KEY, + _redis_exec, +) async def cleanup_stale_online_users() -> tuple[int, int]: """清理过期的在线和游玩用户,返回清理的用户数""" redis_sync = get_redis_message() redis_async = get_redis() - + online_cleaned = 0 playing_cleaned = 0 - + try: # 获取所有在线用户 online_users = await _redis_exec(redis_sync.smembers, REDIS_ONLINE_USERS_KEY) playing_users = await _redis_exec(redis_sync.smembers, REDIS_PLAYING_USERS_KEY) - + # 检查在线用户的最后活动时间 current_time = datetime.utcnow() stale_threshold = current_time - timedelta(hours=2) # 2小时无活动视为过期 - + # 对于在线用户,我们检查metadata在线标记 stale_online_users = [] for user_id in online_users: - user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id) + user_id_str = ( + user_id.decode() if isinstance(user_id, bytes) else str(user_id) + ) metadata_key = f"metadata:online:{user_id_str}" - + # 如果metadata标记不存在,说明用户已经离线 if not await redis_async.exists(metadata_key): stale_online_users.append(user_id_str) - + # 清理过期的在线用户 if stale_online_users: - await _redis_exec(redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users) + await _redis_exec( + redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users + ) online_cleaned = len(stale_online_users) logger.info(f"Cleaned {online_cleaned} stale online users") - + # 对于游玩用户,我们也检查对应的spectator状态 stale_playing_users = [] for user_id in playing_users: - user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id) - + user_id_str = ( + user_id.decode() if isinstance(user_id, bytes) else str(user_id) + ) + # 如果用户不在在线用户列表中,说明已经离线,也应该从游玩列表中移除 if user_id_str in stale_online_users or user_id_str not in [ u.decode() if isinstance(u, bytes) else str(u) for u in online_users ]: stale_playing_users.append(user_id_str) - + # 清理过期的游玩用户 if stale_playing_users: - await _redis_exec(redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users) + await _redis_exec( + redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users + ) playing_cleaned = len(stale_playing_users) logger.info(f"Cleaned {playing_cleaned} stale playing users") - + except Exception as e: logger.error(f"Error cleaning stale users: {e}") - + return online_cleaned, playing_cleaned async def refresh_redis_key_expiry() -> None: """刷新Redis键的过期时间,防止数据丢失""" redis_async = get_redis() - + try: # 刷新在线用户key的过期时间 if await redis_async.exists(REDIS_ONLINE_USERS_KEY): await redis_async.expire(REDIS_ONLINE_USERS_KEY, 6 * 3600) # 6小时 - + # 刷新游玩用户key的过期时间 if await redis_async.exists(REDIS_PLAYING_USERS_KEY): await redis_async.expire(REDIS_PLAYING_USERS_KEY, 6 * 3600) # 6小时 - + logger.debug("Refreshed Redis key expiry times") - + except Exception as e: logger.error(f"Error refreshing Redis key expiry: {e}") diff --git a/app/service/stats_scheduler.py b/app/service/stats_scheduler.py index 61c1210..02adc73 100644 --- a/app/service/stats_scheduler.py +++ b/app/service/stats_scheduler.py @@ -5,46 +5,49 @@ from datetime import datetime, timedelta from app.log import logger from app.router.v2.stats import record_hourly_stats, update_registered_users_count -from app.service.stats_cleanup import cleanup_stale_online_users, refresh_redis_key_expiry from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager +from app.service.stats_cleanup import ( + cleanup_stale_online_users, + refresh_redis_key_expiry, +) class StatsScheduler: """统计数据调度器""" - + def __init__(self): self._running = False self._stats_task: asyncio.Task | None = None self._registered_task: asyncio.Task | None = None self._cleanup_task: asyncio.Task | None = None - + def start(self) -> None: """启动调度器""" if self._running: return - + self._running = True self._stats_task = asyncio.create_task(self._stats_loop()) self._registered_task = asyncio.create_task(self._registered_users_loop()) self._cleanup_task = asyncio.create_task(self._cleanup_loop()) logger.info("Stats scheduler started") - + def stop(self) -> None: """停止调度器""" if not self._running: return - + self._running = False - + if self._stats_task: self._stats_task.cancel() if self._registered_task: self._registered_task.cancel() if self._cleanup_task: self._cleanup_task.cancel() - + logger.info("Stats scheduler stopped") - + async def _stats_loop(self) -> None: """统计数据记录循环 - 每30分钟记录一次""" # 启动时立即记录一次统计数据 @@ -53,49 +56,57 @@ class StatsScheduler: logger.info("Initial enhanced interval statistics initialized on startup") except Exception as e: logger.error(f"Error initializing enhanced interval stats: {e}") - + while self._running: try: # 计算下次记录时间(下个30分钟整点) now = datetime.utcnow() - + # 计算当前区间边界 current_minute = (now.minute // 30) * 30 - current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(minutes=30) - + current_interval_end = now.replace( + minute=current_minute, second=0, microsecond=0 + ) + timedelta(minutes=30) + # 如果已经过了当前区间结束时间,立即处理 if now >= current_interval_end: current_interval_end += timedelta(minutes=30) - + # 计算需要等待的时间(到下个区间结束) sleep_seconds = (current_interval_end - now).total_seconds() - + # 确保至少等待1分钟,最多等待31分钟 sleep_seconds = max(min(sleep_seconds, 31 * 60), 60) - - logger.debug(f"Next interval finalization in {sleep_seconds/60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}") + + logger.debug( + f"Next interval finalization in {sleep_seconds / 60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}" + ) await asyncio.sleep(sleep_seconds) - + if not self._running: break - + # 完成当前区间并记录到历史 finalized_stats = await EnhancedIntervalStatsManager.finalize_interval() if finalized_stats: - logger.info(f"Finalized enhanced interval statistics at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}") + logger.info( + f"Finalized enhanced interval statistics at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}" + ) else: # 如果区间完成失败,使用原有方式记录 await record_hourly_stats() - logger.info(f"Recorded hourly statistics (fallback) at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}") - + logger.info( + f"Recorded hourly statistics (fallback) at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}" + ) + # 开始新的区间统计 await EnhancedIntervalStatsManager.initialize_current_interval() - + except Exception as e: logger.error(f"Error in stats loop: {e}") # 出错时等待5分钟再重试 await asyncio.sleep(5 * 60) - + async def _registered_users_loop(self) -> None: """注册用户数更新循环 - 每5分钟更新一次""" # 启动时立即更新一次注册用户数 @@ -104,14 +115,14 @@ class StatsScheduler: logger.info("Initial registered users count updated on startup") except Exception as e: logger.error(f"Error updating initial registered users count: {e}") - + while self._running: # 等待5分钟 await asyncio.sleep(5 * 60) - + if not self._running: break - + try: await update_registered_users_count() logger.debug("Updated registered users count") @@ -124,31 +135,35 @@ class StatsScheduler: try: online_cleaned, playing_cleaned = await cleanup_stale_online_users() if online_cleaned > 0 or playing_cleaned > 0: - logger.info(f"Initial cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users") - + logger.info( + f"Initial cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users" + ) + await refresh_redis_key_expiry() except Exception as e: logger.error(f"Error in initial cleanup: {e}") - + while self._running: # 等待10分钟 await asyncio.sleep(10 * 60) - + if not self._running: break - + try: # 清理过期用户 online_cleaned, playing_cleaned = await cleanup_stale_online_users() if online_cleaned > 0 or playing_cleaned > 0: - logger.info(f"Cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users") - + logger.info( + f"Cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users" + ) + # 刷新Redis key过期时间 await refresh_redis_key_expiry() - + # 清理过期的区间数据 await EnhancedIntervalStatsManager.cleanup_old_intervals() - + except Exception as e: logger.error(f"Error in cleanup loop: {e}") # 出错时等待2分钟再重试 diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 8efea4f..0a753f1 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -92,11 +92,12 @@ class MetadataHub(Hub[MetadataClientState]): @override async def _clean_state(self, state: MetadataClientState) -> None: user_id = int(state.connection_id) - + # Remove from online user tracking from app.router.v2.stats import remove_online_user + asyncio.create_task(remove_online_user(user_id)) - + if state.pushable: await asyncio.gather(*self.broadcast_tasks(user_id, None)) redis = get_redis() @@ -125,6 +126,7 @@ class MetadataHub(Hub[MetadataClientState]): # Track online user from app.router.v2.stats import add_online_user + asyncio.create_task(add_online_user(user_id)) async with with_db() as session: diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 11520dd..305f33b 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -163,11 +163,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]): @override async def _clean_state(self, state: MultiplayerClientState): user_id = int(state.connection_id) - + # Remove from online user tracking from app.router.v2.stats import remove_online_user + asyncio.create_task(remove_online_user(user_id)) - + if state.room_id != 0 and state.room_id in self.rooms: server_room = self.rooms[state.room_id] room = server_room.room @@ -180,9 +181,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): async def on_client_connect(self, client: Client) -> None: """Track online users when connecting to multiplayer hub""" logger.info(f"[MultiplayerHub] Client {client.user_id} connected") - + # Track online user from app.router.v2.stats import add_online_user + asyncio.create_task(add_online_user(client.user_id)) def _ensure_in_room(self, client: Client) -> ServerMultiplayerRoom: @@ -292,11 +294,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room.users.append(user) self.add_to_group(client, self.group_id(room_id)) await server_room.match_type_handler.handle_join(user) - + # Critical fix: Send current room and gameplay state to new user # This ensures spectators joining ongoing games get proper state sync await self._send_room_state_to_new_user(client, server_room) - + await self.event_logger.player_joined(room_id, user.user_id) async with with_db() as session: @@ -669,16 +671,22 @@ class MultiplayerHub(Hub[MultiplayerClientState]): # Enhanced spectator validation - allow transitions from more states # This matches official osu-server-spectator behavior if old not in ( - MultiplayerUserState.IDLE, + MultiplayerUserState.IDLE, MultiplayerUserState.READY, MultiplayerUserState.RESULTS, # Allow spectating after results ): # Allow spectating during gameplay states only if the room is in appropriate state - if not (old.is_playing and room.room.state in ( - MultiplayerRoomState.WAITING_FOR_LOAD, - MultiplayerRoomState.PLAYING - )): - raise InvokeException(f"Cannot change state from {old} to {new}") + if not ( + old.is_playing + and room.room.state + in ( + MultiplayerRoomState.WAITING_FOR_LOAD, + MultiplayerRoomState.PLAYING, + ) + ): + raise InvokeException( + f"Cannot change state from {old} to {new}" + ) case _: raise InvokeException(f"Invalid state transition from {old} to {new}") @@ -691,7 +699,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if user.state == state: return - + # Special handling for state changes during gameplay match state: case MultiplayerUserState.IDLE: @@ -704,15 +712,15 @@ class MultiplayerHub(Hub[MultiplayerClientState]): logger.info( f"[MultiplayerHub] User {user.user_id} changing state from {user.state} to {state}" ) - + await self.validate_user_stare( server_room, user.state, state, ) - + await self.change_user_state(server_room, user, state) - + # Enhanced spectator handling based on official implementation if state == MultiplayerUserState.SPECTATING: await self.handle_spectator_state_change(client, server_room, user) @@ -738,24 +746,21 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async def handle_spectator_state_change( - self, - client: Client, - room: ServerMultiplayerRoom, - user: MultiplayerRoomUser + self, client: Client, room: ServerMultiplayerRoom, user: MultiplayerRoomUser ): """ Handle special logic for users entering spectator mode during ongoing gameplay. Based on official osu-server-spectator implementation. """ room_state = room.room.state - + # If switching to spectating during gameplay, immediately request load if room_state == MultiplayerRoomState.WAITING_FOR_LOAD: logger.info( f"[MultiplayerHub] Spectator {user.user_id} joining during load phase" ) await self.call_noblock(client, "LoadRequested") - + elif room_state == MultiplayerRoomState.PLAYING: logger.info( f"[MultiplayerHub] Spectator {user.user_id} joining during active gameplay" @@ -763,9 +768,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await self.call_noblock(client, "LoadRequested") async def _send_current_gameplay_state_to_spectator( - self, - client: Client, - room: ServerMultiplayerRoom + self, client: Client, room: ServerMultiplayerRoom ): """ Send current gameplay state information to a newly joined spectator. @@ -773,12 +776,8 @@ class MultiplayerHub(Hub[MultiplayerClientState]): """ try: # Send current room state - await self.call_noblock( - client, - "RoomStateChanged", - room.room.state - ) - + await self.call_noblock(client, "RoomStateChanged", room.room.state) + # Send current user states for all players for room_user in room.room.users: if room_user.state.is_playing: @@ -788,7 +787,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room_user.user_id, room_user.state, ) - + logger.debug( f"[MultiplayerHub] Sent current gameplay state to spectator {client.user_id}" ) @@ -798,9 +797,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async def _send_room_state_to_new_user( - self, - client: Client, - room: ServerMultiplayerRoom + self, client: Client, room: ServerMultiplayerRoom ): """ Send complete room state to a newly joined user. @@ -809,23 +806,19 @@ class MultiplayerHub(Hub[MultiplayerClientState]): try: # Send current room state if room.room.state != MultiplayerRoomState.OPEN: - await self.call_noblock( - client, - "RoomStateChanged", - room.room.state - ) - + await self.call_noblock(client, "RoomStateChanged", room.room.state) + # If room is in gameplay state, send LoadRequested immediately if room.room.state in ( MultiplayerRoomState.WAITING_FOR_LOAD, - MultiplayerRoomState.PLAYING + MultiplayerRoomState.PLAYING, ): logger.info( f"[MultiplayerHub] Sending LoadRequested to user {client.user_id} " f"joining ongoing game (room state: {room.room.state})" ) await self.call_noblock(client, "LoadRequested") - + # Send all user states to help with synchronization for room_user in room.room.users: if room_user.user_id != client.user_id: # Don't send own state @@ -835,11 +828,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room_user.user_id, room_user.state, ) - + # Critical addition: Send current playing users to SpectatorHub for cross-hub sync # This ensures spectators can watch multiplayer players properly await self._sync_with_spectator_hub(client, room) - + logger.debug( f"[MultiplayerHub] Sent complete room state to new user {client.user_id}" ) @@ -849,9 +842,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async def _sync_with_spectator_hub( - self, - client: Client, - room: ServerMultiplayerRoom + self, client: Client, room: ServerMultiplayerRoom ): """ Sync with SpectatorHub to ensure cross-hub spectating works properly. @@ -860,7 +851,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): try: # Import here to avoid circular imports from app.signalr.hub import SpectatorHubs - + # For each playing user in the room, check if they have SpectatorHub state # and notify the new client about their playing status for room_user in room.room.users: @@ -878,7 +869,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): f"[MultiplayerHub] Synced spectator state for user {room_user.user_id} " f"to new client {client.user_id}" ) - + except Exception as e: logger.debug(f"[MultiplayerHub] Failed to sync with SpectatorHub: {e}") # This is not critical, so we don't raise the exception diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index 71e2a92..6702f28 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -170,21 +170,22 @@ class SpectatorHub(Hub[StoreClientState]): Properly notifies watched users when spectator disconnects. """ user_id = int(state.connection_id) - + # Remove from online and playing tracking from app.router.v2.stats import remove_online_user + asyncio.create_task(remove_online_user(user_id)) - + if state.state: await self._end_session(user_id, state.state, state) - + # Critical fix: Notify all watched users that this spectator has disconnected # This matches the official CleanUpState implementation for watched_user_id in state.watched_user: - if (target_client := self.get_client_by_id(str(watched_user_id))) is not None: - await self.call_noblock( - target_client, "UserEndedWatching", user_id - ) + if ( + target_client := self.get_client_by_id(str(watched_user_id)) + ) is not None: + await self.call_noblock(target_client, "UserEndedWatching", user_id) logger.debug( f"[SpectatorHub] Notified {watched_user_id} that {user_id} stopped watching" ) @@ -195,18 +196,19 @@ class SpectatorHub(Hub[StoreClientState]): Send all active player states to newly connected clients. """ logger.info(f"[SpectatorHub] Client {client.user_id} connected") - + # Track online user from app.router.v2.stats import add_online_user + asyncio.create_task(add_online_user(client.user_id)) - + # Send all current player states to the new client # This matches the official OnConnectedAsync behavior active_states = [] for user_id, store in self.state.items(): if store.state is not None: active_states.append((user_id, store.state)) - + if active_states: logger.debug( f"[SpectatorHub] Sending {len(active_states)} active player states to {client.user_id}" @@ -216,8 +218,10 @@ class SpectatorHub(Hub[StoreClientState]): try: await self.call_noblock(client, "UserBeganPlaying", user_id, state) except Exception as e: - logger.debug(f"[SpectatorHub] Failed to send state for user {user_id}: {e}") - + logger.debug( + f"[SpectatorHub] Failed to send state for user {user_id}: {e}" + ) + # Also sync with MultiplayerHub for cross-hub spectating await self._sync_with_multiplayer_hub(client) @@ -229,14 +233,15 @@ class SpectatorHub(Hub[StoreClientState]): try: # Import here to avoid circular imports from app.signalr.hub import MultiplayerHubs - + # Check all active multiplayer rooms for playing users for room_id, server_room in MultiplayerHubs.rooms.items(): for room_user in server_room.room.users: # If user is playing in multiplayer but we don't have their spectator state - if (room_user.state.is_playing and - room_user.user_id not in self.state): - + if ( + room_user.state.is_playing + and room_user.user_id not in self.state + ): # Create a synthetic SpectatorState for multiplayer players # This helps with cross-hub spectating try: @@ -245,9 +250,9 @@ class SpectatorHub(Hub[StoreClientState]): ruleset_id=room_user.ruleset_id or 0, # Default to osu! mods=room_user.mods, state=SpectatedUserState.Playing, - maximum_statistics={} + maximum_statistics={}, ) - + await self.call_noblock( client, "UserBeganPlaying", @@ -258,8 +263,10 @@ class SpectatorHub(Hub[StoreClientState]): f"[SpectatorHub] Sent synthetic multiplayer state for user {room_user.user_id}" ) except Exception as e: - logger.debug(f"[SpectatorHub] Failed to create synthetic state: {e}") - + logger.debug( + f"[SpectatorHub] Failed to create synthetic state: {e}" + ) + except Exception as e: logger.debug(f"[SpectatorHub] Failed to sync with MultiplayerHub: {e}") # This is not critical, so we don't raise the exception @@ -306,6 +313,7 @@ class SpectatorHub(Hub[StoreClientState]): # Track playing user from app.router.v2.stats import add_playing_user + asyncio.create_task(add_playing_user(user_id)) # # 预缓存beatmap文件以加速后续PP计算 @@ -356,11 +364,12 @@ class SpectatorHub(Hub[StoreClientState]): ) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()): await self._process_score(store, client) await self._end_session(user_id, state, store) - + # Remove from playing user tracking from app.router.v2.stats import remove_playing_user + asyncio.create_task(remove_playing_user(user_id)) - + store.state = None store.beatmap_status = None store.checksum = None @@ -473,8 +482,10 @@ class SpectatorHub(Hub[StoreClientState]): if state.state == SpectatedUserState.Playing: state.state = SpectatedUserState.Quit - logger.debug(f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}") - + logger.debug( + f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}" + ) + # Calculate exit time safely exit_time = 0 if store.score and store.score.replay_frames: @@ -491,7 +502,7 @@ class SpectatorHub(Hub[StoreClientState]): ) self.tasks.add(task) task.add_done_callback(self.tasks.discard) - + # Background task for failtime tracking - only for failed/quit states with valid data if ( state.beatmap_id is not None @@ -519,14 +530,16 @@ class SpectatorHub(Hub[StoreClientState]): Properly handles state synchronization and watcher notifications. """ user_id = int(client.connection_id) - + logger.info(f"[SpectatorHub] {user_id} started watching {target_id}") - + try: # Get target user's current state if it exists target_store = self.state.get(target_id) if target_store and target_store.state: - logger.debug(f"[SpectatorHub] {target_id} is currently {target_store.state.state}") + logger.debug( + f"[SpectatorHub] {target_id} is currently {target_store.state.state}" + ) # Send current state to the watcher immediately await self.call_noblock( client, @@ -552,7 +565,9 @@ class SpectatorHub(Hub[StoreClientState]): await session.exec(select(User.username).where(User.id == user_id)) ).first() if not username: - logger.warning(f"[SpectatorHub] Could not find username for user {user_id}") + logger.warning( + f"[SpectatorHub] Could not find username for user {user_id}" + ) return # Notify target user that someone started watching @@ -562,7 +577,9 @@ class SpectatorHub(Hub[StoreClientState]): await self.call_noblock( target_client, "UserStartedWatching", watcher_info ) - logger.debug(f"[SpectatorHub] Notified {target_id} that {username} started watching") + logger.debug( + f"[SpectatorHub] Notified {target_id} that {username} started watching" + ) except Exception as e: logger.error(f"[SpectatorHub] Error notifying target user {target_id}: {e}") @@ -572,19 +589,23 @@ class SpectatorHub(Hub[StoreClientState]): Properly cleans up watcher state and notifies target user. """ user_id = int(client.connection_id) - + logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}") - + # Remove from SignalR group self.remove_from_group(client, self.group_id(target_id)) - + # Remove from our tracked watched users store = self.get_or_create_state(client) store.watched_user.discard(target_id) - + # Notify target user that watcher stopped watching if (target_client := self.get_client_by_id(str(target_id))) is not None: await self.call_noblock(target_client, "UserEndedWatching", user_id) - logger.debug(f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching") + logger.debug( + f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching" + ) else: - logger.debug(f"[SpectatorHub] Target user {target_id} not found for end watching notification") + logger.debug( + f"[SpectatorHub] Target user {target_id} not found for end watching notification" + )