Tidy up code
@@ -45,7 +45,7 @@ from .relationship import (
 )
 from .score_token import ScoreToken
 
-from pydantic import field_validator, field_serializer
+from pydantic import field_serializer, field_validator
 from redis.asyncio import Redis
 from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime
 from sqlalchemy.ext.asyncio import AsyncAttrs
@@ -126,7 +126,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
         if isinstance(v, dict):
             serialized = {}
             for key, value in v.items():
-                if hasattr(key, 'value'):
+                if hasattr(key, "value"):
                     # If it's an enum, use its value
                     serialized[key.value] = value
                 else:
@@ -138,7 +138,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
     @field_serializer("rank", when_used="json")
     def serialize_rank(self, v):
         """Serialize the rank, making sure enum values are converted to strings correctly"""
-        if hasattr(v, 'value'):
+        if hasattr(v, "value"):
             return v.value
         return str(v)
 
@@ -188,7 +188,7 @@ class Score(ScoreBase, table=True):
     @field_serializer("gamemode", when_used="json")
     def serialize_gamemode(self, v):
         """Serialize the game mode, making sure enum values are converted to strings correctly"""
-        if hasattr(v, 'value'):
+        if hasattr(v, "value"):
             return v.value
         return str(v)
 
@@ -281,7 +281,7 @@ class ScoreResp(ScoreBase):
         if isinstance(v, dict):
             serialized = {}
             for key, value in v.items():
-                if hasattr(key, 'value'):
+                if hasattr(key, "value"):
                     # If it's an enum, use its value
                     serialized[key.value] = value
                 else:
@@ -294,7 +294,7 @@ class ScoreResp(ScoreBase):
     async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
         # Make sure the score object is fully loaded to avoid lazy-loading issues
         await session.refresh(score)
-
+
         s = cls.model_validate(score.model_dump())
         assert score.id
         await score.awaitable_attrs.beatmap
@@ -10,8 +10,8 @@ from app.config import settings
 from fastapi import Depends
 from pydantic import BaseModel
-import redis.asyncio as redis
+import redis as sync_redis
+import redis.asyncio as redis
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlmodel import SQLModel
 from sqlmodel.ext.asyncio.session import AsyncSession
@@ -40,7 +40,9 @@ engine = create_async_engine(
 redis_client = redis.from_url(settings.redis_url, decode_responses=True)
 
 # Redis message cache connection (db1) - sync client executed in a thread pool
-redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
+redis_message_client = sync_redis.from_url(
+    settings.redis_url, decode_responses=True, db=1
+)
 
 
 # Database dependencies
@@ -7,7 +7,7 @@ from app.config import settings
 
 from .mods import API_MODS, APIMod
 
-from pydantic import BaseModel, Field, ValidationInfo, field_validator, field_serializer
+from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator
 
 if TYPE_CHECKING:
     import rosu_pp_py as rosu
@@ -212,7 +212,7 @@ class SoloScoreSubmissionInfo(BaseModel):
         if isinstance(v, dict):
             serialized = {}
             for key, value in v.items():
-                if hasattr(key, 'value'):
+                if hasattr(key, "value"):
                     # If it's an enum, use its value
                     serialized[key.value] = value
                 else:
@@ -224,7 +224,7 @@ class SoloScoreSubmissionInfo(BaseModel):
     @field_serializer("rank", when_used="json")
     def serialize_rank(self, v):
         """Serialize the rank, making sure enum values are converted to strings correctly"""
-        if hasattr(v, 'value'):
+        if hasattr(v, "value"):
             return v.value
         return str(v)
 
@@ -1,13 +1,13 @@
 from __future__ import annotations
 
 from datetime import datetime
-from typing import Any
 
 from pydantic import BaseModel
 
 
 class OnlineStats(BaseModel):
     """Online statistics"""
+
     registered_users: int
     online_users: int
     playing_users: int
@@ -16,6 +16,7 @@ class OnlineStats(BaseModel):
 
 class OnlineHistoryPoint(BaseModel):
     """Online history data point"""
+
     timestamp: datetime
     online_count: int
     playing_count: int
@@ -23,12 +24,14 @@ class OnlineHistoryPoint(BaseModel):
 
 class OnlineHistoryStats(BaseModel):
     """24-hour online history statistics"""
+
     history: list[OnlineHistoryPoint]
     current_stats: OnlineStats
 
 
 class ServerStatistics(BaseModel):
     """Server statistics"""
+
     total_users: int
     online_users: int
     playing_users: int
@@ -62,7 +62,7 @@ async def get_update(
     if db_channel:
         # Extract the required attributes to avoid lazy loading
         channel_type = db_channel.type
-
+
         resp.presence.append(
             await ChatChannelResp.from_db(
                 db_channel,
@@ -122,9 +122,7 @@ async def join_channel(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
@@ -154,9 +152,7 @@ async def leave_channel(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
@@ -187,7 +183,7 @@ async def get_channel_list(
         # Extract the required attributes to avoid lazy loading
         channel_id = channel.channel_id
         channel_type = channel.type
-
+
         assert channel_id is not None
         results.append(
             await ChatChannelResp.from_db(
@@ -230,19 +226,17 @@ async def get_channel(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
         raise HTTPException(status_code=404, detail="Channel not found")
 
     # Extract the needed attributes right away
     channel_id = db_channel.channel_id
     channel_type = db_channel.type
     channel_name = db_channel.name
 
     assert channel_id is not None
 
     users = []
@@ -325,7 +319,9 @@ async def create_channel(
         channel_name = f"pm_{current_user.id}_{req.target_id}"
     else:
         channel_name = req.channel.name if req.channel else "Unnamed Channel"
-    result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
+    result = await session.exec(
+        select(ChatChannel).where(ChatChannel.name == channel_name)
+    )
     channel = result.first()
 
     if channel is None:
@@ -350,11 +346,11 @@ async def create_channel(
     await server.batch_join_channel([*target_users, current_user], channel, session)
 
     await server.join_channel(current_user, channel, session)
 
     # Extract the required attributes to avoid lazy loading
     channel_id = channel.channel_id
     assert channel_id
-
+
     return await ChatChannelResp.from_db(
         channel,
         session,
@@ -1,10 +1,5 @@
 from __future__ import annotations
 
-import json
-import uuid
-from datetime import datetime
-from typing import Optional
-
 from app.database import ChatMessageResp
 from app.database.chat import (
     ChannelType,
@@ -16,14 +11,13 @@ from app.database.chat import (
     UserSilenceResp,
 )
 from app.database.lazer_user import User
-from app.dependencies.database import Database, get_redis, get_redis_message
+from app.dependencies.database import Database, get_redis
 from app.dependencies.param import BodyOrForm
 from app.dependencies.user import get_current_user
+from app.log import logger
 from app.models.notification import ChannelMessage, ChannelMessageTeam
 from app.router.v2 import api_v2_router as router
-from app.service.optimized_message import optimized_message_service
 from app.service.redis_message_system import redis_message_system
-from app.log import logger
 
 from .banchobot import bot
 from .server import server
@@ -106,11 +100,9 @@ async def send_message(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
         raise HTTPException(status_code=404, detail="Channel not found")
@@ -118,29 +110,29 @@ async def send_message(
     channel_id = db_channel.channel_id
     channel_type = db_channel.type
     channel_name = db_channel.name
 
     assert channel_id is not None
     assert current_user.id
 
     # Send the message through the Redis message system - returns immediately
     resp = await redis_message_system.send_message(
         channel_id=channel_id,
         user=current_user,
         content=req.message,
         is_action=req.is_action,
-        user_uuid=req.uuid
+        user_uuid=req.uuid,
     )
 
     # Broadcast the message to all clients immediately
     is_bot_command = req.message.startswith("!")
     await server.send_message_to_channel(
         resp, is_bot_command and channel_type == ChannelType.PUBLIC
     )
 
     # Handle bot commands
     if is_bot_command:
         await bot.try_handle(current_user, db_channel, req.message, session)
 
     # Create a temporary ChatMessage object for the notification system (PM and team channels only)
     if channel_type in [ChannelType.PM, ChannelType.TEAM]:
         temp_msg = ChatMessage(
@@ -151,7 +143,7 @@ async def send_message(
             type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
             uuid=req.uuid,
         )
-
+
         if channel_type == ChannelType.PM:
             user_ids = channel_name.split("_")[1:]
             await server.new_private_notification(
@@ -163,7 +155,7 @@ async def send_message(
             await server.new_private_notification(
                 ChannelMessageTeam.init(temp_msg, current_user)
             )
-
+
     return resp
 
 
@@ -191,11 +183,9 @@ async def get_message(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
         raise HTTPException(status_code=404, detail="Channel not found")
@@ -218,7 +208,7 @@ async def get_message(
         query = query.where(col(ChatMessage.message_id) > since)
     if until is not None:
         query = query.where(col(ChatMessage.message_id) < until)
-
+
     query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
     messages = (await session.exec(query)).all()
     resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
@@ -247,14 +237,12 @@ async def mark_as_read(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
         raise HTTPException(status_code=404, detail="Channel not found")
 
     # Extract the needed attributes right away
     channel_id = db_channel.channel_id
     assert channel_id
@@ -96,8 +96,10 @@ class ChatServer:
     async def send_message_to_channel(
         self, message: ChatMessageResp, is_bot_command: bool = False
     ):
-        logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}")
+        logger.info(
+            f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
+        )
 
         event = ChatEvent(
             event="chat.message.new",
             data={"messages": [message], "users": [message.sender]},
@@ -107,24 +109,32 @@ class ChatServer:
             self._add_task(self.send_event(message.sender_id, event))
         else:
             # Always broadcast the message, whether the ID is temporary or real
-            logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
+            logger.info(
+                f"Broadcasting message to all users in channel {message.channel_id}"
+            )
             self._add_task(
                 self.broadcast(
                     message.channel_id,
                     event,
                 )
             )
 
         # Only real message IDs (positive and non-zero) get marked as read and set as the last message
        # IDs generated by the Redis message system are all positive, so they should all be handled correctly here
         if message.message_id and message.message_id > 0:
             await self.mark_as_read(
                 message.channel_id, message.sender_id, message.message_id
             )
-            await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
-            logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
+            await self.redis.set(
+                f"chat:{message.channel_id}:last_msg", message.message_id
+            )
+            logger.info(
+                f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
+            )
         else:
-            logger.debug(f"Skipping last message update for message ID: {message.message_id}")
+            logger.debug(
+                f"Skipping last message update for message ID: {message.message_id}"
+            )
 
     async def batch_join_channel(
         self, users: list[User], channel: ChatChannel, session: AsyncSession
@@ -340,11 +350,9 @@ async def chat_websocket(
     server.connect(user_id, websocket)
     # Use an explicit query to avoid lazy loading
     db_channel = (
-        await session.exec(
-            select(ChatChannel).where(ChatChannel.channel_id == 1)
-        )
+        await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
     ).first()
     if db_channel is not None:
         await server.join_channel(user, db_channel, session)
 
     await _listen_stop(websocket, user_id, factory)
@@ -5,4 +5,3 @@ from fastapi import APIRouter
 router = APIRouter(prefix="/api/v2")
 
-# Import all sub-router modules to register their routes
 from . import stats  # statistics routes
@@ -75,9 +75,10 @@ READ_SCORE_TIMEOUT = 10
 
 
 async def process_user_achievement(score_id: int):
-    from sqlmodel.ext.asyncio.session import AsyncSession
     from app.dependencies.database import engine
 
+    from sqlmodel.ext.asyncio.session import AsyncSession
+
     session = AsyncSession(engine)
     try:
         await process_achievements(session, get_redis(), score_id)
@@ -99,7 +100,7 @@ async def submit_score(
 ):
     # Grab the user ID immediately to avoid lazy-loading issues later
     user_id = current_user.id
-
+
     if not info.passed:
         info.rank = Rank.F
     score_token = (
@@ -166,13 +167,15 @@ async def submit_score(
         has_pp,
         has_leaderboard,
     )
-    score = (await db.exec(
-        select(Score)
-        .options(joinedload(Score.user))  # pyright: ignore[reportArgumentType]
-        .where(Score.id == score_id)
-    )).first()
+    score = (
+        await db.exec(
+            select(Score)
+            .options(joinedload(Score.user))  # pyright: ignore[reportArgumentType]
+            .where(Score.id == score_id)
+        )
+    ).first()
     assert score is not None
 
     resp = await ScoreResp.from_db(db, score)
     total_users = (await db.exec(select(func.count()).select_from(User))).first()
     assert total_users is not None
@@ -202,13 +205,10 @@ async def submit_score(
     # Make sure the score object is refreshed to avoid triggering lazy loads in the background task
     await db.refresh(score)
     score_gamemode = score.gamemode
-
+
     if user_id is not None:
         background_task.add_task(
-            _refresh_user_cache_background,
-            redis,
-            user_id,
-            score_gamemode
+            _refresh_user_cache_background, redis, user_id, score_gamemode
         )
     background_task.add_task(process_user_achievement, resp.id)
     return resp
@@ -217,9 +217,10 @@ async def submit_score(
 async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameMode):
     """Background task: refresh the user cache"""
     try:
-        from sqlmodel.ext.asyncio.session import AsyncSession
         from app.dependencies.database import engine
 
+        from sqlmodel.ext.asyncio.session import AsyncSession
+
         user_cache_service = get_user_cache_service(redis)
         # Create an independent database session
         session = AsyncSession(engine)
@@ -422,7 +423,7 @@ async def create_solo_score(
     assert current_user.id is not None
     # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
     async with db:
         score_token = ScoreToken(
@@ -480,7 +481,7 @@ async def create_playlist_score(
     assert current_user.id is not None
     # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     room = await session.get(Room, room_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -557,10 +558,10 @@ async def submit_playlist_score(
     fetcher: Fetcher = Depends(get_fetcher),
 ):
     assert current_user.id is not None
-
+
     # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     item = (
         await session.exec(
             select(Playlist).where(
@@ -627,7 +628,7 @@ async def index_playlist_scores(
 ):
     # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     room = await session.get(Room, room_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -694,7 +695,7 @@ async def show_playlist_score(
 ):
     # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     room = await session.get(Room, room_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -803,7 +804,7 @@ async def pin_score(
 ):
     # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     score_record = (
         await db.exec(
             select(Score).where(
@@ -848,7 +849,7 @@ async def unpin_score(
 ):
     # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     score_record = (
         await db.exec(
             select(Score).where(Score.id == score_id, Score.user_id == user_id)
@@ -892,7 +893,7 @@ async def reorder_score_pin(
 ):
     # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     score_record = (
         await db.exec(
             select(Score).where(Score.id == score_id, Score.user_id == user_id)
@@ -986,9 +987,9 @@ async def download_score_replay(
     current_user: User = Security(get_current_user, scopes=["public"]),
     storage_service: StorageService = Depends(get_storage_service),
 ):
-    # Grab the user ID immediately to avoid lazy-loading issues
+    # Grab the user ID immediately to avoid lazy-loading issues
     user_id = current_user.id
-
+
     score = (await db.exec(select(Score).where(Score.id == score_id))).first()
     if not score:
         raise HTTPException(status_code=404, detail="Score not found")
@@ -1,42 +1,45 @@
 from __future__ import annotations
 
 import asyncio
-from datetime import datetime, timedelta
-import json
-from typing import Any
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+import json
 
 from app.dependencies.database import get_redis, get_redis_message
 from app.log import logger
 
 from .router import router
 
 from fastapi import APIRouter
 from pydantic import BaseModel
 
 # Redis key constants
 REDIS_ONLINE_USERS_KEY = "server:online_users"
-REDIS_PLAYING_USERS_KEY = "server:playing_users"
+REDIS_PLAYING_USERS_KEY = "server:playing_users"
 REDIS_REGISTERED_USERS_KEY = "server:registered_users"
 REDIS_ONLINE_HISTORY_KEY = "server:online_history"
 
 # Thread pool for synchronous Redis operations
 _executor = ThreadPoolExecutor(max_workers=2)
 
 
 async def _redis_exec(func, *args, **kwargs):
     """Run a synchronous Redis operation in the thread pool"""
     loop = asyncio.get_event_loop()
     return await loop.run_in_executor(_executor, func, *args, **kwargs)
 
 
 class ServerStats(BaseModel):
     """Server statistics response model"""
+
     registered_users: int
     online_users: int
     playing_users: int
     timestamp: datetime
 
 
 class OnlineHistoryPoint(BaseModel):
     """Online history data point"""
+
     timestamp: datetime
     online_count: int
     playing_count: int
@@ -44,33 +47,36 @@ class OnlineHistoryPoint(BaseModel):
     peak_playing: int | None = None  # peak playing count (enhanced data)
     total_samples: int | None = None  # sample count (enhanced data)
 
 
 class OnlineHistoryResponse(BaseModel):
     """24-hour online history response model"""
+
     history: list[OnlineHistoryPoint]
     current_stats: ServerStats
 
 
 @router.get("/stats", response_model=ServerStats, tags=["统计"])
 async def get_server_stats() -> ServerStats:
     """
     Get real-time server statistics
 
     Returns real-time statistics such as registered users, online users, and currently playing users
     """
     redis = get_redis()
 
     try:
         # Fetch all statistics in parallel
         registered_count, online_count, playing_count = await asyncio.gather(
             _get_registered_users_count(redis),
             _get_online_users_count(redis),
-            _get_playing_users_count(redis)
+            _get_playing_users_count(redis),
         )
 
         return ServerStats(
             registered_users=registered_count,
             online_users=online_count,
             playing_users=playing_count,
-            timestamp=datetime.utcnow()
+            timestamp=datetime.utcnow(),
         )
     except Exception as e:
         logger.error(f"Error getting server stats: {e}")
@@ -79,75 +85,86 @@ async def get_server_stats() -> ServerStats:
             registered_users=0,
             online_users=0,
             playing_users=0,
-            timestamp=datetime.utcnow()
+            timestamp=datetime.utcnow(),
         )
 
 
 @router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
 async def get_online_history() -> OnlineHistoryResponse:
     """
     Get the online statistics history for the last 24 hours
 
     Returns hourly online and playing user counts for the past 24 hours,
     with the current real-time data included as the latest data point
     """
     try:
         # Fetch the history data - uses the sync Redis client
         redis_sync = get_redis_message()
-        history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
+        history_data = await _redis_exec(
+            redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
+        )
         history_points = []
 
         # Process the history data
         for data in history_data:
             try:
                 point_data = json.loads(data)
                 # Support both old- and new-format history data
-                history_points.append(OnlineHistoryPoint(
-                    timestamp=datetime.fromisoformat(point_data["timestamp"]),
-                    online_count=point_data["online_count"],
-                    playing_count=point_data["playing_count"],
-                    peak_online=point_data.get("peak_online"),  # new field, may be missing
-                    peak_playing=point_data.get("peak_playing"),  # new field, may be missing
-                    total_samples=point_data.get("total_samples")  # new field, may be missing
-                ))
+                history_points.append(
+                    OnlineHistoryPoint(
+                        timestamp=datetime.fromisoformat(point_data["timestamp"]),
+                        online_count=point_data["online_count"],
+                        playing_count=point_data["playing_count"],
+                        peak_online=point_data.get("peak_online"),  # new field, may be missing
+                        peak_playing=point_data.get(
+                            "peak_playing"
+                        ),  # new field, may be missing
+                        total_samples=point_data.get(
+                            "total_samples"
+                        ),  # new field, may be missing
+                    )
+                )
             except (json.JSONDecodeError, KeyError, ValueError) as e:
                 logger.warning(f"Invalid history data point: {data}, error: {e}")
                 continue
 
         # Get the current real-time statistics
         current_stats = await get_server_stats()
 
         # If the history is empty or the newest point is older than 15 minutes, append the current data point
         if not history_points or (
-            history_points and
-            (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60
+            history_points
+            and (
+                current_stats.timestamp
+                - max(history_points, key=lambda x: x.timestamp).timestamp
+            ).total_seconds()
+            > 15 * 60
         ):
-            history_points.append(OnlineHistoryPoint(
-                timestamp=current_stats.timestamp,
-                online_count=current_stats.online_users,
-                playing_count=current_stats.playing_users,
-                peak_online=current_stats.online_users,  # current real-time data used as the peak
-                peak_playing=current_stats.playing_users,
-                total_samples=1
-            ))
+            history_points.append(
+                OnlineHistoryPoint(
+                    timestamp=current_stats.timestamp,
+                    online_count=current_stats.online_users,
+                    playing_count=current_stats.playing_users,
+                    peak_online=current_stats.online_users,  # current real-time data used as the peak
+                    peak_playing=current_stats.playing_users,
+                    total_samples=1,
+                )
+            )
 
         # Sort by time (newest first)
         history_points.sort(key=lambda x: x.timestamp, reverse=True)
 
         # Cap at 48 data points (24 hours)
         history_points = history_points[:48]
 
         return OnlineHistoryResponse(
-            history=history_points,
-            current_stats=current_stats
+            history=history_points, current_stats=current_stats
         )
     except Exception as e:
         logger.error(f"Error getting online history: {e}")
         # Return an empty history plus the current stats
         current_stats = await get_server_stats()
-        return OnlineHistoryResponse(
-            history=[],
-            current_stats=current_stats
-        )
+        return OnlineHistoryResponse(history=[], current_stats=current_stats)
 
 
 async def _get_registered_users_count(redis) -> int:
@@ -159,6 +176,7 @@ async def _get_registered_users_count(redis) -> int:
         logger.error(f"Error getting registered users count: {e}")
         return 0
 
+
 async def _get_online_users_count(redis) -> int:
     """Get the current online user count"""
     try:
@@ -168,6 +186,7 @@ async def _get_online_users_count(redis) -> int:
         logger.error(f"Error getting online users count: {e}")
         return 0
 
+
 async def _get_playing_users_count(redis) -> int:
     """Get the current playing user count"""
     try:
@@ -177,14 +196,16 @@ async def _get_playing_users_count(redis) -> int:
         logger.error(f"Error getting playing users count: {e}")
         return 0
 
+
 # Statistics update functions
 async def update_registered_users_count() -> None:
     """Update the registered user count cache"""
-    from app.dependencies.database import with_db
-    from app.database import User
     from app.const import BANCHOBOT_ID
-    from sqlmodel import select, func
+
+    from app.database import User
+    from app.dependencies.database import with_db
+
+    from sqlmodel import func, select
 
     redis = get_redis()
     try:
         async with with_db() as db:
@@ -198,6 +219,7 @@ async def update_registered_users_count() -> None:
     except Exception as e:
         logger.error(f"Error updating registered users count: {e}")
 
+
 async def add_online_user(user_id: int) -> None:
     """Add an online user"""
     redis_sync = get_redis_message()
@@ -209,14 +231,16 @@ async def add_online_user(user_id: int) -> None:
         if ttl <= 0:  # -1 means no expiry, -2 means missing, 0 means expired
             await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600)  # expires in 3 hours
         logger.debug(f"Added online user {user_id}")
 
         # Update the current interval statistics immediately
         from app.service.enhanced_interval_stats import update_user_activity_in_interval
+
         asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
 
     except Exception as e:
         logger.error(f"Error adding online user {user_id}: {e}")
 
+
 async def remove_online_user(user_id: int) -> None:
     """Remove an online user"""
     redis_sync = get_redis_message()
@@ -226,6 +250,7 @@ async def remove_online_user(user_id: int) -> None:
     except Exception as e:
         logger.error(f"Error removing online user {user_id}: {e}")
 
+
 async def add_playing_user(user_id: int) -> None:
     """Add a playing user"""
     redis_sync = get_redis_message()
@@ -237,14 +262,16 @@ async def add_playing_user(user_id: int) -> None:
         if ttl <= 0:  # -1 means no expiry, -2 means missing, 0 means expired
             await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600)  # expires in 3 hours
         logger.debug(f"Added playing user {user_id}")
 
         # Update the current interval statistics immediately
         from app.service.enhanced_interval_stats import update_user_activity_in_interval
+
         asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
 
     except Exception as e:
         logger.error(f"Error adding playing user {user_id}: {e}")
 
+
 async def remove_playing_user(user_id: int) -> None:
     """Remove a playing user"""
     redis_sync = get_redis_message()
@@ -253,6 +280,7 @@ async def remove_playing_user(user_id: int) -> None:
     except Exception as e:
         logger.error(f"Error removing playing user {user_id}: {e}")
 
+
 async def record_hourly_stats() -> None:
     """Record statistics - simplified version, mainly used as a fallback"""
     redis_sync = get_redis_message()
@@ -260,10 +288,10 @@ async def record_hourly_stats() -> None:
     try:
         # Make sure the Redis connection is healthy first
         await redis_async.ping()
-
+
         online_count = await _get_online_users_count(redis_async)
         playing_count = await _get_playing_users_count(redis_async)
-
+
         current_time = datetime.utcnow()
         history_point = {
             "timestamp": current_time.isoformat(),
@@ -271,16 +299,20 @@ async def record_hourly_stats() -> None:
             "playing_count": playing_count,
             "peak_online": online_count,
             "peak_playing": playing_count,
-            "total_samples": 1
+            "total_samples": 1,
         }
 
         # Append to the history
-        await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
+        await _redis_exec(
+            redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
+        )
        # Keep only 48 data points (24 hours, one point every 30 minutes)
         await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
         # Set the expiry to 26 hours to leave enough buffer
         await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
 
-        logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}")
+        logger.info(
+            f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}"
+        )
     except Exception as e:
         logger.error(f"Error recording fallback stats: {e}")
@@ -4,11 +4,9 @@
 from __future__ import annotations
 
-import json
 import asyncio
 from dataclasses import dataclass
 from datetime import datetime, timedelta
-from typing import Dict, Set, Optional, List
-from dataclasses import dataclass, asdict
+import json
 
 from app.dependencies.database import get_redis, get_redis_message
 from app.log import logger
@@ -16,7 +14,7 @@ from app.router.v2.stats import (
     REDIS_ONLINE_HISTORY_KEY,
     _get_online_users_count,
     _get_playing_users_count,
-    _redis_exec
+    _redis_exec,
 )
 
 # Redis keys for interval statistics
@@ -29,34 +27,36 @@ CURRENT_INTERVAL_INFO_KEY = "server:current_interval_info"  # current interval info
 @dataclass
 class IntervalInfo:
     """Interval info"""
+
     start_time: datetime
     end_time: datetime
     interval_key: str
 
     def is_current(self) -> bool:
         """Check whether this is the current interval"""
         now = datetime.utcnow()
         return self.start_time <= now < self.end_time
 
-    def to_dict(self) -> Dict:
+    def to_dict(self) -> dict:
         return {
-            'start_time': self.start_time.isoformat(),
-            'end_time': self.end_time.isoformat(),
-            'interval_key': self.interval_key
+            "start_time": self.start_time.isoformat(),
+            "end_time": self.end_time.isoformat(),
+            "interval_key": self.interval_key,
         }
 
     @classmethod
-    def from_dict(cls, data: Dict) -> 'IntervalInfo':
+    def from_dict(cls, data: dict) -> "IntervalInfo":
         return cls(
-            start_time=datetime.fromisoformat(data['start_time']),
-            end_time=datetime.fromisoformat(data['end_time']),
-            interval_key=data['interval_key']
+            start_time=datetime.fromisoformat(data["start_time"]),
+            end_time=datetime.fromisoformat(data["end_time"]),
+            interval_key=data["interval_key"],
         )
 
 
 @dataclass
 class IntervalStats:
     """Interval statistics data"""
+
     interval_key: str
     start_time: datetime
     end_time: datetime
@@ -66,38 +66,38 @@ class IntervalStats:
     peak_playing_count: int  # peak playing user count within the interval
     total_samples: int  # sample count
     created_at: datetime
 
-    def to_dict(self) -> Dict:
+    def to_dict(self) -> dict:
         return {
-            'interval_key': self.interval_key,
-            'start_time': self.start_time.isoformat(),
-            'end_time': self.end_time.isoformat(),
-            'unique_online_users': self.unique_online_users,
-            'unique_playing_users': self.unique_playing_users,
-            'peak_online_count': self.peak_online_count,
-            'peak_playing_count': self.peak_playing_count,
-            'total_samples': self.total_samples,
-            'created_at': self.created_at.isoformat()
+            "interval_key": self.interval_key,
+            "start_time": self.start_time.isoformat(),
+            "end_time": self.end_time.isoformat(),
+            "unique_online_users": self.unique_online_users,
+            "unique_playing_users": self.unique_playing_users,
+            "peak_online_count": self.peak_online_count,
+            "peak_playing_count": self.peak_playing_count,
+            "total_samples": self.total_samples,
+            "created_at": self.created_at.isoformat(),
         }
 
     @classmethod
-    def from_dict(cls, data: Dict) -> 'IntervalStats':
+    def from_dict(cls, data: dict) -> "IntervalStats":
         return cls(
-            interval_key=data['interval_key'],
-            start_time=datetime.fromisoformat(data['start_time']),
-            end_time=datetime.fromisoformat(data['end_time']),
-            unique_online_users=data['unique_online_users'],
-            unique_playing_users=data['unique_playing_users'],
-            peak_online_count=data['peak_online_count'],
-            peak_playing_count=data['peak_playing_count'],
-            total_samples=data['total_samples'],
-            created_at=datetime.fromisoformat(data['created_at'])
+            interval_key=data["interval_key"],
+            start_time=datetime.fromisoformat(data["start_time"]),
+            end_time=datetime.fromisoformat(data["end_time"]),
+            unique_online_users=data["unique_online_users"],
+            unique_playing_users=data["unique_playing_users"],
+            peak_online_count=data["peak_online_count"],
+            peak_playing_count=data["peak_playing_count"],
+            total_samples=data["total_samples"],
+            created_at=datetime.fromisoformat(data["created_at"]),
         )
 
 
 class EnhancedIntervalStatsManager:
     """Enhanced interval statistics manager - actually tracks user activity within each half-hour interval"""
 
     @staticmethod
     def get_current_interval_boundaries() -> tuple[datetime, datetime]:
         """Get the boundaries of the current 30-minute interval"""
@@ -108,49 +108,53 @@ class EnhancedIntervalStatsManager:
         # Interval end time
         end_time = start_time + timedelta(minutes=30)
         return start_time, end_time
 
     @staticmethod
     def generate_interval_key(start_time: datetime) -> str:
         """Generate a unique interval identifier"""
         return f"{INTERVAL_STATS_BASE_KEY}:{start_time.strftime('%Y%m%d_%H%M')}"
 
     @staticmethod
     async def get_current_interval_info() -> IntervalInfo:
         """Get the current interval info"""
-        start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries()
-        interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time)
-
-        return IntervalInfo(
-            start_time=start_time,
-            end_time=end_time,
-            interval_key=interval_key
-        )
+        start_time, end_time = (
+            EnhancedIntervalStatsManager.get_current_interval_boundaries()
+        )
+
+        interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time)
+
+        return IntervalInfo(
+            start_time=start_time, end_time=end_time, interval_key=interval_key
+        )
 
     @staticmethod
     async def initialize_current_interval() -> None:
         """Initialize the current interval"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
 
             # Store the current interval info
             await _redis_exec(
                 redis_sync.set,
                 CURRENT_INTERVAL_INFO_KEY,
-                json.dumps(current_interval.to_dict())
+                json.dumps(current_interval.to_dict()),
             )
             await redis_async.expire(CURRENT_INTERVAL_INFO_KEY, 35 * 60)  # expires in 35 minutes
 
             # Initialize the interval user sets (if they don't exist)
             online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
-            playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+            playing_key = (
+                f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+            )
 
             # Set the expiry to 35 minutes
             await redis_async.expire(online_key, 35 * 60)
             await redis_async.expire(playing_key, 35 * 60)
 
             # Initialize the interval statistics record
             stats = IntervalStats(
                 interval_key=current_interval.interval_key,
@@ -161,157 +165,193 @@ class EnhancedIntervalStatsManager:
                 peak_online_count=0,
                 peak_playing_count=0,
                 total_samples=0,
-                created_at=datetime.utcnow()
+                created_at=datetime.utcnow(),
             )
 
             await _redis_exec(
-                redis_sync.set,
-                current_interval.interval_key,
-                json.dumps(stats.to_dict())
+                redis_sync.set,
+                current_interval.interval_key,
+                json.dumps(stats.to_dict()),
             )
             await redis_async.expire(current_interval.interval_key, 35 * 60)
 
             # If the history is empty, automatically backfill the previous 24 hours with zeros
             await EnhancedIntervalStatsManager._ensure_24h_history_exists()
 
-            logger.info(f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}")
+            logger.info(
+                f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}"
+            )
 
         except Exception as e:
             logger.error(f"Error initializing current interval: {e}")
 
     @staticmethod
     async def _ensure_24h_history_exists() -> None:
         """Make sure 24 hours of history exists, filling with zeros if not"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
             # Check how much history data already exists
-            history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY)
+            history_length = await _redis_exec(
+                redis_sync.llen, REDIS_ONLINE_HISTORY_KEY
+            )
 
             if history_length < 48:  # fewer than 48 data points (24 hours * 2)
-                logger.info(f"History has only {history_length} points, filling with zeros for 24h")
+                logger.info(
+                    f"History has only {history_length} points, filling with zeros for 24h"
+                )
 
                 # Work out how many data points need to be filled
                 needed_points = 48 - history_length
 
                 # Walk back from the current time and create the missing time points (all zero-filled)
                 current_time = datetime.utcnow()
-                current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries()
+                current_interval_start, _ = (
+                    EnhancedIntervalStatsManager.get_current_interval_boundaries()
+                )
 
                 # Walk back from the current interval and create history points (aligned to 30-minute boundaries)
                 fill_points = []
                 for i in range(needed_points):
                     # Step back 30 minutes each time, keeping the times aligned
-                    point_time = current_interval_start - timedelta(minutes=30 * (i + 1))
+                    point_time = current_interval_start - timedelta(
+                        minutes=30 * (i + 1)
+                    )
 
                     # Align the time to a 30-minute boundary
                     aligned_minute = (point_time.minute // 30) * 30
-                    point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0)
+                    point_time = point_time.replace(
+                        minute=aligned_minute, second=0, microsecond=0
+                    )
 
                     history_point = {
                         "timestamp": point_time.isoformat(),
                         "online_count": 0,
                         "playing_count": 0,
                         "peak_online": 0,
                         "peak_playing": 0,
-                        "total_samples": 0
+                        "total_samples": 0,
                     }
                     fill_points.append(json.dumps(history_point))
 
                 # Append the fill data to the end of the history (oldest data)
                 if fill_points:
                     # Move the existing data to a temporary location first
                     temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp"
                     if history_length > 0:
                         # Copy the existing data to a temporary key
-                        existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
+                        existing_data = await _redis_exec(
+                            redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
+                        )
                         if existing_data:
                             for data in existing_data:
                                 await _redis_exec(redis_sync.rpush, temp_key, data)
 
                     # Clear the original key
                     await redis_async.delete(REDIS_ONLINE_HISTORY_KEY)
 
                     # Add the fill data first (oldest)
                     for point in reversed(fill_points):  # add in reverse so the oldest ends up last
-                        await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point)
+                        await _redis_exec(
+                            redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point
+                        )
 
                     # Then add the original data (newer)
                     if history_length > 0:
-                        existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1)
+                        existing_data = await _redis_exec(
+                            redis_sync.lrange, temp_key, 0, -1
+                        )
                         for data in existing_data:
-                            await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data)
+                            await _redis_exec(
+                                redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data
+                            )
 
                     # Clean up the temporary key
                     await redis_async.delete(temp_key)
 
                 # Make sure only 48 data points are kept
                 await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
 
                 # Set the expiry
                 await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
 
-                logger.info(f"Filled {len(fill_points)} historical data points with zeros")
+                logger.info(
+                    f"Filled {len(fill_points)} historical data points with zeros"
+                )
 
         except Exception as e:
             logger.error(f"Error ensuring 24h history exists: {e}")
 
     @staticmethod
     async def add_user_to_interval(user_id: int, is_playing: bool = False) -> None:
         """Add a user to the current interval statistics - updates the running interval in real time"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
 
             # Add to the interval's online user set
             online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
             await _redis_exec(redis_sync.sadd, online_key, str(user_id))
             await redis_async.expire(online_key, 35 * 60)
 
             # If the user is playing, also add them to the playing user set
             if is_playing:
-                playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+                playing_key = (
+                    f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+                )
                 await _redis_exec(redis_sync.sadd, playing_key, str(user_id))
                 await redis_async.expire(playing_key, 35 * 60)
 
             # Update the interval statistics immediately (synchronous update to keep the data fresh)
             await EnhancedIntervalStatsManager._update_interval_stats()
 
-            logger.debug(f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}")
+            logger.debug(
+                f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}"
+            )
 
         except Exception as e:
             logger.error(f"Error adding user {user_id} to interval: {e}")
 
     @staticmethod
     async def _update_interval_stats() -> None:
         """Update the current interval statistics - immediate synchronous update"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
 
             # Get the number of unique users in the interval
             online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
-            playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+            playing_key = (
+                f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+            )
 
             unique_online = await _redis_exec(redis_sync.scard, online_key)
             unique_playing = await _redis_exec(redis_sync.scard, playing_key)
 
             # Get the current real-time user counts as peak references
             current_online = await _get_online_users_count(redis_async)
             current_playing = await _get_playing_users_count(redis_async)
 
             # Get the existing statistics
-            existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
+            existing_data = await _redis_exec(
+                redis_sync.get, current_interval.interval_key
+            )
             if existing_data:
                 stats = IntervalStats.from_dict(json.loads(existing_data))
                 # Update the peaks
                 stats.peak_online_count = max(stats.peak_online_count, current_online)
-                stats.peak_playing_count = max(stats.peak_playing_count, current_playing)
+                stats.peak_playing_count = max(
+                    stats.peak_playing_count, current_playing
+                )
                 stats.total_samples += 1
             else:
                 # Create a new statistics record
@@ -324,46 +364,52 @@ class EnhancedIntervalStatsManager:
                 peak_online_count=current_online,
                 peak_playing_count=current_playing,
                 total_samples=1,
-                created_at=datetime.utcnow()
+                created_at=datetime.utcnow(),
             )
 
             # Update the unique user counts
             stats.unique_online_users = unique_online
             stats.unique_playing_users = unique_playing
 
             # Save the updated statistics immediately
             await _redis_exec(
                 redis_sync.set,
                 current_interval.interval_key,
-                json.dumps(stats.to_dict())
+                json.dumps(stats.to_dict()),
             )
             await redis_async.expire(current_interval.interval_key, 35 * 60)
 
-            logger.debug(f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}")
+            logger.debug(
+                f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
+            )
 
         except Exception as e:
             logger.error(f"Error updating interval stats: {e}")
 
     @staticmethod
-    async def finalize_interval() -> Optional[IntervalStats]:
+    async def finalize_interval() -> IntervalStats | None:
         """Finalize the current interval statistics and save them to the history"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
 
             # Update the statistics one last time
             await EnhancedIntervalStatsManager._update_interval_stats()
 
             # Get the final statistics
-            stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
+            stats_data = await _redis_exec(
+                redis_sync.get, current_interval.interval_key
+            )
             if not stats_data:
                 logger.warning("No interval stats found to finalize")
                 return None
 
             stats = IntervalStats.from_dict(json.loads(stats_data))
 
             # Create the history point (using the interval end time as the timestamp to keep times aligned)
             history_point = {
                 "timestamp": current_interval.end_time.isoformat(),
@@ -371,16 +417,18 @@ class EnhancedIntervalStatsManager:
                 "playing_count": stats.unique_playing_users,
                 "peak_online": stats.peak_online_count,
                 "peak_playing": stats.peak_playing_count,
-                "total_samples": stats.total_samples
+                "total_samples": stats.total_samples,
             }
 
             # Append to the history
-            await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
+            await _redis_exec(
+                redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
+            )
             # Keep only 48 data points (24 hours, one point every 30 minutes)
             await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
             # Set the expiry to 26 hours to leave enough buffer
             await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
 
             logger.info(
                 f"Finalized interval stats: "
                 f"unique_online={stats.unique_online_users}, "
@@ -390,64 +438,70 @@ class EnhancedIntervalStatsManager:
                 f"samples={stats.total_samples} "
                 f"for {stats.start_time.strftime('%H:%M')}-{stats.end_time.strftime('%H:%M')}"
             )
 
             return stats
 
         except Exception as e:
             logger.error(f"Error finalizing interval stats: {e}")
             return None
 
     @staticmethod
-    async def get_current_interval_stats() -> Optional[IntervalStats]:
+    async def get_current_interval_stats() -> IntervalStats | None:
         """Get the current interval statistics"""
         redis_sync = get_redis_message()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
-            stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
+            stats_data = await _redis_exec(
+                redis_sync.get, current_interval.interval_key
+            )
 
             if stats_data:
                 return IntervalStats.from_dict(json.loads(stats_data))
             return None
 
         except Exception as e:
             logger.error(f"Error getting current interval stats: {e}")
             return None
 
     @staticmethod
     async def cleanup_old_intervals() -> None:
         """Clean up expired interval data"""
         redis_async = get_redis()
 
         try:
             # Delete expired interval statistics (older than 2 hours)
             cutoff_time = datetime.utcnow() - timedelta(hours=2)
             pattern = f"{INTERVAL_STATS_BASE_KEY}:*"
 
             keys = await redis_async.keys(pattern)
             for key in keys:
                 try:
                     # Extract the time from the key
-                    time_part = key.decode().split(':')[-1]  # YYYYMMDD_HHMM format
-                    key_time = datetime.strptime(time_part, '%Y%m%d_%H%M')
+                    time_part = key.decode().split(":")[-1]  # YYYYMMDD_HHMM format
+                    key_time = datetime.strptime(time_part, "%Y%m%d_%H%M")
 
                     if key_time < cutoff_time:
                         await redis_async.delete(key)
                         # Also delete the corresponding user sets
                         await redis_async.delete(f"{INTERVAL_ONLINE_USERS_KEY}:{key}")
                         await redis_async.delete(f"{INTERVAL_PLAYING_USERS_KEY}:{key}")
 
                 except (ValueError, IndexError):
                     # Ignore keys that fail to parse
                     continue
 
             logger.debug("Cleaned up old interval data")
 
         except Exception as e:
             logger.error(f"Error cleaning up old intervals: {e}")
 
 
 # Convenience function meant to replace the existing statistics update functions
-async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None:
+async def update_user_activity_in_interval(
+    user_id: int, is_playing: bool = False
+) -> None:
     """Update interval statistics on user activity (called on login, when starting to play, etc.)"""
     await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing)
@@ -3,53 +3,52 @@ Redis message queue service
 Implements real-time message delivery with asynchronous database persistence
 """
 
-import asyncio
-import json
-import uuid
-from datetime import datetime
-from functools import partial
-from typing import Optional, Union
-import concurrent.futures
+from __future__ import annotations
 
-from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType
+import asyncio
+import concurrent.futures
+from datetime import datetime
+import uuid
+
+from app.database.chat import ChatMessage, MessageType
 from app.dependencies.database import get_redis, with_db
 from app.log import logger
 
 
 class MessageQueue:
     """Redis message queue service"""
 
     def __init__(self):
         self.redis = get_redis()
         self._processing = False
         self._batch_size = 50  # batch size
         self._batch_timeout = 1.0  # batch timeout (seconds)
         self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
 
     async def _run_in_executor(self, func, *args):
         """Run a synchronous Redis operation in the thread pool"""
         loop = asyncio.get_event_loop()
         return await loop.run_in_executor(self._executor, func, *args)
 
     async def start_processing(self):
         """Start the message processing task"""
         if not self._processing:
             self._processing = True
             asyncio.create_task(self._process_message_queue())
             logger.info("Message queue processing started")
 
     async def stop_processing(self):
         """Stop message processing"""
         self._processing = False
         logger.info("Message queue processing stopped")
 
     async def enqueue_message(self, message_data: dict) -> str:
         """
         Enqueue the message into the Redis queue (real-time response)
 
         Args:
             message_data: message data dict containing all required fields
 
         Returns:
             The message's temporary UUID
         """
@@ -58,36 +57,42 @@ class MessageQueue:
         message_data["temp_uuid"] = temp_uuid
         message_data["timestamp"] = datetime.now().isoformat()
         message_data["status"] = "pending"  # pending, processing, completed, failed
 
         # Store the message in Redis
         await self._run_in_executor(
             lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data)
         )
-        await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600)  # expires in 1 hour
+        await self._run_in_executor(
+            self.redis.expire, f"msg:{temp_uuid}", 3600
+        )  # expires in 1 hour
 
         # Add to the processing queue
         await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)
 
         logger.info(f"Message enqueued with temp_uuid: {temp_uuid}")
         return temp_uuid
 
-    async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
+    async def get_message_status(self, temp_uuid: str) -> dict | None:
         """Get the message status"""
-        message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
+        message_data = await self._run_in_executor(
+            self.redis.hgetall, f"msg:{temp_uuid}"
+        )
         if not message_data:
             return None
 
         return message_data
 
-    async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
+    async def get_cached_messages(
+        self, channel_id: int, limit: int = 50, since: int = 0
+    ) -> list[dict]:
         """
         Fetch cached messages from Redis
 
         Args:
             channel_id: channel ID
             limit: maximum number of messages
             since: fetch messages after this message ID
 
         Returns:
             List of messages
         """
@@ -95,29 +100,39 @@ class MessageQueue:
         message_uuids = await self._run_in_executor(
             self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1
         )
 
         messages = []
         for uuid_str in message_uuids:
-            message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}")
+            message_data = await self._run_in_executor(
+                self.redis.hgetall, f"msg:{uuid_str}"
+            )
             if message_data:
                 # Check the since condition
                 if since > 0 and "message_id" in message_data:
                     if int(message_data["message_id"]) <= since:
                         continue
 
                 messages.append(message_data)
 
         return messages[::-1]  # return in chronological order
 
-    async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100):
+    async def cache_channel_message(
+        self, channel_id: int, temp_uuid: str, max_cache: int = 100
+    ):
         """Cache the message UUID in the channel's message list"""
         # Prepend to the channel message list
-        await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid)
+        await self._run_in_executor(
+            self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid
+        )
         # Limit the cache size
-        await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1)
+        await self._run_in_executor(
+            self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1
+        )
         # Set the expiry (24 hours)
-        await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400)
+        await self._run_in_executor(
+            self.redis.expire, f"channel:{channel_id}:messages", 86400
+        )
 
     async def _process_message_queue(self):
         """Asynchronously process the message queue, writing to the database in batches"""
         while self._processing:
@@ -132,75 +147,90 @@ class MessageQueue:
|
||||
message_uuids.append(result[1])
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
if message_uuids:
|
||||
await self._process_message_batch(message_uuids)
|
||||
else:
|
||||
# 没有消息时短暂等待
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message queue: {e}")
|
||||
await asyncio.sleep(1) # 错误时等待1秒再重试
|
||||
|
||||
|
||||
async def _process_message_batch(self, message_uuids: list[str]):
|
||||
"""批量处理消息写入数据库"""
|
||||
async with with_db() as session:
|
||||
messages_to_insert = []
|
||||
|
||||
|
||||
for temp_uuid in message_uuids:
|
||||
try:
|
||||
# 获取消息数据
|
||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
|
||||
message_data = await self._run_in_executor(
|
||||
self.redis.hgetall, f"msg:{temp_uuid}"
|
||||
)
|
||||
if not message_data:
|
||||
continue
|
||||
|
||||
|
||||
# 更新状态为处理中
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing")
|
||||
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"msg:{temp_uuid}", "status", "processing"
|
||||
)
|
||||
|
||||
# 创建数据库消息对象
|
||||
msg = ChatMessage(
|
||||
channel_id=int(message_data["channel_id"]),
|
||||
content=message_data["content"],
|
||||
sender_id=int(message_data["sender_id"]),
|
||||
type=MessageType(message_data["type"]),
|
||||
uuid=message_data.get("user_uuid") # 用户提供的 UUID(如果有)
|
||||
uuid=message_data.get("user_uuid"), # 用户提供的 UUID(如果有)
|
||||
)
|
||||
|
||||
|
||||
messages_to_insert.append((msg, temp_uuid))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing message {temp_uuid}: {e}")
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
|
||||
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
|
||||
)
|
||||
|
||||
if messages_to_insert:
|
||||
try:
|
||||
# 批量插入数据库
|
||||
for msg, temp_uuid in messages_to_insert:
|
||||
session.add(msg)
|
||||
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 更新所有消息状态和真实 ID
|
||||
for msg, temp_uuid in messages_to_insert:
|
||||
await session.refresh(msg)
|
||||
await self._run_in_executor(
|
||||
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping={
|
||||
"status": "completed",
|
||||
"message_id": str(msg.message_id),
|
||||
"created_at": msg.timestamp.isoformat() if msg.timestamp else ""
|
||||
})
|
||||
lambda: self.redis.hset(
|
||||
f"msg:{temp_uuid}",
|
||||
mapping={
|
||||
"status": "completed",
|
||||
"message_id": str(msg.message_id),
|
||||
"created_at": msg.timestamp.isoformat()
|
||||
if msg.timestamp
|
||||
else "",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting messages to database: {e}")
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# 标记所有消息为失败
|
||||
for _, temp_uuid in messages_to_insert:
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
|
||||
)
|
||||
|
||||
|
||||
# 全局消息队列实例
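
# --- Editor's note: the module-level singleton follows here (it is imported elsewhere
# in this commit as `from app.service.message_queue import message_queue`); a hedged
# sketch of the intended startup/shutdown wiring, with the FastAPI lifespan shown
# being illustrative rather than taken from this commit:
#
#     message_queue = MessageQueue()
#
#     @asynccontextmanager
#     async def lifespan(app: FastAPI):
#         await message_queue.start_processing()
#         yield
#         await message_queue.stop_processing()
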

@@ -3,12 +3,12 @@
专门处理 Redis 消息队列的异步写入数据库
"""

from __future__ import annotations

import asyncio
import json
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Optional
import json

from app.database.chat import ChatMessage, MessageType
from app.dependencies.database import get_redis_message, with_db
@@ -17,103 +17,132 @@ from app.log import logger

class MessageQueueProcessor:
    """消息队列处理器"""

    def __init__(self):
        self.redis_message = get_redis_message()
        self.executor = ThreadPoolExecutor(max_workers=2)
        self._processing = False
        self._queue_task = None

    async def _redis_exec(self, func, *args, **kwargs):
        """在线程池中执行 Redis 操作"""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))

    async def cache_message(self, channel_id: int, message_data: dict, temp_uuid: str):
        """将消息缓存到 Redis"""
        try:
            # 存储消息数据
            await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data)
            await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600)  # 1小时过期
            await self._redis_exec(
                self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data
            )
            await self._redis_exec(
                self.redis_message.expire, f"msg:{temp_uuid}", 3600
            )  # 1小时过期

            # 加入频道消息列表
            await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid)
            await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99)  # 保持最新100条
            await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400)  # 24小时过期
            await self._redis_exec(
                self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid
            )
            await self._redis_exec(
                self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99
            )  # 保持最新100条
            await self._redis_exec(
                self.redis_message.expire, f"channel:{channel_id}:messages", 86400
            )  # 24小时过期

            # 加入异步处理队列
            await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid)
            await self._redis_exec(
                self.redis_message.lpush, "message_write_queue", temp_uuid
            )

            logger.info(f"Message cached to Redis: {temp_uuid}")
        except Exception as e:
            logger.error(f"Failed to cache message to Redis: {e}")

    async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
    async def get_cached_messages(
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[dict]:
        """从 Redis 获取缓存的消息"""
        try:
            message_uuids = await self._redis_exec(
                self.redis_message.lrange, f"channel:{channel_id}:messages", 0, limit - 1
                self.redis_message.lrange,
                f"channel:{channel_id}:messages",
                0,
                limit - 1,
            )

            messages = []
            for temp_uuid in message_uuids:
                # 解码 UUID 如果它是字节类型
                if isinstance(temp_uuid, bytes):
                    temp_uuid = temp_uuid.decode('utf-8')

                raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
                    temp_uuid = temp_uuid.decode("utf-8")

                raw_data = await self._redis_exec(
                    self.redis_message.hgetall, f"msg:{temp_uuid}"
                )
                if raw_data:
                    # 解码 Redis 返回的字节数据
                    message_data = {
                        k.decode('utf-8') if isinstance(k, bytes) else k:
                        v.decode('utf-8') if isinstance(v, bytes) else v
                        k.decode("utf-8") if isinstance(k, bytes) else k: v.decode(
                            "utf-8"
                        )
                        if isinstance(v, bytes)
                        else v
                        for k, v in raw_data.items()
                    }

                    # 检查 since 条件
                    if since > 0 and message_data.get("message_id"):
                        if int(message_data["message_id"]) <= since:
                            continue
                    messages.append(message_data)

            return messages[::-1]  # 按时间顺序返回
        except Exception as e:
            logger.error(f"Failed to get cached messages: {e}")
            return []

    async def update_message_status(self, temp_uuid: str, status: str, message_id: Optional[int] = None):
    async def update_message_status(
        self, temp_uuid: str, status: str, message_id: int | None = None
    ):
        """更新消息状态"""
        try:
            update_data = {"status": status}
            if message_id:
                update_data["message_id"] = str(message_id)
            update_data["db_timestamp"] = datetime.now().isoformat()

            await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data)
            await self._redis_exec(
                self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data
            )
        except Exception as e:
            logger.error(f"Failed to update message status: {e}")

    async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
    async def get_message_status(self, temp_uuid: str) -> dict | None:
        """获取消息状态"""
        try:
            raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
            raw_data = await self._redis_exec(
                self.redis_message.hgetall, f"msg:{temp_uuid}"
            )
            if not raw_data:
                return None

            # 解码 Redis 返回的字节数据
            return {
                k.decode('utf-8') if isinstance(k, bytes) else k:
                v.decode('utf-8') if isinstance(v, bytes) else v
                k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
                if isinstance(v, bytes)
                else v
                for k, v in raw_data.items()
            }
        except Exception as e:
            logger.error(f"Failed to get message status: {e}")
            return None

    async def _process_message_queue(self):
        """处理消息队列,异步写入数据库"""
        logger.info("Message queue processing started")

        while self._processing:
            try:
                # 批量获取消息
@@ -126,47 +155,52 @@ class MessageQueueProcessor:
                    # result是 (queue_name, value) 的元组,需要解码
                    uuid_value = result[1]
                    if isinstance(uuid_value, bytes):
                        uuid_value = uuid_value.decode('utf-8')
                        uuid_value = uuid_value.decode("utf-8")
                    message_uuids.append(uuid_value)
                else:
                    break

                if not message_uuids:
                    await asyncio.sleep(0.5)
                    continue

                # 批量写入数据库
                await self._process_message_batch(message_uuids)

            except Exception as e:
                logger.error(f"Error in message queue processing: {e}")
                await asyncio.sleep(1)

        logger.info("Message queue processing stopped")

    async def _process_message_batch(self, message_uuids: list[str]):
        """批量处理消息写入数据库"""
        async with with_db() as session:
            for temp_uuid in message_uuids:
                try:
                    # 获取消息数据并解码
                    raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
                    raw_data = await self._redis_exec(
                        self.redis_message.hgetall, f"msg:{temp_uuid}"
                    )
                    if not raw_data:
                        continue

                    # 解码 Redis 返回的字节数据
                    message_data = {
                        k.decode('utf-8') if isinstance(k, bytes) else k:
                        v.decode('utf-8') if isinstance(v, bytes) else v
                        k.decode("utf-8") if isinstance(k, bytes) else k: v.decode(
                            "utf-8"
                        )
                        if isinstance(v, bytes)
                        else v
                        for k, v in raw_data.items()
                    }

                    if message_data.get("status") != "pending":
                        continue

                    # 更新状态为处理中
                    await self.update_message_status(temp_uuid, "processing")

                    # 创建数据库消息
                    msg = ChatMessage(
                        channel_id=int(message_data["channel_id"]),
@@ -175,15 +209,17 @@ class MessageQueueProcessor:
                        type=MessageType(message_data["type"]),
                        uuid=message_data.get("user_uuid") or None,
                    )

                    session.add(msg)
                    await session.commit()
                    await session.refresh(msg)

                    # 更新成功状态,包含临时消息ID映射
                    assert msg.message_id is not None
                    await self.update_message_status(temp_uuid, "completed", msg.message_id)
                    await self.update_message_status(
                        temp_uuid, "completed", msg.message_id
                    )

                    # 如果有临时消息ID,存储映射关系并通知客户端更新
                    if message_data.get("temp_message_id"):
                        temp_msg_id = int(message_data["temp_message_id"])
@@ -191,53 +227,65 @@ class MessageQueueProcessor:
                            self.redis_message.set,
                            f"temp_to_real:{temp_msg_id}",
                            str(msg.message_id),
                            ex=3600  # 1小时过期
                            ex=3600,  # 1小时过期
                        )

                        # 发送消息ID更新通知到频道
                        channel_id = int(message_data["channel_id"])
                        await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data)

                    logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}")

                        await self._notify_message_update(
                            channel_id, temp_msg_id, msg.message_id, message_data
                        )

                    logger.info(
                        f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}"
                    )

                except Exception as e:
                    logger.error(f"Failed to process message {temp_uuid}: {e}")
                    await self.update_message_status(temp_uuid, "failed")

    async def _notify_message_update(self, channel_id: int, temp_message_id: int, real_message_id: int, message_data: dict):
    async def _notify_message_update(
        self,
        channel_id: int,
        temp_message_id: int,
        real_message_id: int,
        message_data: dict,
    ):
        """通知客户端消息ID已更新"""
        try:
            # 这里我们需要通过 SignalR 发送消息更新通知
            # 但为了避免循环依赖,我们将通过 Redis 发布消息更新事件
            update_event = {
                "event": "chat.message.update",
                "event": "chat.message.update",
                "data": {
                    "channel_id": channel_id,
                    "temp_message_id": temp_message_id,
                    "real_message_id": real_message_id,
                    "timestamp": message_data.get("timestamp")
                }
                    "timestamp": message_data.get("timestamp"),
                },
            }

            # 发布到 Redis 频道,让 SignalR 服务处理
            await self._redis_exec(
                self.redis_message.publish,
                f"chat_updates:{channel_id}",
                json.dumps(update_event)
                json.dumps(update_event),
            )

            logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}")
            logger.info(
                f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}"
            )

        except Exception as e:
            logger.error(f"Failed to notify message update: {e}")
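
# --- Editor's note: a hedged sketch of the consumer side of the publish above, using
# the standard redis-py pub/sub API; handle_update is a hypothetical callback:
#
#     import json
#     import redis
#
#     r = redis.Redis(decode_responses=True)
#     pubsub = r.pubsub()
#     pubsub.subscribe("chat_updates:1")
#     for item in pubsub.listen():
#         if item["type"] == "message":
#             event = json.loads(item["data"])
#             # event["data"] carries the temp_message_id -> real_message_id mapping
#             handle_update(event["data"])
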

    def start_processing(self):
        """启动消息队列处理"""
        if not self._processing:
            self._processing = True
            self._queue_task = asyncio.create_task(self._process_message_queue())
            logger.info("Message queue processor started")

    def stop_processing(self):
        """停止消息队列处理"""
        if self._processing:
@@ -246,10 +294,10 @@ class MessageQueueProcessor:
            self._queue_task.cancel()
            self._queue_task = None
        logger.info("Message queue processor stopped")

    def __del__(self):
        """清理资源"""
        if hasattr(self, 'executor'):
        if hasattr(self, "executor"):
            self.executor.shutdown(wait=False)

@@ -272,11 +320,13 @@ async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid:
    await message_queue_processor.cache_message(channel_id, message_data, temp_uuid)

async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
async def get_cached_messages(
    channel_id: int, limit: int = 50, since: int = 0
) -> list[dict]:
    """从 Redis 获取缓存的消息 - 便捷接口"""
    return await message_queue_processor.get_cached_messages(channel_id, limit, since)

async def get_message_status(temp_uuid: str) -> Optional[dict]:
async def get_message_status(temp_uuid: str) -> dict | None:
    """获取消息状态 - 便捷接口"""
    return await message_queue_processor.get_message_status(temp_uuid)

@@ -3,23 +3,26 @@
结合 Redis 缓存和异步数据库写入实现实时消息传送
"""

from typing import Optional
from fastapi import HTTPException
from __future__ import annotations

from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType, ChatMessageResp
from app.database.chat import (
    ChannelType,
    ChatMessageResp,
    MessageType,
)
from app.database.lazer_user import User
from app.router.notification.server import server
from app.service.message_queue import message_queue
from app.log import logger
from app.service.message_queue import message_queue

from sqlalchemy.ext.asyncio import AsyncSession


class OptimizedMessageService:
    """优化的消息服务"""

    def __init__(self):
        self.message_queue = message_queue

    async def send_message_fast(
        self,
        channel_id: int,
@@ -28,12 +31,12 @@ class OptimizedMessageService:
        content: str,
        sender: User,
        is_action: bool = False,
        user_uuid: Optional[str] = None,
        session: Optional[AsyncSession] = None
        user_uuid: str | None = None,
        session: AsyncSession | None = None,
    ) -> ChatMessageResp:
        """
        快速发送消息(先缓存到 Redis,异步写入数据库)

        Args:
            channel_id: 频道 ID
            channel_type: 频道类型
@@ -43,12 +46,12 @@ class OptimizedMessageService:
            is_action: 是否为动作消息
            user_uuid: 用户提供的 UUID
            session: 数据库会话(可选,用于一些验证)

        Returns:
            消息响应对象
        """
        assert sender.id is not None

        # 准备消息数据
        message_data = {
            "channel_id": str(channel_id),
@@ -57,27 +60,28 @@ class OptimizedMessageService:
            "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
            "user_uuid": user_uuid or "",
            "channel_type": channel_type.value,
            "channel_name": channel_name
            "channel_name": channel_name,
        }

        # 立即将消息加入 Redis 队列(实时响应)
        temp_uuid = await self.message_queue.enqueue_message(message_data)

        # 缓存到频道消息列表
        await self.message_queue.cache_channel_message(channel_id, temp_uuid)

        # 创建临时响应对象(简化版本,用于立即响应)
        from datetime import datetime

        from app.database.lazer_user import UserResp

        # 创建基本的用户响应对象
        user_resp = UserResp(
            id=sender.id,
            username=sender.username,
            country_code=getattr(sender, 'country_code', 'XX'),
            country_code=getattr(sender, "country_code", "XX"),
            # 基本字段,其他复杂字段可以后续异步加载
        )

        temp_response = ChatMessageResp(
            message_id=0,  # 临时 ID,等数据库写入后会更新
            channel_id=channel_id,
@@ -86,63 +90,62 @@ class OptimizedMessageService:
            sender_id=sender.id,
            sender=user_resp,
            is_action=is_action,
            uuid=user_uuid
            uuid=user_uuid,
        )
        temp_response.temp_uuid = temp_uuid  # 添加临时 UUID 用于后续更新

        logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}")
        return temp_response

    async def get_cached_messages(
        self,
        channel_id: int,
        limit: int = 50,
        since: int = 0
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[dict]:
        """
        获取缓存的消息

        Args:
            channel_id: 频道 ID
            limit: 限制数量
            since: 获取自此消息 ID 之后的消息

        Returns:
            消息列表
        """
        return await self.message_queue.get_cached_messages(channel_id, limit, since)

    async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
    async def get_message_status(self, temp_uuid: str) -> dict | None:
        """
        获取消息状态

        Args:
            temp_uuid: 临时消息 UUID

        Returns:
            消息状态信息
        """
        return await self.message_queue.get_message_status(temp_uuid)

    async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> Optional[dict]:
    async def wait_for_message_persisted(
        self, temp_uuid: str, timeout: int = 30
    ) -> dict | None:
        """
        等待消息持久化到数据库

        Args:
            temp_uuid: 临时消息 UUID
            timeout: 超时时间(秒)

        Returns:
            完成后的消息状态
        """
        import asyncio

        for _ in range(timeout * 10):  # 每100ms检查一次
            status = await self.get_message_status(temp_uuid)
            if status and status.get("status") in ["completed", "failed"]:
                return status
            await asyncio.sleep(0.1)

        return None
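
# --- Editor's note: a minimal end-to-end sketch of the fast path (illustrative only;
# the ChannelType member, channel name, and keyword usage are assumptions, since part
# of the signature is elided by the hunk above):
#
#     service = OptimizedMessageService()
#     resp = await service.send_message_fast(
#         channel_id=1,
#         channel_type=ChannelType.PUBLIC,
#         channel_name="#osu",
#         content="hi",
#         sender=user,
#     )
#     final = await service.wait_for_message_persisted(resp.temp_uuid, timeout=10)
#     if final and final["status"] == "completed":
#         real_id = int(final["message_id"])  # replaces the temporary message_id=0
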

@@ -5,59 +5,66 @@
- 支持消息状态同步和故障恢复
"""

from __future__ import annotations

import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json
import time
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from app.database.chat import ChatMessage, MessageType, ChatMessageResp
from app.database.lazer_user import User, UserResp, RANKING_INCLUDES
from app.database.chat import ChatMessage, ChatMessageResp, MessageType
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
from app.dependencies.database import get_redis_message, with_db
from app.log import logger


class RedisMessageSystem:
    """Redis 消息系统"""

    def __init__(self):
        self.redis = get_redis_message()
        self.executor = ThreadPoolExecutor(max_workers=2)
        self._batch_timer: Optional[asyncio.Task] = None
        self._batch_timer: asyncio.Task | None = None
        self._running = False
        self.batch_interval = 5.0  # 5秒批量存储一次
        self.max_batch_size = 100  # 每批最多处理100条消息

    async def _redis_exec(self, func, *args, **kwargs):
        """在线程池中执行 Redis 操作"""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))

    async def send_message(self, channel_id: int, user: User, content: str,
                           is_action: bool = False, user_uuid: Optional[str] = None) -> ChatMessageResp:
    async def send_message(
        self,
        channel_id: int,
        user: User,
        content: str,
        is_action: bool = False,
        user_uuid: str | None = None,
    ) -> ChatMessageResp:
        """
        发送消息 - 立即存储到 Redis 并返回

        Args:
            channel_id: 频道ID
            user: 发送用户
            content: 消息内容
            is_action: 是否为动作消息
            user_uuid: 用户UUID

        Returns:
            ChatMessageResp: 消息响应对象
        """
        # 生成消息ID和时间戳
        message_id = await self._generate_message_id(channel_id)
        timestamp = datetime.now()

        # 确保用户ID存在
        if not user.id:
            raise ValueError("User ID is required")

        # 准备消息数据
        message_data = {
            "message_id": message_id,
@@ -68,19 +75,20 @@ class RedisMessageSystem:
            "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
            "uuid": user_uuid or "",
            "status": "cached",  # Redis 缓存状态
            "created_at": time.time()
            "created_at": time.time(),
        }

        # 立即存储到 Redis
        await self._store_to_redis(message_id, channel_id, message_data)

        # 创建响应对象
        async with with_db() as session:
            user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)

            # 确保 statistics 不为空
            if user_resp.statistics is None:
                from app.database.statistics import UserStatisticsResp

                user_resp.statistics = UserStatisticsResp(
                    mode=user.playmode,
                    global_rank=0,
@@ -96,9 +104,9 @@ class RedisMessageSystem:
                    replays_watched_by_others=0,
                    is_ranked=False,
                    grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
                    level={"current": 1, "progress": 0}
                    level={"current": 1, "progress": 0},
                )

            response = ChatMessageResp(
                message_id=message_id,
                channel_id=channel_id,
@@ -107,51 +115,71 @@ class RedisMessageSystem:
                sender_id=user.id,
                sender=user_resp,
                is_action=is_action,
                uuid=user_uuid
                uuid=user_uuid,
            )

            logger.info(
                f"Message {message_id} sent to Redis cache for channel {channel_id}"
            )

        logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
        return response

    async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> List[ChatMessageResp]:
    async def get_messages(
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[ChatMessageResp]:
        """
        获取频道消息 - 优先从 Redis 获取最新消息

        Args:
            channel_id: 频道ID
            limit: 消息数量限制
            since: 起始消息ID

        Returns:
            List[ChatMessageResp]: 消息列表
        """
        messages = []

        try:
            # 从 Redis 获取最新消息
            redis_messages = await self._get_from_redis(channel_id, limit, since)

            # 为每条消息构建响应对象
            async with with_db() as session:
                for msg_data in redis_messages:
                    # 获取发送者信息
                    sender = await session.get(User, msg_data["sender_id"])
                    if sender:
                        user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)

                        user_resp = await UserResp.from_db(
                            sender, session, RANKING_INCLUDES
                        )

                        if user_resp.statistics is None:
                            from app.database.statistics import UserStatisticsResp

                            user_resp.statistics = UserStatisticsResp(
                                mode=sender.playmode,
                                global_rank=0, country_rank=0, pp=0.0,
                                ranked_score=0, hit_accuracy=0.0, play_count=0,
                                play_time=0, total_score=0, total_hits=0,
                                maximum_combo=0, replays_watched_by_others=0,
                                global_rank=0,
                                country_rank=0,
                                pp=0.0,
                                ranked_score=0,
                                hit_accuracy=0.0,
                                play_count=0,
                                play_time=0,
                                total_score=0,
                                total_hits=0,
                                maximum_combo=0,
                                replays_watched_by_others=0,
                                is_ranked=False,
                                grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
                                level={"current": 1, "progress": 0}
                                grade_counts={
                                    "ssh": 0,
                                    "ss": 0,
                                    "sh": 0,
                                    "s": 0,
                                    "a": 0,
                                },
                                level={"current": 1, "progress": 0},
                            )

                        message_resp = ChatMessageResp(
                            message_id=msg_data["message_id"],
                            channel_id=msg_data["channel_id"],
@@ -160,77 +188,97 @@ class RedisMessageSystem:
                            sender_id=msg_data["sender_id"],
                            sender=user_resp,
                            is_action=msg_data["type"] == MessageType.ACTION.value,
                            uuid=msg_data.get("uuid") or None
                            uuid=msg_data.get("uuid") or None,
                        )
                        messages.append(message_resp)

            # 如果 Redis 消息不够,从数据库补充
            if len(messages) < limit and since == 0:
                await self._backfill_from_database(channel_id, messages, limit)

        except Exception as e:
            logger.error(f"Failed to get messages from Redis: {e}")
            # 回退到数据库查询
            messages = await self._get_from_database_only(channel_id, limit, since)

        return messages[:limit]

    async def _generate_message_id(self, channel_id: int) -> int:
        """生成唯一的消息ID - 确保全局唯一且严格递增"""
        # 使用全局计数器确保所有频道的消息ID都是严格递增的
        message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter")

        message_id = await self._redis_exec(
            self.redis.incr, "global_message_id_counter"
        )

        # 同时更新频道的最后消息ID,用于客户端状态同步
        await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id)

        await self._redis_exec(
            self.redis.set, f"channel:{channel_id}:last_msg_id", message_id
        )

        return message_id
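
# --- Editor's note: INCR is atomic in Redis, so concurrent senders each get a unique,
# strictly increasing ID with no application-side locking; a tiny illustration
# (the values shown depend on the counter's prior state):
#
#     >>> r.set("global_message_id_counter", 41)
#     >>> r.incr("global_message_id_counter")
#     42
#     >>> r.incr("global_message_id_counter")
#     43
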

    async def _store_to_redis(self, message_id: int, channel_id: int, message_data: Dict[str, Any]):
    async def _store_to_redis(
        self, message_id: int, channel_id: int, message_data: dict[str, Any]
    ):
        """存储消息到 Redis"""
        try:
            # 存储消息数据
            await self._redis_exec(
                self.redis.hset,
                f"msg:{channel_id}:{message_id}",
                mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
                         for k, v in message_data.items()}
                self.redis.hset,
                f"msg:{channel_id}:{message_id}",
                mapping={
                    k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
                    for k, v in message_data.items()
                },
            )

            # 设置消息过期时间(7天)
            await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800)
            await self._redis_exec(
                self.redis.expire, f"msg:{channel_id}:{message_id}", 604800
            )

            # 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
            channel_messages_key = f"channel:{channel_id}:messages"

            # 检查键的类型,如果不是 zset 类型则删除
            try:
                key_type = await self._redis_exec(self.redis.type, channel_messages_key)
                if key_type and key_type != "zset":
                    logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
                    logger.warning(
                        f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}"
                    )
                    await self._redis_exec(self.redis.delete, channel_messages_key)
            except Exception as type_check_error:
                logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
                logger.warning(
                    f"Failed to check key type for {channel_messages_key}: {type_check_error}"
                )
                # 如果检查失败,直接删除键以确保清理
                await self._redis_exec(self.redis.delete, channel_messages_key)

            # 添加到频道消息列表(sorted set)
            await self._redis_exec(
                self.redis.zadd,
                channel_messages_key,
                {f"msg:{channel_id}:{message_id}": message_id}
                self.redis.zadd,
                channel_messages_key,
                {f"msg:{channel_id}:{message_id}": message_id},
            )

            # 保持频道消息列表大小(最多1000条)
            await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001)

            await self._redis_exec(
                self.redis.zremrangebyrank, channel_messages_key, 0, -1001
            )

            # 添加到待持久化队列
            await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}")

            await self._redis_exec(
                self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
            )

        except Exception as e:
            logger.error(f"Failed to store message to Redis: {e}")
            raise
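
# --- Editor's note: the Redis layout used by _store_to_redis, summarized from the
# calls above:
#
#     msg:{channel_id}:{message_id}    hash  full message fields, EXPIRE 7 days
#     channel:{channel_id}:messages    zset  member "msg:{cid}:{mid}", score message_id
#     pending_messages                 list  "{channel_id}:{message_id}" awaiting DB write
#
# Scoring the zset by message_id is what lets _get_from_redis answer "since" queries
# with ZRANGEBYSCORE and "latest N" queries with ZREVRANGE.
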

    async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> List[Dict[str, Any]]:
    async def _get_from_redis(
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[dict[str, Any]]:
        """从 Redis 获取消息"""
        try:
            # 获取消息键列表,按消息ID排序
@@ -239,22 +287,22 @@ class RedisMessageSystem:
                message_keys = await self._redis_exec(
                    self.redis.zrangebyscore,
                    f"channel:{channel_id}:messages",
                    since + 1, "+inf",
                    start=0, num=limit
                    since + 1,
                    "+inf",
                    start=0,
                    num=limit,
                )
            else:
                # 获取最新的消息(倒序获取,然后反转)
                message_keys = await self._redis_exec(
                    self.redis.zrevrange,
                    f"channel:{channel_id}:messages",
                    0, limit - 1
                    self.redis.zrevrange, f"channel:{channel_id}:messages", 0, limit - 1
                )

            messages = []
            for key in message_keys:
                if isinstance(key, bytes):
                    key = key.decode('utf-8')

                    key = key.decode("utf-8")

                # 获取消息数据
                raw_data = await self._redis_exec(self.redis.hgetall, key)
                if raw_data:
@@ -262,106 +310,118 @@ class RedisMessageSystem:
                    message_data = {}
                    for k, v in raw_data.items():
                        if isinstance(k, bytes):
                            k = k.decode('utf-8')
                            k = k.decode("utf-8")
                        if isinstance(v, bytes):
                            v = v.decode('utf-8')

                            v = v.decode("utf-8")

                        # 尝试解析 JSON
                        try:
                            if k in ['grade_counts', 'level'] or v.startswith(('{', '[')):
                            if k in ["grade_counts", "level"] or v.startswith(
                                ("{", "[")
                            ):
                                message_data[k] = json.loads(v)
                            elif k in ['message_id', 'channel_id', 'sender_id']:
                            elif k in ["message_id", "channel_id", "sender_id"]:
                                message_data[k] = int(v)
                            elif k == 'created_at':
                            elif k == "created_at":
                                message_data[k] = float(v)
                            else:
                                message_data[k] = v
                        except (json.JSONDecodeError, ValueError):
                            message_data[k] = v

                    messages.append(message_data)

            # 确保消息按ID正序排序(时间顺序)
            messages.sort(key=lambda x: x.get('message_id', 0))

            messages.sort(key=lambda x: x.get("message_id", 0))

            # 如果是获取最新消息(since=0),需要保持倒序(最新的在前面)
            if since == 0:
                messages.reverse()

            return messages

        except Exception as e:
            logger.error(f"Failed to get messages from Redis: {e}")
            return []

    async def _backfill_from_database(self, channel_id: int, existing_messages: List[ChatMessageResp], limit: int):
    async def _backfill_from_database(
        self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int
    ):
        """从数据库补充历史消息"""
        try:
            # 找到最小的消息ID
            min_id = float('inf')
            min_id = float("inf")
            if existing_messages:
                for msg in existing_messages:
                    if msg.message_id is not None and msg.message_id < min_id:
                        min_id = msg.message_id

            needed = limit - len(existing_messages)

            if needed <= 0:
                return

            async with with_db() as session:
                from sqlmodel import select, col
                query = select(ChatMessage).where(
                    ChatMessage.channel_id == channel_id
                )

                if min_id != float('inf'):
                from sqlmodel import col, select

                query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)

                if min_id != float("inf"):
                    query = query.where(col(ChatMessage.message_id) < min_id)

                query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed)

                db_messages = (await session.exec(query)).all()

                for msg in reversed(db_messages):  # 按时间正序插入
                    msg_resp = await ChatMessageResp.from_db(msg, session)
                    existing_messages.insert(0, msg_resp)

        except Exception as e:
            logger.error(f"Failed to backfill from database: {e}")

    async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> List[ChatMessageResp]:
    async def _get_from_database_only(
        self, channel_id: int, limit: int, since: int
    ) -> list[ChatMessageResp]:
        """仅从数据库获取消息(回退方案)"""
        try:
            async with with_db() as session:
                from sqlmodel import select, col
                from sqlmodel import col, select

                query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)

                if since > 0:
                    # 获取指定ID之后的消息,按ID正序
                    query = query.where(col(ChatMessage.message_id) > since)
                    query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit)
                    query = query.order_by(col(ChatMessage.message_id).asc()).limit(
                        limit
                    )
                else:
                    # 获取最新消息,按ID倒序(最新的在前面)
                    query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)

                    query = query.order_by(col(ChatMessage.message_id).desc()).limit(
                        limit
                    )

                messages = (await session.exec(query)).all()

                results = [await ChatMessageResp.from_db(msg, session) for msg in messages]

                results = [
                    await ChatMessageResp.from_db(msg, session) for msg in messages
                ]

                # 如果是 since > 0,保持正序;否则反转为时间正序
                if since == 0:
                    results.reverse()

                return results

        except Exception as e:
            logger.error(f"Failed to get messages from database: {e}")
            return []

    async def _batch_persist_to_database(self):
        """批量持久化消息到数据库"""
        logger.info("Starting batch persistence to database")

        while self._running:
            try:
                # 获取待处理的消息
@@ -374,52 +434,52 @@ class RedisMessageSystem:
                    # key 是 (queue_name, value) 的元组
                    value = key[1]
                    if isinstance(value, bytes):
                        value = value.decode('utf-8')
                        value = value.decode("utf-8")
                    message_keys.append(value)
                else:
                    break

                if message_keys:
                    await self._process_message_batch(message_keys)
                else:
                    await asyncio.sleep(self.batch_interval)

            except Exception as e:
                logger.error(f"Error in batch persistence: {e}")
                await asyncio.sleep(1)

        logger.info("Stopped batch persistence to database")

    async def _process_message_batch(self, message_keys: List[str]):
    async def _process_message_batch(self, message_keys: list[str]):
        """处理消息批次"""
        async with with_db() as session:
            for key in message_keys:
                try:
                    # 解析频道ID和消息ID
                    channel_id, message_id = map(int, key.split(':'))

                    channel_id, message_id = map(int, key.split(":"))

                    # 从 Redis 获取消息数据
                    raw_data = await self._redis_exec(
                        self.redis.hgetall, f"msg:{channel_id}:{message_id}"
                    )

                    if not raw_data:
                        continue

                    # 解码数据
                    message_data = {}
                    for k, v in raw_data.items():
                        if isinstance(k, bytes):
                            k = k.decode('utf-8')
                            k = k.decode("utf-8")
                        if isinstance(v, bytes):
                            v = v.decode('utf-8')
                            v = v.decode("utf-8")
                        message_data[k] = v

                    # 检查消息是否已存在于数据库
                    existing = await session.get(ChatMessage, int(message_id))
                    if existing:
                        continue

                    # 创建数据库消息 - 使用 Redis 生成的正数ID
                    db_message = ChatMessage(
                        message_id=int(message_id),  # 使用 Redis 系统生成的正数ID
@@ -428,31 +488,34 @@ class RedisMessageSystem:
                        content=message_data["content"],
                        timestamp=datetime.fromisoformat(message_data["timestamp"]),
                        type=MessageType(message_data["type"]),
                        uuid=message_data.get("uuid") or None
                        uuid=message_data.get("uuid") or None,
                    )

                    session.add(db_message)

                    # 更新 Redis 中的状态
                    await self._redis_exec(
                        self.redis.hset,
                        f"msg:{channel_id}:{message_id}",
                        "status", "persisted"
                        "status",
                        "persisted",
                    )

                    logger.debug(f"Message {message_id} persisted to database")

                except Exception as e:
                    logger.error(f"Failed to process message {key}: {e}")

            # 提交批次
            try:
                await session.commit()
                logger.info(f"Batch of {len(message_keys)} messages committed to database")
                logger.info(
                    f"Batch of {len(message_keys)} messages committed to database"
                )
            except Exception as e:
                logger.error(f"Failed to commit message batch: {e}")
                await session.rollback()

    def start(self):
        """启动系统"""
        if not self._running:
@@ -461,63 +524,71 @@ class RedisMessageSystem:
            # 启动时初始化消息ID计数器
            asyncio.create_task(self._initialize_message_counter())
            logger.info("Redis message system started")

    async def _initialize_message_counter(self):
        """初始化全局消息ID计数器,确保从数据库最大ID开始"""
        try:
            # 清理可能存在的问题键
            await self._cleanup_redis_keys()

            async with with_db() as session:
                from sqlmodel import select, func

                from sqlmodel import func, select

                # 获取数据库中最大的消息ID
                result = await session.exec(
                    select(func.max(ChatMessage.message_id))
                )
                result = await session.exec(select(func.max(ChatMessage.message_id)))
                max_id = result.one() or 0

                # 检查 Redis 中的计数器值
                current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter")
                current_counter = await self._redis_exec(
                    self.redis.get, "global_message_id_counter"
                )
                current_counter = int(current_counter) if current_counter else 0

                # 设置计数器为两者中的最大值
                initial_counter = max(max_id, current_counter)
                await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter)

                logger.info(f"Initialized global message ID counter to {initial_counter}")

                await self._redis_exec(
                    self.redis.set, "global_message_id_counter", initial_counter
                )

                logger.info(
                    f"Initialized global message ID counter to {initial_counter}"
                )

        except Exception as e:
            logger.error(f"Failed to initialize message counter: {e}")
            # 如果初始化失败,设置一个安全的起始值
            await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000)
            await self._redis_exec(
                self.redis.setnx, "global_message_id_counter", 1000000
            )

    async def _cleanup_redis_keys(self):
        """清理可能存在问题的 Redis 键"""
        try:
            # 扫描所有 channel:*:messages 键并检查类型
            keys_pattern = "channel:*:messages"
            keys = await self._redis_exec(self.redis.keys, keys_pattern)

            for key in keys:
                if isinstance(key, bytes):
                    key = key.decode('utf-8')

                    key = key.decode("utf-8")

                try:
                    key_type = await self._redis_exec(self.redis.type, key)
                    if key_type and key_type != "zset":
                        logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
                        logger.warning(
                            f"Cleaning up Redis key {key} with wrong type: {key_type}"
                        )
                        await self._redis_exec(self.redis.delete, key)
                except Exception as cleanup_error:
                    logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
                    # 强制删除问题键
                    await self._redis_exec(self.redis.delete, key)

            logger.info("Redis keys cleanup completed")

        except Exception as e:
            logger.error(f"Failed to cleanup Redis keys: {e}")

    def stop(self):
        """停止系统"""
        if self._running:
@@ -526,10 +597,10 @@ class RedisMessageSystem:
            self._batch_timer.cancel()
            self._batch_timer = None
        logger.info("Redis message system stopped")

    def __del__(self):
        """清理资源"""
        if hasattr(self, 'executor'):
        if hasattr(self, "executor"):
            self.executor.shutdown(wait=False)


@@ -1,81 +1,94 @@
from __future__ import annotations

from datetime import datetime, timedelta

from app.dependencies.database import get_redis, get_redis_message
from app.log import logger
from app.router.v2.stats import REDIS_ONLINE_USERS_KEY, REDIS_PLAYING_USERS_KEY, _redis_exec
from app.router.v2.stats import (
    REDIS_ONLINE_USERS_KEY,
    REDIS_PLAYING_USERS_KEY,
    _redis_exec,
)


async def cleanup_stale_online_users() -> tuple[int, int]:
    """清理过期的在线和游玩用户,返回清理的用户数"""
    redis_sync = get_redis_message()
    redis_async = get_redis()

    online_cleaned = 0
    playing_cleaned = 0

    try:
        # 获取所有在线用户
        online_users = await _redis_exec(redis_sync.smembers, REDIS_ONLINE_USERS_KEY)
        playing_users = await _redis_exec(redis_sync.smembers, REDIS_PLAYING_USERS_KEY)

        # 检查在线用户的最后活动时间
        current_time = datetime.utcnow()
        stale_threshold = current_time - timedelta(hours=2)  # 2小时无活动视为过期

        # 对于在线用户,我们检查metadata在线标记
        stale_online_users = []
        for user_id in online_users:
            user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id)
            user_id_str = (
                user_id.decode() if isinstance(user_id, bytes) else str(user_id)
            )
            metadata_key = f"metadata:online:{user_id_str}"

            # 如果metadata标记不存在,说明用户已经离线
            if not await redis_async.exists(metadata_key):
                stale_online_users.append(user_id_str)

        # 清理过期的在线用户
        if stale_online_users:
            await _redis_exec(redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users)
            await _redis_exec(
                redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users
            )
            online_cleaned = len(stale_online_users)
            logger.info(f"Cleaned {online_cleaned} stale online users")

        # 对于游玩用户,我们也检查对应的spectator状态
        stale_playing_users = []
        for user_id in playing_users:
            user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id)

            user_id_str = (
                user_id.decode() if isinstance(user_id, bytes) else str(user_id)
            )

            # 如果用户不在在线用户列表中,说明已经离线,也应该从游玩列表中移除
            if user_id_str in stale_online_users or user_id_str not in [
                u.decode() if isinstance(u, bytes) else str(u) for u in online_users
            ]:
                stale_playing_users.append(user_id_str)

        # 清理过期的游玩用户
        if stale_playing_users:
            await _redis_exec(redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users)
            await _redis_exec(
                redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users
            )
            playing_cleaned = len(stale_playing_users)
            logger.info(f"Cleaned {playing_cleaned} stale playing users")

    except Exception as e:
        logger.error(f"Error cleaning stale users: {e}")

    return online_cleaned, playing_cleaned
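
# --- Editor's note: the staleness rule above in one line - an online-set member is
# considered stale once its per-connection marker key is gone (the marker is expected
# to expire with the connection); hedged equivalent check:
#
#     is_stale = not await redis_async.exists(f"metadata:online:{user_id_str}")
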

async def refresh_redis_key_expiry() -> None:
    """刷新Redis键的过期时间,防止数据丢失"""
    redis_async = get_redis()

    try:
        # 刷新在线用户key的过期时间
        if await redis_async.exists(REDIS_ONLINE_USERS_KEY):
            await redis_async.expire(REDIS_ONLINE_USERS_KEY, 6 * 3600)  # 6小时

        # 刷新游玩用户key的过期时间
        if await redis_async.exists(REDIS_PLAYING_USERS_KEY):
            await redis_async.expire(REDIS_PLAYING_USERS_KEY, 6 * 3600)  # 6小时

        logger.debug("Refreshed Redis key expiry times")

    except Exception as e:
        logger.error(f"Error refreshing Redis key expiry: {e}")

@@ -5,46 +5,49 @@ from datetime import datetime, timedelta

from app.log import logger
from app.router.v2.stats import record_hourly_stats, update_registered_users_count
from app.service.stats_cleanup import cleanup_stale_online_users, refresh_redis_key_expiry
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
from app.service.stats_cleanup import (
    cleanup_stale_online_users,
    refresh_redis_key_expiry,
)


class StatsScheduler:
    """统计数据调度器"""

    def __init__(self):
        self._running = False
        self._stats_task: asyncio.Task | None = None
        self._registered_task: asyncio.Task | None = None
        self._cleanup_task: asyncio.Task | None = None

    def start(self) -> None:
        """启动调度器"""
        if self._running:
            return

        self._running = True
        self._stats_task = asyncio.create_task(self._stats_loop())
        self._registered_task = asyncio.create_task(self._registered_users_loop())
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("Stats scheduler started")

    def stop(self) -> None:
        """停止调度器"""
        if not self._running:
            return

        self._running = False

        if self._stats_task:
            self._stats_task.cancel()
        if self._registered_task:
            self._registered_task.cancel()
        if self._cleanup_task:
            self._cleanup_task.cancel()

        logger.info("Stats scheduler stopped")

    async def _stats_loop(self) -> None:
        """统计数据记录循环 - 每30分钟记录一次"""
        # 启动时立即记录一次统计数据
@@ -53,49 +56,57 @@ class StatsScheduler:
            logger.info("Initial enhanced interval statistics initialized on startup")
        except Exception as e:
            logger.error(f"Error initializing enhanced interval stats: {e}")

        while self._running:
            try:
                # 计算下次记录时间(下个30分钟整点)
                now = datetime.utcnow()

                # 计算当前区间边界
                current_minute = (now.minute // 30) * 30
                current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(minutes=30)

                current_interval_end = now.replace(
                    minute=current_minute, second=0, microsecond=0
                ) + timedelta(minutes=30)

                # 如果已经过了当前区间结束时间,立即处理
                if now >= current_interval_end:
                    current_interval_end += timedelta(minutes=30)

                # 计算需要等待的时间(到下个区间结束)
                sleep_seconds = (current_interval_end - now).total_seconds()

                # 确保至少等待1分钟,最多等待31分钟
                sleep_seconds = max(min(sleep_seconds, 31 * 60), 60)

                logger.debug(f"Next interval finalization in {sleep_seconds/60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}")

                logger.debug(
                    f"Next interval finalization in {sleep_seconds / 60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}"
                )
                await asyncio.sleep(sleep_seconds)

                if not self._running:
                    break

                # 完成当前区间并记录到历史
                finalized_stats = await EnhancedIntervalStatsManager.finalize_interval()
                if finalized_stats:
                    logger.info(f"Finalized enhanced interval statistics at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}")
                    logger.info(
                        f"Finalized enhanced interval statistics at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}"
                    )
                else:
                    # 如果区间完成失败,使用原有方式记录
                    await record_hourly_stats()
                    logger.info(f"Recorded hourly statistics (fallback) at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}")

                    logger.info(
                        f"Recorded hourly statistics (fallback) at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}"
                    )

                # 开始新的区间统计
                await EnhancedIntervalStatsManager.initialize_current_interval()

            except Exception as e:
                logger.error(f"Error in stats loop: {e}")
                # 出错时等待5分钟再重试
                await asyncio.sleep(5 * 60)
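
# --- Editor's note: a worked example of the 30-minute boundary arithmetic above.
# With now = 12:47:10, current_minute = (47 // 30) * 30 = 30, so the interval ends at
# 12:30:00 + 30 min = 13:00:00 and sleep_seconds = 770, inside the [60, 1860] clamp:
#
#     from datetime import datetime, timedelta
#
#     now = datetime(2024, 1, 1, 12, 47, 10)
#     current_minute = (now.minute // 30) * 30
#     end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(minutes=30)
#     assert end == datetime(2024, 1, 1, 13, 0, 0)
#     assert (end - now).total_seconds() == 770.0
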
|
||||
|
||||
|
||||
async def _registered_users_loop(self) -> None:
|
||||
"""注册用户数更新循环 - 每5分钟更新一次"""
|
||||
# 启动时立即更新一次注册用户数
|
||||
@@ -104,14 +115,14 @@ class StatsScheduler:
|
||||
logger.info("Initial registered users count updated on startup")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating initial registered users count: {e}")
|
||||
|
||||
|
||||
while self._running:
|
||||
# 等待5分钟
|
||||
await asyncio.sleep(5 * 60)
|
||||
|
||||
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
await update_registered_users_count()
|
||||
logger.debug("Updated registered users count")
|
||||
@@ -124,31 +135,35 @@ class StatsScheduler:
|
||||
try:
|
||||
online_cleaned, playing_cleaned = await cleanup_stale_online_users()
|
||||
if online_cleaned > 0 or playing_cleaned > 0:
|
||||
logger.info(f"Initial cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users")
|
||||
|
||||
logger.info(
|
||||
f"Initial cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users"
|
||||
)
|
||||
|
||||
await refresh_redis_key_expiry()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in initial cleanup: {e}")
|
||||
|
||||
|
||||
while self._running:
|
||||
# 等待10分钟
|
||||
await asyncio.sleep(10 * 60)
|
||||
|
||||
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
# 清理过期用户
|
||||
online_cleaned, playing_cleaned = await cleanup_stale_online_users()
|
||||
if online_cleaned > 0 or playing_cleaned > 0:
|
||||
logger.info(f"Cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users")
|
||||
|
||||
logger.info(
|
||||
f"Cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users"
|
||||
)
|
||||
|
||||
# 刷新Redis key过期时间
|
||||
await refresh_redis_key_expiry()
|
||||
|
||||
|
||||
# 清理过期的区间数据
|
||||
await EnhancedIntervalStatsManager.cleanup_old_intervals()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup loop: {e}")
|
||||
# 出错时等待2分钟再重试
|
||||
|
||||
@@ -92,11 +92,12 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
@override
|
||||
async def _clean_state(self, state: MetadataClientState) -> None:
|
||||
user_id = int(state.connection_id)
|
||||
|
||||
|
||||
# Remove from online user tracking
|
||||
from app.router.v2.stats import remove_online_user
|
||||
|
||||
asyncio.create_task(remove_online_user(user_id))
|
||||
|
||||
|
||||
if state.pushable:
|
||||
await asyncio.gather(*self.broadcast_tasks(user_id, None))
|
||||
redis = get_redis()
|
||||
@@ -125,6 +126,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
|
||||
# Track online user
|
||||
from app.router.v2.stats import add_online_user
|
||||
|
||||
asyncio.create_task(add_online_user(user_id))
|
||||
|
||||
async with with_db() as session:
|
||||
|
||||
@@ -163,11 +163,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
@override
|
||||
async def _clean_state(self, state: MultiplayerClientState):
|
||||
user_id = int(state.connection_id)
|
||||
|
||||
|
||||
# Remove from online user tracking
|
||||
from app.router.v2.stats import remove_online_user
|
||||
|
||||
asyncio.create_task(remove_online_user(user_id))
|
||||
|
||||
|
||||
if state.room_id != 0 and state.room_id in self.rooms:
|
||||
server_room = self.rooms[state.room_id]
|
||||
room = server_room.room
|
||||
@@ -180,9 +181,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
"""Track online users when connecting to multiplayer hub"""
|
||||
logger.info(f"[MultiplayerHub] Client {client.user_id} connected")
|
||||
|
||||
|
||||
# Track online user
|
||||
from app.router.v2.stats import add_online_user
|
||||
|
||||
asyncio.create_task(add_online_user(client.user_id))
|
||||
|
||||
def _ensure_in_room(self, client: Client) -> ServerMultiplayerRoom:
|
||||
@@ -292,11 +294,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
room.users.append(user)
self.add_to_group(client, self.group_id(room_id))
await server_room.match_type_handler.handle_join(user)

# Critical fix: Send current room and gameplay state to new user
# This ensures spectators joining ongoing games get proper state sync
await self._send_room_state_to_new_user(client, server_room)

await self.event_logger.player_joined(room_id, user.user_id)

async with with_db() as session:
@@ -669,16 +671,22 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
# Enhanced spectator validation - allow transitions from more states
# This matches official osu-server-spectator behavior
if old not in (
MultiplayerUserState.IDLE,
MultiplayerUserState.READY,
MultiplayerUserState.RESULTS,  # Allow spectating after results
):
# Allow spectating during gameplay states only if the room is in appropriate state
if not (old.is_playing and room.room.state in (
MultiplayerRoomState.WAITING_FOR_LOAD,
MultiplayerRoomState.PLAYING
)):
raise InvokeException(f"Cannot change state from {old} to {new}")
if not (
old.is_playing
and room.room.state
in (
MultiplayerRoomState.WAITING_FOR_LOAD,
MultiplayerRoomState.PLAYING,
)
):
raise InvokeException(
f"Cannot change state from {old} to {new}"
)
case _:
raise InvokeException(f"Invalid state transition from {old} to {new}")

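
The transition check above leans on old.is_playing grouping the gameplay states together. The enum members below are an assumption modeled on osu!'s MultiplayerUserState (only the members visible in this diff are certain); the sketch shows how such a property keeps the validation readable:

from enum import Enum, auto


class MultiplayerUserState(Enum):
    # IDLE, READY, RESULTS and the gameplay check appear in this diff;
    # the remaining members are assumptions based on the upstream osu! model.
    IDLE = auto()
    READY = auto()
    WAITING_FOR_LOAD = auto()
    LOADED = auto()
    PLAYING = auto()
    RESULTS = auto()
    SPECTATING = auto()

    @property
    def is_playing(self) -> bool:
        return self in (
            MultiplayerUserState.WAITING_FOR_LOAD,
            MultiplayerUserState.LOADED,
            MultiplayerUserState.PLAYING,
        )
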
@@ -691,7 +699,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if user.state == state:
return

# Special handling for state changes during gameplay
match state:
case MultiplayerUserState.IDLE:
@@ -704,15 +712,15 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
logger.info(
f"[MultiplayerHub] User {user.user_id} changing state from {user.state} to {state}"
)

await self.validate_user_stare(
server_room,
user.state,
state,
)

await self.change_user_state(server_room, user, state)

# Enhanced spectator handling based on official implementation
if state == MultiplayerUserState.SPECTATING:
await self.handle_spectator_state_change(client, server_room, user)
@@ -738,24 +746,21 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
)

async def handle_spectator_state_change(
self,
client: Client,
room: ServerMultiplayerRoom,
user: MultiplayerRoomUser
self, client: Client, room: ServerMultiplayerRoom, user: MultiplayerRoomUser
):
"""
Handle special logic for users entering spectator mode during ongoing gameplay.
Based on official osu-server-spectator implementation.
"""
room_state = room.room.state

# If switching to spectating during gameplay, immediately request load
if room_state == MultiplayerRoomState.WAITING_FOR_LOAD:
logger.info(
f"[MultiplayerHub] Spectator {user.user_id} joining during load phase"
)
await self.call_noblock(client, "LoadRequested")

elif room_state == MultiplayerRoomState.PLAYING:
logger.info(
f"[MultiplayerHub] Spectator {user.user_id} joining during active gameplay"
@@ -763,9 +768,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await self.call_noblock(client, "LoadRequested")

async def _send_current_gameplay_state_to_spectator(
self,
client: Client,
room: ServerMultiplayerRoom
self, client: Client, room: ServerMultiplayerRoom
):
"""
Send current gameplay state information to a newly joined spectator.
@@ -773,12 +776,8 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
"""
try:
# Send current room state
await self.call_noblock(
client,
"RoomStateChanged",
room.room.state
)
await self.call_noblock(client, "RoomStateChanged", room.room.state)

# Send current user states for all players
for room_user in room.room.users:
if room_user.state.is_playing:
@@ -788,7 +787,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
room_user.user_id,
room_user.state,
)

logger.debug(
f"[MultiplayerHub] Sent current gameplay state to spectator {client.user_id}"
)
@@ -798,9 +797,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
)

async def _send_room_state_to_new_user(
self,
client: Client,
room: ServerMultiplayerRoom
self, client: Client, room: ServerMultiplayerRoom
):
"""
Send complete room state to a newly joined user.
@@ -809,23 +806,19 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
try:
# Send current room state
if room.room.state != MultiplayerRoomState.OPEN:
await self.call_noblock(
client,
"RoomStateChanged",
room.room.state
)
await self.call_noblock(client, "RoomStateChanged", room.room.state)

# If room is in gameplay state, send LoadRequested immediately
if room.room.state in (
MultiplayerRoomState.WAITING_FOR_LOAD,
MultiplayerRoomState.PLAYING
MultiplayerRoomState.PLAYING,
):
logger.info(
f"[MultiplayerHub] Sending LoadRequested to user {client.user_id} "
f"joining ongoing game (room state: {room.room.state})"
)
await self.call_noblock(client, "LoadRequested")

# Send all user states to help with synchronization
for room_user in room.room.users:
if room_user.user_id != client.user_id:  # Don't send own state
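
The late-joiner sync above boils down to three messages in order. A compact sketch of the decision flow, with send standing in for this codebase's call_noblock and plain strings standing in for the room-state enum:

from typing import Awaitable, Callable

Send = Callable[..., Awaitable[None]]


async def sync_new_user(send: Send, room_state: str, users: list, new_user_id: int) -> None:
    # 1. Late joiners need the current room state if the room is not idle.
    if room_state != "Open":
        await send("RoomStateChanged", room_state)
    # 2. If gameplay is underway, ask the client to load the beatmap now.
    if room_state in ("WaitingForLoad", "Playing"):
        await send("LoadRequested")
    # 3. Replay every other user's state so the client's room view is consistent.
    for u in users:
        if u.user_id != new_user_id:
            await send("UserStateChanged", u.user_id, u.state)
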
@@ -835,11 +828,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
room_user.user_id,
room_user.state,
)

# Critical addition: Send current playing users to SpectatorHub for cross-hub sync
# This ensures spectators can watch multiplayer players properly
await self._sync_with_spectator_hub(client, room)

logger.debug(
f"[MultiplayerHub] Sent complete room state to new user {client.user_id}"
)
@@ -849,9 +842,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
)

async def _sync_with_spectator_hub(
self,
client: Client,
room: ServerMultiplayerRoom
self, client: Client, room: ServerMultiplayerRoom
):
"""
Sync with SpectatorHub to ensure cross-hub spectating works properly.
@@ -860,7 +851,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
try:
# Import here to avoid circular imports
from app.signalr.hub import SpectatorHubs

# For each playing user in the room, check if they have SpectatorHub state
# and notify the new client about their playing status
for room_user in room.room.users:
@@ -878,7 +869,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
f"[MultiplayerHub] Synced spectator state for user {room_user.user_id} "
f"to new client {client.user_id}"
)

except Exception as e:
logger.debug(f"[MultiplayerHub] Failed to sync with SpectatorHub: {e}")
# This is not critical, so we don't raise the exception

@@ -170,21 +170,22 @@ class SpectatorHub(Hub[StoreClientState]):
Properly notifies watched users when spectator disconnects.
"""
user_id = int(state.connection_id)

# Remove from online and playing tracking
from app.router.v2.stats import remove_online_user

asyncio.create_task(remove_online_user(user_id))

if state.state:
await self._end_session(user_id, state.state, state)

# Critical fix: Notify all watched users that this spectator has disconnected
# This matches the official CleanUpState implementation
for watched_user_id in state.watched_user:
if (target_client := self.get_client_by_id(str(watched_user_id))) is not None:
await self.call_noblock(
target_client, "UserEndedWatching", user_id
)
if (
target_client := self.get_client_by_id(str(watched_user_id))
) is not None:
await self.call_noblock(target_client, "UserEndedWatching", user_id)
logger.debug(
f"[SpectatorHub] Notified {watched_user_id} that {user_id} stopped watching"
)
@@ -195,18 +196,19 @@ class SpectatorHub(Hub[StoreClientState]):
Send all active player states to newly connected clients.
"""
logger.info(f"[SpectatorHub] Client {client.user_id} connected")

# Track online user
from app.router.v2.stats import add_online_user

asyncio.create_task(add_online_user(client.user_id))

# Send all current player states to the new client
# This matches the official OnConnectedAsync behavior
active_states = []
for user_id, store in self.state.items():
if store.state is not None:
active_states.append((user_id, store.state))

if active_states:
logger.debug(
f"[SpectatorHub] Sending {len(active_states)} active player states to {client.user_id}"
@@ -216,8 +218,10 @@ class SpectatorHub(Hub[StoreClientState]):
try:
await self.call_noblock(client, "UserBeganPlaying", user_id, state)
except Exception as e:
logger.debug(f"[SpectatorHub] Failed to send state for user {user_id}: {e}")
logger.debug(
f"[SpectatorHub] Failed to send state for user {user_id}: {e}"
)

# Also sync with MultiplayerHub for cross-hub spectating
await self._sync_with_multiplayer_hub(client)

@@ -229,14 +233,15 @@ class SpectatorHub(Hub[StoreClientState]):
try:
# Import here to avoid circular imports
from app.signalr.hub import MultiplayerHubs

# Check all active multiplayer rooms for playing users
for room_id, server_room in MultiplayerHubs.rooms.items():
for room_user in server_room.room.users:
# If user is playing in multiplayer but we don't have their spectator state
if (room_user.state.is_playing and
room_user.user_id not in self.state):
if (
room_user.state.is_playing
and room_user.user_id not in self.state
):
# Create a synthetic SpectatorState for multiplayer players
# This helps with cross-hub spectating
try:
@@ -245,9 +250,9 @@ class SpectatorHub(Hub[StoreClientState]):
ruleset_id=room_user.ruleset_id or 0,  # Default to osu!
mods=room_user.mods,
state=SpectatedUserState.Playing,
maximum_statistics={}
maximum_statistics={},
)

await self.call_noblock(
client,
"UserBeganPlaying",
@@ -258,8 +263,10 @@ class SpectatorHub(Hub[StoreClientState]):
f"[SpectatorHub] Sent synthetic multiplayer state for user {room_user.user_id}"
)
except Exception as e:
logger.debug(f"[SpectatorHub] Failed to create synthetic state: {e}")
logger.debug(
f"[SpectatorHub] Failed to create synthetic state: {e}"
)

except Exception as e:
logger.debug(f"[SpectatorHub] Failed to sync with MultiplayerHub: {e}")
# This is not critical, so we don't raise the exception
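
The synthetic state built above lets spectators watch multiplayer players even though those players never registered with SpectatorHub directly. A hedged sketch of that construction; SpectatorState and SpectatedUserState here are stand-in models with hypothetical field values, since their real definitions are not part of this diff:

from enum import Enum

from pydantic import BaseModel


class SpectatedUserState(str, Enum):
    Playing = "Playing"
    Quit = "Quit"


class SpectatorState(BaseModel):
    beatmap_id: int | None = None
    ruleset_id: int = 0
    mods: list = []
    state: SpectatedUserState
    maximum_statistics: dict = {}


# Mirrors the construction in the hunk above (field values are hypothetical).
synthetic = SpectatorState(
    beatmap_id=42,
    ruleset_id=0,  # Default to osu!
    mods=[],
    state=SpectatedUserState.Playing,
    maximum_statistics={},
)
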
@@ -306,6 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):

# Track playing user
from app.router.v2.stats import add_playing_user

asyncio.create_task(add_playing_user(user_id))

# # Pre-cache the beatmap file to speed up subsequent PP calculation
@@ -356,11 +364,12 @@ class SpectatorHub(Hub[StoreClientState]):
) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()):
await self._process_score(store, client)
await self._end_session(user_id, state, store)

# Remove from playing user tracking
from app.router.v2.stats import remove_playing_user

asyncio.create_task(remove_playing_user(user_id))

store.state = None
store.beatmap_status = None
store.checksum = None
@@ -473,8 +482,10 @@ class SpectatorHub(Hub[StoreClientState]):

if state.state == SpectatedUserState.Playing:
state.state = SpectatedUserState.Quit
logger.debug(f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}")
logger.debug(
f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}"
)

# Calculate exit time safely
exit_time = 0
if store.score and store.score.replay_frames:
@@ -491,7 +502,7 @@ class SpectatorHub(Hub[StoreClientState]):
)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)

# Background task for failtime tracking - only for failed/quit states with valid data
if (
state.beatmap_id is not None
@@ -519,14 +530,16 @@ class SpectatorHub(Hub[StoreClientState]):
Properly handles state synchronization and watcher notifications.
"""
user_id = int(client.connection_id)

logger.info(f"[SpectatorHub] {user_id} started watching {target_id}")

try:
# Get target user's current state if it exists
target_store = self.state.get(target_id)
if target_store and target_store.state:
logger.debug(f"[SpectatorHub] {target_id} is currently {target_store.state.state}")
logger.debug(
f"[SpectatorHub] {target_id} is currently {target_store.state.state}"
)
# Send current state to the watcher immediately
await self.call_noblock(
client,
@@ -552,7 +565,9 @@ class SpectatorHub(Hub[StoreClientState]):
await session.exec(select(User.username).where(User.id == user_id))
).first()
if not username:
logger.warning(f"[SpectatorHub] Could not find username for user {user_id}")
logger.warning(
f"[SpectatorHub] Could not find username for user {user_id}"
)
return

# Notify target user that someone started watching
@@ -562,7 +577,9 @@ class SpectatorHub(Hub[StoreClientState]):
await self.call_noblock(
target_client, "UserStartedWatching", watcher_info
)
logger.debug(f"[SpectatorHub] Notified {target_id} that {username} started watching")
logger.debug(
f"[SpectatorHub] Notified {target_id} that {username} started watching"
)
except Exception as e:
logger.error(f"[SpectatorHub] Error notifying target user {target_id}: {e}")
@@ -572,19 +589,23 @@ class SpectatorHub(Hub[StoreClientState]):
Properly cleans up watcher state and notifies target user.
"""
user_id = int(client.connection_id)

logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}")

# Remove from SignalR group
self.remove_from_group(client, self.group_id(target_id))

# Remove from our tracked watched users
store = self.get_or_create_state(client)
store.watched_user.discard(target_id)

# Notify target user that watcher stopped watching
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(target_client, "UserEndedWatching", user_id)
logger.debug(f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching")
logger.debug(
f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching"
)
else:
logger.debug(f"[SpectatorHub] Target user {target_id} not found for end watching notification")
logger.debug(
f"[SpectatorHub] Target user {target_id} not found for end watching notification"
)