143 lines
4.7 KiB
Python
143 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
from app.config import settings
|
|
from app.database.lazer_user import User
|
|
from app.database.notification import Notification, UserNotification
|
|
from app.dependencies.database import Database
|
|
from app.dependencies.user import get_client_user
|
|
from app.models.chat import ChatEvent
|
|
from app.router.v2 import api_v2_router as router
|
|
|
|
from . import channel, message # noqa: F401
|
|
from .server import (
|
|
chat_router as chat_router,
|
|
server,
|
|
)
|
|
|
|
from fastapi import Body, Query, Security
|
|
from pydantic import BaseModel
|
|
from sqlmodel import col, func, select
|
|
|
|
# Public API of this module: only ``chat_router`` (imported from ``.server``
# above) is re-exported.
__all__ = ["chat_router"]
|
|
|
|
|
|
# Response body of the ``GET /notifications`` endpoint.
# NOTE: documented with ``#`` comments instead of a class docstring, because a
# docstring on a pydantic model is emitted as the schema description.
class NotificationResp(BaseModel):
    # Whether more unread notifications exist than were returned.
    has_more: bool
    # The notification objects returned for this request.
    notifications: list[Notification]
    # Number of unread notifications reported to the client.
    unread_count: int
    # URL (or path) of the notification server the client should connect to.
    notification_endpoint: str
|
|
|
|
|
|
@router.get(
    "/notifications",
    tags=["通知", "聊天"],
    name="获取通知",
    description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
    response_model=NotificationResp,
)
async def get_notifications(
    session: Database,
    max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
    current_user: User = Security(get_client_user),
):
    """Return the current user's unread notifications and the notification endpoint.

    Args:
        session: Database session used for both the listing and the count query.
        max_id: When given, only notifications whose ID is strictly less than
            this value are returned.
        current_user: The authenticated user, resolved via ``get_client_user``.

    Returns:
        A ``NotificationResp`` whose notifications are ordered by descending
        notification ID.
    """
    if settings.server_url is not None:
        # Rewrite the HTTP(S) scheme of the configured server URL to WS(S).
        notification_endpoint = f"{settings.server_url}notification-server".replace("http://", "ws://").replace(
            "https://", "wss://"
        )
    else:
        notification_endpoint = "/notification-server"

    # Filters shared by the listing query and the total-count query.
    unread_filters = (
        UserNotification.user_id == current_user.id,
        col(UserNotification.is_read).is_(False),
    )

    query = select(UserNotification).where(*unread_filters)
    if max_id is not None:
        query = query.where(UserNotification.notification_id < max_id)
    # The endpoint promises results sorted by ID; without an explicit ORDER BY
    # the row order is left to the database. Descending order matches the
    # "ID less than max_id" cursor.
    query = query.order_by(col(UserNotification.notification_id).desc())
    notifications = (await session.exec(query)).all()

    total_count = (
        await session.exec(select(func.count()).select_from(UserNotification).where(*unread_filters))
    ).one()
    # NOTE(review): unread_count/has_more are derived from the rows returned
    # after the max_id filter, while total_count ignores max_id and no LIMIT is
    # applied. Kept as-is; confirm this is the intended paging signal.
    unread_count = len(notifications)

    return NotificationResp(
        has_more=unread_count < total_count,
        notifications=[notification.notification for notification in notifications],
        unread_count=unread_count,
        notification_endpoint=notification_endpoint,
    )
|
|
|
|
|
|
# Identifies one notification, or a group of notifications, in a request body.
# Every field is optional; ``_get_notifications`` only filters on the fields
# that are not ``None``.
class _IdentityReq(BaseModel):
    # Compared against ``Notification.category``.
    category: str | None = None
    # Compared against ``UserNotification.notification_id``.
    id: int | None = None
    # Compared against ``Notification.object_id``.
    object_id: int | None = None
    # Compared against ``Notification.object_type``.
    object_type: int | None = None
|
|
|
|
|
|
async def _get_notifications(
    session: Database, current_user: User, identities: list[_IdentityReq]
) -> list[UserNotification]:
    """Collect the user's unread notifications that match any of ``identities``.

    One query is issued per identity. Every field set on an identity narrows
    that identity's query (the conditions are AND-ed together); results of all
    identities are merged and de-duplicated by ``notification_id``.

    Args:
        session: Database session used to run the queries.
        current_user: Only this user's unread ``UserNotification`` rows are considered.
        identities: Filters describing which notifications to select. An
            identity with no field set matches every unread notification of
            the user.

    Returns:
        The matching ``UserNotification`` rows, each at most once.
    """
    result: dict[int, UserNotification] = {}
    base_query = select(UserNotification).where(
        UserNotification.user_id == current_user.id,
        col(UserNotification.is_read).is_(False),
    )
    for identity in identities:
        # Build on ``query`` (not ``base_query``) so that each filter adds to
        # the previous ones instead of replacing them; otherwise only the last
        # non-None field would be applied.
        query = base_query
        if identity.id is not None:
            query = query.where(UserNotification.notification_id == identity.id)

        # Filters that apply to the related ``Notification`` row.
        related_filters = (
            (Notification.object_id, identity.object_id),
            (Notification.object_type, identity.object_type),
            (Notification.category, identity.category),
        )
        for column, wanted in related_filters:
            if wanted is not None:
                query = query.where(col(UserNotification.notification).has(col(column) == wanted))

        result.update({n.notification_id: n for n in await session.exec(query)})
    return list(result.values())
|
|
|
|
|
|
@router.post(
    "/notifications/mark-read",
    tags=["通知", "聊天"],
    name="标记通知为已读",
    description="标记当前用户的通知为已读。",
    status_code=204,
)
async def mark_notifications_as_read(
    session: Database,
    identities: list[_IdentityReq] = Body(default_factory=list),
    notifications: list[_IdentityReq] = Body(default_factory=list),
    current_user: User = Security(get_client_user),
):
    """Mark the current user's matching notifications as read.

    The ``identities`` and ``notifications`` body fields are merged into one
    list of filters. Every matching unread notification is flagged as read, a
    ``read`` event is sent to the user through ``server``, and the change is
    then committed.
    """
    recipient_id = current_user.id

    # Both body fields carry the same kind of filter; treat them as one list.
    identities += notifications

    matched = await _get_notifications(session, current_user, identities)
    for entry in matched:
        entry.is_read = True

    payload = {
        "notifications": [identity.model_dump() for identity in identities],
        "read_count": len(matched),
        "timestamp": datetime.now(UTC).isoformat(),
    }
    read_event = ChatEvent(event="read", data=payload)

    # The event is sent before the commit, as in the original ordering.
    await server.send_event(recipient_id, read_event)
    await session.commit()
|