Tidy up code

咕谷酱
2025-08-22 05:57:28 +08:00
parent ad131c0158
commit ce465aa049
20 changed files with 1078 additions and 799 deletions

View File

@@ -45,7 +45,7 @@ from .relationship import (
 )
 from .score_token import ScoreToken
-from pydantic import field_validator, field_serializer
+from pydantic import field_serializer, field_validator
 from redis.asyncio import Redis
 from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime
 from sqlalchemy.ext.asyncio import AsyncAttrs
@@ -126,7 +126,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
         if isinstance(v, dict):
             serialized = {}
             for key, value in v.items():
-                if hasattr(key, 'value'):
+                if hasattr(key, "value"):
                     # If it is an enum, use its value
                     serialized[key.value] = value
                 else:
@@ -138,7 +138,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
     @field_serializer("rank", when_used="json")
     def serialize_rank(self, v):
         """Serialize the rank, making sure enum values are correctly converted to strings"""
-        if hasattr(v, 'value'):
+        if hasattr(v, "value"):
             return v.value
         return str(v)
@@ -188,7 +188,7 @@ class Score(ScoreBase, table=True):
     @field_serializer("gamemode", when_used="json")
     def serialize_gamemode(self, v):
         """Serialize the game mode, making sure enum values are correctly converted to strings"""
-        if hasattr(v, 'value'):
+        if hasattr(v, "value"):
             return v.value
         return str(v)
@@ -281,7 +281,7 @@ class ScoreResp(ScoreBase):
         if isinstance(v, dict):
             serialized = {}
             for key, value in v.items():
-                if hasattr(key, 'value'):
+                if hasattr(key, "value"):
                     # If it is an enum, use its value
                     serialized[key.value] = value
                 else:
@@ -294,7 +294,7 @@ class ScoreResp(ScoreBase):
     async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
         # Make sure the score object is fully loaded to avoid lazy-loading issues
         await session.refresh(score)
 
         s = cls.model_validate(score.model_dump())
         assert score.id
         await score.awaitable_attrs.beatmap
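A note on the `hasattr(key, "value")` pattern that this commit reformats in several serializers above: it defensively unwraps enum keys to their plain values so that JSON output gets string keys. A minimal self-contained sketch of the same idea (illustrative `Rank`/`RankCounts` names, not the project's actual models):

from enum import Enum

from pydantic import BaseModel, field_serializer


class Rank(Enum):
    S = "S"
    A = "A"


class RankCounts(BaseModel):
    counts: dict[Rank, int]

    @field_serializer("counts", when_used="json")
    def serialize_counts(self, v):
        # Unwrap enum keys to their plain values; pass other keys through unchanged
        if isinstance(v, dict):
            return {(k.value if hasattr(k, "value") else k): val for k, val in v.items()}
        return v


print(RankCounts(counts={Rank.S: 3, Rank.A: 1}).model_dump_json())
# {"counts":{"S":3,"A":1}}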

View File

@@ -10,8 +10,8 @@ from app.config import settings
 from fastapi import Depends
 from pydantic import BaseModel
-import redis.asyncio as redis
 import redis as sync_redis
+import redis.asyncio as redis
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlmodel import SQLModel
 from sqlmodel.ext.asyncio.session import AsyncSession
@@ -40,7 +40,9 @@ engine = create_async_engine(
 redis_client = redis.from_url(settings.redis_url, decode_responses=True)
 # Redis message cache connection (db1) - a sync client executed in a thread pool
-redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
+redis_message_client = sync_redis.from_url(
+    settings.redis_url, decode_responses=True, db=1
+)
 
 # Database dependency

View File

@@ -7,7 +7,7 @@ from app.config import settings
 from .mods import API_MODS, APIMod
-from pydantic import BaseModel, Field, ValidationInfo, field_validator, field_serializer
+from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator
 
 if TYPE_CHECKING:
     import rosu_pp_py as rosu
@@ -212,7 +212,7 @@ class SoloScoreSubmissionInfo(BaseModel):
         if isinstance(v, dict):
             serialized = {}
             for key, value in v.items():
-                if hasattr(key, 'value'):
+                if hasattr(key, "value"):
                     # If it is an enum, use its value
                     serialized[key.value] = value
                 else:
@@ -224,7 +224,7 @@ class SoloScoreSubmissionInfo(BaseModel):
     @field_serializer("rank", when_used="json")
     def serialize_rank(self, v):
         """Serialize the rank, making sure enum values are correctly converted to strings"""
-        if hasattr(v, 'value'):
+        if hasattr(v, "value"):
             return v.value
         return str(v)

View File

@@ -1,13 +1,13 @@
 from __future__ import annotations
 
 from datetime import datetime
-from typing import Any
 
 from pydantic import BaseModel
 
 
 class OnlineStats(BaseModel):
     """Online statistics"""
+
     registered_users: int
     online_users: int
     playing_users: int
@@ -16,6 +16,7 @@ class OnlineStats(BaseModel):
 class OnlineHistoryPoint(BaseModel):
     """Online history data point"""
+
     timestamp: datetime
     online_count: int
     playing_count: int
@@ -23,12 +24,14 @@ class OnlineHistoryPoint(BaseModel):
 class OnlineHistoryStats(BaseModel):
     """24-hour online history statistics"""
+
     history: list[OnlineHistoryPoint]
     current_stats: OnlineStats
 
 
 class ServerStatistics(BaseModel):
     """Server statistics"""
+
     total_users: int
     online_users: int
     playing_users: int

View File

@@ -62,7 +62,7 @@ async def get_update(
         if db_channel:
             # Extract the needed attributes to avoid lazy loading
             channel_type = db_channel.type
 
             resp.presence.append(
                 await ChatChannelResp.from_db(
                     db_channel,
@@ -122,9 +122,7 @@ async def join_channel(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
@@ -154,9 +152,7 @@ async def leave_channel(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
@@ -187,7 +183,7 @@ async def get_channel_list(
         # Extract the needed attributes to avoid lazy loading
         channel_id = channel.channel_id
         channel_type = channel.type
 
         assert channel_id is not None
         results.append(
             await ChatChannelResp.from_db(
@@ -230,19 +226,17 @@ async def get_channel(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
         raise HTTPException(status_code=404, detail="Channel not found")
 
     # Extract the needed attributes immediately
     channel_id = db_channel.channel_id
     channel_type = db_channel.type
     channel_name = db_channel.name
     assert channel_id is not None
 
     users = []
@@ -325,7 +319,9 @@ async def create_channel(
         channel_name = f"pm_{current_user.id}_{req.target_id}"
     else:
         channel_name = req.channel.name if req.channel else "Unnamed Channel"
-    result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
+    result = await session.exec(
+        select(ChatChannel).where(ChatChannel.name == channel_name)
+    )
     channel = result.first()
 
     if channel is None:
@@ -350,11 +346,11 @@ async def create_channel(
     await server.batch_join_channel([*target_users, current_user], channel, session)
     await server.join_channel(current_user, channel, session)
 
     # Extract the needed attributes to avoid lazy loading
     channel_id = channel.channel_id
     assert channel_id
 
     return await ChatChannelResp.from_db(
         channel,
         session,

View File

@@ -1,10 +1,5 @@
 from __future__ import annotations
 
-import json
-import uuid
-from datetime import datetime
-from typing import Optional
-
 from app.database import ChatMessageResp
 from app.database.chat import (
     ChannelType,
@@ -16,14 +11,13 @@ from app.database.chat import (
     UserSilenceResp,
 )
 from app.database.lazer_user import User
-from app.dependencies.database import Database, get_redis, get_redis_message
+from app.dependencies.database import Database, get_redis
 from app.dependencies.param import BodyOrForm
 from app.dependencies.user import get_current_user
+from app.log import logger
 from app.models.notification import ChannelMessage, ChannelMessageTeam
 from app.router.v2 import api_v2_router as router
-from app.service.optimized_message import optimized_message_service
 from app.service.redis_message_system import redis_message_system
-from app.log import logger
 
 from .banchobot import bot
 from .server import server
@@ -106,11 +100,9 @@ async def send_message(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
         raise HTTPException(status_code=404, detail="Channel not found")
@@ -118,29 +110,29 @@ async def send_message(
     channel_id = db_channel.channel_id
     channel_type = db_channel.type
     channel_name = db_channel.name
     assert channel_id is not None
     assert current_user.id
 
     # Send the message through the Redis message system - returns immediately
     resp = await redis_message_system.send_message(
         channel_id=channel_id,
         user=current_user,
         content=req.message,
         is_action=req.is_action,
-        user_uuid=req.uuid
+        user_uuid=req.uuid,
     )
 
     # Broadcast the message to all clients right away
     is_bot_command = req.message.startswith("!")
     await server.send_message_to_channel(
         resp, is_bot_command and channel_type == ChannelType.PUBLIC
     )
 
     # Handle bot commands
     if is_bot_command:
         await bot.try_handle(current_user, db_channel, req.message, session)
 
     # Create a temporary ChatMessage object for the notification system (PM and team channels only)
     if channel_type in [ChannelType.PM, ChannelType.TEAM]:
         temp_msg = ChatMessage(
@@ -151,7 +143,7 @@ async def send_message(
             type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
             uuid=req.uuid,
         )
 
         if channel_type == ChannelType.PM:
             user_ids = channel_name.split("_")[1:]
             await server.new_private_notification(
@@ -163,7 +155,7 @@ async def send_message(
             await server.new_private_notification(
                 ChannelMessageTeam.init(temp_msg, current_user)
             )
 
     return resp
@@ -191,11 +183,9 @@ async def get_message(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
         raise HTTPException(status_code=404, detail="Channel not found")
@@ -218,7 +208,7 @@ async def get_message(
         query = query.where(col(ChatMessage.message_id) > since)
     if until is not None:
         query = query.where(col(ChatMessage.message_id) < until)
 
     query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
     messages = (await session.exec(query)).all()
     resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
@@ -247,14 +237,12 @@ async def mark_as_read(
         ).first()
     else:
         db_channel = (
-            await session.exec(
-                select(ChatChannel).where(ChatChannel.name == channel)
-            )
+            await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
         ).first()
 
     if db_channel is None:
         raise HTTPException(status_code=404, detail="Channel not found")
 
     # Extract the needed attributes immediately
     channel_id = db_channel.channel_id
     assert channel_id

View File

@@ -96,8 +96,10 @@ class ChatServer:
     async def send_message_to_channel(
         self, message: ChatMessageResp, is_bot_command: bool = False
     ):
-        logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}")
+        logger.info(
+            f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
+        )
         event = ChatEvent(
             event="chat.message.new",
             data={"messages": [message], "users": [message.sender]},
@@ -107,24 +109,32 @@ class ChatServer:
             self._add_task(self.send_event(message.sender_id, event))
         else:
             # Always broadcast the message, whether the ID is temporary or real
-            logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
+            logger.info(
+                f"Broadcasting message to all users in channel {message.channel_id}"
+            )
             self._add_task(
                 self.broadcast(
                     message.channel_id,
                     event,
                 )
             )
 
             # Only real message IDs (positive and non-zero) trigger mark-as-read and the last-message update
             # IDs generated by the Redis message system are all positive, so they should all be handled here
             if message.message_id and message.message_id > 0:
                 await self.mark_as_read(
                     message.channel_id, message.sender_id, message.message_id
                 )
-                await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
-                logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
+                await self.redis.set(
+                    f"chat:{message.channel_id}:last_msg", message.message_id
+                )
+                logger.info(
+                    f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
+                )
             else:
-                logger.debug(f"Skipping last message update for message ID: {message.message_id}")
+                logger.debug(
+                    f"Skipping last message update for message ID: {message.message_id}"
+                )
 
     async def batch_join_channel(
         self, users: list[User], channel: ChatChannel, session: AsyncSession
@@ -340,11 +350,9 @@ async def chat_websocket(
     server.connect(user_id, websocket)
     # Use an explicit query to avoid lazy loading
     db_channel = (
-        await session.exec(
-            select(ChatChannel).where(ChatChannel.channel_id == 1)
-        )
+        await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
    ).first()
     if db_channel is not None:
         await server.join_channel(user, db_channel, session)
     await _listen_stop(websocket, user_id, factory)

View File

@@ -5,4 +5,3 @@ from fastapi import APIRouter
 router = APIRouter(prefix="/api/v2")
 
 # Import all sub-router modules to register their routes
-from . import stats  # statistics routes

View File

@@ -75,9 +75,10 @@ READ_SCORE_TIMEOUT = 10
 async def process_user_achievement(score_id: int):
-    from sqlmodel.ext.asyncio.session import AsyncSession
     from app.dependencies.database import engine
+    from sqlmodel.ext.asyncio.session import AsyncSession
 
     session = AsyncSession(engine)
     try:
         await process_achievements(session, get_redis(), score_id)
@@ -99,7 +100,7 @@ async def submit_score(
 ):
     # Grab the user ID immediately to avoid lazy-loading problems later
     user_id = current_user.id
 
     if not info.passed:
         info.rank = Rank.F
     score_token = (
@@ -166,13 +167,15 @@ async def submit_score(
         has_pp,
         has_leaderboard,
     )
-    score = (await db.exec(
-        select(Score)
-        .options(joinedload(Score.user))  # pyright: ignore[reportArgumentType]
-        .where(Score.id == score_id)
-    )).first()
+    score = (
+        await db.exec(
+            select(Score)
+            .options(joinedload(Score.user))  # pyright: ignore[reportArgumentType]
+            .where(Score.id == score_id)
+        )
+    ).first()
     assert score is not None
     resp = await ScoreResp.from_db(db, score)
     total_users = (await db.exec(select(func.count()).select_from(User))).first()
     assert total_users is not None
@@ -202,13 +205,10 @@ async def submit_score(
     # Make sure the score object has been refreshed, so no lazy loads fire in the background task
     await db.refresh(score)
     score_gamemode = score.gamemode
 
     if user_id is not None:
         background_task.add_task(
-            _refresh_user_cache_background,
-            redis,
-            user_id,
-            score_gamemode
+            _refresh_user_cache_background, redis, user_id, score_gamemode
         )
     background_task.add_task(process_user_achievement, resp.id)
     return resp
@@ -217,9 +217,10 @@ async def submit_score(
 async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameMode):
     """Background task: refresh the user cache"""
     try:
-        from sqlmodel.ext.asyncio.session import AsyncSession
         from app.dependencies.database import engine
+        from sqlmodel.ext.asyncio.session import AsyncSession
 
         user_cache_service = get_user_cache_service(redis)
         # Create an independent database session
         session = AsyncSession(engine)
@@ -422,7 +423,7 @@ async def create_solo_score(
     assert current_user.id is not None
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
     async with db:
         score_token = ScoreToken(
@@ -480,7 +481,7 @@ async def create_playlist_score(
     assert current_user.id is not None
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     room = await session.get(Room, room_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -557,10 +558,10 @@ async def submit_playlist_score(
     fetcher: Fetcher = Depends(get_fetcher),
 ):
     assert current_user.id is not None
 
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     item = (
         await session.exec(
             select(Playlist).where(
@@ -627,7 +628,7 @@ async def index_playlist_scores(
 ):
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     room = await session.get(Room, room_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -694,7 +695,7 @@ async def show_playlist_score(
 ):
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     room = await session.get(Room, room_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -803,7 +804,7 @@ async def pin_score(
 ):
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     score_record = (
         await db.exec(
             select(Score).where(
@@ -848,7 +849,7 @@ async def unpin_score(
 ):
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     score_record = (
         await db.exec(
             select(Score).where(Score.id == score_id, Score.user_id == user_id)
@@ -892,7 +893,7 @@ async def reorder_score_pin(
 ):
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     score_record = (
         await db.exec(
             select(Score).where(Score.id == score_id, Score.user_id == user_id)
@@ -986,9 +987,9 @@ async def download_score_replay(
     current_user: User = Security(get_current_user, scopes=["public"]),
     storage_service: StorageService = Depends(get_storage_service),
 ):
     # Grab the user ID immediately to avoid lazy-loading problems
     user_id = current_user.id
 
     score = (await db.exec(select(Score).where(Score.id == score_id))).first()
     if not score:
         raise HTTPException(status_code=404, detail="Score not found")

View File

@@ -1,42 +1,45 @@
 from __future__ import annotations
 
 import asyncio
-from datetime import datetime, timedelta
-import json
-from typing import Any
 from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+import json
 
 from app.dependencies.database import get_redis, get_redis_message
 from app.log import logger
 from .router import router
-from fastapi import APIRouter
 from pydantic import BaseModel
 
 # Redis key constants
 REDIS_ONLINE_USERS_KEY = "server:online_users"
 REDIS_PLAYING_USERS_KEY = "server:playing_users"
 REDIS_REGISTERED_USERS_KEY = "server:registered_users"
 REDIS_ONLINE_HISTORY_KEY = "server:online_history"
 
 # Thread pool for synchronous Redis operations
 _executor = ThreadPoolExecutor(max_workers=2)
 
 
 async def _redis_exec(func, *args, **kwargs):
     """Run a synchronous Redis operation in the thread pool"""
     loop = asyncio.get_event_loop()
     return await loop.run_in_executor(_executor, func, *args, **kwargs)
 
 
 class ServerStats(BaseModel):
     """Server statistics response model"""
+
     registered_users: int
     online_users: int
     playing_users: int
     timestamp: datetime
 
 
 class OnlineHistoryPoint(BaseModel):
     """Online history data point"""
+
     timestamp: datetime
     online_count: int
     playing_count: int
@@ -44,33 +47,36 @@ class OnlineHistoryPoint(BaseModel):
     peak_playing: int | None = None  # peak playing count (enhanced data)
     total_samples: int | None = None  # number of samples (enhanced data)
 
 
 class OnlineHistoryResponse(BaseModel):
     """24-hour online history response model"""
+
     history: list[OnlineHistoryPoint]
     current_stats: ServerStats
 
 
 @router.get("/stats", response_model=ServerStats, tags=["统计"])
 async def get_server_stats() -> ServerStats:
     """
     Get real-time server statistics.
 
    Returns real-time statistics such as the number of registered, online, and currently playing users.
     """
     redis = get_redis()
     try:
         # Fetch all statistics in parallel
         registered_count, online_count, playing_count = await asyncio.gather(
             _get_registered_users_count(redis),
             _get_online_users_count(redis),
-            _get_playing_users_count(redis)
+            _get_playing_users_count(redis),
         )
         return ServerStats(
             registered_users=registered_count,
             online_users=online_count,
             playing_users=playing_count,
-            timestamp=datetime.utcnow()
+            timestamp=datetime.utcnow(),
         )
     except Exception as e:
         logger.error(f"Error getting server stats: {e}")
@@ -79,75 +85,86 @@ async def get_server_stats() -> ServerStats:
             registered_users=0,
             online_users=0,
             playing_users=0,
-            timestamp=datetime.utcnow()
+            timestamp=datetime.utcnow(),
         )
 
 
 @router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
 async def get_online_history() -> OnlineHistoryResponse:
     """
     Get the online statistics history for the last 24 hours.
 
     Returns the online and playing user counts for the past 24 hours,
     with the current real-time data included as the latest data point.
     """
     try:
         # Fetch the history data - uses the synchronous Redis client
         redis_sync = get_redis_message()
-        history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
+        history_data = await _redis_exec(
+            redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
+        )
         history_points = []
 
         # Process the history data
         for data in history_data:
             try:
                 point_data = json.loads(data)
                 # Support both old- and new-format history data
-                history_points.append(OnlineHistoryPoint(
-                    timestamp=datetime.fromisoformat(point_data["timestamp"]),
-                    online_count=point_data["online_count"],
-                    playing_count=point_data["playing_count"],
-                    peak_online=point_data.get("peak_online"),  # new field, may be absent
-                    peak_playing=point_data.get("peak_playing"),  # new field, may be absent
-                    total_samples=point_data.get("total_samples")  # new field, may be absent
-                ))
+                history_points.append(
+                    OnlineHistoryPoint(
+                        timestamp=datetime.fromisoformat(point_data["timestamp"]),
+                        online_count=point_data["online_count"],
+                        playing_count=point_data["playing_count"],
+                        peak_online=point_data.get("peak_online"),  # new field, may be absent
+                        peak_playing=point_data.get(
+                            "peak_playing"
+                        ),  # new field, may be absent
+                        total_samples=point_data.get(
+                            "total_samples"
+                        ),  # new field, may be absent
+                    )
+                )
             except (json.JSONDecodeError, KeyError, ValueError) as e:
                 logger.warning(f"Invalid history data point: {data}, error: {e}")
                 continue
 
         # Get the current real-time statistics
         current_stats = await get_server_stats()
 
        # If the history is empty, or the latest point is older than 15 minutes, add the current data point
         if not history_points or (
-            history_points and
-            (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60
+            history_points
+            and (
+                current_stats.timestamp
+                - max(history_points, key=lambda x: x.timestamp).timestamp
+            ).total_seconds()
+            > 15 * 60
         ):
-            history_points.append(OnlineHistoryPoint(
-                timestamp=current_stats.timestamp,
-                online_count=current_stats.online_users,
-                playing_count=current_stats.playing_users,
-                peak_online=current_stats.online_users,  # current real-time data as the peak
-                peak_playing=current_stats.playing_users,
-                total_samples=1
-            ))
+            history_points.append(
+                OnlineHistoryPoint(
+                    timestamp=current_stats.timestamp,
+                    online_count=current_stats.online_users,
+                    playing_count=current_stats.playing_users,
+                    peak_online=current_stats.online_users,  # current real-time data as the peak
+                    peak_playing=current_stats.playing_users,
+                    total_samples=1,
+                )
+            )
 
         # Sort by time (newest first)
         history_points.sort(key=lambda x: x.timestamp, reverse=True)
         # Limit to at most 48 data points (24 hours)
         history_points = history_points[:48]
 
         return OnlineHistoryResponse(
-            history=history_points,
-            current_stats=current_stats
+            history=history_points, current_stats=current_stats
         )
     except Exception as e:
         logger.error(f"Error getting online history: {e}")
         # Return an empty history plus the current status
         current_stats = await get_server_stats()
-        return OnlineHistoryResponse(
-            history=[],
-            current_stats=current_stats
-        )
+        return OnlineHistoryResponse(history=[], current_stats=current_stats)
 
 
 async def _get_registered_users_count(redis) -> int:
@@ -159,6 +176,7 @@ async def _get_registered_users_count(redis) -> int:
         logger.error(f"Error getting registered users count: {e}")
         return 0
 
+
 async def _get_online_users_count(redis) -> int:
     """Get the current number of online users"""
     try:
@@ -168,6 +186,7 @@ async def _get_online_users_count(redis) -> int:
         logger.error(f"Error getting online users count: {e}")
         return 0
 
+
 async def _get_playing_users_count(redis) -> int:
     """Get the current number of playing users"""
     try:
@@ -177,14 +196,16 @@ async def _get_playing_users_count(redis) -> int:
         logger.error(f"Error getting playing users count: {e}")
         return 0
 
+
 # Statistics update functions
 async def update_registered_users_count() -> None:
     """Update the registered user count cache"""
-    from app.dependencies.database import with_db
-    from app.database import User
     from app.const import BANCHOBOT_ID
-    from sqlmodel import select, func
+    from app.database import User
+    from app.dependencies.database import with_db
+    from sqlmodel import func, select
 
     redis = get_redis()
     try:
         async with with_db() as db:
@@ -198,6 +219,7 @@ async def update_registered_users_count() -> None:
     except Exception as e:
         logger.error(f"Error updating registered users count: {e}")
 
+
 async def add_online_user(user_id: int) -> None:
     """Add an online user"""
     redis_sync = get_redis_message()
@@ -209,14 +231,16 @@ async def add_online_user(user_id: int) -> None:
             if ttl <= 0:  # -1 = never expires, -2 = missing, 0 = already expired
                 await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600)  # 3-hour expiry
         logger.debug(f"Added online user {user_id}")
 
         # Immediately update the current interval statistics
         from app.service.enhanced_interval_stats import update_user_activity_in_interval
+
         asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
     except Exception as e:
         logger.error(f"Error adding online user {user_id}: {e}")
 
+
 async def remove_online_user(user_id: int) -> None:
     """Remove an online user"""
     redis_sync = get_redis_message()
@@ -226,6 +250,7 @@ async def remove_online_user(user_id: int) -> None:
     except Exception as e:
         logger.error(f"Error removing online user {user_id}: {e}")
 
+
 async def add_playing_user(user_id: int) -> None:
     """Add a playing user"""
     redis_sync = get_redis_message()
@@ -237,14 +262,16 @@ async def add_playing_user(user_id: int) -> None:
             if ttl <= 0:  # -1 = never expires, -2 = missing, 0 = already expired
                 await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600)  # 3-hour expiry
         logger.debug(f"Added playing user {user_id}")
 
         # Immediately update the current interval statistics
         from app.service.enhanced_interval_stats import update_user_activity_in_interval
+
         asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
     except Exception as e:
         logger.error(f"Error adding playing user {user_id}: {e}")
 
+
 async def remove_playing_user(user_id: int) -> None:
     """Remove a playing user"""
     redis_sync = get_redis_message()
@@ -253,6 +280,7 @@ async def remove_playing_user(user_id: int) -> None:
     except Exception as e:
         logger.error(f"Error removing playing user {user_id}: {e}")
 
+
 async def record_hourly_stats() -> None:
     """Record statistics - simplified version, mainly used as a fallback"""
     redis_sync = get_redis_message()
@@ -260,10 +288,10 @@ async def record_hourly_stats() -> None:
     try:
         # First make sure the Redis connection is healthy
         await redis_async.ping()
 
         online_count = await _get_online_users_count(redis_async)
         playing_count = await _get_playing_users_count(redis_async)
         current_time = datetime.utcnow()
 
         history_point = {
             "timestamp": current_time.isoformat(),
@@ -271,16 +299,20 @@ async def record_hourly_stats() -> None:
             "playing_count": playing_count,
             "peak_online": online_count,
             "peak_playing": playing_count,
-            "total_samples": 1
+            "total_samples": 1,
         }
 
         # Append to the history record
-        await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
+        await _redis_exec(
+            redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
+        )
        # Keep only 48 data points (24 hours, one point every 30 minutes)
         await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
         # Set the expiry to 26 hours to leave enough buffer
         await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
 
-        logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}")
+        logger.info(
+            f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}"
+        )
     except Exception as e:
         logger.error(f"Error recording fallback stats: {e}")

View File

@@ -4,11 +4,9 @@
 from __future__ import annotations
 
-import json
-import asyncio
+from dataclasses import dataclass
 from datetime import datetime, timedelta
-from typing import Dict, Set, Optional, List
-from dataclasses import dataclass, asdict
+import json
 
 from app.dependencies.database import get_redis, get_redis_message
 from app.log import logger
@@ -16,7 +14,7 @@ from app.router.v2.stats import (
     REDIS_ONLINE_HISTORY_KEY,
     _get_online_users_count,
     _get_playing_users_count,
-    _redis_exec
+    _redis_exec,
 )
# Redis keys for interval statistics # Redis keys for interval statistics
@@ -29,34 +27,36 @@ CURRENT_INTERVAL_INFO_KEY = "server:current_interval_info"  # current interval info
 @dataclass
 class IntervalInfo:
     """Interval info"""
+
     start_time: datetime
     end_time: datetime
     interval_key: str
 
     def is_current(self) -> bool:
         """Check whether this is the current interval"""
         now = datetime.utcnow()
         return self.start_time <= now < self.end_time
 
-    def to_dict(self) -> Dict:
+    def to_dict(self) -> dict:
         return {
-            'start_time': self.start_time.isoformat(),
-            'end_time': self.end_time.isoformat(),
-            'interval_key': self.interval_key
+            "start_time": self.start_time.isoformat(),
+            "end_time": self.end_time.isoformat(),
+            "interval_key": self.interval_key,
         }
 
     @classmethod
-    def from_dict(cls, data: Dict) -> 'IntervalInfo':
+    def from_dict(cls, data: dict) -> "IntervalInfo":
         return cls(
-            start_time=datetime.fromisoformat(data['start_time']),
-            end_time=datetime.fromisoformat(data['end_time']),
-            interval_key=data['interval_key']
+            start_time=datetime.fromisoformat(data["start_time"]),
+            end_time=datetime.fromisoformat(data["end_time"]),
+            interval_key=data["interval_key"],
         )
 
 
 @dataclass
 class IntervalStats:
     """Interval statistics data"""
+
     interval_key: str
     start_time: datetime
     end_time: datetime
@@ -66,38 +66,38 @@ class IntervalStats:
     peak_playing_count: int  # peak playing user count within the interval
     total_samples: int  # number of samples
     created_at: datetime
 
-    def to_dict(self) -> Dict:
+    def to_dict(self) -> dict:
         return {
-            'interval_key': self.interval_key,
-            'start_time': self.start_time.isoformat(),
-            'end_time': self.end_time.isoformat(),
-            'unique_online_users': self.unique_online_users,
-            'unique_playing_users': self.unique_playing_users,
-            'peak_online_count': self.peak_online_count,
-            'peak_playing_count': self.peak_playing_count,
-            'total_samples': self.total_samples,
-            'created_at': self.created_at.isoformat()
+            "interval_key": self.interval_key,
+            "start_time": self.start_time.isoformat(),
+            "end_time": self.end_time.isoformat(),
+            "unique_online_users": self.unique_online_users,
+            "unique_playing_users": self.unique_playing_users,
+            "peak_online_count": self.peak_online_count,
+            "peak_playing_count": self.peak_playing_count,
+            "total_samples": self.total_samples,
+            "created_at": self.created_at.isoformat(),
         }
 
     @classmethod
-    def from_dict(cls, data: Dict) -> 'IntervalStats':
+    def from_dict(cls, data: dict) -> "IntervalStats":
         return cls(
-            interval_key=data['interval_key'],
-            start_time=datetime.fromisoformat(data['start_time']),
-            end_time=datetime.fromisoformat(data['end_time']),
-            unique_online_users=data['unique_online_users'],
-            unique_playing_users=data['unique_playing_users'],
-            peak_online_count=data['peak_online_count'],
-            peak_playing_count=data['peak_playing_count'],
-            total_samples=data['total_samples'],
-            created_at=datetime.fromisoformat(data['created_at'])
+            interval_key=data["interval_key"],
+            start_time=datetime.fromisoformat(data["start_time"]),
+            end_time=datetime.fromisoformat(data["end_time"]),
+            unique_online_users=data["unique_online_users"],
+            unique_playing_users=data["unique_playing_users"],
+            peak_online_count=data["peak_online_count"],
+            peak_playing_count=data["peak_playing_count"],
+            total_samples=data["total_samples"],
+            created_at=datetime.fromisoformat(data["created_at"]),
        )
 
 
 class EnhancedIntervalStatsManager:
     """Enhanced interval stats manager - tracks real user activity within each half-hour interval"""
 
     @staticmethod
     def get_current_interval_boundaries() -> tuple[datetime, datetime]:
         """Get the boundaries of the current 30-minute interval"""
@@ -108,49 +108,53 @@
         # Interval end time
         end_time = start_time + timedelta(minutes=30)
         return start_time, end_time
 
     @staticmethod
     def generate_interval_key(start_time: datetime) -> str:
         """Generate a unique key for an interval"""
         return f"{INTERVAL_STATS_BASE_KEY}:{start_time.strftime('%Y%m%d_%H%M')}"
 
     @staticmethod
     async def get_current_interval_info() -> IntervalInfo:
         """Get info about the current interval"""
-        start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries()
+        start_time, end_time = (
+            EnhancedIntervalStatsManager.get_current_interval_boundaries()
+        )
         interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time)
         return IntervalInfo(
-            start_time=start_time,
-            end_time=end_time,
-            interval_key=interval_key
+            start_time=start_time, end_time=end_time, interval_key=interval_key
        )
 
     @staticmethod
     async def initialize_current_interval() -> None:
         """Initialize the current interval"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
 
             # Store the current interval info
             await _redis_exec(
                 redis_sync.set,
                 CURRENT_INTERVAL_INFO_KEY,
-                json.dumps(current_interval.to_dict())
+                json.dumps(current_interval.to_dict()),
             )
             await redis_async.expire(CURRENT_INTERVAL_INFO_KEY, 35 * 60)  # 35-minute expiry
 
             # Initialize the interval user sets (if they do not exist yet)
             online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
-            playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+            playing_key = (
+                f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+            )
 
             # Set the expiry to 35 minutes
             await redis_async.expire(online_key, 35 * 60)
             await redis_async.expire(playing_key, 35 * 60)
 
             # Initialize the interval statistics record
             stats = IntervalStats(
                 interval_key=current_interval.interval_key,
@@ -161,157 +165,193 @@ class EnhancedIntervalStatsManager:
                 peak_online_count=0,
                 peak_playing_count=0,
                 total_samples=0,
-                created_at=datetime.utcnow()
+                created_at=datetime.utcnow(),
             )
 
             await _redis_exec(
                 redis_sync.set,
                 current_interval.interval_key,
-                json.dumps(stats.to_dict())
+                json.dumps(stats.to_dict()),
             )
             await redis_async.expire(current_interval.interval_key, 35 * 60)
 
             # If the history is empty, auto-fill the previous 24 hours with zeros
             await EnhancedIntervalStatsManager._ensure_24h_history_exists()
 
-            logger.info(f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}")
+            logger.info(
+                f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}"
+            )
         except Exception as e:
             logger.error(f"Error initializing current interval: {e}")
 
     @staticmethod
     async def _ensure_24h_history_exists() -> None:
         """Make sure 24 hours of history data exist; fill with zeros if missing"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
             # Check how many history data points already exist
-            history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY)
+            history_length = await _redis_exec(
+                redis_sync.llen, REDIS_ONLINE_HISTORY_KEY
+            )
 
             if history_length < 48:  # fewer than 48 data points (24 hours * 2)
-                logger.info(f"History has only {history_length} points, filling with zeros for 24h")
+                logger.info(
+                    f"History has only {history_length} points, filling with zeros for 24h"
+                )
 
                 # Work out how many data points need to be filled
                 needed_points = 48 - history_length
 
                 # Walk back from the current time and create the missing points, all filled with zeros
                 current_time = datetime.utcnow()
-                current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries()
+                current_interval_start, _ = (
+                    EnhancedIntervalStatsManager.get_current_interval_boundaries()
+                )
 
                 # Walk back from the current interval, keeping times aligned to 30-minute boundaries
                 fill_points = []
                 for i in range(needed_points):
                     # Step back 30 minutes each time to keep the alignment
-                    point_time = current_interval_start - timedelta(minutes=30 * (i + 1))
+                    point_time = current_interval_start - timedelta(
+                        minutes=30 * (i + 1)
+                    )
                     # Make sure the time is aligned to a 30-minute boundary
                     aligned_minute = (point_time.minute // 30) * 30
-                    point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0)
+                    point_time = point_time.replace(
+                        minute=aligned_minute, second=0, microsecond=0
+                    )
 
                     history_point = {
                         "timestamp": point_time.isoformat(),
                         "online_count": 0,
                         "playing_count": 0,
                         "peak_online": 0,
                         "peak_playing": 0,
-                        "total_samples": 0
+                        "total_samples": 0,
                     }
                     fill_points.append(json.dumps(history_point))
 
                 # Append the fill data to the end of the history (oldest data)
                 if fill_points:
                     # First move the existing data to a temporary location
                     temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp"
                     if history_length > 0:
                         # Copy the existing data to the temporary key
-                        existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
+                        existing_data = await _redis_exec(
+                            redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
+                        )
                         if existing_data:
                             for data in existing_data:
                                 await _redis_exec(redis_sync.rpush, temp_key, data)
 
                     # Clear the original key
                     await redis_async.delete(REDIS_ONLINE_HISTORY_KEY)
 
                     # Add the fill data first (oldest)
                     for point in reversed(fill_points):  # add in reverse, oldest last
-                        await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point)
+                        await _redis_exec(
+                            redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point
+                        )
 
                     # Then add the original data back (newer)
                     if history_length > 0:
-                        existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1)
+                        existing_data = await _redis_exec(
+                            redis_sync.lrange, temp_key, 0, -1
+                        )
                         for data in existing_data:
-                            await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data)
+                            await _redis_exec(
+                                redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data
+                            )
 
                     # Clean up the temporary key
                     await redis_async.delete(temp_key)
 
                 # Make sure only 48 data points are kept
                 await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
                 # Set the expiry
                 await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
 
-                logger.info(f"Filled {len(fill_points)} historical data points with zeros")
+                logger.info(
+                    f"Filled {len(fill_points)} historical data points with zeros"
+                )
         except Exception as e:
             logger.error(f"Error ensuring 24h history exists: {e}")
 
     @staticmethod
     async def add_user_to_interval(user_id: int, is_playing: bool = False) -> None:
         """Add a user to the current interval stats - updates the running interval in real time"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
 
             # Add to the interval's online user set
             online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
             await _redis_exec(redis_sync.sadd, online_key, str(user_id))
             await redis_async.expire(online_key, 35 * 60)
 
             # If the user is playing, also add them to the playing user set
             if is_playing:
-                playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+                playing_key = (
+                    f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+                )
                 await _redis_exec(redis_sync.sadd, playing_key, str(user_id))
                 await redis_async.expire(playing_key, 35 * 60)
 
             # Update the interval stats right away (synchronously, to keep the data fresh)
             await EnhancedIntervalStatsManager._update_interval_stats()
 
-            logger.debug(f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}")
+            logger.debug(
+                f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}"
+            )
         except Exception as e:
             logger.error(f"Error adding user {user_id} to interval: {e}")
 
     @staticmethod
     async def _update_interval_stats() -> None:
         """Update the current interval stats - immediate synchronous update"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
 
             # Get the number of unique users in the interval
             online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
-            playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+            playing_key = (
+                f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
+            )
 
             unique_online = await _redis_exec(redis_sync.scard, online_key)
             unique_playing = await _redis_exec(redis_sync.scard, playing_key)
 
             # Get the current real-time user counts as peak references
             current_online = await _get_online_users_count(redis_async)
             current_playing = await _get_playing_users_count(redis_async)
 
             # Get the existing statistics
-            existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
+            existing_data = await _redis_exec(
+                redis_sync.get, current_interval.interval_key
+            )
             if existing_data:
                 stats = IntervalStats.from_dict(json.loads(existing_data))
                 # Update the peaks
                 stats.peak_online_count = max(stats.peak_online_count, current_online)
-                stats.peak_playing_count = max(stats.peak_playing_count, current_playing)
+                stats.peak_playing_count = max(
+                    stats.peak_playing_count, current_playing
+                )
                 stats.total_samples += 1
             else:
                 # Create a new statistics record
@@ -324,46 +364,52 @@ class EnhancedIntervalStatsManager:
                     peak_online_count=current_online,
                     peak_playing_count=current_playing,
                     total_samples=1,
-                    created_at=datetime.utcnow()
+                    created_at=datetime.utcnow(),
                 )
 
             # Update the unique user counts
             stats.unique_online_users = unique_online
             stats.unique_playing_users = unique_playing
 
             # Save the updated statistics immediately
             await _redis_exec(
                 redis_sync.set,
                 current_interval.interval_key,
-                json.dumps(stats.to_dict())
+                json.dumps(stats.to_dict()),
             )
             await redis_async.expire(current_interval.interval_key, 35 * 60)
 
-            logger.debug(f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}")
+            logger.debug(
+                f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
+            )
         except Exception as e:
             logger.error(f"Error updating interval stats: {e}")
 
     @staticmethod
-    async def finalize_interval() -> Optional[IntervalStats]:
+    async def finalize_interval() -> IntervalStats | None:
         """Finalize the current interval stats and save them to the history"""
         redis_sync = get_redis_message()
         redis_async = get_redis()
 
         try:
-            current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
+            current_interval = (
+                await EnhancedIntervalStatsManager.get_current_interval_info()
+            )
 
             # One final stats update
             await EnhancedIntervalStatsManager._update_interval_stats()
 
             # Get the final statistics
-            stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
+            stats_data = await _redis_exec(
+                redis_sync.get, current_interval.interval_key
+            )
             if not stats_data:
                 logger.warning("No interval stats found to finalize")
                 return None
 
             stats = IntervalStats.from_dict(json.loads(stats_data))
 
             # Create the history point (use the interval end time as the timestamp to keep alignment)
             history_point = {
                 "timestamp": current_interval.end_time.isoformat(),
@@ -371,16 +417,18 @@ class EnhancedIntervalStatsManager:
"playing_count": stats.unique_playing_users, "playing_count": stats.unique_playing_users,
"peak_online": stats.peak_online_count, "peak_online": stats.peak_online_count,
"peak_playing": stats.peak_playing_count, "peak_playing": stats.peak_playing_count,
"total_samples": stats.total_samples "total_samples": stats.total_samples,
} }
# 添加到历史记录 # 添加到历史记录
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)) await _redis_exec(
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
)
# 只保留48个数据点24小时每30分钟一个点 # 只保留48个数据点24小时每30分钟一个点
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47) await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
# 设置过期时间为26小时确保有足够缓冲 # 设置过期时间为26小时确保有足够缓冲
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600) await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
logger.info( logger.info(
f"Finalized interval stats: " f"Finalized interval stats: "
f"unique_online={stats.unique_online_users}, " f"unique_online={stats.unique_online_users}, "
@@ -390,64 +438,70 @@ class EnhancedIntervalStatsManager:
f"samples={stats.total_samples} " f"samples={stats.total_samples} "
f"for {stats.start_time.strftime('%H:%M')}-{stats.end_time.strftime('%H:%M')}" f"for {stats.start_time.strftime('%H:%M')}-{stats.end_time.strftime('%H:%M')}"
) )
return stats return stats
except Exception as e: except Exception as e:
logger.error(f"Error finalizing interval stats: {e}") logger.error(f"Error finalizing interval stats: {e}")
return None return None
@staticmethod @staticmethod
async def get_current_interval_stats() -> Optional[IntervalStats]: async def get_current_interval_stats() -> IntervalStats | None:
"""获取当前区间统计""" """获取当前区间统计"""
redis_sync = get_redis_message() redis_sync = get_redis_message()
try: try:
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() current_interval = (
stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key) await EnhancedIntervalStatsManager.get_current_interval_info()
)
stats_data = await _redis_exec(
redis_sync.get, current_interval.interval_key
)
if stats_data: if stats_data:
return IntervalStats.from_dict(json.loads(stats_data)) return IntervalStats.from_dict(json.loads(stats_data))
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error getting current interval stats: {e}") logger.error(f"Error getting current interval stats: {e}")
return None return None
@staticmethod @staticmethod
async def cleanup_old_intervals() -> None: async def cleanup_old_intervals() -> None:
"""清理过期的区间数据""" """清理过期的区间数据"""
redis_async = get_redis() redis_async = get_redis()
try: try:
# 删除过期的区间统计数据超过2小时的 # 删除过期的区间统计数据超过2小时的
cutoff_time = datetime.utcnow() - timedelta(hours=2) cutoff_time = datetime.utcnow() - timedelta(hours=2)
pattern = f"{INTERVAL_STATS_BASE_KEY}:*" pattern = f"{INTERVAL_STATS_BASE_KEY}:*"
keys = await redis_async.keys(pattern) keys = await redis_async.keys(pattern)
for key in keys: for key in keys:
try: try:
# 从key中提取时间 # 从key中提取时间
time_part = key.decode().split(':')[-1] # YYYYMMDD_HHMM格式 time_part = key.decode().split(":")[-1] # YYYYMMDD_HHMM格式
key_time = datetime.strptime(time_part, '%Y%m%d_%H%M') key_time = datetime.strptime(time_part, "%Y%m%d_%H%M")
if key_time < cutoff_time: if key_time < cutoff_time:
await redis_async.delete(key) await redis_async.delete(key)
# 也删除对应的用户集合 # 也删除对应的用户集合
await redis_async.delete(f"{INTERVAL_ONLINE_USERS_KEY}:{key}") await redis_async.delete(f"{INTERVAL_ONLINE_USERS_KEY}:{key}")
await redis_async.delete(f"{INTERVAL_PLAYING_USERS_KEY}:{key}") await redis_async.delete(f"{INTERVAL_PLAYING_USERS_KEY}:{key}")
except (ValueError, IndexError): except (ValueError, IndexError):
# 忽略解析错误的key # 忽略解析错误的key
continue continue
logger.debug("Cleaned up old interval data") logger.debug("Cleaned up old interval data")
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up old intervals: {e}") logger.error(f"Error cleaning up old intervals: {e}")
# 便捷函数,用于替换现有的统计更新函数 # 便捷函数,用于替换现有的统计更新函数
async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None: async def update_user_activity_in_interval(
user_id: int, is_playing: bool = False
) -> None:
"""用户活动时更新区间统计(在登录、开始游玩等时调用)""" """用户活动时更新区间统计(在登录、开始游玩等时调用)"""
await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing) await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing)

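The interval key format appears above only indirectly: cleanup_old_intervals parses a YYYYMMDD_HHMM suffix off the stats key. A minimal sketch of how such a key could be derived from a timestamp, assuming 30-minute buckets and an "interval_stats" base key; interval_key_for and INTERVAL_MINUTES are illustrative names, not part of this codebase:

from datetime import datetime

INTERVAL_MINUTES = 30  # assumed bucket width, matching the 30-minute points above

def interval_key_for(now: datetime, base_key: str = "interval_stats") -> str:
    """Floor a timestamp to its 30-minute bucket and build the Redis key."""
    floored = now.replace(
        minute=(now.minute // INTERVAL_MINUTES) * INTERVAL_MINUTES,
        second=0,
        microsecond=0,
    )
    return f"{base_key}:{floored.strftime('%Y%m%d_%H%M')}"

# e.g. 05:43 falls in the 05:30-06:00 bucket:
print(interval_key_for(datetime(2025, 8, 22, 5, 43)))  # interval_stats:20250822_0530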
View File

@@ -3,53 +3,52 @@ Redis 消息队列服务
Used for real-time message push and asynchronous database persistence
"""

from __future__ import annotations

import asyncio
import concurrent.futures
from datetime import datetime
import uuid

from app.database.chat import ChatMessage, MessageType
from app.dependencies.database import get_redis, with_db
from app.log import logger


class MessageQueue:
    """Redis message queue service"""

    def __init__(self):
        self.redis = get_redis()
        self._processing = False
        self._batch_size = 50  # batch size
        self._batch_timeout = 1.0  # batch timeout (seconds)
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)

    async def _run_in_executor(self, func, *args):
        """Run a synchronous Redis operation in the thread pool"""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self._executor, func, *args)

    async def start_processing(self):
        """Start the message-processing task"""
        if not self._processing:
            self._processing = True
            asyncio.create_task(self._process_message_queue())
            logger.info("Message queue processing started")

    async def stop_processing(self):
        """Stop message processing"""
        self._processing = False
        logger.info("Message queue processing stopped")

    async def enqueue_message(self, message_data: dict) -> str:
        """
        Enqueue a message into Redis (real-time response)

        Args:
            message_data: message data dict containing all required fields

        Returns:
            The message's temporary UUID
        """
@@ -58,36 +57,42 @@ class MessageQueue:
        message_data["temp_uuid"] = temp_uuid
        message_data["timestamp"] = datetime.now().isoformat()
        message_data["status"] = "pending"  # pending, processing, completed, failed

        # Store the message in Redis
        await self._run_in_executor(
            lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data)
        )
        await self._run_in_executor(
            self.redis.expire, f"msg:{temp_uuid}", 3600
        )  # expires in 1 hour

        # Push onto the processing queue
        await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)

        logger.info(f"Message enqueued with temp_uuid: {temp_uuid}")
        return temp_uuid

    async def get_message_status(self, temp_uuid: str) -> dict | None:
        """Get message status"""
        message_data = await self._run_in_executor(
            self.redis.hgetall, f"msg:{temp_uuid}"
        )
        if not message_data:
            return None
        return message_data

    async def get_cached_messages(
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[dict]:
        """
        Fetch cached messages from Redis

        Args:
            channel_id: channel ID
            limit: maximum number of messages
            since: fetch messages after this message ID

        Returns:
            List of messages
        """
@@ -95,29 +100,39 @@ class MessageQueue:
        message_uuids = await self._run_in_executor(
            self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1
        )

        messages = []
        for uuid_str in message_uuids:
            message_data = await self._run_in_executor(
                self.redis.hgetall, f"msg:{uuid_str}"
            )
            if message_data:
                # Check the since condition
                if since > 0 and "message_id" in message_data:
                    if int(message_data["message_id"]) <= since:
                        continue
                messages.append(message_data)

        return messages[::-1]  # return in chronological order

    async def cache_channel_message(
        self, channel_id: int, temp_uuid: str, max_cache: int = 100
    ):
        """Cache a message UUID in the channel's message list"""
        # Push to the head of the channel message list
        await self._run_in_executor(
            self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid
        )
        # Cap the cache size
        await self._run_in_executor(
            self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1
        )
        # Expire after 24 hours
        await self._run_in_executor(
            self.redis.expire, f"channel:{channel_id}:messages", 86400
        )

    async def _process_message_queue(self):
        """Process the message queue asynchronously, batching writes to the database"""
        while self._processing:
@@ -132,75 +147,90 @@ class MessageQueue:
                        message_uuids.append(result[1])
                    else:
                        break

                if message_uuids:
                    await self._process_message_batch(message_uuids)
                else:
                    # Briefly wait when there are no messages
                    await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error processing message queue: {e}")
                await asyncio.sleep(1)  # wait 1 second before retrying on error

    async def _process_message_batch(self, message_uuids: list[str]):
        """Persist a batch of messages to the database"""
        async with with_db() as session:
            messages_to_insert = []

            for temp_uuid in message_uuids:
                try:
                    # Fetch the message data
                    message_data = await self._run_in_executor(
                        self.redis.hgetall, f"msg:{temp_uuid}"
                    )
                    if not message_data:
                        continue

                    # Mark as processing
                    await self._run_in_executor(
                        self.redis.hset, f"msg:{temp_uuid}", "status", "processing"
                    )

                    # Build the database message object
                    msg = ChatMessage(
                        channel_id=int(message_data["channel_id"]),
                        content=message_data["content"],
                        sender_id=int(message_data["sender_id"]),
                        type=MessageType(message_data["type"]),
                        uuid=message_data.get("user_uuid"),  # user-provided UUID, if any
                    )
                    messages_to_insert.append((msg, temp_uuid))

                except Exception as e:
                    logger.error(f"Error preparing message {temp_uuid}: {e}")
                    await self._run_in_executor(
                        self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
                    )

            if messages_to_insert:
                try:
                    # Bulk-insert into the database
                    for msg, temp_uuid in messages_to_insert:
                        session.add(msg)
                    await session.commit()

                    # Update each message's status and real ID
                    for msg, temp_uuid in messages_to_insert:
                        await session.refresh(msg)
                        await self._run_in_executor(
                            lambda: self.redis.hset(
                                f"msg:{temp_uuid}",
                                mapping={
                                    "status": "completed",
                                    "message_id": str(msg.message_id),
                                    "created_at": msg.timestamp.isoformat()
                                    if msg.timestamp
                                    else "",
                                },
                            )
                        )
                        logger.info(
                            f"Message {temp_uuid} persisted to DB with ID {msg.message_id}"
                        )

                except Exception as e:
                    logger.error(f"Error inserting messages to database: {e}")
                    await session.rollback()
                    # Mark all messages as failed
                    for _, temp_uuid in messages_to_insert:
                        await self._run_in_executor(
                            self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
                        )


# Global message queue instance

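The queue above follows a common hash-plus-list pattern: the payload lives in a hash keyed by a temporary UUID, and only the UUID travels through the list. A standalone sketch of that pattern with redis-py's synchronous client (enqueue and drain are illustrative names; the real class runs these calls through a thread pool):

import uuid

import redis

r = redis.Redis(decode_responses=True)

def enqueue(message: dict) -> str:
    """Producer: stash the payload in a hash, push its UUID onto the queue list."""
    temp_uuid = str(uuid.uuid4())
    r.hset(f"msg:{temp_uuid}", mapping={**message, "status": "pending"})
    r.expire(f"msg:{temp_uuid}", 3600)
    r.lpush("message_queue", temp_uuid)
    return temp_uuid

def drain(batch_size: int = 50) -> list[str]:
    """Consumer: block briefly for the first UUID, then drain the rest non-blocking."""
    batch = []
    first = r.brpop("message_queue", timeout=1)  # (queue_name, value) or None
    if first is None:
        return batch
    batch.append(first[1])
    while len(batch) < batch_size:
        nxt = r.rpop("message_queue")
        if nxt is None:
            break
        batch.append(nxt)
    return batch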
View File

@@ -3,12 +3,12 @@
Handles asynchronous writes from the Redis message queue to the database
"""

from __future__ import annotations

import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json

from app.database.chat import ChatMessage, MessageType
from app.dependencies.database import get_redis_message, with_db
@@ -17,103 +17,132 @@ from app.log import logger


class MessageQueueProcessor:
    """Message queue processor"""

    def __init__(self):
        self.redis_message = get_redis_message()
        self.executor = ThreadPoolExecutor(max_workers=2)
        self._processing = False
        self._queue_task = None

    async def _redis_exec(self, func, *args, **kwargs):
        """Execute a Redis operation in the thread pool"""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))

    async def cache_message(self, channel_id: int, message_data: dict, temp_uuid: str):
        """Cache a message in Redis"""
        try:
            # Store the message data
            await self._redis_exec(
                self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data
            )
            await self._redis_exec(
                self.redis_message.expire, f"msg:{temp_uuid}", 3600
            )  # expires in 1 hour

            # Add to the channel message list
            await self._redis_exec(
                self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid
            )
            await self._redis_exec(
                self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99
            )  # keep the latest 100 messages
            await self._redis_exec(
                self.redis_message.expire, f"channel:{channel_id}:messages", 86400
            )  # expires in 24 hours

            # Push onto the async processing queue
            await self._redis_exec(
                self.redis_message.lpush, "message_write_queue", temp_uuid
            )

            logger.info(f"Message cached to Redis: {temp_uuid}")
        except Exception as e:
            logger.error(f"Failed to cache message to Redis: {e}")

    async def get_cached_messages(
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[dict]:
        """Fetch cached messages from Redis"""
        try:
            message_uuids = await self._redis_exec(
                self.redis_message.lrange,
                f"channel:{channel_id}:messages",
                0,
                limit - 1,
            )

            messages = []
            for temp_uuid in message_uuids:
                # Decode the UUID if it is bytes
                if isinstance(temp_uuid, bytes):
                    temp_uuid = temp_uuid.decode("utf-8")

                raw_data = await self._redis_exec(
                    self.redis_message.hgetall, f"msg:{temp_uuid}"
                )
                if raw_data:
                    # Decode the bytes returned by Redis
                    message_data = {
                        k.decode("utf-8") if isinstance(k, bytes) else k: v.decode(
                            "utf-8"
                        )
                        if isinstance(v, bytes)
                        else v
                        for k, v in raw_data.items()
                    }
                    # Check the since condition
                    if since > 0 and message_data.get("message_id"):
                        if int(message_data["message_id"]) <= since:
                            continue
                    messages.append(message_data)

            return messages[::-1]  # return in chronological order
        except Exception as e:
            logger.error(f"Failed to get cached messages: {e}")
            return []

    async def update_message_status(
        self, temp_uuid: str, status: str, message_id: int | None = None
    ):
        """Update a message's status"""
        try:
            update_data = {"status": status}
            if message_id:
                update_data["message_id"] = str(message_id)
                update_data["db_timestamp"] = datetime.now().isoformat()

            await self._redis_exec(
                self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data
            )
        except Exception as e:
            logger.error(f"Failed to update message status: {e}")

    async def get_message_status(self, temp_uuid: str) -> dict | None:
        """Get a message's status"""
        try:
            raw_data = await self._redis_exec(
                self.redis_message.hgetall, f"msg:{temp_uuid}"
            )
            if not raw_data:
                return None

            # Decode the bytes returned by Redis
            return {
                k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
                if isinstance(v, bytes)
                else v
                for k, v in raw_data.items()
            }
        except Exception as e:
            logger.error(f"Failed to get message status: {e}")
            return None

    async def _process_message_queue(self):
        """Process the message queue, writing to the database asynchronously"""
        logger.info("Message queue processing started")

        while self._processing:
            try:
                # Fetch messages in batches
@@ -126,47 +155,52 @@ class MessageQueueProcessor:
                        # result is a (queue_name, value) tuple; decode the value
                        uuid_value = result[1]
                        if isinstance(uuid_value, bytes):
                            uuid_value = uuid_value.decode("utf-8")
                        message_uuids.append(uuid_value)
                    else:
                        break

                if not message_uuids:
                    await asyncio.sleep(0.5)
                    continue

                # Write the batch to the database
                await self._process_message_batch(message_uuids)

            except Exception as e:
                logger.error(f"Error in message queue processing: {e}")
                await asyncio.sleep(1)

        logger.info("Message queue processing stopped")

    async def _process_message_batch(self, message_uuids: list[str]):
        """Persist a batch of messages to the database"""
        async with with_db() as session:
            for temp_uuid in message_uuids:
                try:
                    # Fetch the message data and decode it
                    raw_data = await self._redis_exec(
                        self.redis_message.hgetall, f"msg:{temp_uuid}"
                    )
                    if not raw_data:
                        continue

                    # Decode the bytes returned by Redis
                    message_data = {
                        k.decode("utf-8") if isinstance(k, bytes) else k: v.decode(
                            "utf-8"
                        )
                        if isinstance(v, bytes)
                        else v
                        for k, v in raw_data.items()
                    }

                    if message_data.get("status") != "pending":
                        continue

                    # Mark as processing
                    await self.update_message_status(temp_uuid, "processing")

                    # Create the database message
                    msg = ChatMessage(
                        channel_id=int(message_data["channel_id"]),
@@ -175,15 +209,17 @@ class MessageQueueProcessor:
                        type=MessageType(message_data["type"]),
                        uuid=message_data.get("user_uuid") or None,
                    )

                    session.add(msg)
                    await session.commit()
                    await session.refresh(msg)

                    # Mark success (including the temporary message ID mapping)
                    assert msg.message_id is not None
                    await self.update_message_status(
                        temp_uuid, "completed", msg.message_id
                    )

                    # If there is a temporary message ID, store the mapping and notify clients
                    if message_data.get("temp_message_id"):
                        temp_msg_id = int(message_data["temp_message_id"])
@@ -191,53 +227,65 @@ class MessageQueueProcessor:
                            self.redis_message.set,
                            f"temp_to_real:{temp_msg_id}",
                            str(msg.message_id),
                            ex=3600,  # expires in 1 hour
                        )

                        # Publish a message-ID update notification to the channel
                        channel_id = int(message_data["channel_id"])
                        await self._notify_message_update(
                            channel_id, temp_msg_id, msg.message_id, message_data
                        )

                    logger.info(
                        f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}"
                    )

                except Exception as e:
                    logger.error(f"Failed to process message {temp_uuid}: {e}")
                    await self.update_message_status(temp_uuid, "failed")

    async def _notify_message_update(
        self,
        channel_id: int,
        temp_message_id: int,
        real_message_id: int,
        message_data: dict,
    ):
        """Notify clients that a message ID has been updated"""
        try:
            # Ideally this would go through SignalR directly, but to avoid a
            # circular dependency we publish the update event via Redis
            update_event = {
                "event": "chat.message.update",
                "data": {
                    "channel_id": channel_id,
                    "temp_message_id": temp_message_id,
                    "real_message_id": real_message_id,
                    "timestamp": message_data.get("timestamp"),
                },
            }

            # Publish to a Redis channel for the SignalR service to handle
            await self._redis_exec(
                self.redis_message.publish,
                f"chat_updates:{channel_id}",
                json.dumps(update_event),
            )
            logger.info(
                f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}"
            )
        except Exception as e:
            logger.error(f"Failed to notify message update: {e}")

    def start_processing(self):
        """Start message queue processing"""
        if not self._processing:
            self._processing = True
            self._queue_task = asyncio.create_task(self._process_message_queue())
            logger.info("Message queue processor started")

    def stop_processing(self):
        """Stop message queue processing"""
        if self._processing:
@@ -246,10 +294,10 @@ class MessageQueueProcessor:
                self._queue_task.cancel()
                self._queue_task = None
            logger.info("Message queue processor stopped")

    def __del__(self):
        """Clean up resources"""
        if hasattr(self, "executor"):
            self.executor.shutdown(wait=False)
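The _redis_exec helper above wraps every blocking redis-py call in a thread pool so the event loop never stalls. A minimal self-contained sketch of that pattern (redis_exec and the key name are illustrative); the lambda preserves keyword-argument support, since run_in_executor itself only forwards positional arguments:

import asyncio
from concurrent.futures import ThreadPoolExecutor

import redis

executor = ThreadPoolExecutor(max_workers=2)
r = redis.Redis(decode_responses=True)

async def redis_exec(func, *args, **kwargs):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, lambda: func(*args, **kwargs))

async def main():
    await redis_exec(r.set, "demo:key", "value", ex=60)  # kwargs work via the lambda
    print(await redis_exec(r.get, "demo:key"))  # "value"

asyncio.run(main())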
@@ -272,11 +320,13 @@ async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid:
    await message_queue_processor.cache_message(channel_id, message_data, temp_uuid)


async def get_cached_messages(
    channel_id: int, limit: int = 50, since: int = 0
) -> list[dict]:
    """Fetch cached messages from Redis - convenience interface"""
    return await message_queue_processor.get_cached_messages(channel_id, limit, since)


async def get_message_status(temp_uuid: str) -> dict | None:
    """Get message status - convenience interface"""
    return await message_queue_processor.get_message_status(temp_uuid)

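The bytes-decoding dict comprehension recurs in get_cached_messages, get_message_status, and _process_message_batch. It could plausibly be factored into one helper; decode_hash below is a hypothetical name, not part of the codebase:

def decode_hash(raw: dict) -> dict:
    """Decode a redis-py HGETALL result whose keys and values may be bytes."""
    def _d(x):
        return x.decode("utf-8") if isinstance(x, bytes) else x
    return {_d(k): _d(v) for k, v in raw.items()}

print(decode_hash({b"status": b"pending"}))  # {'status': 'pending'}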
View File

@@ -3,23 +3,26 @@
Combines Redis caching with asynchronous database writes for real-time message delivery
"""

from __future__ import annotations

from app.database.chat import (
    ChannelType,
    ChatMessageResp,
    MessageType,
)
from app.database.lazer_user import User
from app.log import logger
from app.service.message_queue import message_queue
from sqlalchemy.ext.asyncio import AsyncSession


class OptimizedMessageService:
    """Optimized message service"""

    def __init__(self):
        self.message_queue = message_queue

    async def send_message_fast(
        self,
        channel_id: int,
@@ -28,12 +31,12 @@ class OptimizedMessageService:
        content: str,
        sender: User,
        is_action: bool = False,
        user_uuid: str | None = None,
        session: AsyncSession | None = None,
    ) -> ChatMessageResp:
        """
        Send a message fast (cache in Redis first, write to the database asynchronously)

        Args:
            channel_id: channel ID
            channel_type: channel type
@@ -43,12 +46,12 @@ class OptimizedMessageService:
            is_action: whether this is an action message
            user_uuid: user-provided UUID
            session: database session (optional, used for some validation)

        Returns:
            The message response object
        """
        assert sender.id is not None

        # Prepare the message data
        message_data = {
            "channel_id": str(channel_id),
@@ -57,27 +60,28 @@ class OptimizedMessageService:
            "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
            "user_uuid": user_uuid or "",
            "channel_type": channel_type.value,
            "channel_name": channel_name,
        }

        # Immediately enqueue the message in Redis (real-time response)
        temp_uuid = await self.message_queue.enqueue_message(message_data)

        # Cache it in the channel message list
        await self.message_queue.cache_channel_message(channel_id, temp_uuid)

        # Build a temporary response object (simplified, for the immediate reply)
        from datetime import datetime

        from app.database.lazer_user import UserResp

        # Build a basic user response object
        user_resp = UserResp(
            id=sender.id,
            username=sender.username,
            country_code=getattr(sender, "country_code", "XX"),
            # basic fields; the more complex ones can be loaded asynchronously later
        )

        temp_response = ChatMessageResp(
            message_id=0,  # temporary ID, updated after the database write
            channel_id=channel_id,
@@ -86,63 +90,62 @@ class OptimizedMessageService:
            sender_id=sender.id,
            sender=user_resp,
            is_action=is_action,
            uuid=user_uuid,
        )
        temp_response.temp_uuid = temp_uuid  # attach the temp UUID for later updates

        logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}")
        return temp_response

    async def get_cached_messages(
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[dict]:
        """
        Fetch cached messages

        Args:
            channel_id: channel ID
            limit: maximum number of messages
            since: fetch messages after this message ID

        Returns:
            List of messages
        """
        return await self.message_queue.get_cached_messages(channel_id, limit, since)

    async def get_message_status(self, temp_uuid: str) -> dict | None:
        """
        Get message status

        Args:
            temp_uuid: temporary message UUID

        Returns:
            Message status info
        """
        return await self.message_queue.get_message_status(temp_uuid)

    async def wait_for_message_persisted(
        self, temp_uuid: str, timeout: int = 30
    ) -> dict | None:
        """
        Wait for a message to be persisted to the database

        Args:
            temp_uuid: temporary message UUID
            timeout: timeout in seconds

        Returns:
            The final message status
        """
        import asyncio

        for _ in range(timeout * 10):  # check every 100 ms
            status = await self.get_message_status(temp_uuid)
            if status and status.get("status") in ["completed", "failed"]:
                return status
            await asyncio.sleep(0.1)

        return None

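A hedged usage sketch of the fast-send flow: return the temporary response immediately, then confirm persistence with the polling helper. handle_send, service, and channel are illustrative names, and the exact send_message_fast parameter order is assumed from the fields visible above:

async def handle_send(service, channel, current_user) -> None:
    resp = await service.send_message_fast(
        channel_id=channel.id,
        channel_type=channel.type,
        channel_name=channel.name,
        content="hello",
        sender=current_user,
    )
    # The caller already holds a temporary response; confirm persistence later.
    status = await service.wait_for_message_persisted(resp.temp_uuid, timeout=10)
    if status is None or status.get("status") == "failed":
        ...  # e.g. surface a delivery failure to the client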
View File

@@ -5,59 +5,66 @@
- Supports message status synchronization and failure recovery
"""

from __future__ import annotations

import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json
import time
from typing import Any

from app.database.chat import ChatMessage, ChatMessageResp, MessageType
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
from app.dependencies.database import get_redis_message, with_db
from app.log import logger


class RedisMessageSystem:
    """Redis message system"""

    def __init__(self):
        self.redis = get_redis_message()
        self.executor = ThreadPoolExecutor(max_workers=2)
        self._batch_timer: asyncio.Task | None = None
        self._running = False
        self.batch_interval = 5.0  # persist a batch every 5 seconds
        self.max_batch_size = 100  # at most 100 messages per batch

    async def _redis_exec(self, func, *args, **kwargs):
        """Execute a Redis operation in the thread pool"""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))

    async def send_message(
        self,
        channel_id: int,
        user: User,
        content: str,
        is_action: bool = False,
        user_uuid: str | None = None,
    ) -> ChatMessageResp:
        """
        Send a message - store it in Redis immediately and return

        Args:
            channel_id: channel ID
            user: sending user
            content: message content
            is_action: whether this is an action message
            user_uuid: user UUID

        Returns:
            ChatMessageResp: the message response object
        """
        # Generate the message ID and timestamp
        message_id = await self._generate_message_id(channel_id)
        timestamp = datetime.now()

        # Make sure the user ID exists
        if not user.id:
            raise ValueError("User ID is required")

        # Prepare the message data
        message_data = {
            "message_id": message_id,
@@ -68,19 +75,20 @@ class RedisMessageSystem:
            "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
            "uuid": user_uuid or "",
            "status": "cached",  # Redis cache status
            "created_at": time.time(),
        }

        # Store in Redis immediately
        await self._store_to_redis(message_id, channel_id, message_data)

        # Build the response object
        async with with_db() as session:
            user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)

            # Make sure statistics is not None
            if user_resp.statistics is None:
                from app.database.statistics import UserStatisticsResp

                user_resp.statistics = UserStatisticsResp(
                    mode=user.playmode,
                    global_rank=0,
@@ -96,9 +104,9 @@ class RedisMessageSystem:
                    replays_watched_by_others=0,
                    is_ranked=False,
                    grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
                    level={"current": 1, "progress": 0},
                )

            response = ChatMessageResp(
                message_id=message_id,
                channel_id=channel_id,
@@ -107,51 +115,71 @@ class RedisMessageSystem:
                sender_id=user.id,
                sender=user_resp,
                is_action=is_action,
                uuid=user_uuid,
            )

        logger.info(
            f"Message {message_id} sent to Redis cache for channel {channel_id}"
        )
        return response

    async def get_messages(
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[ChatMessageResp]:
        """
        Get channel messages - prefer the latest messages from Redis

        Args:
            channel_id: channel ID
            limit: maximum number of messages
            since: starting message ID

        Returns:
            list[ChatMessageResp]: the message list
        """
        messages = []

        try:
            # Fetch the latest messages from Redis
            redis_messages = await self._get_from_redis(channel_id, limit, since)

            # Build a response object for each message
            async with with_db() as session:
                for msg_data in redis_messages:
                    # Fetch the sender's info
                    sender = await session.get(User, msg_data["sender_id"])
                    if sender:
                        user_resp = await UserResp.from_db(
                            sender, session, RANKING_INCLUDES
                        )
                        if user_resp.statistics is None:
                            from app.database.statistics import UserStatisticsResp

                            user_resp.statistics = UserStatisticsResp(
                                mode=sender.playmode,
                                global_rank=0,
                                country_rank=0,
                                pp=0.0,
                                ranked_score=0,
                                hit_accuracy=0.0,
                                play_count=0,
                                play_time=0,
                                total_score=0,
                                total_hits=0,
                                maximum_combo=0,
                                replays_watched_by_others=0,
                                is_ranked=False,
                                grade_counts={
                                    "ssh": 0,
                                    "ss": 0,
                                    "sh": 0,
                                    "s": 0,
                                    "a": 0,
                                },
                                level={"current": 1, "progress": 0},
                            )

                        message_resp = ChatMessageResp(
                            message_id=msg_data["message_id"],
                            channel_id=msg_data["channel_id"],
@@ -160,77 +188,97 @@ class RedisMessageSystem:
                            sender_id=msg_data["sender_id"],
                            sender=user_resp,
                            is_action=msg_data["type"] == MessageType.ACTION.value,
                            uuid=msg_data.get("uuid") or None,
                        )
                        messages.append(message_resp)

            # If Redis does not have enough messages, backfill from the database
            if len(messages) < limit and since == 0:
                await self._backfill_from_database(channel_id, messages, limit)

        except Exception as e:
            logger.error(f"Failed to get messages from Redis: {e}")
            # Fall back to a database query
            messages = await self._get_from_database_only(channel_id, limit, since)

        return messages[:limit]

    async def _generate_message_id(self, channel_id: int) -> int:
        """Generate a unique message ID - globally unique and strictly increasing"""
        # A global counter keeps message IDs strictly increasing across all channels
        message_id = await self._redis_exec(
            self.redis.incr, "global_message_id_counter"
        )
        # Also track the channel's last message ID for client-side state sync
        await self._redis_exec(
            self.redis.set, f"channel:{channel_id}:last_msg_id", message_id
        )
        return message_id

    async def _store_to_redis(
        self, message_id: int, channel_id: int, message_data: dict[str, Any]
    ):
        """Store a message in Redis"""
        try:
            # Store the message data
            await self._redis_exec(
                self.redis.hset,
                f"msg:{channel_id}:{message_id}",
                mapping={
                    k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
                    for k, v in message_data.items()
                },
            )
            # Messages expire after 7 days
            await self._redis_exec(
                self.redis.expire, f"msg:{channel_id}:{message_id}", 604800
            )

            # Clean up a possibly mistyped key, then add to the channel message list (time-ordered)
            channel_messages_key = f"channel:{channel_id}:messages"

            # Check the key's type; delete it if it is not a zset
            try:
                key_type = await self._redis_exec(self.redis.type, channel_messages_key)
                if key_type and key_type != "zset":
                    logger.warning(
                        f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}"
                    )
                    await self._redis_exec(self.redis.delete, channel_messages_key)
            except Exception as type_check_error:
                logger.warning(
                    f"Failed to check key type for {channel_messages_key}: {type_check_error}"
                )
                # If the check fails, delete the key outright to be safe
                await self._redis_exec(self.redis.delete, channel_messages_key)

            # Add to the channel message list (sorted set)
            await self._redis_exec(
                self.redis.zadd,
                channel_messages_key,
                {f"msg:{channel_id}:{message_id}": message_id},
            )
            # Cap the channel message list at 1000 entries
            await self._redis_exec(
                self.redis.zremrangebyrank, channel_messages_key, 0, -1001
            )

            # Add to the pending-persistence queue
            await self._redis_exec(
                self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
            )

        except Exception as e:
            logger.error(f"Failed to store message to Redis: {e}")
            raise
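The channel timeline above scores a sorted set by message ID, which is what makes both "latest N" and "everything after ID X" cheap. A standalone sketch of that pattern with redis-py (function names are illustrative):

import redis

r = redis.Redis(decode_responses=True)

def add_to_timeline(channel_id: int, message_id: int) -> None:
    """Score each member by its message ID so range queries come back in ID order."""
    key = f"channel:{channel_id}:messages"
    r.zadd(key, {f"msg:{channel_id}:{message_id}": message_id})
    r.zremrangebyrank(key, 0, -1001)  # keep only the newest 1000 entries

def latest(channel_id: int, limit: int = 50) -> list[str]:
    """Newest first, mirroring the zrevrange call above."""
    return r.zrevrange(f"channel:{channel_id}:messages", 0, limit - 1)

def after(channel_id: int, since_id: int, limit: int = 50) -> list[str]:
    """Strictly after a given ID, mirroring the zrangebyscore call above."""
    return r.zrangebyscore(
        f"channel:{channel_id}:messages", since_id + 1, "+inf", start=0, num=limit
    )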
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> List[Dict[str, Any]]: async def _get_from_redis(
self, channel_id: int, limit: int = 50, since: int = 0
) -> list[dict[str, Any]]:
"""从 Redis 获取消息""" """从 Redis 获取消息"""
try: try:
# 获取消息键列表按消息ID排序 # 获取消息键列表按消息ID排序
@@ -239,22 +287,22 @@ class RedisMessageSystem:
message_keys = await self._redis_exec( message_keys = await self._redis_exec(
self.redis.zrangebyscore, self.redis.zrangebyscore,
f"channel:{channel_id}:messages", f"channel:{channel_id}:messages",
since + 1, "+inf", since + 1,
start=0, num=limit "+inf",
start=0,
num=limit,
) )
else: else:
# 获取最新的消息(倒序获取,然后反转) # 获取最新的消息(倒序获取,然后反转)
message_keys = await self._redis_exec( message_keys = await self._redis_exec(
self.redis.zrevrange, self.redis.zrevrange, f"channel:{channel_id}:messages", 0, limit - 1
f"channel:{channel_id}:messages",
0, limit - 1
) )
messages = [] messages = []
for key in message_keys: for key in message_keys:
if isinstance(key, bytes): if isinstance(key, bytes):
key = key.decode('utf-8') key = key.decode("utf-8")
# 获取消息数据 # 获取消息数据
raw_data = await self._redis_exec(self.redis.hgetall, key) raw_data = await self._redis_exec(self.redis.hgetall, key)
if raw_data: if raw_data:
@@ -262,106 +310,118 @@ class RedisMessageSystem:
message_data = {} message_data = {}
for k, v in raw_data.items(): for k, v in raw_data.items():
if isinstance(k, bytes): if isinstance(k, bytes):
k = k.decode('utf-8') k = k.decode("utf-8")
if isinstance(v, bytes): if isinstance(v, bytes):
v = v.decode('utf-8') v = v.decode("utf-8")
# 尝试解析 JSON # 尝试解析 JSON
try: try:
if k in ['grade_counts', 'level'] or v.startswith(('{', '[')): if k in ["grade_counts", "level"] or v.startswith(
("{", "[")
):
message_data[k] = json.loads(v) message_data[k] = json.loads(v)
elif k in ['message_id', 'channel_id', 'sender_id']: elif k in ["message_id", "channel_id", "sender_id"]:
message_data[k] = int(v) message_data[k] = int(v)
elif k == 'created_at': elif k == "created_at":
message_data[k] = float(v) message_data[k] = float(v)
else: else:
message_data[k] = v message_data[k] = v
except (json.JSONDecodeError, ValueError): except (json.JSONDecodeError, ValueError):
message_data[k] = v message_data[k] = v
messages.append(message_data) messages.append(message_data)
# 确保消息按ID正序排序时间顺序 # 确保消息按ID正序排序时间顺序
messages.sort(key=lambda x: x.get('message_id', 0)) messages.sort(key=lambda x: x.get("message_id", 0))
# 如果是获取最新消息since=0需要保持倒序最新的在前面 # 如果是获取最新消息since=0需要保持倒序最新的在前面
if since == 0: if since == 0:
messages.reverse() messages.reverse()
return messages return messages
except Exception as e: except Exception as e:
logger.error(f"Failed to get messages from Redis: {e}") logger.error(f"Failed to get messages from Redis: {e}")
return [] return []
async def _backfill_from_database(self, channel_id: int, existing_messages: List[ChatMessageResp], limit: int): async def _backfill_from_database(
self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int
):
"""从数据库补充历史消息""" """从数据库补充历史消息"""
try: try:
# 找到最小的消息ID # 找到最小的消息ID
min_id = float('inf') min_id = float("inf")
if existing_messages: if existing_messages:
for msg in existing_messages: for msg in existing_messages:
if msg.message_id is not None and msg.message_id < min_id: if msg.message_id is not None and msg.message_id < min_id:
min_id = msg.message_id min_id = msg.message_id
needed = limit - len(existing_messages) needed = limit - len(existing_messages)
if needed <= 0: if needed <= 0:
return return
async with with_db() as session: async with with_db() as session:
from sqlmodel import select, col from sqlmodel import col, select
query = select(ChatMessage).where(
ChatMessage.channel_id == channel_id query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
)
if min_id != float("inf"):
if min_id != float('inf'):
query = query.where(col(ChatMessage.message_id) < min_id) query = query.where(col(ChatMessage.message_id) < min_id)
query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed) query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed)
db_messages = (await session.exec(query)).all() db_messages = (await session.exec(query)).all()
for msg in reversed(db_messages): # 按时间正序插入 for msg in reversed(db_messages): # 按时间正序插入
msg_resp = await ChatMessageResp.from_db(msg, session) msg_resp = await ChatMessageResp.from_db(msg, session)
existing_messages.insert(0, msg_resp) existing_messages.insert(0, msg_resp)
except Exception as e: except Exception as e:
logger.error(f"Failed to backfill from database: {e}") logger.error(f"Failed to backfill from database: {e}")
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> List[ChatMessageResp]: async def _get_from_database_only(
self, channel_id: int, limit: int, since: int
) -> list[ChatMessageResp]:
"""仅从数据库获取消息(回退方案)""" """仅从数据库获取消息(回退方案)"""
try: try:
async with with_db() as session: async with with_db() as session:
from sqlmodel import select, col from sqlmodel import col, select
query = select(ChatMessage).where(ChatMessage.channel_id == channel_id) query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
if since > 0: if since > 0:
# 获取指定ID之后的消息按ID正序 # 获取指定ID之后的消息按ID正序
query = query.where(col(ChatMessage.message_id) > since) query = query.where(col(ChatMessage.message_id) > since)
query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit) query = query.order_by(col(ChatMessage.message_id).asc()).limit(
limit
)
else: else:
# 获取最新消息按ID倒序最新的在前面 # 获取最新消息按ID倒序最新的在前面
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit) query = query.order_by(col(ChatMessage.message_id).desc()).limit(
limit
)
messages = (await session.exec(query)).all() messages = (await session.exec(query)).all()
results = [await ChatMessageResp.from_db(msg, session) for msg in messages] results = [
await ChatMessageResp.from_db(msg, session) for msg in messages
]
# 如果是 since > 0保持正序否则反转为时间正序 # 如果是 since > 0保持正序否则反转为时间正序
if since == 0: if since == 0:
results.reverse() results.reverse()
return results return results
except Exception as e: except Exception as e:
logger.error(f"Failed to get messages from database: {e}") logger.error(f"Failed to get messages from database: {e}")
return [] return []
async def _batch_persist_to_database(self): async def _batch_persist_to_database(self):
"""批量持久化消息到数据库""" """批量持久化消息到数据库"""
logger.info("Starting batch persistence to database") logger.info("Starting batch persistence to database")
while self._running: while self._running:
try: try:
# 获取待处理的消息 # 获取待处理的消息
@@ -374,52 +434,52 @@ class RedisMessageSystem:
# key 是 (queue_name, value) 的元组 # key 是 (queue_name, value) 的元组
value = key[1] value = key[1]
if isinstance(value, bytes): if isinstance(value, bytes):
value = value.decode('utf-8') value = value.decode("utf-8")
message_keys.append(value) message_keys.append(value)
else: else:
break break
if message_keys: if message_keys:
await self._process_message_batch(message_keys) await self._process_message_batch(message_keys)
else: else:
await asyncio.sleep(self.batch_interval) await asyncio.sleep(self.batch_interval)
except Exception as e: except Exception as e:
logger.error(f"Error in batch persistence: {e}") logger.error(f"Error in batch persistence: {e}")
await asyncio.sleep(1) await asyncio.sleep(1)
logger.info("Stopped batch persistence to database") logger.info("Stopped batch persistence to database")
    async def _process_message_batch(self, message_keys: list[str]):
        """Process one batch of messages."""
        async with with_db() as session:
            for key in message_keys:
                try:
                    # Parse the channel ID and message ID
                    channel_id, message_id = map(int, key.split(":"))

                    # Fetch the message data from Redis
                    raw_data = await self._redis_exec(
                        self.redis.hgetall, f"msg:{channel_id}:{message_id}"
                    )
                    if not raw_data:
                        continue

                    # Decode the raw hash fields
                    message_data = {}
                    for k, v in raw_data.items():
                        if isinstance(k, bytes):
                            k = k.decode("utf-8")
                        if isinstance(v, bytes):
                            v = v.decode("utf-8")
                        message_data[k] = v

                    # Skip messages that already exist in the database
                    existing = await session.get(ChatMessage, int(message_id))
                    if existing:
                        continue

                    # Create the database row, reusing the positive ID generated by Redis
                    db_message = ChatMessage(
                        message_id=int(message_id),  # positive ID from the Redis counter
@@ -428,31 +488,34 @@ class RedisMessageSystem:
                        content=message_data["content"],
                        timestamp=datetime.fromisoformat(message_data["timestamp"]),
                        type=MessageType(message_data["type"]),
                        uuid=message_data.get("uuid") or None,
                    )
                    session.add(db_message)

                    # Mark the message as persisted in Redis
                    await self._redis_exec(
                        self.redis.hset,
                        f"msg:{channel_id}:{message_id}",
                        "status",
                        "persisted",
                    )
                    logger.debug(f"Message {message_id} persisted to database")
                except Exception as e:
                    logger.error(f"Failed to process message {key}: {e}")

            # Commit the whole batch
            try:
                await session.commit()
                logger.info(
                    f"Batch of {len(message_keys)} messages committed to database"
                )
            except Exception as e:
                logger.error(f"Failed to commit message batch: {e}")
                await session.rollback()
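The pending-queue convention assumed by the parsing above can be checked in isolation: each queue entry is the string "channel_id:message_id", and the full payload lives in a hash at msg:{channel_id}:{message_id}. A standalone round-trip sketch (helper names are illustrative, not from the source):

    def make_queue_key(channel_id: int, message_id: int) -> str:
        # Entry pushed onto the pending queue
        return f"{channel_id}:{message_id}"

    def parse_queue_key(key: str) -> tuple[int, int]:
        # Mirrors the map(int, key.split(":")) call in _process_message_batch
        channel_id, message_id = map(int, key.split(":"))
        return channel_id, message_id

    channel_id, message_id = parse_queue_key(make_queue_key(5, 1000042))
    assert (channel_id, message_id) == (5, 1000042)
    print(f"payload hash: msg:{channel_id}:{message_id}")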
    def start(self):
        """Start the system."""
        if not self._running:
@@ -461,63 +524,71 @@ class RedisMessageSystem:
            # Initialize the message ID counter on startup
            asyncio.create_task(self._initialize_message_counter())
            logger.info("Redis message system started")
    async def _initialize_message_counter(self):
        """Initialize the global message ID counter, starting from the database maximum."""
        try:
            # Clean up any problematic keys first
            await self._cleanup_redis_keys()

            async with with_db() as session:
                from sqlmodel import func, select

                # Find the largest message ID stored in the database
                result = await session.exec(select(func.max(ChatMessage.message_id)))
                max_id = result.one() or 0

                # Read the current counter value from Redis
                current_counter = await self._redis_exec(
                    self.redis.get, "global_message_id_counter"
                )
                current_counter = int(current_counter) if current_counter else 0

                # Set the counter to the larger of the two
                initial_counter = max(max_id, current_counter)
                await self._redis_exec(
                    self.redis.set, "global_message_id_counter", initial_counter
                )
                logger.info(
                    f"Initialized global message ID counter to {initial_counter}"
                )
        except Exception as e:
            logger.error(f"Failed to initialize message counter: {e}")
            # Fall back to a safe starting value if initialization fails
            await self._redis_exec(
                self.redis.setnx, "global_message_id_counter", 1000000
            )
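A self-contained model of the counter scheme above: the counter is seeded with max(database maximum, existing Redis counter), and each new message presumably takes the result of an INCR on that key, so IDs stay strictly increasing across restarts. The class below is a pure-Python stand-in for the Redis GET/SET/INCR calls, for illustration only:

    class CounterModel:
        """Stand-in for the global_message_id_counter key."""

        def __init__(self, db_max_id: int, redis_counter: int) -> None:
            # Mirrors initial_counter = max(max_id, current_counter)
            self.value = max(db_max_id, redis_counter)

        def next_id(self) -> int:
            # Stands in for INCR global_message_id_counter
            self.value += 1
            return self.value

    counter = CounterModel(db_max_id=1_000_041, redis_counter=999_000)
    assert [counter.next_id() for _ in range(3)] == [1_000_042, 1_000_043, 1_000_044]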
    async def _cleanup_redis_keys(self):
        """Clean up Redis keys that may be in a broken state."""
        try:
            # Scan all channel:*:messages keys and check their types
            keys_pattern = "channel:*:messages"
            keys = await self._redis_exec(self.redis.keys, keys_pattern)

            for key in keys:
                if isinstance(key, bytes):
                    key = key.decode("utf-8")
                try:
                    key_type = await self._redis_exec(self.redis.type, key)
                    if key_type and key_type != "zset":
                        logger.warning(
                            f"Cleaning up Redis key {key} with wrong type: {key_type}"
                        )
                        await self._redis_exec(self.redis.delete, key)
                except Exception as cleanup_error:
                    logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
                    # Force-delete the problematic key
                    await self._redis_exec(self.redis.delete, key)

            logger.info("Redis keys cleanup completed")
        except Exception as e:
            logger.error(f"Failed to cleanup Redis keys: {e}")
    def stop(self):
        """Stop the system."""
        if self._running:
@@ -526,10 +597,10 @@ class RedisMessageSystem:
            self._batch_timer.cancel()
            self._batch_timer = None
        logger.info("Redis message system stopped")
    def __del__(self):
        """Release resources."""
        if hasattr(self, "executor"):
            self.executor.shutdown(wait=False)
@@ -1,81 +1,94 @@
from __future__ import annotations

from datetime import datetime, timedelta

from app.dependencies.database import get_redis, get_redis_message
from app.log import logger
from app.router.v2.stats import (
    REDIS_ONLINE_USERS_KEY,
    REDIS_PLAYING_USERS_KEY,
    _redis_exec,
)
async def cleanup_stale_online_users() -> tuple[int, int]:
    """Clean up stale online and playing users; returns how many were removed from each set."""
    redis_sync = get_redis_message()
    redis_async = get_redis()
    online_cleaned = 0
    playing_cleaned = 0

    try:
        # Fetch all online and playing users
        online_users = await _redis_exec(redis_sync.smembers, REDIS_ONLINE_USERS_KEY)
        playing_users = await _redis_exec(redis_sync.smembers, REDIS_PLAYING_USERS_KEY)

        # Check each online user's last activity
        current_time = datetime.utcnow()
        stale_threshold = current_time - timedelta(hours=2)  # no activity for 2 hours counts as stale

        # For online users, check the metadata online marker
        stale_online_users = []
        for user_id in online_users:
            user_id_str = (
                user_id.decode() if isinstance(user_id, bytes) else str(user_id)
            )
            metadata_key = f"metadata:online:{user_id_str}"
            # If the metadata marker is gone, the user has already gone offline
            if not await redis_async.exists(metadata_key):
                stale_online_users.append(user_id_str)

        # Remove the stale online users
        if stale_online_users:
            await _redis_exec(
                redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users
            )
            online_cleaned = len(stale_online_users)
            logger.info(f"Cleaned {online_cleaned} stale online users")

        # For playing users, also check the corresponding spectator status
        stale_playing_users = []
        for user_id in playing_users:
            user_id_str = (
                user_id.decode() if isinstance(user_id, bytes) else str(user_id)
            )
            # A user missing from the online set is offline and should leave the playing set too
            if user_id_str in stale_online_users or user_id_str not in [
                u.decode() if isinstance(u, bytes) else str(u) for u in online_users
            ]:
                stale_playing_users.append(user_id_str)

        # Remove the stale playing users
        if stale_playing_users:
            await _redis_exec(
                redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users
            )
            playing_cleaned = len(stale_playing_users)
            logger.info(f"Cleaned {playing_cleaned} stale playing users")
    except Exception as e:
        logger.error(f"Error cleaning stale users: {e}")

    return online_cleaned, playing_cleaned
async def refresh_redis_key_expiry() -> None:
    """Refresh Redis key TTLs to prevent data loss."""
    redis_async = get_redis()
    try:
        # Refresh the TTL on the online-users key
        if await redis_async.exists(REDIS_ONLINE_USERS_KEY):
            await redis_async.expire(REDIS_ONLINE_USERS_KEY, 6 * 3600)  # 6 hours
        # Refresh the TTL on the playing-users key
        if await redis_async.exists(REDIS_PLAYING_USERS_KEY):
            await redis_async.expire(REDIS_PLAYING_USERS_KEY, 6 * 3600)  # 6 hours
        logger.debug("Refreshed Redis key expiry times")
    except Exception as e:
        logger.error(f"Error refreshing Redis key expiry: {e}")
@@ -5,46 +5,49 @@ from datetime import datetime, timedelta
from app.log import logger
from app.router.v2.stats import record_hourly_stats, update_registered_users_count
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
from app.service.stats_cleanup import (
    cleanup_stale_online_users,
    refresh_redis_key_expiry,
)
class StatsScheduler:
    """Scheduler for periodic statistics tasks."""

    def __init__(self):
        self._running = False
        self._stats_task: asyncio.Task | None = None
        self._registered_task: asyncio.Task | None = None
        self._cleanup_task: asyncio.Task | None = None

    def start(self) -> None:
        """Start the scheduler."""
        if self._running:
            return

        self._running = True
        self._stats_task = asyncio.create_task(self._stats_loop())
        self._registered_task = asyncio.create_task(self._registered_users_loop())
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("Stats scheduler started")

    def stop(self) -> None:
        """Stop the scheduler."""
        if not self._running:
            return

        self._running = False
        if self._stats_task:
            self._stats_task.cancel()
        if self._registered_task:
            self._registered_task.cancel()
        if self._cleanup_task:
            self._cleanup_task.cancel()
        logger.info("Stats scheduler stopped")

    async def _stats_loop(self) -> None:
        """Statistics recording loop: finalizes one interval every 30 minutes."""
        # Record statistics immediately on startup
@@ -53,49 +56,57 @@ class StatsScheduler:
logger.info("Initial enhanced interval statistics initialized on startup") logger.info("Initial enhanced interval statistics initialized on startup")
except Exception as e: except Exception as e:
logger.error(f"Error initializing enhanced interval stats: {e}") logger.error(f"Error initializing enhanced interval stats: {e}")
while self._running: while self._running:
try: try:
# 计算下次记录时间下个30分钟整点 # 计算下次记录时间下个30分钟整点
now = datetime.utcnow() now = datetime.utcnow()
# 计算当前区间边界 # 计算当前区间边界
current_minute = (now.minute // 30) * 30 current_minute = (now.minute // 30) * 30
current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(minutes=30) current_interval_end = now.replace(
minute=current_minute, second=0, microsecond=0
) + timedelta(minutes=30)
# 如果已经过了当前区间结束时间,立即处理 # 如果已经过了当前区间结束时间,立即处理
if now >= current_interval_end: if now >= current_interval_end:
current_interval_end += timedelta(minutes=30) current_interval_end += timedelta(minutes=30)
# 计算需要等待的时间(到下个区间结束) # 计算需要等待的时间(到下个区间结束)
sleep_seconds = (current_interval_end - now).total_seconds() sleep_seconds = (current_interval_end - now).total_seconds()
# 确保至少等待1分钟最多等待31分钟 # 确保至少等待1分钟最多等待31分钟
sleep_seconds = max(min(sleep_seconds, 31 * 60), 60) sleep_seconds = max(min(sleep_seconds, 31 * 60), 60)
logger.debug(f"Next interval finalization in {sleep_seconds/60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}") logger.debug(
f"Next interval finalization in {sleep_seconds / 60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}"
)
await asyncio.sleep(sleep_seconds) await asyncio.sleep(sleep_seconds)
if not self._running: if not self._running:
break break
# 完成当前区间并记录到历史 # 完成当前区间并记录到历史
finalized_stats = await EnhancedIntervalStatsManager.finalize_interval() finalized_stats = await EnhancedIntervalStatsManager.finalize_interval()
if finalized_stats: if finalized_stats:
logger.info(f"Finalized enhanced interval statistics at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}") logger.info(
f"Finalized enhanced interval statistics at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}"
)
else: else:
# 如果区间完成失败,使用原有方式记录 # 如果区间完成失败,使用原有方式记录
await record_hourly_stats() await record_hourly_stats()
logger.info(f"Recorded hourly statistics (fallback) at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}") logger.info(
f"Recorded hourly statistics (fallback) at {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}"
)
# 开始新的区间统计 # 开始新的区间统计
await EnhancedIntervalStatsManager.initialize_current_interval() await EnhancedIntervalStatsManager.initialize_current_interval()
except Exception as e: except Exception as e:
logger.error(f"Error in stats loop: {e}") logger.error(f"Error in stats loop: {e}")
# 出错时等待5分钟再重试 # 出错时等待5分钟再重试
await asyncio.sleep(5 * 60) await asyncio.sleep(5 * 60)
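The boundary arithmetic in this loop is easy to check by hand; isolated as a pure function (same expressions, hypothetical name), 12:47 falls in the [12:30, 13:00) interval, so the loop sleeps until 13:00:

    from datetime import datetime, timedelta

    def next_interval_end(now: datetime) -> datetime:
        current_minute = (now.minute // 30) * 30
        end = now.replace(
            minute=current_minute, second=0, microsecond=0
        ) + timedelta(minutes=30)
        if now >= end:
            end += timedelta(minutes=30)
        return end

    assert next_interval_end(datetime(2025, 8, 22, 12, 47)) == datetime(2025, 8, 22, 13, 0)
    assert next_interval_end(datetime(2025, 8, 22, 12, 0)) == datetime(2025, 8, 22, 12, 30)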
    async def _registered_users_loop(self) -> None:
        """Registered-user count loop: updates once every 5 minutes."""
        # Update the registered user count immediately on startup
@@ -104,14 +115,14 @@ class StatsScheduler:
logger.info("Initial registered users count updated on startup") logger.info("Initial registered users count updated on startup")
except Exception as e: except Exception as e:
logger.error(f"Error updating initial registered users count: {e}") logger.error(f"Error updating initial registered users count: {e}")
while self._running: while self._running:
# 等待5分钟 # 等待5分钟
await asyncio.sleep(5 * 60) await asyncio.sleep(5 * 60)
if not self._running: if not self._running:
break break
try: try:
await update_registered_users_count() await update_registered_users_count()
logger.debug("Updated registered users count") logger.debug("Updated registered users count")
@@ -124,31 +135,35 @@ class StatsScheduler:
        try:
            online_cleaned, playing_cleaned = await cleanup_stale_online_users()
            if online_cleaned > 0 or playing_cleaned > 0:
                logger.info(
                    f"Initial cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users"
                )
            await refresh_redis_key_expiry()
        except Exception as e:
            logger.error(f"Error in initial cleanup: {e}")

        while self._running:
            # Wait 10 minutes
            await asyncio.sleep(10 * 60)

            if not self._running:
                break

            try:
                # Clean up stale users
                online_cleaned, playing_cleaned = await cleanup_stale_online_users()
                if online_cleaned > 0 or playing_cleaned > 0:
                    logger.info(
                        f"Cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users"
                    )

                # Refresh Redis key TTLs
                await refresh_redis_key_expiry()

                # Drop expired interval data
                await EnhancedIntervalStatsManager.cleanup_old_intervals()
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")
                # Wait 2 minutes before retrying after an error
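For context, a minimal sketch of wiring this scheduler into a FastAPI lifespan; the module-level instance and the lifespan hookup are assumptions, not necessarily how the application actually constructs it:

    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    stats_scheduler = StatsScheduler()

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        stats_scheduler.start()   # spawns the three background loops
        try:
            yield
        finally:
            stats_scheduler.stop()  # cancels them on shutdown

    app = FastAPI(lifespan=lifespan)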
@@ -92,11 +92,12 @@ class MetadataHub(Hub[MetadataClientState]):
    @override
    async def _clean_state(self, state: MetadataClientState) -> None:
        user_id = int(state.connection_id)

        # Remove from online user tracking
        from app.router.v2.stats import remove_online_user

        asyncio.create_task(remove_online_user(user_id))

        if state.pushable:
            await asyncio.gather(*self.broadcast_tasks(user_id, None))
        redis = get_redis()
@@ -125,6 +126,7 @@ class MetadataHub(Hub[MetadataClientState]):
        # Track online user
        from app.router.v2.stats import add_online_user

        asyncio.create_task(add_online_user(user_id))

        async with with_db() as session:
@@ -163,11 +163,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
    @override
    async def _clean_state(self, state: MultiplayerClientState):
        user_id = int(state.connection_id)

        # Remove from online user tracking
        from app.router.v2.stats import remove_online_user

        asyncio.create_task(remove_online_user(user_id))

        if state.room_id != 0 and state.room_id in self.rooms:
            server_room = self.rooms[state.room_id]
            room = server_room.room
@@ -180,9 +181,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
    async def on_client_connect(self, client: Client) -> None:
        """Track online users when connecting to multiplayer hub"""
        logger.info(f"[MultiplayerHub] Client {client.user_id} connected")

        # Track online user
        from app.router.v2.stats import add_online_user

        asyncio.create_task(add_online_user(client.user_id))

    def _ensure_in_room(self, client: Client) -> ServerMultiplayerRoom:
@@ -292,11 +294,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
            room.users.append(user)
            self.add_to_group(client, self.group_id(room_id))
            await server_room.match_type_handler.handle_join(user)

            # Critical fix: Send current room and gameplay state to new user
            # This ensures spectators joining ongoing games get proper state sync
            await self._send_room_state_to_new_user(client, server_room)

            await self.event_logger.player_joined(room_id, user.user_id)
            async with with_db() as session:
@@ -669,16 +671,22 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
                # Enhanced spectator validation - allow transitions from more states
                # This matches official osu-server-spectator behavior
                if old not in (
                    MultiplayerUserState.IDLE,
                    MultiplayerUserState.READY,
                    MultiplayerUserState.RESULTS,  # Allow spectating after results
                ):
                    # Allow spectating during gameplay states only if the room is in an appropriate state
                    if not (
                        old.is_playing
                        and room.room.state
                        in (
                            MultiplayerRoomState.WAITING_FOR_LOAD,
                            MultiplayerRoomState.PLAYING,
                        )
                    ):
                        raise InvokeException(
                            f"Cannot change state from {old} to {new}"
                        )
            case _:
                raise InvokeException(f"Invalid state transition from {old} to {new}")
@@ -691,7 +699,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
        if user.state == state:
            return

        # Special handling for state changes during gameplay
        match state:
            case MultiplayerUserState.IDLE:
@@ -704,15 +712,15 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
        logger.info(
            f"[MultiplayerHub] User {user.user_id} changing state from {user.state} to {state}"
        )
        await self.validate_user_stare(
            server_room,
            user.state,
            state,
        )
        await self.change_user_state(server_room, user, state)

        # Enhanced spectator handling based on official implementation
        if state == MultiplayerUserState.SPECTATING:
            await self.handle_spectator_state_change(client, server_room, user)
@@ -738,24 +746,21 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
        )

    async def handle_spectator_state_change(
        self, client: Client, room: ServerMultiplayerRoom, user: MultiplayerRoomUser
    ):
        """
        Handle special logic for users entering spectator mode during ongoing gameplay.
        Based on official osu-server-spectator implementation.
        """
        room_state = room.room.state

        # If switching to spectating during gameplay, immediately request load
        if room_state == MultiplayerRoomState.WAITING_FOR_LOAD:
            logger.info(
                f"[MultiplayerHub] Spectator {user.user_id} joining during load phase"
            )
            await self.call_noblock(client, "LoadRequested")
        elif room_state == MultiplayerRoomState.PLAYING:
            logger.info(
                f"[MultiplayerHub] Spectator {user.user_id} joining during active gameplay"
@@ -763,9 +768,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
            await self.call_noblock(client, "LoadRequested")

    async def _send_current_gameplay_state_to_spectator(
        self, client: Client, room: ServerMultiplayerRoom
    ):
        """
        Send current gameplay state information to a newly joined spectator.
@@ -773,12 +776,8 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
""" """
try: try:
# Send current room state # Send current room state
await self.call_noblock( await self.call_noblock(client, "RoomStateChanged", room.room.state)
client,
"RoomStateChanged",
room.room.state
)
# Send current user states for all players # Send current user states for all players
for room_user in room.room.users: for room_user in room.room.users:
if room_user.state.is_playing: if room_user.state.is_playing:
@@ -788,7 +787,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
                        room_user.user_id,
                        room_user.state,
                    )

            logger.debug(
                f"[MultiplayerHub] Sent current gameplay state to spectator {client.user_id}"
            )
@@ -798,9 +797,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
            )

    async def _send_room_state_to_new_user(
        self, client: Client, room: ServerMultiplayerRoom
    ):
        """
        Send complete room state to a newly joined user.
@@ -809,23 +806,19 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
        try:
            # Send current room state
            if room.room.state != MultiplayerRoomState.OPEN:
                await self.call_noblock(client, "RoomStateChanged", room.room.state)

            # If room is in gameplay state, send LoadRequested immediately
            if room.room.state in (
                MultiplayerRoomState.WAITING_FOR_LOAD,
                MultiplayerRoomState.PLAYING,
            ):
                logger.info(
                    f"[MultiplayerHub] Sending LoadRequested to user {client.user_id} "
                    f"joining ongoing game (room state: {room.room.state})"
                )
                await self.call_noblock(client, "LoadRequested")

            # Send all user states to help with synchronization
            for room_user in room.room.users:
                if room_user.user_id != client.user_id:  # Don't send own state
@@ -835,11 +828,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
                        room_user.user_id,
                        room_user.state,
                    )

            # Critical addition: Send current playing users to SpectatorHub for cross-hub sync
            # This ensures spectators can watch multiplayer players properly
            await self._sync_with_spectator_hub(client, room)

            logger.debug(
                f"[MultiplayerHub] Sent complete room state to new user {client.user_id}"
            )
@@ -849,9 +842,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
            )

    async def _sync_with_spectator_hub(
        self, client: Client, room: ServerMultiplayerRoom
    ):
        """
        Sync with SpectatorHub to ensure cross-hub spectating works properly.
@@ -860,7 +851,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
        try:
            # Import here to avoid circular imports
            from app.signalr.hub import SpectatorHubs

            # For each playing user in the room, check if they have SpectatorHub state
            # and notify the new client about their playing status
            for room_user in room.room.users:
@@ -878,7 +869,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
f"[MultiplayerHub] Synced spectator state for user {room_user.user_id} " f"[MultiplayerHub] Synced spectator state for user {room_user.user_id} "
f"to new client {client.user_id}" f"to new client {client.user_id}"
) )
except Exception as e: except Exception as e:
logger.debug(f"[MultiplayerHub] Failed to sync with SpectatorHub: {e}") logger.debug(f"[MultiplayerHub] Failed to sync with SpectatorHub: {e}")
# This is not critical, so we don't raise the exception # This is not critical, so we don't raise the exception
@@ -170,21 +170,22 @@ class SpectatorHub(Hub[StoreClientState]):
        Properly notifies watched users when spectator disconnects.
        """
        user_id = int(state.connection_id)

        # Remove from online and playing tracking
        from app.router.v2.stats import remove_online_user

        asyncio.create_task(remove_online_user(user_id))

        if state.state:
            await self._end_session(user_id, state.state, state)

        # Critical fix: Notify all watched users that this spectator has disconnected
        # This matches the official CleanUpState implementation
        for watched_user_id in state.watched_user:
            if (
                target_client := self.get_client_by_id(str(watched_user_id))
            ) is not None:
                await self.call_noblock(target_client, "UserEndedWatching", user_id)
                logger.debug(
                    f"[SpectatorHub] Notified {watched_user_id} that {user_id} stopped watching"
                )
@@ -195,18 +196,19 @@ class SpectatorHub(Hub[StoreClientState]):
        Send all active player states to newly connected clients.
        """
        logger.info(f"[SpectatorHub] Client {client.user_id} connected")

        # Track online user
        from app.router.v2.stats import add_online_user

        asyncio.create_task(add_online_user(client.user_id))

        # Send all current player states to the new client
        # This matches the official OnConnectedAsync behavior
        active_states = []
        for user_id, store in self.state.items():
            if store.state is not None:
                active_states.append((user_id, store.state))

        if active_states:
            logger.debug(
                f"[SpectatorHub] Sending {len(active_states)} active player states to {client.user_id}"
@@ -216,8 +218,10 @@ class SpectatorHub(Hub[StoreClientState]):
                try:
                    await self.call_noblock(client, "UserBeganPlaying", user_id, state)
                except Exception as e:
                    logger.debug(
                        f"[SpectatorHub] Failed to send state for user {user_id}: {e}"
                    )

        # Also sync with MultiplayerHub for cross-hub spectating
        await self._sync_with_multiplayer_hub(client)
@@ -229,14 +233,15 @@ class SpectatorHub(Hub[StoreClientState]):
        try:
            # Import here to avoid circular imports
            from app.signalr.hub import MultiplayerHubs

            # Check all active multiplayer rooms for playing users
            for room_id, server_room in MultiplayerHubs.rooms.items():
                for room_user in server_room.room.users:
                    # If user is playing in multiplayer but we don't have their spectator state
                    if (
                        room_user.state.is_playing
                        and room_user.user_id not in self.state
                    ):
                        # Create a synthetic SpectatorState for multiplayer players
                        # This helps with cross-hub spectating
                        try:
@@ -245,9 +250,9 @@ class SpectatorHub(Hub[StoreClientState]):
                                ruleset_id=room_user.ruleset_id or 0,  # Default to osu!
                                mods=room_user.mods,
                                state=SpectatedUserState.Playing,
                                maximum_statistics={},
                            )
                            await self.call_noblock(
                                client,
                                "UserBeganPlaying",
@@ -258,8 +263,10 @@ class SpectatorHub(Hub[StoreClientState]):
f"[SpectatorHub] Sent synthetic multiplayer state for user {room_user.user_id}" f"[SpectatorHub] Sent synthetic multiplayer state for user {room_user.user_id}"
) )
except Exception as e: except Exception as e:
logger.debug(f"[SpectatorHub] Failed to create synthetic state: {e}") logger.debug(
f"[SpectatorHub] Failed to create synthetic state: {e}"
)
except Exception as e: except Exception as e:
logger.debug(f"[SpectatorHub] Failed to sync with MultiplayerHub: {e}") logger.debug(f"[SpectatorHub] Failed to sync with MultiplayerHub: {e}")
# This is not critical, so we don't raise the exception # This is not critical, so we don't raise the exception
@@ -306,6 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
        # Track playing user
        from app.router.v2.stats import add_playing_user

        asyncio.create_task(add_playing_user(user_id))

        # # Pre-cache the beatmap file to speed up later PP calculation
@@ -356,11 +364,12 @@ class SpectatorHub(Hub[StoreClientState]):
        ) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()):
            await self._process_score(store, client)
        await self._end_session(user_id, state, store)

        # Remove from playing user tracking
        from app.router.v2.stats import remove_playing_user

        asyncio.create_task(remove_playing_user(user_id))

        store.state = None
        store.beatmap_status = None
        store.checksum = None
@@ -473,8 +482,10 @@ class SpectatorHub(Hub[StoreClientState]):
        if state.state == SpectatedUserState.Playing:
            state.state = SpectatedUserState.Quit
            logger.debug(
                f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}"
            )

        # Calculate exit time safely
        exit_time = 0
        if store.score and store.score.replay_frames:
@@ -491,7 +502,7 @@ class SpectatorHub(Hub[StoreClientState]):
        )
        self.tasks.add(task)
        task.add_done_callback(self.tasks.discard)

        # Background task for failtime tracking - only for failed/quit states with valid data
        if (
            state.beatmap_id is not None
@@ -519,14 +530,16 @@ class SpectatorHub(Hub[StoreClientState]):
        Properly handles state synchronization and watcher notifications.
        """
        user_id = int(client.connection_id)

        logger.info(f"[SpectatorHub] {user_id} started watching {target_id}")

        try:
            # Get target user's current state if it exists
            target_store = self.state.get(target_id)

            if target_store and target_store.state:
                logger.debug(
                    f"[SpectatorHub] {target_id} is currently {target_store.state.state}"
                )

                # Send current state to the watcher immediately
                await self.call_noblock(
                    client,
@@ -552,7 +565,9 @@ class SpectatorHub(Hub[StoreClientState]):
                    await session.exec(select(User.username).where(User.id == user_id))
                ).first()
                if not username:
                    logger.warning(
                        f"[SpectatorHub] Could not find username for user {user_id}"
                    )
                    return

                # Notify target user that someone started watching
@@ -562,7 +577,9 @@ class SpectatorHub(Hub[StoreClientState]):
                    await self.call_noblock(
                        target_client, "UserStartedWatching", watcher_info
                    )
                    logger.debug(
                        f"[SpectatorHub] Notified {target_id} that {username} started watching"
                    )
        except Exception as e:
            logger.error(f"[SpectatorHub] Error notifying target user {target_id}: {e}")
@@ -572,19 +589,23 @@ class SpectatorHub(Hub[StoreClientState]):
        Properly cleans up watcher state and notifies target user.
        """
        user_id = int(client.connection_id)

        logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}")

        # Remove from SignalR group
        self.remove_from_group(client, self.group_id(target_id))

        # Remove from our tracked watched users
        store = self.get_or_create_state(client)
        store.watched_user.discard(target_id)

        # Notify target user that watcher stopped watching
        if (target_client := self.get_client_by_id(str(target_id))) is not None:
            await self.call_noblock(target_client, "UserEndedWatching", user_id)
            logger.debug(
                f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching"
            )
        else:
            logger.debug(
                f"[SpectatorHub] Target user {target_id} not found for end watching notification"
            )