From 9fb0d0c198cff77cc32c416a2198856b3f202d3f Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 21 Aug 2025 07:22:44 +0000 Subject: [PATCH] feat(notification): support notification --- app/database/__init__.py | 3 + app/database/chat.py | 1 + app/database/notification.py | 73 +++++++ app/models/notification.py | 184 ++++++++++++++++++ app/router/__init__.py | 2 +- app/router/chat/__init__.py | 35 ---- app/router/notification/__init__.py | 149 ++++++++++++++ .../{chat => notification}/banchobot.py | 0 app/router/{chat => notification}/channel.py | 0 app/router/{chat => notification}/message.py | 10 + app/router/{chat => notification}/server.py | 95 +++++---- app/utils.py | 6 + ...6c43d8601_notification_add_notification.py | 138 +++++++++++++ 13 files changed, 626 insertions(+), 70 deletions(-) create mode 100644 app/database/notification.py create mode 100644 app/models/notification.py delete mode 100644 app/router/chat/__init__.py create mode 100644 app/router/notification/__init__.py rename app/router/{chat => notification}/banchobot.py (100%) rename app/router/{chat => notification}/channel.py (100%) rename app/router/{chat => notification}/message.py (94%) rename app/router/{chat => notification}/server.py (82%) create mode 100644 migrations/versions/4f46c43d8601_notification_add_notification.py diff --git a/app/database/__init__.py b/app/database/__init__.py index 5c58b47..6ff1d21 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -29,6 +29,7 @@ from .lazer_user import ( UserResp, ) from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp +from .notification import Notification, UserNotification from .playlist_attempts import ( ItemAttemptsCount, ItemAttemptsResp, @@ -86,6 +87,7 @@ __all__ = [ "MultiplayerEvent", "MultiplayerEventResp", "MultiplayerScores", + "Notification", "OAuthClient", "OAuthToken", "PPBestScore", @@ -120,6 +122,7 @@ __all__ = [ "UserAchievement", "UserAchievementResp", "UserLoginLog", + "UserNotification", "UserResp", "UserStatistics", "UserStatisticsResp", diff --git a/app/database/chat.py b/app/database/chat.py index 29334d6..9647b48 100644 --- a/app/database/chat.py +++ b/app/database/chat.py @@ -190,6 +190,7 @@ class ChatMessageBase(UTCBaseModel, SQLModel): class ChatMessage(ChatMessageBase, table=True): __tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType] user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) + channel: ChatChannel = Relationship() class ChatMessageResp(ChatMessageBase): diff --git a/app/database/notification.py b/app/database/notification.py new file mode 100644 index 0000000..a0f568b --- /dev/null +++ b/app/database/notification.py @@ -0,0 +1,73 @@ +from datetime import UTC, datetime +from typing import Any + +from app.models.notification import NotificationDetail, NotificationName + +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) +from sqlmodel.ext.asyncio.session import AsyncSession + + +class Notification(SQLModel, table=True): + __tablename__ = "notifications" # pyright: ignore[reportAssignmentType] + + id: int = Field(primary_key=True, index=True, default=None) + name: NotificationName = Field(index=True) + category: str = Field(max_length=255, index=True) + created_at: datetime = Field(sa_column=Column(DateTime)) + object_type: str = Field(index=True) + object_id: int = Field(sa_column=Column(BigInteger, index=True)) + source_user_id: int = Field(index=True) + details: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON)) + + +class UserNotification(SQLModel, table=True): + __tablename__ = "user_notifications" # pyright: ignore[reportAssignmentType] + id: int = Field( + sa_column=Column( + BigInteger, + primary_key=True, + index=True, + ), + default=None, + ) + notification_id: int = Field(index=True, foreign_key="notifications.id") + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + is_read: bool = Field(index=True) + + notification: Notification = Relationship(sa_relationship_kwargs={"lazy": "joined"}) + + +async def insert_notification(session: AsyncSession, detail: NotificationDetail): + notification = Notification( + name=detail.name, + category=detail.name.category, + object_type=detail.object_type, + object_id=detail.object_id, + source_user_id=detail.source_user_id, + details=detail.model_dump(), + created_at=datetime.now(UTC), + ) + session.add(notification) + await session.commit() + await session.refresh(notification) + id_ = notification.id + for receiver in await detail.get_receivers(session): + user_notification = UserNotification( + notification_id=id_, + user_id=receiver, + is_read=False, + ) + session.add(user_notification) + await session.commit() + return id_ diff --git a/app/models/notification.py b/app/models/notification.py new file mode 100644 index 0000000..14a44fd --- /dev/null +++ b/app/models/notification.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from abc import abstractmethod +from enum import Enum +from typing import TYPE_CHECKING + +from app.utils import truncate + +from pydantic import BaseModel, PrivateAttr +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +CONTENT_TRUNCATE = 36 + +if TYPE_CHECKING: + from app.database import ChannelType, ChatMessage, User + + +# https://github.com/ppy/osu-web/blob/master/app/Models/Notification.php +class NotificationName(str, Enum): + BEATMAP_OWNER_CHANGE = "beatmap_owner_change" + BEATMAPSET_DISCUSSION_LOCK = "beatmapset_discussion_lock" + BEATMAPSET_DISCUSSION_POST_NEW = "beatmapset_discussion_post_new" + BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM = "beatmapset_discussion_qualified_problem" + BEATMAPSET_DISCUSSION_REVIEW_NEW = "beatmapset_discussion_review_new" + BEATMAPSET_DISCUSSION_UNLOCK = "beatmapset_discussion_unlock" + BEATMAPSET_DISQUALIFY = "beatmapset_disqualify" + BEATMAPSET_LOVE = "beatmapset_love" + BEATMAPSET_NOMINATE = "beatmapset_nominate" + BEATMAPSET_QUALIFY = "beatmapset_qualify" + BEATMAPSET_RANK = "beatmapset_rank" + BEATMAPSET_REMOVE_FROM_LOVED = "beatmapset_remove_from_loved" + BEATMAPSET_RESET_NOMINATIONS = "beatmapset_reset_nominations" + CHANNEL_ANNOUNCEMENT = "channel_announcement" + CHANNEL_MESSAGE = "channel_message" + CHANNEL_TEAM = "channel_team" + COMMENT_NEW = "comment_new" + FORUM_TOPIC_REPLY = "forum_topic_reply" + TEAM_APPLICATION_ACCEPT = "team_application_accept" + TEAM_APPLICATION_REJECT = "team_application_reject" + TEAM_APPLICATION_STORE = "team_application_store" + USER_ACHIEVEMENT_UNLOCK = "user_achievement_unlock" + USER_BEATMAPSET_NEW = "user_beatmapset_new" + USER_BEATMAPSET_REVIVE = "user_beatmapset_revive" + + # NAME_TO_CATEGORY + @property + def category(self) -> str: + return { + NotificationName.BEATMAP_OWNER_CHANGE: "beatmap_owner_change", + NotificationName.BEATMAPSET_DISCUSSION_LOCK: "beatmapset_discussion", + NotificationName.BEATMAPSET_DISCUSSION_POST_NEW: "beatmapset_discussion", + NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem", # noqa: E501 + NotificationName.BEATMAPSET_DISCUSSION_REVIEW_NEW: "beatmapset_discussion", + NotificationName.BEATMAPSET_DISCUSSION_UNLOCK: "beatmapset_discussion", + NotificationName.BEATMAPSET_DISQUALIFY: "beatmapset_state", + NotificationName.BEATMAPSET_LOVE: "beatmapset_state", + NotificationName.BEATMAPSET_NOMINATE: "beatmapset_state", + NotificationName.BEATMAPSET_QUALIFY: "beatmapset_state", + NotificationName.BEATMAPSET_RANK: "beatmapset_state", + NotificationName.BEATMAPSET_REMOVE_FROM_LOVED: "beatmapset_state", + NotificationName.BEATMAPSET_RESET_NOMINATIONS: "beatmapset_state", + NotificationName.CHANNEL_ANNOUNCEMENT: "announcement", + NotificationName.CHANNEL_MESSAGE: "channel", + NotificationName.CHANNEL_TEAM: "channel_team", + NotificationName.COMMENT_NEW: "comment", + NotificationName.FORUM_TOPIC_REPLY: "forum_topic_reply", + NotificationName.TEAM_APPLICATION_ACCEPT: "team_application", + NotificationName.TEAM_APPLICATION_REJECT: "team_application", + NotificationName.TEAM_APPLICATION_STORE: "team_application", + NotificationName.USER_ACHIEVEMENT_UNLOCK: "user_achievement_unlock", + NotificationName.USER_BEATMAPSET_NEW: "user_beatmapset_new", + NotificationName.USER_BEATMAPSET_REVIVE: "user_beatmapset_new", + }[self] + + +class NotificationDetail(BaseModel): + @property + @abstractmethod + def name(self) -> NotificationName: + raise NotImplementedError + + @property + @abstractmethod + def object_type(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def object_id(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def source_user_id(self) -> int: + raise NotImplementedError + + @abstractmethod + async def get_receivers(self, session: AsyncSession) -> list[int]: + raise NotImplementedError + + +class ChannelMessageBase(NotificationDetail): + title: str = "" + type: str = "" + cover_url: str = "" + + _message: "ChatMessage" = PrivateAttr() + _user: "User" = PrivateAttr() + _receiver: list[int] = PrivateAttr() + + def __init__( + self, + message: "ChatMessage", + user: "User", + receiver: list[int], + channel_type: "ChannelType", + ) -> None: + super().__init__( + title=truncate(message.content, CONTENT_TRUNCATE), + type=channel_type.value.lower(), + cover_url=user.avatar_url, + ) + self._message = message + self._user = user + self._receiver = receiver + + async def get_receivers(self, session: AsyncSession) -> list[int]: + return self._receiver + + @property + def source_user_id(self) -> int: + return self._user.id + + @property + def object_type(self) -> str: + return "channel" + + @property + def object_id(self) -> int: + return self._message.channel_id + + +class ChannelMessage(ChannelMessageBase): + def __init__( + self, + message: "ChatMessage", + user: "User", + receiver: list[int], + channel_type: "ChannelType", + ) -> None: + super().__init__(message, user, receiver, channel_type) + + @property + def name(self) -> NotificationName: + return NotificationName.CHANNEL_MESSAGE + + +class ChannelMessageTeam(ChannelMessageBase): + def __init__(self, message: "ChatMessage", user: "User") -> None: + from app.database import ChannelType + + super().__init__(message, user, [], ChannelType.TEAM) + + @property + def name(self) -> NotificationName: + return NotificationName.CHANNEL_TEAM + + async def get_receivers(self, session: AsyncSession) -> list[int]: + from app.database import TeamMember + + user_team_id = ( + await session.exec( + select(TeamMember.team_id).where(TeamMember.user_id == self._user.id) + ) + ).first() + if not user_team_id: + return [] + user_ids = ( + await session.exec( + select(TeamMember.user_id).where(TeamMember.team_id == user_team_id) + ) + ).all() + return list(user_ids) diff --git a/app/router/__init__.py b/app/router/__init__.py index 6be08fc..3905fb0 100644 --- a/app/router/__init__.py +++ b/app/router/__init__.py @@ -3,9 +3,9 @@ from __future__ import annotations from app.signalr import signalr_router as signalr_router from .auth import router as auth_router -from .chat import chat_router as chat_router from .fetcher import fetcher_router as fetcher_router from .file import file_router as file_router +from .notification import chat_router as chat_router from .private import private_router as private_router from .redirect import ( redirect_api_router as redirect_api_router, diff --git a/app/router/chat/__init__.py b/app/router/chat/__init__.py deleted file mode 100644 index f9fc0b3..0000000 --- a/app/router/chat/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from app.config import settings -from app.router.v2 import api_v2_router as router - -from . import channel, message # noqa: F401 -from .server import chat_router as chat_router - -from fastapi import Query - -__all__ = ["chat_router"] - - -@router.get( - "/notifications", - tags=["通知", "聊天"], - name="获取通知", - description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。", -) -async def get_notifications( - max_id: int | None = Query(None, description="获取 ID 小于此值的通知"), -): - if settings.server_url is not None: - notification_endpoint = f"{settings.server_url}notification-server".replace( - "http://", "ws://" - ).replace("https://", "wss://") - else: - notification_endpoint = "/notification-server" - - return { - "has_more": False, - "notifications": [], - "unread_count": 0, - "notification_endpoint": notification_endpoint, - } diff --git a/app/router/notification/__init__.py b/app/router/notification/__init__.py new file mode 100644 index 0000000..206abd8 --- /dev/null +++ b/app/router/notification/__init__.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from app.config import settings +from app.database.lazer_user import User +from app.database.notification import Notification, UserNotification +from app.dependencies.database import Database +from app.dependencies.user import get_client_user +from app.models.chat import ChatEvent +from app.router.v2 import api_v2_router as router + +from . import channel, message # noqa: F401 +from .server import ( + chat_router as chat_router, + server, +) + +from fastapi import Body, Query, Security +from pydantic import BaseModel +from sqlmodel import col, func, select + +__all__ = ["chat_router"] + + +class NotificationResp(BaseModel): + has_more: bool + notifications: list[Notification] + unread_count: int + notification_endpoint: str + + +@router.get( + "/notifications", + tags=["通知", "聊天"], + name="获取通知", + description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。", + response_model=NotificationResp, +) +async def get_notifications( + session: Database, + max_id: int | None = Query(None, description="获取 ID 小于此值的通知"), + current_user: User = Security(get_client_user), +): + if settings.server_url is not None: + notification_endpoint = f"{settings.server_url}notification-server".replace( + "http://", "ws://" + ).replace("https://", "wss://") + else: + notification_endpoint = "/notification-server" + query = select(UserNotification).where( + UserNotification.user_id == current_user.id, + col(UserNotification.is_read).is_(False), + ) + if max_id is not None: + query = query.where(UserNotification.notification_id < max_id) + notifications = (await session.exec(query)).all() + total_count = ( + await session.exec( + select(func.count()) + .select_from(UserNotification) + .where( + UserNotification.user_id == current_user.id, + col(UserNotification.is_read).is_(False), + ) + ) + ).one() + unread_count = len(notifications) + + return NotificationResp( + has_more=unread_count < total_count, + notifications=[notification.notification for notification in notifications], + unread_count=unread_count, + notification_endpoint=notification_endpoint, + ) + + +class _IdentityReq(BaseModel): + category: str | None = None + id: int | None = None + object_id: int | None = None + object_type: int | None = None + + +async def _get_notifications( + session: Database, current_user: User, identities: list[_IdentityReq] +) -> list[UserNotification]: + result: dict[int, UserNotification] = {} + base_query = select(UserNotification).where( + UserNotification.user_id == current_user.id, + col(UserNotification.is_read).is_(False), + ) + for identity in identities: + query = base_query + if identity.id is not None: + query = base_query.where(UserNotification.notification_id == identity.id) + if identity.object_id is not None: + query = base_query.where( + col(UserNotification.notification).has( + col(Notification.object_id) == identity.object_id + ) + ) + if identity.object_type is not None: + query = base_query.where( + col(UserNotification.notification).has( + col(Notification.object_type) == identity.object_type + ) + ) + if identity.category is not None: + query = base_query.where( + col(UserNotification.notification).has( + col(Notification.category) == identity.category + ) + ) + result.update({n.notification_id: n for n in await session.exec(query)}) + return list(result.values()) + + +@router.post( + "/notifications/mark-read", + tags=["通知", "聊天"], + name="标记通知为已读", + description="标记当前用户的通知为已读。", + status_code=204, +) +async def mark_notifications_as_read( + session: Database, + identities: list[_IdentityReq] = Body(default_factory=list), + notifications: list[_IdentityReq] = Body(default_factory=list), + current_user: User = Security(get_client_user), +): + identities.extend(notifications) + user_notifications = await _get_notifications(session, current_user, identities) + for user_notification in user_notifications: + user_notification.is_read = True + + assert current_user.id + await server.send_event( + current_user.id, + ChatEvent( + event="read", + data={ + "notifications": [i.model_dump() for i in identities], + "read_count": len(user_notifications), + "timestamp": datetime.now(UTC).isoformat(), + }, + ), + ) + await session.commit() diff --git a/app/router/chat/banchobot.py b/app/router/notification/banchobot.py similarity index 100% rename from app/router/chat/banchobot.py rename to app/router/notification/banchobot.py diff --git a/app/router/chat/channel.py b/app/router/notification/channel.py similarity index 100% rename from app/router/chat/channel.py rename to app/router/notification/channel.py diff --git a/app/router/chat/message.py b/app/router/notification/message.py similarity index 94% rename from app/router/chat/message.py rename to app/router/notification/message.py index 45888bc..b9594aa 100644 --- a/app/router/chat/message.py +++ b/app/router/notification/message.py @@ -14,6 +14,7 @@ from app.database.lazer_user import User from app.dependencies.database import Database, get_redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user +from app.models.notification import ChannelMessage, ChannelMessageTeam from app.router.v2 import api_v2_router as router from .banchobot import bot @@ -113,6 +114,15 @@ async def send_message( ) if is_bot_command: await bot.try_handle(current_user, db_channel, req.message, session) + if db_channel.type == ChannelType.PM: + user_ids = db_channel.name.split("_")[1:] + await server.new_private_notification( + ChannelMessage( + msg, current_user, [int(u) for u in user_ids], db_channel.type + ) + ) + elif db_channel.type == ChannelType.TEAM: + await server.new_private_notification(ChannelMessageTeam(msg, current_user)) return resp diff --git a/app/router/chat/server.py b/app/router/notification/server.py similarity index 82% rename from app/router/chat/server.py rename to app/router/notification/server.py index 88990a8..0f891c3 100644 --- a/app/router/chat/server.py +++ b/app/router/notification/server.py @@ -1,9 +1,11 @@ from __future__ import annotations import asyncio +from typing import overload from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.lazer_user import User +from app.database.notification import UserNotification, insert_notification from app.dependencies.database import ( DBFactory, get_db_factory, @@ -13,12 +15,14 @@ from app.dependencies.database import ( from app.dependencies.user import get_current_user from app.log import logger from app.models.chat import ChatEvent +from app.models.notification import NotificationDetail from app.service.subscribers.chat import ChatSubscriber from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes from fastapi.websockets import WebSocketState from redis.asyncio import Redis +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -59,15 +63,24 @@ class ChatServer: if channel: await self.leave_channel(user, channel, session) - async def send_event(self, client: WebSocket, event: ChatEvent): + @overload + async def send_event(self, client: int, event: ChatEvent): ... + + @overload + async def send_event(self, client: WebSocket, event: ChatEvent): ... + + async def send_event(self, client: WebSocket | int, event: ChatEvent): + if isinstance(client, int): + client_ = self.connect_client.get(client) + if client_ is None: + return + client = client_ if client.client_state == WebSocketState.CONNECTED: await client.send_text(event.model_dump_json()) async def broadcast(self, channel_id: int, event: ChatEvent): for user_id in self.channels.get(channel_id, []): - client = self.connect_client.get(user_id) - if client: - await self.send_event(client, event) + await self.send_event(user_id, event) async def mark_as_read(self, channel_id: int, user_id: int, message_id: int): await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id) @@ -80,9 +93,7 @@ class ChatServer: data={"messages": [message], "users": [message.sender]}, ) if is_bot_command: - client = self.connect_client.get(message.sender_id) - if client: - self._add_task(self.send_event(client, event)) + self._add_task(self.send_event(message.sender_id, event)) else: self._add_task( self.broadcast( @@ -123,15 +134,13 @@ class ChatServer: if channel.type != ChannelType.PUBLIC else None, ) - client = self.connect_client.get(user.id) - if client: - await self.send_event( - client, - ChatEvent( - event="chat.channel.join", - data=channel_resp.model_dump(), - ), - ) + await self.send_event( + user.id, + ChatEvent( + event="chat.channel.join", + data=channel_resp.model_dump(), + ), + ) async def join_channel( self, user: User, channel: ChatChannel, session: AsyncSession @@ -154,15 +163,13 @@ class ChatServer: self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None, ) - client = self.connect_client.get(user_id) - if client: - await self.send_event( - client, - ChatEvent( - event="chat.channel.join", - data=channel_resp.model_dump(), - ), - ) + await self.send_event( + user_id, + ChatEvent( + event="chat.channel.join", + data=channel_resp.model_dump(), + ), + ) return channel_resp @@ -189,15 +196,13 @@ class ChatServer: if channel.type != ChannelType.PUBLIC else None, ) - client = self.connect_client.get(user_id) - if client: - await self.send_event( - client, - ChatEvent( - event="chat.channel.part", - data=channel_resp.model_dump(), - ), - ) + await self.send_event( + user_id, + ChatEvent( + event="chat.channel.part", + data=channel_resp.model_dump(), + ), + ) async def join_room_channel(self, channel_id: int, user_id: int): async with with_db() as session: @@ -223,6 +228,28 @@ class ChatServer: await self.leave_channel(user, channel, session) + async def new_private_notification(self, detail: NotificationDetail): + async with with_db() as session: + id = await insert_notification(session, detail) + users = ( + await session.exec( + select(UserNotification).where( + UserNotification.notification_id == id + ) + ) + ).all() + for user_notification in users: + data = user_notification.notification.model_dump() + data["is_read"] = user_notification.is_read + data["details"] = user_notification.notification.details + await server.send_event( + user_notification.user_id, + ChatEvent( + event="new", + data=data, + ), + ) + server = ChatServer() diff --git a/app/utils.py b/app/utils.py index 444eb79..e8e932d 100644 --- a/app/utils.py +++ b/app/utils.py @@ -118,3 +118,9 @@ def are_adjacent_weeks(dt1: datetime, dt2: datetime) -> bool: def are_same_weeks(dt1: datetime, dt2: datetime) -> bool: return dt1.isocalendar()[:2] == dt2.isocalendar()[:2] + + +def truncate(text: str, limit: int = 100, ellipsis: str = "...") -> str: + if len(text) > limit: + return text[:limit] + ellipsis + return text diff --git a/migrations/versions/4f46c43d8601_notification_add_notification.py b/migrations/versions/4f46c43d8601_notification_add_notification.py new file mode 100644 index 0000000..15fbc57 --- /dev/null +++ b/migrations/versions/4f46c43d8601_notification_add_notification.py @@ -0,0 +1,138 @@ +"""notification: add notification + +Revision ID: 4f46c43d8601 +Revises: 2fcfc28846c1 +Create Date: 2025-08-21 07:03:45.813547 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +import sqlmodel + +# revision identifiers, used by Alembic. +revision: str = "4f46c43d8601" +down_revision: str | Sequence[str] | None = "2fcfc28846c1" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "notifications", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "name", + sa.Enum( + "BEATMAP_OWNER_CHANGE", + "BEATMAPSET_DISCUSSION_LOCK", + "BEATMAPSET_DISCUSSION_POST_NEW", + "BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM", + "BEATMAPSET_DISCUSSION_REVIEW_NEW", + "BEATMAPSET_DISCUSSION_UNLOCK", + "BEATMAPSET_DISQUALIFY", + "BEATMAPSET_LOVE", + "BEATMAPSET_NOMINATE", + "BEATMAPSET_QUALIFY", + "BEATMAPSET_RANK", + "BEATMAPSET_REMOVE_FROM_LOVED", + "BEATMAPSET_RESET_NOMINATIONS", + "CHANNEL_ANNOUNCEMENT", + "CHANNEL_MESSAGE", + "CHANNEL_TEAM", + "COMMENT_NEW", + "FORUM_TOPIC_REPLY", + "TEAM_APPLICATION_ACCEPT", + "TEAM_APPLICATION_REJECT", + "TEAM_APPLICATION_STORE", + "USER_ACHIEVEMENT_UNLOCK", + "USER_BEATMAPSET_NEW", + "USER_BEATMAPSET_REVIVE", + name="notificationname", + ), + nullable=False, + ), + sa.Column( + "category", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False + ), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("object_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("object_id", sa.BigInteger(), nullable=True), + sa.Column("source_user_id", sa.Integer(), nullable=False), + sa.Column("details", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_notifications_category"), "notifications", ["category"], unique=False + ) + op.create_index(op.f("ix_notifications_id"), "notifications", ["id"], unique=False) + op.create_index( + op.f("ix_notifications_name"), "notifications", ["name"], unique=False + ) + op.create_index( + op.f("ix_notifications_object_id"), "notifications", ["object_id"], unique=False + ) + op.create_index( + op.f("ix_notifications_object_type"), + "notifications", + ["object_type"], + unique=False, + ) + op.create_index( + op.f("ix_notifications_source_user_id"), + "notifications", + ["source_user_id"], + unique=False, + ) + op.create_table( + "user_notifications", + sa.Column("id", sa.BigInteger(), nullable=False), + sa.Column("notification_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("is_read", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["notification_id"], + ["notifications.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_user_notifications_id"), "user_notifications", ["id"], unique=False + ) + op.create_index( + op.f("ix_user_notifications_is_read"), + "user_notifications", + ["is_read"], + unique=False, + ) + op.create_index( + op.f("ix_user_notifications_notification_id"), + "user_notifications", + ["notification_id"], + unique=False, + ) + op.create_index( + op.f("ix_user_notifications_user_id"), + "user_notifications", + ["user_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("user_notifications") + op.drop_table("notifications") + # ### end Alembic commands ###