feat(notification): support notification
This commit is contained in:
149
app/router/notification/__init__.py
Normal file
149
app/router/notification/__init__.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.config import settings
|
||||
from app.database.lazer_user import User
|
||||
from app.database.notification import Notification, UserNotification
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.models.chat import ChatEvent
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from . import channel, message # noqa: F401
|
||||
from .server import (
|
||||
chat_router as chat_router,
|
||||
server,
|
||||
)
|
||||
|
||||
from fastapi import Body, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, func, select
|
||||
|
||||
__all__ = ["chat_router"]
|
||||
|
||||
|
||||
class NotificationResp(BaseModel):
|
||||
has_more: bool
|
||||
notifications: list[Notification]
|
||||
unread_count: int
|
||||
notification_endpoint: str
|
||||
|
||||
|
||||
@router.get(
|
||||
"/notifications",
|
||||
tags=["通知", "聊天"],
|
||||
name="获取通知",
|
||||
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
|
||||
response_model=NotificationResp,
|
||||
)
|
||||
async def get_notifications(
|
||||
session: Database,
|
||||
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
query = select(UserNotification).where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
if max_id is not None:
|
||||
query = query.where(UserNotification.notification_id < max_id)
|
||||
notifications = (await session.exec(query)).all()
|
||||
total_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(UserNotification)
|
||||
.where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
unread_count = len(notifications)
|
||||
|
||||
return NotificationResp(
|
||||
has_more=unread_count < total_count,
|
||||
notifications=[notification.notification for notification in notifications],
|
||||
unread_count=unread_count,
|
||||
notification_endpoint=notification_endpoint,
|
||||
)
|
||||
|
||||
|
||||
class _IdentityReq(BaseModel):
|
||||
category: str | None = None
|
||||
id: int | None = None
|
||||
object_id: int | None = None
|
||||
object_type: int | None = None
|
||||
|
||||
|
||||
async def _get_notifications(
|
||||
session: Database, current_user: User, identities: list[_IdentityReq]
|
||||
) -> list[UserNotification]:
|
||||
result: dict[int, UserNotification] = {}
|
||||
base_query = select(UserNotification).where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
for identity in identities:
|
||||
query = base_query
|
||||
if identity.id is not None:
|
||||
query = base_query.where(UserNotification.notification_id == identity.id)
|
||||
if identity.object_id is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_id) == identity.object_id
|
||||
)
|
||||
)
|
||||
if identity.object_type is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_type) == identity.object_type
|
||||
)
|
||||
)
|
||||
if identity.category is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.category) == identity.category
|
||||
)
|
||||
)
|
||||
result.update({n.notification_id: n for n in await session.exec(query)})
|
||||
return list(result.values())
|
||||
|
||||
|
||||
@router.post(
|
||||
"/notifications/mark-read",
|
||||
tags=["通知", "聊天"],
|
||||
name="标记通知为已读",
|
||||
description="标记当前用户的通知为已读。",
|
||||
status_code=204,
|
||||
)
|
||||
async def mark_notifications_as_read(
|
||||
session: Database,
|
||||
identities: list[_IdentityReq] = Body(default_factory=list),
|
||||
notifications: list[_IdentityReq] = Body(default_factory=list),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
identities.extend(notifications)
|
||||
user_notifications = await _get_notifications(session, current_user, identities)
|
||||
for user_notification in user_notifications:
|
||||
user_notification.is_read = True
|
||||
|
||||
assert current_user.id
|
||||
await server.send_event(
|
||||
current_user.id,
|
||||
ChatEvent(
|
||||
event="read",
|
||||
data={
|
||||
"notifications": [i.model_dump() for i in identities],
|
||||
"read_count": len(user_notifications),
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
},
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
Reference in New Issue
Block a user