Files
g0v0-server/app/router/notification/server.py

342 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
from typing import Annotated, overload
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.notification import UserNotification, insert_notification
from app.database.user import User
from app.dependencies.database import (
DBFactory,
Redis,
get_db_factory,
get_redis,
with_db,
)
from app.dependencies.user import get_current_user_and_token
from app.log import log
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
from fastapi.websockets import WebSocketState
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
# Module-level logger obtained from the project's log() helper under the name "NotificationServer".
logger = log("NotificationServer")
class ChatServer:
    """In-process registry of websocket clients and channel membership.

    Keeps one websocket per user id and, for each channel id, the list of
    user ids currently joined, and pushes ``ChatEvent`` payloads to them.
    """

    def __init__(self):
        # user_id -> websocket; a later connect() for the same user replaces the entry.
        self.connect_client: dict[int, WebSocket] = {}
        # channel_id -> ids of the users currently joined to that channel.
        self.channels: dict[int, list[int]] = {}
        self.redis: Redis = get_redis()
        # Not referenced by any method of this class; presumably used by callers elsewhere.
        self.tasks: set[asyncio.Task] = set()
        # The subscriber receives a back-reference to this server.
        self.ChatSubscriber = ChatSubscriber()
        self.ChatSubscriber.chat_server = self
        # Flipped by the websocket endpoint once start_subscribe() has been triggered.
        self._subscribed = False

    def connect(self, user_id: int, client: WebSocket) -> None:
        """Register ``client`` as the websocket for ``user_id``."""
        self.connect_client[user_id] = client

    def get_user_joined_channel(self, user_id: int) -> list[int]:
        """Return the ids of all channels whose member list contains ``user_id``."""
        return [channel_id for channel_id, users in self.channels.items() if user_id in users]

    async def disconnect(self, user: User, session: AsyncSession) -> None:
        """Forget the user's websocket and remove the user from every joined channel.

        Args:
            user: The user that is going away.
            session: Database session used to load the ``ChatChannel`` rows.
        """
        user_id = user.id
        self.connect_client.pop(user_id, None)

        # get_user_joined_channel() builds a fresh list, so self.channels can be
        # mutated safely while we walk over the result.
        for channel_id in self.get_user_joined_channel(user_id):
            members = self.channels.get(channel_id)
            # Re-check membership: it may have changed while we awaited on a previous channel.
            if members is None or user_id not in members:
                continue
            members.remove(user_id)

            # Explicit query to avoid lazy loading.
            db_channel = (
                await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
            ).first()
            if db_channel is not None:
                await self.leave_channel(user, db_channel, session)
            elif not self.channels.get(channel_id):
                # No database row means leave_channel() never runs, so drop the
                # now-empty member list here instead of leaking it.
                self.channels.pop(channel_id, None)

    @overload
    async def send_event(self, client: int, event: ChatEvent): ...

    @overload
    async def send_event(self, client: WebSocket, event: ChatEvent): ...

    async def send_event(self, client: WebSocket | int, event: ChatEvent):
        """Send ``event`` as JSON text to a websocket, or to the websocket of a user id.

        Unknown user ids and websockets that are not in the CONNECTED state
        are skipped silently.
        """
        if isinstance(client, int):
            client_ = self.connect_client.get(client)
            if client_ is None:
                return
            client = client_
        if client.client_state == WebSocketState.CONNECTED:
            await client.send_text(event.model_dump_json())

    async def broadcast(self, channel_id: int, event: ChatEvent) -> None:
        """Send ``event`` to every user currently joined to ``channel_id``."""
        # Snapshot the member list: joins/leaves that happen while we await a
        # send must not shift the list under the loop and make it skip users.
        users_in_channel = list(self.channels.get(channel_id, ()))
        logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}")

        if not users_in_channel:
            # Nobody to deliver to; check whether this is a multiplayer channel.
            try:
                async with with_db() as session:
                    channel = await session.get(ChatChannel, channel_id)
                    if channel and channel.type == ChannelType.MULTIPLAYER:
                        # For multiplayer rooms this can be normal (everyone has
                        # left the room), but it is still logged to aid debugging.
                        logger.warning(
                            f"No users in multiplayer channel {channel_id}, message will not be delivered to anyone"
                        )
            except Exception as e:
                logger.error(f"Failed to check channel type for {channel_id}: {e}")
            return

        for user_id in users_in_channel:
            try:
                await self.send_event(user_id, event)
            except Exception:
                # One broken connection must not prevent delivery to the rest.
                logger.exception(f"Failed to send event to user {user_id} in channel {channel_id}")
                continue
            logger.debug(f"Sent event to user {user_id} in channel {channel_id}")

    async def mark_as_read(self, channel_id: int, user_id: int, message_id: int) -> None:
        """Store ``message_id`` in Redis as the user's last-read message for the channel."""
        await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)

    async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
        """Schedule delivery of a new chat message and update the Redis bookkeeping.

        Args:
            message: The message to deliver.
            is_bot_command: When true the event is sent only to the sender
                instead of being broadcast to the whole channel.
        """
        logger.info(
            f"Sending message to channel {message.channel_id}, message_id: "
            f"{message.message_id}, is_bot_command: {is_bot_command}"
        )
        event = ChatEvent(
            event="chat.message.new",
            data={"messages": [message], "users": [message.sender]},
        )
        if is_bot_command:
            logger.info(f"Sending bot command to user {message.sender_id}")
            bg_tasks.add_task(self.send_event, message.sender_id, event)
        else:
            # Always broadcast, whether the message carries a temporary or a real id.
            logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
            bg_tasks.add_task(
                self.broadcast,
                message.channel_id,
                event,
            )

        # Only real message ids (positive, non-zero) are marked as read and
        # recorded as the channel's last message. Ids generated by the Redis
        # message system are all positive, so they normally take this branch.
        if message.message_id and message.message_id > 0:
            await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
            await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
            logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
        else:
            logger.debug(f"Skipping last message update for message ID: {message.message_id}")

    async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
        """Add several users to a channel and notify only those that were not yet members."""
        channel_id = channel.channel_id
        members = self.channels.setdefault(channel_id, [])

        newly_joined = []
        for user in users:
            if user.id not in members:
                members.append(user.id)
                newly_joined.append(user)

        for user in newly_joined:
            channel_resp = await ChatChannelResp.from_db(
                channel,
                session,
                user,
                self.redis,
                # The member list is passed along only for non-public channels.
                self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
            )
            await self.send_event(
                user.id,
                ChatEvent(
                    event="chat.channel.join",
                    data=channel_resp.model_dump(),
                ),
            )

    async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
        """Add the user to the channel, send a ``chat.channel.join`` event and return the channel response.

        The event is sent even when the user was already a member.
        """
        user_id = user.id
        channel_id = channel.channel_id
        members = self.channels.setdefault(channel_id, [])
        if user_id not in members:
            members.append(user_id)

        channel_resp = await ChatChannelResp.from_db(
            channel,
            session,
            user,
            self.redis,
            # The member list is passed along only for non-public channels.
            self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
        )
        await self.send_event(
            user_id,
            ChatEvent(
                event="chat.channel.join",
                data=channel_resp.model_dump(),
            ),
        )
        return channel_resp

    async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
        """Remove the user from the channel and send a ``chat.channel.part`` event.

        A channel whose member list becomes (or already is) empty is dropped
        from the in-memory registry.
        """
        user_id = user.id
        channel_id = channel.channel_id

        members = self.channels.get(channel_id)
        if members is not None:
            if user_id in members:
                members.remove(user_id)
            if not members:
                del self.channels[channel_id]

        channel_resp = await ChatChannelResp.from_db(
            channel,
            session,
            user,
            self.redis,
            # The member list is passed along only for non-public channels.
            self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
        )
        await self.send_event(
            user_id,
            ChatEvent(
                event="chat.channel.part",
                data=channel_resp.model_dump(),
            ),
        )

    async def join_room_channel(self, channel_id: int, user_id: int):
        """Load channel and user by id in a fresh session and join the user to the channel.

        Logs a warning and does nothing when either row does not exist.
        """
        async with with_db() as session:
            # Explicit query to avoid lazy loading.
            db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
            if db_channel is None:
                logger.warning(f"Attempted to join non-existent channel {channel_id} by user {user_id}")
                return

            user = await session.get(User, user_id)
            if user is None:
                logger.warning(f"Attempted to join channel {channel_id} by non-existent user {user_id}")
                return

            logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
            await self.join_channel(user, db_channel, session)

    async def leave_room_channel(self, channel_id: int, user_id: int):
        """Load channel and user by id in a fresh session and remove the user from the channel.

        Logs a warning and does nothing when either row does not exist.
        """
        async with with_db() as session:
            # Explicit query to avoid lazy loading.
            db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
            if db_channel is None:
                logger.warning(f"Attempted to leave non-existent channel {channel_id} by user {user_id}")
                return

            user = await session.get(User, user_id)
            if user is None:
                logger.warning(f"Attempted to leave channel {channel_id} by non-existent user {user_id}")
                return

            logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
            await self.leave_channel(user, db_channel, session)

    async def new_private_notification(self, detail: NotificationDetail):
        """Persist a notification and push a ``new`` event to every user it was created for."""
        async with with_db() as session:
            # Named notification_id rather than ``id`` to avoid shadowing the builtin.
            notification_id = await insert_notification(session, detail)
            user_notifications = (
                await session.exec(
                    select(UserNotification).where(UserNotification.notification_id == notification_id)
                )
            ).all()
            for user_notification in user_notifications:
                data = user_notification.notification.model_dump()
                data["is_read"] = user_notification.is_read
                data["details"] = user_notification.notification.details
                # Send through this instance rather than the module-level singleton.
                await self.send_event(
                    user_notification.user_id,
                    ChatEvent(
                        event="new",
                        data=data,
                    ),
                )
# Single module-level ChatServer instance used by the websocket handlers below.
server = ChatServer()
# Router carrying the notification websocket endpoint; include_in_schema=False
# keeps its routes out of the generated OpenAPI schema.
chat_router = APIRouter(include_in_schema=False)
async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
    """Read JSON packets from ``ws`` until a ``chat.end`` event or a disconnect.

    On ``chat.end`` the user is loaded through a session from ``factory``,
    unregistered from the module-level ``server``, and the socket is closed
    with code 1000. Disconnects and other errors are logged, not raised.
    """
    try:
        while True:
            message = await ws.receive_json()
            if message.get("event") != "chat.end":
                # Any other packet is ignored; keep waiting for the end marker.
                continue

            async for session in factory():
                departing = await session.get(User, user_id)
                if departing is None:
                    break
                await server.disconnect(departing, session)
            await ws.close(code=1000)
            return
    except WebSocketDisconnect as e:
        logger.info(f"Client {user_id} disconnected: {e.code}, {e.reason}")
    except RuntimeError as e:
        # A RuntimeError mentioning "disconnect message" is treated as an
        # ordinary client-side close rather than an error.
        if "disconnect message" not in str(e):
            logger.exception(f"RuntimeError in client {user_id}: {e}")
        else:
            logger.info(f"Client {user_id} closed the connection.")
    except Exception:
        logger.exception(f"Error in client {user_id}")
@chat_router.websocket("/notification-server")
async def chat_websocket(
    websocket: WebSocket,
    factory: Annotated[DBFactory, Depends(get_db_factory)],
    token: Annotated[str | None, Query(description="认证令牌支持通过URL参数传递")] = None,
    access_token: Annotated[str | None, Query(description="访问令牌支持通过URL参数传递")] = None,
    authorization: Annotated[str | None, Header(description="Bearer认证头")] = None,
):
    """Websocket endpoint for the notification/chat server.

    Authenticates the caller, accepts the socket, waits for a ``chat.start``
    packet, registers the user with the module-level ``server``, joins the
    user to channel 1 when that channel exists, and then listens until the
    client ends the session.

    Args:
        websocket: The incoming websocket connection.
        factory: Database session factory dependency.
        token: Auth token passed as the ``token`` query parameter.
        access_token: Auth token passed as the ``access_token`` query parameter.
        authorization: ``Authorization`` header; a leading ``"Bearer "`` is stripped.
    """
    if not server._subscribed:
        # The flag is set before awaiting so that connections arriving while
        # the subscription starts do not start it a second time.
        server._subscribed = True
        await server.ChatSubscriber.start_subscribe()

    async for session in factory():
        # Query-parameter tokens take precedence (either ``token`` or
        # ``access_token``); the Authorization header is the fallback.
        auth_token = token or access_token
        if not auth_token and authorization:
            auth_token = authorization.removeprefix("Bearer ")
        if not auth_token:
            await websocket.close(code=1008, reason="Missing authentication token")
            return

        user_and_token = await get_current_user_and_token(
            session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token
        )
        if user_and_token is None:
            await websocket.close(code=1008, reason="Invalid or expired token")
            return

        await websocket.accept()

        # The first packet must be a JSON object {"event": "chat.start", ...}.
        # Guard the read so a client that goes away or sends garbage during
        # the handshake does not raise out of the endpoint.
        try:
            login = await websocket.receive_json()
        except WebSocketDisconnect as e:
            logger.info(f"Client disconnected before sending chat.start: {e.code}, {e.reason}")
            return
        except (ValueError, KeyError):
            # ValueError covers invalid JSON; KeyError presumably covers a
            # non-text frame (depends on the ASGI server's message shape).
            await websocket.close(code=1008)
            return
        if not isinstance(login, dict) or login.get("event") != "chat.start":
            await websocket.close(code=1008)
            return

        user = user_and_token[0]
        user_id = user.id
        server.connect(user_id, websocket)

        # Explicit query to avoid lazy loading.
        db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
        if db_channel is not None:
            await server.join_channel(user, db_channel, session)

        await _listen_stop(websocket, user_id, factory)