feat(notification): support notification
This commit is contained in:
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
from app.signalr import signalr_router as signalr_router
|
||||
|
||||
from .auth import router as auth_router
|
||||
from .chat import chat_router as chat_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
from .file import file_router as file_router
|
||||
from .notification import chat_router as chat_router
|
||||
from .private import private_router as private_router
|
||||
from .redirect import (
|
||||
redirect_api_router as redirect_api_router,
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from . import channel, message # noqa: F401
|
||||
from .server import chat_router as chat_router
|
||||
|
||||
from fastapi import Query
|
||||
|
||||
__all__ = ["chat_router"]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/notifications",
|
||||
tags=["通知", "聊天"],
|
||||
name="获取通知",
|
||||
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
|
||||
)
|
||||
async def get_notifications(
|
||||
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
|
||||
):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
|
||||
return {
|
||||
"has_more": False,
|
||||
"notifications": [],
|
||||
"unread_count": 0,
|
||||
"notification_endpoint": notification_endpoint,
|
||||
}
|
||||
149
app/router/notification/__init__.py
Normal file
149
app/router/notification/__init__.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.config import settings
|
||||
from app.database.lazer_user import User
|
||||
from app.database.notification import Notification, UserNotification
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.models.chat import ChatEvent
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from . import channel, message # noqa: F401
|
||||
from .server import (
|
||||
chat_router as chat_router,
|
||||
server,
|
||||
)
|
||||
|
||||
from fastapi import Body, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, func, select
|
||||
|
||||
__all__ = ["chat_router"]
|
||||
|
||||
|
||||
class NotificationResp(BaseModel):
|
||||
has_more: bool
|
||||
notifications: list[Notification]
|
||||
unread_count: int
|
||||
notification_endpoint: str
|
||||
|
||||
|
||||
@router.get(
|
||||
"/notifications",
|
||||
tags=["通知", "聊天"],
|
||||
name="获取通知",
|
||||
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
|
||||
response_model=NotificationResp,
|
||||
)
|
||||
async def get_notifications(
|
||||
session: Database,
|
||||
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
query = select(UserNotification).where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
if max_id is not None:
|
||||
query = query.where(UserNotification.notification_id < max_id)
|
||||
notifications = (await session.exec(query)).all()
|
||||
total_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(UserNotification)
|
||||
.where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
unread_count = len(notifications)
|
||||
|
||||
return NotificationResp(
|
||||
has_more=unread_count < total_count,
|
||||
notifications=[notification.notification for notification in notifications],
|
||||
unread_count=unread_count,
|
||||
notification_endpoint=notification_endpoint,
|
||||
)
|
||||
|
||||
|
||||
class _IdentityReq(BaseModel):
|
||||
category: str | None = None
|
||||
id: int | None = None
|
||||
object_id: int | None = None
|
||||
object_type: int | None = None
|
||||
|
||||
|
||||
async def _get_notifications(
|
||||
session: Database, current_user: User, identities: list[_IdentityReq]
|
||||
) -> list[UserNotification]:
|
||||
result: dict[int, UserNotification] = {}
|
||||
base_query = select(UserNotification).where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
for identity in identities:
|
||||
query = base_query
|
||||
if identity.id is not None:
|
||||
query = base_query.where(UserNotification.notification_id == identity.id)
|
||||
if identity.object_id is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_id) == identity.object_id
|
||||
)
|
||||
)
|
||||
if identity.object_type is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_type) == identity.object_type
|
||||
)
|
||||
)
|
||||
if identity.category is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.category) == identity.category
|
||||
)
|
||||
)
|
||||
result.update({n.notification_id: n for n in await session.exec(query)})
|
||||
return list(result.values())
|
||||
|
||||
|
||||
@router.post(
|
||||
"/notifications/mark-read",
|
||||
tags=["通知", "聊天"],
|
||||
name="标记通知为已读",
|
||||
description="标记当前用户的通知为已读。",
|
||||
status_code=204,
|
||||
)
|
||||
async def mark_notifications_as_read(
|
||||
session: Database,
|
||||
identities: list[_IdentityReq] = Body(default_factory=list),
|
||||
notifications: list[_IdentityReq] = Body(default_factory=list),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
identities.extend(notifications)
|
||||
user_notifications = await _get_notifications(session, current_user, identities)
|
||||
for user_notification in user_notifications:
|
||||
user_notification.is_read = True
|
||||
|
||||
assert current_user.id
|
||||
await server.send_event(
|
||||
current_user.id,
|
||||
ChatEvent(
|
||||
event="read",
|
||||
data={
|
||||
"notifications": [i.model_dump() for i in identities],
|
||||
"read_count": len(user_notifications),
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
},
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
@@ -14,6 +14,7 @@ from app.database.lazer_user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.notification import ChannelMessage, ChannelMessageTeam
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from .banchobot import bot
|
||||
@@ -113,6 +114,15 @@ async def send_message(
|
||||
)
|
||||
if is_bot_command:
|
||||
await bot.try_handle(current_user, db_channel, req.message, session)
|
||||
if db_channel.type == ChannelType.PM:
|
||||
user_ids = db_channel.name.split("_")[1:]
|
||||
await server.new_private_notification(
|
||||
ChannelMessage(
|
||||
msg, current_user, [int(u) for u in user_ids], db_channel.type
|
||||
)
|
||||
)
|
||||
elif db_channel.type == ChannelType.TEAM:
|
||||
await server.new_private_notification(ChannelMessageTeam(msg, current_user))
|
||||
return resp
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import overload
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database.lazer_user import User
|
||||
from app.database.notification import UserNotification, insert_notification
|
||||
from app.dependencies.database import (
|
||||
DBFactory,
|
||||
get_db_factory,
|
||||
@@ -13,12 +15,14 @@ from app.dependencies.database import (
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
from fastapi.websockets import WebSocketState
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -59,15 +63,24 @@ class ChatServer:
|
||||
if channel:
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
async def send_event(self, client: WebSocket, event: ChatEvent):
|
||||
@overload
|
||||
async def send_event(self, client: int, event: ChatEvent): ...
|
||||
|
||||
@overload
|
||||
async def send_event(self, client: WebSocket, event: ChatEvent): ...
|
||||
|
||||
async def send_event(self, client: WebSocket | int, event: ChatEvent):
|
||||
if isinstance(client, int):
|
||||
client_ = self.connect_client.get(client)
|
||||
if client_ is None:
|
||||
return
|
||||
client = client_
|
||||
if client.client_state == WebSocketState.CONNECTED:
|
||||
await client.send_text(event.model_dump_json())
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
for user_id in self.channels.get(channel_id, []):
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(client, event)
|
||||
await self.send_event(user_id, event)
|
||||
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
@@ -80,9 +93,7 @@ class ChatServer:
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
)
|
||||
if is_bot_command:
|
||||
client = self.connect_client.get(message.sender_id)
|
||||
if client:
|
||||
self._add_task(self.send_event(client, event))
|
||||
self._add_task(self.send_event(message.sender_id, event))
|
||||
else:
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
@@ -123,15 +134,13 @@ class ChatServer:
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
client = self.connect_client.get(user.id)
|
||||
if client:
|
||||
await self.send_event(
|
||||
client,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
await self.send_event(
|
||||
user.id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
@@ -154,15 +163,13 @@ class ChatServer:
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(
|
||||
client,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
return channel_resp
|
||||
|
||||
@@ -189,15 +196,13 @@ class ChatServer:
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(
|
||||
client,
|
||||
ChatEvent(
|
||||
event="chat.channel.part",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.part",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
@@ -223,6 +228,28 @@ class ChatServer:
|
||||
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
id = await insert_notification(session, detail)
|
||||
users = (
|
||||
await session.exec(
|
||||
select(UserNotification).where(
|
||||
UserNotification.notification_id == id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
for user_notification in users:
|
||||
data = user_notification.notification.model_dump()
|
||||
data["is_read"] = user_notification.is_read
|
||||
data["details"] = user_notification.notification.details
|
||||
await server.send_event(
|
||||
user_notification.user_id,
|
||||
ChatEvent(
|
||||
event="new",
|
||||
data=data,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
server = ChatServer()
|
||||
|
||||
Reference in New Issue
Block a user