diff --git a/app/database/__init__.py b/app/database/__init__.py index 0283923..d8794c4 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -10,6 +10,13 @@ from .beatmapset import ( BeatmapsetResp, ) from .best_score import BestScore +from .chat import ( + ChannelType, + ChatChannel, + ChatChannelResp, + ChatMessage, + ChatMessageResp, +) from .counts import ( CountResp, MonthlyPlaycounts, @@ -63,6 +70,11 @@ __all__ = [ "Beatmapset", "BeatmapsetResp", "BestScore", + "ChannelType", + "ChatChannel", + "ChatChannelResp", + "ChatMessage", + "ChatMessageResp", "CountResp", "DailyChallengeStats", "DailyChallengeStatsResp", diff --git a/app/database/chat.py b/app/database/chat.py new file mode 100644 index 0000000..565657d --- /dev/null +++ b/app/database/chat.py @@ -0,0 +1,193 @@ +from datetime import UTC, datetime +from enum import Enum +from typing import Self + +from app.database.lazer_user import RANKING_INCLUDES, User, UserResp +from app.models.model import UTCBaseModel + +from pydantic import BaseModel +from redis.asyncio import Redis +from sqlmodel import ( + VARCHAR, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +# ChatChannel + + +class ChatUserAttributes(BaseModel): + can_message: bool + can_message_error: str | None = None + last_read_id: int + + +class ChannelType(str, Enum): + PUBLIC = "PUBLIC" + PRIVATE = "PRIVATE" + MULTIPLAYER = "MULTIPLAYER" + SPECTATOR = "SPECTATOR" + TEMPORARY = "TEMPORARY" + PM = "PM" + GROUP = "GROUP" + SYSTEM = "SYSTEM" + ANNOUNCE = "ANNOUNCE" + TEAM = "TEAM" + + +class ChatChannelBase(SQLModel): + name: str = Field(sa_column=Column(VARCHAR(50), index=True)) + description: str = Field(sa_column=Column(VARCHAR(255), index=True)) + icon: str | None = Field(default=None) + type: ChannelType = Field(index=True) + + +class ChatChannel(ChatChannelBase, table=True): + __tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType] + channel_id: int | None = Field(primary_key=True, index=True, default=None) + + @classmethod + async def get( + cls, channel: str | int, session: AsyncSession + ) -> "ChatChannel | None": + if isinstance(channel, int) or channel.isdigit(): + channel_ = await session.get(ChatChannel, channel) + if channel_ is not None: + return channel_ + return ( + await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) + ).first() + + +class ChatChannelResp(ChatChannelBase): + channel_id: int + moderated: bool = False + uuid: str | None = None + current_user_attributes: ChatUserAttributes | None = None + last_read_id: int | None = None + last_message_id: int | None = None + recent_messages: list[str] | None = None + users: list[int] | None = None + message_length_limit: int = 1000 + + @classmethod + async def from_db( + cls, + channel: ChatChannel, + session: AsyncSession, + users: list[int], + user: User, + redis: Redis, + ) -> Self: + c = cls.model_validate(channel) + silence = ( + await session.exec( + select(SilenceUser).where( + SilenceUser.channel_id == channel.channel_id, + SilenceUser.user_id == user.id, + ) + ) + ).first() + + last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg") + if last_msg and last_msg.isdigit(): + last_msg = int(last_msg) + else: + last_msg = None + + last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") + if last_read_id and last_read_id.isdigit(): + last_read_id = int(last_read_id) + else: + last_read_id = last_msg + + if silence is not None: + attribute = ChatUserAttributes( + can_message=False, + can_message_error=silence.reason or "You are muted in this channel.", + last_read_id=last_read_id or 0, + ) + c.moderated = True + else: + attribute = ChatUserAttributes( + can_message=True, + last_read_id=last_read_id or 0, + ) + c.moderated = False + + c.current_user_attributes = attribute + c.users = users + c.last_message_id = last_msg + c.last_read_id = last_read_id + return c + + +# ChatMessage + + +class MessageType(str, Enum): + ACTION = "action" + MARKDOWN = "markdown" + PLAIN = "plain" + + +class ChatMessageBase(UTCBaseModel, SQLModel): + channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id") + content: str = Field(sa_column=Column(VARCHAR(1000))) + message_id: int | None = Field(index=True, primary_key=True, default=None) + sender_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + timestamp: datetime = Field( + sa_column=Column(DateTime, index=True), default=datetime.now(UTC) + ) + type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True) + uuid: str | None = Field(default=None) + + +class ChatMessage(ChatMessageBase, table=True): + __tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType] + user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) + + +class ChatMessageResp(ChatMessageBase): + sender: UserResp | None = None + is_action: bool = False + + @classmethod + async def from_db( + cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None + ) -> "ChatMessageResp": + m = cls.model_validate(db_message.model_dump()) + m.is_action = db_message.type == MessageType.ACTION + if user: + m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES) + else: + m.sender = await UserResp.from_db( + db_message.user, session, RANKING_INCLUDES + ) + return m + + +# SilenceUser + + +class SilenceUser(UTCBaseModel, SQLModel, table=True): + __tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType] + id: int | None = Field(primary_key=True, default=None, index=True) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True) + until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None) + reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True)) + banned_at: datetime = Field( + sa_column=Column(DateTime, index=True), default=datetime.now(UTC) + ) diff --git a/app/dependencies/database.py b/app/dependencies/database.py index bb4c065..f7e94ce 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import AsyncIterator, Callable from contextvars import ContextVar import json @@ -51,6 +52,17 @@ async def get_db(): yield session +DBFactory = Callable[[], AsyncIterator[AsyncSession]] + + +async def get_db_factory() -> DBFactory: + async def _factory() -> AsyncIterator[AsyncSession]: + async with AsyncSession(engine) as session: + yield session + + return _factory + + # Redis 依赖 def get_redis(): return redis_client diff --git a/app/dependencies/param.py b/app/dependencies/param.py new file mode 100644 index 0000000..174adde --- /dev/null +++ b/app/dependencies/param.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import Request +from fastapi.exceptions import RequestValidationError +from pydantic import BaseModel, ValidationError + + +def BodyOrForm[T: BaseModel](model: type[T]): + async def dependency( + request: Request, + ) -> T: + content_type = request.headers.get("content-type", "") + + data: dict[str, Any] = {} + if "application/json" in content_type: + try: + data = await request.json() + except Exception: + raise RequestValidationError( + [ + { + "loc": ("body",), + "msg": "Invalid JSON body", + "type": "value_error", + } + ] + ) + else: + form = await request.form() + data = dict(form) + + try: + return model(**data) + except ValidationError as e: + raise RequestValidationError(e.errors()) + + return dependency diff --git a/app/models/chat.py b/app/models/chat.py new file mode 100644 index 0000000..116342f --- /dev/null +++ b/app/models/chat.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + + +class ChatEvent(BaseModel): + event: str + data: dict[str, Any] | None = None diff --git a/app/router/__init__.py b/app/router/__init__.py index 57fc949..6be08fc 100644 --- a/app/router/__init__.py +++ b/app/router/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from app.signalr import signalr_router as signalr_router from .auth import router as auth_router +from .chat import chat_router as chat_router from .fetcher import fetcher_router as fetcher_router from .file import file_router as file_router from .private import private_router as private_router @@ -17,6 +18,7 @@ __all__ = [ "api_v1_router", "api_v2_router", "auth_router", + "chat_router", "fetcher_router", "file_router", "private_router", diff --git a/app/router/chat/__init__.py b/app/router/chat/__init__.py new file mode 100644 index 0000000..3ed7897 --- /dev/null +++ b/app/router/chat/__init__.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from app.config import settings +from app.router.v2 import api_v2_router as router + +from . import channel, message # noqa: F401 +from .server import chat_router as chat_router + +from fastapi import Query + +__all__ = ["chat_router"] + + +@router.get("/notifications") +async def get_notifications(max_id: int | None = Query(None)): + if settings.server_url is not None: + notification_endpoint = f"{settings.server_url}notification-server".replace( + "http://", "ws://" + ).replace("https://", "wss://") + else: + notification_endpoint = "/notification-server" + + return { + "has_more": False, + "notifications": [], + "unread_count": 0, + "notification_endpoint": notification_endpoint, + } diff --git a/app/router/chat/channel.py b/app/router/chat/channel.py new file mode 100644 index 0000000..798a5f2 --- /dev/null +++ b/app/router/chat/channel.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import Any + +from app.database.chat import ( + ChannelType, + ChatChannel, + ChatChannelResp, +) +from app.database.lazer_user import User, UserResp +from app.dependencies.database import get_db, get_redis +from app.dependencies.user import get_current_user +from app.router.v2 import api_v2_router as router + +from .server import server + +from fastapi import Depends, HTTPException, Query, Security +from pydantic import BaseModel, Field +from redis.asyncio import Redis +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + + +class UpdateResponse(BaseModel): + presence: list[ChatChannelResp] = Field(default_factory=list) + silences: list[Any] = Field(default_factory=list) + + +@router.get("/chat/updates", response_model=UpdateResponse) +async def get_update( + history_since: int | None = Query(None), + since: int | None = Query(None), + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), + includes: list[str] = Query(["presence"], alias="includes[]"), + redis: Redis = Depends(get_redis), +): + resp = UpdateResponse() + if "presence" in includes: + channel_ids = server.get_user_joined_channel(current_user.id) + for channel_id in channel_ids: + channel = await ChatChannel.get(channel_id, session) + if channel: + resp.presence.append( + await ChatChannelResp.from_db( + channel, + session, + server.channels.get(channel_id, []), + current_user, + redis, + ) + ) + return resp + + +@router.put("/chat/channels/{channel}/users/{user}", response_model=ChatChannelResp) +async def join_channel( + channel: str, + user: str, + current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + return await server.join_channel(current_user, db_channel, session) + + +@router.delete( + "/chat/channels/{channel}/users/{user}", + status_code=204, +) +async def leave_channel( + channel: str, + user: str, + current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + await server.leave_channel(current_user, db_channel, session) + return + + +@router.get("/chat/channels") +async def get_channel_list( + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + channels = ( + await session.exec( + select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC) + ) + ).all() + results = [] + for channel in channels: + assert channel.channel_id is not None + results.append( + await ChatChannelResp.from_db( + channel, + session, + server.channels.get(channel.channel_id, []), + current_user, + redis, + ) + ) + return results + + +class GetChannelResp(BaseModel): + channel: ChatChannelResp + users: list[UserResp] = Field(default_factory=list) + + +@router.get("/chat/channels/{channel}") +async def get_channel( + channel: str, + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + db_channel = await ChatChannel.get(channel, session) + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + assert db_channel.channel_id is not None + return GetChannelResp( + channel=await ChatChannelResp.from_db( + db_channel, + session, + server.channels.get(db_channel.channel_id, []), + current_user, + redis, + ) + ) diff --git a/app/router/chat/message.py b/app/router/chat/message.py new file mode 100644 index 0000000..5bab869 --- /dev/null +++ b/app/router/chat/message.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from app.database import ChatMessageResp +from app.database.chat import ChatChannel, ChatMessage, MessageType +from app.database.lazer_user import User +from app.dependencies.database import get_db +from app.dependencies.param import BodyOrForm +from app.dependencies.user import get_current_user +from app.router.v2 import api_v2_router as router + +from .server import server + +from fastapi import Depends, HTTPException, Query, Security +from pydantic import BaseModel +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +@router.post("/chat/ack") +async def keep_alive( + history_since: int | None = Query(None), + since: int | None = Query(None), + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), +): + return {"silences": []} + + +class MessageReq(BaseModel): + message: str + is_action: bool = False + uuid: str | None = None + + +@router.post("/chat/channels/{channel}/messages", response_model=ChatMessageResp) +async def send_message( + channel: str, + req: MessageReq = Depends(BodyOrForm(MessageReq)), + current_user: User = Security(get_current_user, scopes=["chat.write"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + msg = ChatMessage( + channel_id=db_channel.channel_id, + content=req.message, + sender_id=current_user.id, + type=MessageType.ACTION if req.is_action else MessageType.PLAIN, + uuid=req.uuid, + ) + session.add(msg) + await session.commit() + await session.refresh(msg) + await session.refresh(current_user) + resp = await ChatMessageResp.from_db(msg, session, current_user) + await server.send_message_to_channel(resp) + return resp + + +@router.get("/chat/channels/{channel}/messages", response_model=list[ChatMessageResp]) +async def get_message( + channel: str, + limit: int = Query(50, ge=1, le=50), + since: int = Query(default=0, ge=0), + until: int | None = Query(None), + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + messages = await session.exec( + select(ChatMessage) + .where( + ChatMessage.channel_id == db_channel.channel_id, + col(ChatMessage.message_id) > since, + col(ChatMessage.message_id) < until if until is not None else True, + ) + .order_by(col(ChatMessage.timestamp).desc()) + .limit(limit) + ) + resp = [await ChatMessageResp.from_db(msg, session) for msg in messages] + resp.reverse() + return resp + + +@router.put("/chat/channels/{channel}/mark-as-read/{message}", status_code=204) +async def mark_as_read( + channel: str, + message: int, + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + await server.mark_as_read(db_channel.channel_id, message) diff --git a/app/router/chat/server.py b/app/router/chat/server.py new file mode 100644 index 0000000..727f426 --- /dev/null +++ b/app/router/chat/server.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import asyncio + +from app.database.chat import ChatChannel, ChatChannelResp, ChatMessageResp +from app.database.lazer_user import User +from app.dependencies.database import DBFactory, get_db_factory, get_redis +from app.dependencies.user import get_current_user +from app.log import logger +from app.models.chat import ChatEvent + +from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect +from fastapi.security import SecurityScopes +from fastapi.websockets import WebSocketState +from redis.asyncio import Redis +from sqlmodel.ext.asyncio.session import AsyncSession + + +class ChatServer: + def __init__(self): + self.connect_client: dict[int, WebSocket] = {} + self.channels: dict[int, list[int]] = {} + self.redis: Redis = get_redis() + + self.tasks: set[asyncio.Task] = set() + + def _add_task(self, task): + task = asyncio.create_task(task) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) + + def connect(self, user_id: int, client: WebSocket): + self.connect_client[user_id] = client + + def get_user_joined_channel(self, user_id: int) -> list[int]: + return [ + channel_id + for channel_id, users in self.channels.items() + if user_id in users + ] + + async def disconnect(self, user: User, session: AsyncSession): + user_id = user.id + if user_id in self.connect_client: + del self.connect_client[user_id] + for channel_id, channel in self.channels.items(): + if user_id in channel: + channel.remove(user_id) + channel = await ChatChannel.get(channel_id, session) + if channel: + await self.leave_channel(user, channel, session) + + async def send_event(self, client: WebSocket, event: ChatEvent): + if client.client_state == WebSocketState.CONNECTED: + await client.send_text(event.model_dump_json()) + + async def broadcast(self, channel_id: int, event: ChatEvent): + for user_id in self.channels.get(channel_id, []): + client = self.connect_client.get(user_id) + if client: + await self.send_event(client, event) + + async def mark_as_read(self, channel_id: int, message_id: int): + await self.redis.set(f"chat:{channel_id}:last_msg", message_id) + + async def send_message_to_channel(self, message: ChatMessageResp): + self._add_task( + self.broadcast( + message.channel_id, + ChatEvent( + event="chat.message.new", + data={"messages": [message], "users": [message.sender]}, + ), + ) + ) + await self.mark_as_read(message.channel_id, message.message_id) + + async def join_channel( + self, user: User, channel: ChatChannel, session: AsyncSession + ) -> ChatChannelResp: + user_id = user.id + channel_id = channel.channel_id + assert channel_id is not None + + if channel_id not in self.channels: + self.channels[channel_id] = [] + if user_id not in self.channels[channel_id]: + self.channels[channel_id].append(user_id) + + channel_resp = await ChatChannelResp.from_db( + channel, session, self.channels[channel_id], user, self.redis + ) + + client = self.connect_client.get(user_id) + if client: + await self.send_event( + client, + ChatEvent( + event="chat.channel.join", + data=channel_resp.model_dump(), + ), + ) + + return channel_resp + + async def leave_channel( + self, user: User, channel: ChatChannel, session: AsyncSession + ) -> None: + user_id = user.id + channel_id = channel.channel_id + assert channel_id is not None + + if channel_id in self.channels and user_id in self.channels[channel_id]: + self.channels[channel_id].remove(user_id) + + if not self.channels.get(channel_id): + del self.channels[channel_id] + + channel_resp = await ChatChannelResp.from_db( + channel, session, self.channels.get(channel_id, []), user, self.redis + ) + client = self.connect_client.get(user_id) + if client: + await self.send_event( + client, + ChatEvent( + event="chat.channel.part", + data=channel_resp.model_dump(), + ), + ) + + +server = ChatServer() + +chat_router = APIRouter(include_in_schema=False) + + +async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory): + try: + while True: + packets = await ws.receive_json() + if packets.get("event") == "chat.end": + async for session in factory(): + user = await session.get(User, user_id) + if user is None: + break + await server.disconnect(user, session) + await ws.close(code=1000) + break + except WebSocketDisconnect as e: + logger.info( + f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}" + ) + except RuntimeError as e: + if "disconnect message" in str(e): + logger.info(f"[NotificationServer] Client {user_id} closed the connection.") + else: + logger.exception(f"RuntimeError in client {user_id}: {e}") + except Exception: + logger.exception(f"Error in client {user_id}") + + +@chat_router.websocket("/notification-server") +async def chat_websocket( + websocket: WebSocket, + authorization: str = Header(...), + factory: DBFactory = Depends(get_db_factory), +): + async for session in factory(): + token = authorization[7:] + if ( + user := await get_current_user( + SecurityScopes(scopes=["chat.read"]), session, token_pw=token + ) + ) is None: + await websocket.close(code=1008) + return + + await websocket.accept() + login = await websocket.receive_json() + if login.get("event") != "chat.start": + await websocket.close(code=1008) + return + user_id = user.id + assert user_id + server.connect(user_id, websocket) + channel = await ChatChannel.get(1, session) + if channel is not None: + await server.join_channel(user, channel, session) + await _listen_stop(websocket, user_id, factory) diff --git a/app/service/subscribers/base.py b/app/service/subscribers/base.py index 144dfd0..8fc5654 100644 --- a/app/service/subscribers/base.py +++ b/app/service/subscribers/base.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable +from fnmatch import fnmatch from typing import Any from app.dependencies.database import get_redis_pubsub @@ -29,12 +30,17 @@ class RedisSubscriber: ignore_subscribe_messages=True, timeout=None ) if message is not None and message["type"] == "message": - method = self.handlers.get(message["channel"]) - if method: + matched_handlers = [] + if message["channel"] in self.handlers: + matched_handlers.extend(self.handlers[message["channel"]]) + for pattern, handlers in self.handlers.items(): + if fnmatch(message["channel"], pattern): + matched_handlers.extend(handlers) + if matched_handlers: await asyncio.gather( *[ handler(message["channel"], message["data"]) - for handler in method + for handler in matched_handlers ] ) diff --git a/main.py b/main.py index c2739c4..66471cc 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ from app.router import ( api_v1_router, api_v2_router, auth_router, + chat_router, fetcher_router, file_router, private_router, @@ -71,6 +72,7 @@ app = FastAPI( app.include_router(api_v2_router) app.include_router(api_v1_router) +app.include_router(chat_router) app.include_router(redirect_api_router) app.include_router(signalr_router) app.include_router(fetcher_router) diff --git a/migrations/versions/dd33d89aa2c2_chat_add_chat.py b/migrations/versions/dd33d89aa2c2_chat_add_chat.py new file mode 100644 index 0000000..f548f55 --- /dev/null +++ b/migrations/versions/dd33d89aa2c2_chat_add_chat.py @@ -0,0 +1,192 @@ +"""chat: add chat + +Revision ID: dd33d89aa2c2 +Revises: 9f6b27e8ea51 +Create Date: 2025-08-15 14:22:34.775877 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +import sqlmodel + +# revision identifiers, used by Alembic. +revision: str = "dd33d89aa2c2" +down_revision: str | Sequence[str] | None = "9f6b27e8ea51" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + channel_table = op.create_table( + "chat_channels", + sa.Column("name", sa.VARCHAR(length=50), nullable=True), + sa.Column("description", sa.VARCHAR(length=255), nullable=True), + sa.Column("icon", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column( + "type", + sa.Enum( + "PUBLIC", + "PRIVATE", + "MULTIPLAYER", + "SPECTATOR", + "TEMPORARY", + "PM", + "GROUP", + "SYSTEM", + "ANNOUNCE", + "TEAM", + name="channeltype", + ), + nullable=False, + ), + sa.Column("channel_id", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("channel_id"), + ) + op.create_index( + op.f("ix_chat_channels_channel_id"), + "chat_channels", + ["channel_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_channels_description"), + "chat_channels", + ["description"], + unique=False, + ) + op.create_index( + op.f("ix_chat_channels_name"), "chat_channels", ["name"], unique=False + ) + op.create_index( + op.f("ix_chat_channels_type"), "chat_channels", ["type"], unique=False + ) + op.create_table( + "chat_messages", + sa.Column("channel_id", sa.Integer(), nullable=False), + sa.Column("content", sa.VARCHAR(length=1000), nullable=True), + sa.Column("message_id", sa.Integer(), nullable=False), + sa.Column("sender_id", sa.BigInteger(), nullable=True), + sa.Column("timestamp", sa.DateTime(), nullable=True), + sa.Column( + "type", + sa.Enum("ACTION", "MARKDOWN", "PLAIN", name="messagetype"), + nullable=False, + ), + sa.Column("uuid", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.ForeignKeyConstraint( + ["channel_id"], + ["chat_channels.channel_id"], + ), + sa.ForeignKeyConstraint( + ["sender_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("message_id"), + ) + op.create_index( + op.f("ix_chat_messages_channel_id"), + "chat_messages", + ["channel_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_messages_message_id"), + "chat_messages", + ["message_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_messages_sender_id"), "chat_messages", ["sender_id"], unique=False + ) + op.create_index( + op.f("ix_chat_messages_timestamp"), "chat_messages", ["timestamp"], unique=False + ) + op.create_index( + op.f("ix_chat_messages_type"), "chat_messages", ["type"], unique=False + ) + op.create_table( + "chat_silence_users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("channel_id", sa.Integer(), nullable=False), + sa.Column("until", sa.DateTime(), nullable=True), + sa.Column("banned_at", sa.DateTime(), nullable=False), + sa.Column("reason", sa.VARCHAR(length=255), nullable=True), + sa.ForeignKeyConstraint( + ["channel_id"], + ["chat_channels.channel_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_chat_silence_users_channel_id"), + "chat_silence_users", + ["channel_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_reason"), + "chat_silence_users", + ["reason"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_until"), + "chat_silence_users", + ["until"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_banned_at"), + "chat_silence_users", + ["banned_at"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_user_id"), + "chat_silence_users", + ["user_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_id"), + "chat_silence_users", + ["id"], + unique=False, + ) + op.bulk_insert( + channel_table, + [ + { + "name": "osu!", + "description": "General discussion for osu!", + "type": "PUBLIC", + }, + { + "name": "announce", + "description": "Official announcements", + "type": "PUBLIC", + }, + ], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("chat_silence_users") + op.drop_table("chat_messages") + op.drop_table("chat_channels") + # ### end Alembic commands ###