feat(notification): support notification

This commit is contained in:
MingxuanGame
2025-08-21 07:22:44 +00:00
parent 6ac9a124ea
commit 9fb0d0c198
13 changed files with 626 additions and 70 deletions

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
from app.signalr import signalr_router as signalr_router
from .auth import router as auth_router
from .chat import chat_router as chat_router
from .fetcher import fetcher_router as fetcher_router
from .file import file_router as file_router
from .notification import chat_router as chat_router
from .private import private_router as private_router
from .redirect import (
redirect_api_router as redirect_api_router,

View File

@@ -1,35 +0,0 @@
from __future__ import annotations
from app.config import settings
from app.router.v2 import api_v2_router as router
from . import channel, message # noqa: F401
from .server import chat_router as chat_router
from fastapi import Query
__all__ = ["chat_router"]
@router.get(
"/notifications",
tags=["通知", "聊天"],
name="获取通知",
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
)
async def get_notifications(
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
):
if settings.server_url is not None:
notification_endpoint = f"{settings.server_url}notification-server".replace(
"http://", "ws://"
).replace("https://", "wss://")
else:
notification_endpoint = "/notification-server"
return {
"has_more": False,
"notifications": [],
"unread_count": 0,
"notification_endpoint": notification_endpoint,
}

View File

@@ -0,0 +1,149 @@
from __future__ import annotations
from datetime import UTC, datetime
from app.config import settings
from app.database.lazer_user import User
from app.database.notification import Notification, UserNotification
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.models.chat import ChatEvent
from app.router.v2 import api_v2_router as router
from . import channel, message # noqa: F401
from .server import (
chat_router as chat_router,
server,
)
from fastapi import Body, Query, Security
from pydantic import BaseModel
from sqlmodel import col, func, select
__all__ = ["chat_router"]
class NotificationResp(BaseModel):
has_more: bool
notifications: list[Notification]
unread_count: int
notification_endpoint: str
@router.get(
"/notifications",
tags=["通知", "聊天"],
name="获取通知",
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
response_model=NotificationResp,
)
async def get_notifications(
session: Database,
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
current_user: User = Security(get_client_user),
):
if settings.server_url is not None:
notification_endpoint = f"{settings.server_url}notification-server".replace(
"http://", "ws://"
).replace("https://", "wss://")
else:
notification_endpoint = "/notification-server"
query = select(UserNotification).where(
UserNotification.user_id == current_user.id,
col(UserNotification.is_read).is_(False),
)
if max_id is not None:
query = query.where(UserNotification.notification_id < max_id)
notifications = (await session.exec(query)).all()
total_count = (
await session.exec(
select(func.count())
.select_from(UserNotification)
.where(
UserNotification.user_id == current_user.id,
col(UserNotification.is_read).is_(False),
)
)
).one()
unread_count = len(notifications)
return NotificationResp(
has_more=unread_count < total_count,
notifications=[notification.notification for notification in notifications],
unread_count=unread_count,
notification_endpoint=notification_endpoint,
)
class _IdentityReq(BaseModel):
category: str | None = None
id: int | None = None
object_id: int | None = None
object_type: int | None = None
async def _get_notifications(
session: Database, current_user: User, identities: list[_IdentityReq]
) -> list[UserNotification]:
result: dict[int, UserNotification] = {}
base_query = select(UserNotification).where(
UserNotification.user_id == current_user.id,
col(UserNotification.is_read).is_(False),
)
for identity in identities:
query = base_query
if identity.id is not None:
query = base_query.where(UserNotification.notification_id == identity.id)
if identity.object_id is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.object_id) == identity.object_id
)
)
if identity.object_type is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.object_type) == identity.object_type
)
)
if identity.category is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.category) == identity.category
)
)
result.update({n.notification_id: n for n in await session.exec(query)})
return list(result.values())
@router.post(
"/notifications/mark-read",
tags=["通知", "聊天"],
name="标记通知为已读",
description="标记当前用户的通知为已读。",
status_code=204,
)
async def mark_notifications_as_read(
session: Database,
identities: list[_IdentityReq] = Body(default_factory=list),
notifications: list[_IdentityReq] = Body(default_factory=list),
current_user: User = Security(get_client_user),
):
identities.extend(notifications)
user_notifications = await _get_notifications(session, current_user, identities)
for user_notification in user_notifications:
user_notification.is_read = True
assert current_user.id
await server.send_event(
current_user.id,
ChatEvent(
event="read",
data={
"notifications": [i.model_dump() for i in identities],
"read_count": len(user_notifications),
"timestamp": datetime.now(UTC).isoformat(),
},
),
)
await session.commit()

View File

@@ -14,6 +14,7 @@ from app.database.lazer_user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.models.notification import ChannelMessage, ChannelMessageTeam
from app.router.v2 import api_v2_router as router
from .banchobot import bot
@@ -113,6 +114,15 @@ async def send_message(
)
if is_bot_command:
await bot.try_handle(current_user, db_channel, req.message, session)
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
await server.new_private_notification(
ChannelMessage(
msg, current_user, [int(u) for u in user_ids], db_channel.type
)
)
elif db_channel.type == ChannelType.TEAM:
await server.new_private_notification(ChannelMessageTeam(msg, current_user))
return resp

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import asyncio
from typing import overload
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.lazer_user import User
from app.database.notification import UserNotification, insert_notification
from app.dependencies.database import (
DBFactory,
get_db_factory,
@@ -13,12 +15,14 @@ from app.dependencies.database import (
from app.dependencies.user import get_current_user
from app.log import logger
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
from fastapi.websockets import WebSocketState
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -59,15 +63,24 @@ class ChatServer:
if channel:
await self.leave_channel(user, channel, session)
async def send_event(self, client: WebSocket, event: ChatEvent):
@overload
async def send_event(self, client: int, event: ChatEvent): ...
@overload
async def send_event(self, client: WebSocket, event: ChatEvent): ...
async def send_event(self, client: WebSocket | int, event: ChatEvent):
if isinstance(client, int):
client_ = self.connect_client.get(client)
if client_ is None:
return
client = client_
if client.client_state == WebSocketState.CONNECTED:
await client.send_text(event.model_dump_json())
async def broadcast(self, channel_id: int, event: ChatEvent):
for user_id in self.channels.get(channel_id, []):
client = self.connect_client.get(user_id)
if client:
await self.send_event(client, event)
await self.send_event(user_id, event)
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
@@ -80,9 +93,7 @@ class ChatServer:
data={"messages": [message], "users": [message.sender]},
)
if is_bot_command:
client = self.connect_client.get(message.sender_id)
if client:
self._add_task(self.send_event(client, event))
self._add_task(self.send_event(message.sender_id, event))
else:
self._add_task(
self.broadcast(
@@ -123,15 +134,13 @@ class ChatServer:
if channel.type != ChannelType.PUBLIC
else None,
)
client = self.connect_client.get(user.id)
if client:
await self.send_event(
client,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
await self.send_event(
user.id,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
async def join_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
@@ -154,15 +163,13 @@ class ChatServer:
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
)
client = self.connect_client.get(user_id)
if client:
await self.send_event(
client,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
return channel_resp
@@ -189,15 +196,13 @@ class ChatServer:
if channel.type != ChannelType.PUBLIC
else None,
)
client = self.connect_client.get(user_id)
if client:
await self.send_event(
client,
ChatEvent(
event="chat.channel.part",
data=channel_resp.model_dump(),
),
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.part",
data=channel_resp.model_dump(),
),
)
async def join_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
@@ -223,6 +228,28 @@ class ChatServer:
await self.leave_channel(user, channel, session)
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
id = await insert_notification(session, detail)
users = (
await session.exec(
select(UserNotification).where(
UserNotification.notification_id == id
)
)
).all()
for user_notification in users:
data = user_notification.notification.model_dump()
data["is_read"] = user_notification.is_read
data["details"] = user_notification.notification.details
await server.send_event(
user_notification.user_id,
ChatEvent(
event="new",
data=data,
),
)
server = ChatServer()