345 lines
14 KiB
Python
345 lines
14 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
from typing import overload
|
||
|
||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||
from app.database.notification import UserNotification, insert_notification
|
||
from app.database.user import User
|
||
from app.dependencies.database import (
|
||
DBFactory,
|
||
get_db_factory,
|
||
get_redis,
|
||
with_db,
|
||
)
|
||
from app.dependencies.user import get_current_user_and_token
|
||
from app.log import logger
|
||
from app.models.chat import ChatEvent
|
||
from app.models.notification import NotificationDetail
|
||
from app.service.subscribers.chat import ChatSubscriber
|
||
from app.utils import bg_tasks
|
||
|
||
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||
from fastapi.security import SecurityScopes
|
||
from fastapi.websockets import WebSocketState
|
||
from redis.asyncio import Redis
|
||
from sqlmodel import select
|
||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||
|
||
|
||
class ChatServer:
    """In-memory hub that pushes ``ChatEvent`` payloads to WebSocket clients.

    The instance keeps one WebSocket per user id and, per channel id, the list
    of user ids currently joined. Channel membership tracked here lives only
    in ``self.channels``.
    """

    def __init__(self):
        # user id -> WebSocket registered through connect()
        self.connect_client: dict[int, WebSocket] = {}
        # channel id -> ids of the users currently joined to that channel
        self.channels: dict[int, list[int]] = {}
        self.redis: Redis = get_redis()

        # Not used by any method of this class.
        self.tasks: set[asyncio.Task] = set()
        # The subscriber receives a back-reference to this server.
        self.ChatSubscriber = ChatSubscriber()
        self.ChatSubscriber.chat_server = self
        # Set by the WebSocket endpoint once start_subscribe() has been requested.
        self._subscribed = False

    def connect(self, user_id: int, client: WebSocket):
        """Register ``client`` as the socket of ``user_id``, replacing any previous one."""
        self.connect_client[user_id] = client

    def get_user_joined_channel(self, user_id: int) -> list[int]:
        """Return the ids of all channels whose member list contains ``user_id``."""
        return [channel_id for channel_id, users in self.channels.items() if user_id in users]

    def _member_ids_for(self, channel: ChatChannel) -> list[int] | None:
        """Return the member list handed to ``ChatChannelResp.from_db``.

        ``None`` is returned for public channels, and for channels that have no
        in-memory entry (for example after the last member left).
        """
        if channel.type == ChannelType.PUBLIC:
            return None
        return self.channels.get(channel.channel_id)

    async def disconnect(self, user: User, session: AsyncSession):
        """Forget the user's socket and remove the user from every joined channel.

        Args:
            user: The user that is going away.
            session: Database session used to load each joined channel.
        """
        user_id = user.id
        self.connect_client.pop(user_id, None)

        # Work on a snapshot of channel ids: self.channels is mutated below.
        for channel_id in self.get_user_joined_channel(user_id):
            members = self.channels.get(channel_id)
            # Re-check membership; it may have changed during an earlier await.
            if members is None or user_id not in members:
                continue
            members.remove(user_id)
            # Explicit query instead of relying on lazy loading.
            db_channel = (
                await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
            ).first()
            if db_channel:
                await self.leave_channel(user, db_channel, session)
            elif channel_id in self.channels and not self.channels[channel_id]:
                # No database row to run leave_channel() with, so drop the
                # now-empty member list here instead of keeping it forever.
                del self.channels[channel_id]

    @overload
    async def send_event(self, client: int, event: ChatEvent): ...

    @overload
    async def send_event(self, client: WebSocket, event: ChatEvent): ...

    async def send_event(self, client: WebSocket | int, event: ChatEvent):
        """Send ``event`` as JSON text to one client.

        Args:
            client: Either a WebSocket or a user id. A user id without a
                registered socket is ignored silently.
            event: The event to serialize with ``model_dump_json()``.

        Nothing is sent unless the socket's ``client_state`` is ``CONNECTED``.
        """
        if isinstance(client, int):
            client_ = self.connect_client.get(client)
            if client_ is None:
                return
            client = client_
        if client.client_state == WebSocketState.CONNECTED:
            await client.send_text(event.model_dump_json())

    async def broadcast(self, channel_id: int, event: ChatEvent):
        """Send ``event`` to every user currently joined to ``channel_id``.

        A failure while sending to one user is logged and does not prevent
        delivery to the remaining users.
        """
        # Copy the member list: users may join or leave while a send is awaited,
        # and mutating a list during iteration can skip entries.
        users_in_channel = list(self.channels.get(channel_id, []))
        logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}")

        # With nobody joined, check whether this is a multiplayer channel.
        if not users_in_channel:
            try:
                async with with_db() as session:
                    channel = await session.get(ChatChannel, channel_id)
                    if channel and channel.type == ChannelType.MULTIPLAYER:
                        # This can be normal for a multiplayer room (everyone
                        # has left); it is still logged to help debugging.
                        logger.warning(
                            f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone"
                        )
            except Exception as e:
                logger.error(f"Failed to check channel type for {channel_id}: {e}")

        for user_id in users_in_channel:
            try:
                await self.send_event(user_id, event)
            except Exception as e:
                logger.warning(f"Failed to send event to user {user_id} in channel {channel_id}: {e}")
                continue
            logger.debug(f"Sent event to user {user_id} in channel {channel_id}")

    async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
        """Store ``message_id`` in Redis under ``chat:{channel_id}:last_read:{user_id}``."""
        await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)

    async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
        """Queue delivery of a new chat message and update the Redis read markers.

        Args:
            message: The message to deliver as a ``chat.message.new`` event.
            is_bot_command: When true the event is sent to the sender only,
                otherwise it is broadcast to the message's channel.
        """
        logger.info(
            f"Sending message to channel {message.channel_id}, message_id: "
            f"{message.message_id}, is_bot_command: {is_bot_command}"
        )

        event = ChatEvent(
            event="chat.message.new",
            data={"messages": [message], "users": [message.sender]},
        )
        if is_bot_command:
            logger.info(f"Sending bot command to user {message.sender_id}")
            bg_tasks.add_task(self.send_event, message.sender_id, event)
        else:
            # Always broadcast, whether the message carries a temporary or a real ID.
            logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
            bg_tasks.add_task(
                self.broadcast,
                message.channel_id,
                event,
            )

        # Only real message IDs (positive, non-zero) update the read marker and
        # the last-message key. Per the original note, IDs produced by the Redis
        # message system are positive, so this branch is the usual one.
        if message.message_id and message.message_id > 0:
            await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
            await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
            logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
        else:
            logger.debug(f"Skipping last message update for message ID: {message.message_id}")

    async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
        """Join several users to ``channel`` and notify the newly joined ones.

        Users already present in the channel are skipped and get no event.
        """
        channel_id = channel.channel_id
        members = self.channels.setdefault(channel_id, [])

        not_joined = []
        for user in users:
            if user.id not in members:
                members.append(user.id)
                not_joined.append(user)

        for user in not_joined:
            # The member list is looked up again on every pass: the channel
            # entry may have been dropped while a previous iteration awaited.
            channel_resp = await ChatChannelResp.from_db(
                channel,
                session,
                user,
                self.redis,
                self._member_ids_for(channel),
            )
            await self.send_event(
                user.id,
                ChatEvent(
                    event="chat.channel.join",
                    data=channel_resp.model_dump(),
                ),
            )

    async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
        """Join ``user`` to ``channel`` and send them a ``chat.channel.join`` event.

        The event is sent even when the user was already a member.

        Returns:
            The channel response built for this user.
        """
        user_id = user.id
        members = self.channels.setdefault(channel.channel_id, [])
        if user_id not in members:
            members.append(user_id)

        channel_resp = await ChatChannelResp.from_db(
            channel,
            session,
            user,
            self.redis,
            self._member_ids_for(channel),
        )

        await self.send_event(
            user_id,
            ChatEvent(
                event="chat.channel.join",
                data=channel_resp.model_dump(),
            ),
        )

        return channel_resp

    async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
        """Remove ``user`` from ``channel`` and send them a ``chat.channel.part`` event.

        The channel's in-memory entry is deleted once its member list is empty.
        """
        user_id = user.id
        channel_id = channel.channel_id

        members = self.channels.get(channel_id)
        if members is not None:
            if user_id in members:
                members.remove(user_id)
            if not members:
                del self.channels[channel_id]

        channel_resp = await ChatChannelResp.from_db(
            channel,
            session,
            user,
            self.redis,
            self._member_ids_for(channel),
        )
        await self.send_event(
            user_id,
            ChatEvent(
                event="chat.channel.part",
                data=channel_resp.model_dump(),
            ),
        )

    async def join_room_channel(self, channel_id: int, user_id: int):
        """Load the channel and user by id and join them; log and return if either is missing."""
        async with with_db() as session:
            # Explicit query instead of relying on lazy loading.
            db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
            if db_channel is None:
                logger.warning(f"Attempted to join non-existent channel {channel_id} by user {user_id}")
                return

            user = await session.get(User, user_id)
            if user is None:
                logger.warning(f"Attempted to join channel {channel_id} by non-existent user {user_id}")
                return

            logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
            await self.join_channel(user, db_channel, session)

    async def leave_room_channel(self, channel_id: int, user_id: int):
        """Load the channel and user by id and part them; log and return if either is missing."""
        async with with_db() as session:
            # Explicit query instead of relying on lazy loading.
            db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
            if db_channel is None:
                logger.warning(f"Attempted to leave non-existent channel {channel_id} by user {user_id}")
                return

            user = await session.get(User, user_id)
            if user is None:
                logger.warning(f"Attempted to leave channel {channel_id} by non-existent user {user_id}")
                return

            logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
            await self.leave_channel(user, db_channel, session)

    async def new_private_notification(self, detail: NotificationDetail):
        """Insert a notification and push a ``new`` event to each of its recipients.

        Recipients are the ``UserNotification`` rows that reference the id
        returned by ``insert_notification``.
        """
        async with with_db() as session:
            notification_id = await insert_notification(session, detail)
            recipients = (
                await session.exec(
                    select(UserNotification).where(UserNotification.notification_id == notification_id)
                )
            ).all()
            for user_notification in recipients:
                data = user_notification.notification.model_dump()
                data["is_read"] = user_notification.is_read
                data["details"] = user_notification.notification.details
                # Send through this instance rather than the module-level global.
                await self.send_event(
                    user_notification.user_id,
                    ChatEvent(
                        event="new",
                        data=data,
                    ),
                )
|
||
|
||
|
||
# Module-level ChatServer instance used by the WebSocket endpoint and helpers in this file.
server = ChatServer()

# Router that carries the WebSocket endpoint; its routes are excluded from the API schema.
chat_router = APIRouter(include_in_schema=False)
|
||
|
||
|
||
async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
    """Read client packets until ``chat.end`` arrives or the socket goes away.

    On ``chat.end`` the user is loaded from the database and, if found,
    disconnected from the chat server; the socket is then closed with code 1000.
    Disconnects and other errors are logged and end the listener instead of
    propagating.

    Args:
        ws: The accepted WebSocket to read from.
        user_id: Id of the user that owns ``ws``.
        factory: Callable returning an async iterator of database sessions.
    """
    try:
        while True:
            packets = await ws.receive_json()
            # A valid JSON frame is not necessarily an object. Skip anything
            # that has no "event" key to inspect instead of failing on .get().
            if not isinstance(packets, dict) or packets.get("event") != "chat.end":
                continue

            async for session in factory():
                user = await session.get(User, user_id)
                if user is None:
                    break
                await server.disconnect(user, session)
            await ws.close(code=1000)
            break
    except WebSocketDisconnect as e:
        logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}")
    except RuntimeError as e:
        # A RuntimeError whose text mentions "disconnect message" is treated
        # as the client having closed the connection.
        if "disconnect message" in str(e):
            logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
        else:
            logger.exception(f"RuntimeError in client {user_id}: {e}")
    except Exception:
        logger.exception(f"Error in client {user_id}")
|
||
|
||
|
||
@chat_router.websocket("/notification-server")
async def chat_websocket(
    websocket: WebSocket,
    token: str | None = Query(None, description="认证令牌,支持通过URL参数传递"),
    access_token: str | None = Query(None, description="访问令牌,支持通过URL参数传递"),
    authorization: str | None = Header(None, description="Bearer认证头"),
    factory: DBFactory = Depends(get_db_factory),
):
    """WebSocket endpoint: authenticate, wait for ``chat.start``, then serve the client.

    The token is taken from the ``token`` or ``access_token`` query parameter,
    falling back to the ``Authorization`` header (with or without a
    ``Bearer `` prefix). Missing or rejected tokens close the socket with 1008.
    After the socket is accepted, the first frame must be a JSON object whose
    ``event`` is ``chat.start``; otherwise the socket is closed with 1008.
    The user is then registered on the chat server, joined to channel 1 if
    that channel exists, and the connection is handed to ``_listen_stop``.
    """
    if not server._subscribed:
        # Set the flag before awaiting so that connections arriving while the
        # subscription is starting do not start it a second time.
        server._subscribed = True
        try:
            await server.ChatSubscriber.start_subscribe()
        except BaseException:
            # Allow a later connection to retry rather than leaving the server
            # marked as subscribed when the subscription never started.
            server._subscribed = False
            raise

    async for session in factory():
        # Query parameters take priority; either `token` or `access_token` is accepted.
        auth_token = token or access_token
        if not auth_token and authorization:
            auth_token = authorization.removeprefix("Bearer ")

        if not auth_token:
            await websocket.close(code=1008, reason="Missing authentication token")
            return

        user_and_token = await get_current_user_and_token(
            session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token
        )
        if user_and_token is None:
            await websocket.close(code=1008, reason="Invalid or expired token")
            return

        await websocket.accept()
        login = await websocket.receive_json()
        # Reject frames that are not JSON objects as well as wrong events.
        if not isinstance(login, dict) or login.get("event") != "chat.start":
            await websocket.close(code=1008)
            return
        user = user_and_token[0]
        user_id = user.id
        server.connect(user_id, websocket)
        # Explicit query instead of relying on lazy loading.
        db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
        if db_channel is not None:
            await server.join_channel(user, db_channel, session)

        await _listen_stop(websocket, user_id, factory)
|