feat(chat): support public channel chat
This commit is contained in:
@@ -10,6 +10,13 @@ from .beatmapset import (
|
||||
BeatmapsetResp,
|
||||
)
|
||||
from .best_score import BestScore
|
||||
from .chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatMessage,
|
||||
ChatMessageResp,
|
||||
)
|
||||
from .counts import (
|
||||
CountResp,
|
||||
MonthlyPlaycounts,
|
||||
@@ -63,6 +70,11 @@ __all__ = [
|
||||
"Beatmapset",
|
||||
"BeatmapsetResp",
|
||||
"BestScore",
|
||||
"ChannelType",
|
||||
"ChatChannel",
|
||||
"ChatChannelResp",
|
||||
"ChatMessage",
|
||||
"ChatMessageResp",
|
||||
"CountResp",
|
||||
"DailyChallengeStats",
|
||||
"DailyChallengeStatsResp",
|
||||
|
||||
193
app/database/chat.py
Normal file
193
app/database/chat.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
|
||||
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
VARCHAR,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
# ChatChannel
|
||||
|
||||
|
||||
class ChatUserAttributes(BaseModel):
|
||||
can_message: bool
|
||||
can_message_error: str | None = None
|
||||
last_read_id: int
|
||||
|
||||
|
||||
class ChannelType(str, Enum):
|
||||
PUBLIC = "PUBLIC"
|
||||
PRIVATE = "PRIVATE"
|
||||
MULTIPLAYER = "MULTIPLAYER"
|
||||
SPECTATOR = "SPECTATOR"
|
||||
TEMPORARY = "TEMPORARY"
|
||||
PM = "PM"
|
||||
GROUP = "GROUP"
|
||||
SYSTEM = "SYSTEM"
|
||||
ANNOUNCE = "ANNOUNCE"
|
||||
TEAM = "TEAM"
|
||||
|
||||
|
||||
class ChatChannelBase(SQLModel):
|
||||
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
|
||||
icon: str | None = Field(default=None)
|
||||
type: ChannelType = Field(index=True)
|
||||
|
||||
|
||||
class ChatChannel(ChatChannelBase, table=True):
|
||||
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
|
||||
channel_id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls, channel: str | int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
if isinstance(channel, int) or channel.isdigit():
|
||||
channel_ = await session.get(ChatChannel, channel)
|
||||
if channel_ is not None:
|
||||
return channel_
|
||||
return (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
|
||||
class ChatChannelResp(ChatChannelBase):
|
||||
channel_id: int
|
||||
moderated: bool = False
|
||||
uuid: str | None = None
|
||||
current_user_attributes: ChatUserAttributes | None = None
|
||||
last_read_id: int | None = None
|
||||
last_message_id: int | None = None
|
||||
recent_messages: list[str] | None = None
|
||||
users: list[int] | None = None
|
||||
message_length_limit: int = 1000
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
channel: ChatChannel,
|
||||
session: AsyncSession,
|
||||
users: list[int],
|
||||
user: User,
|
||||
redis: Redis,
|
||||
) -> Self:
|
||||
c = cls.model_validate(channel)
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
if last_msg and last_msg.isdigit():
|
||||
last_msg = int(last_msg)
|
||||
else:
|
||||
last_msg = None
|
||||
|
||||
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
if last_read_id and last_read_id.isdigit():
|
||||
last_read_id = int(last_read_id)
|
||||
else:
|
||||
last_read_id = last_msg
|
||||
|
||||
if silence is not None:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=False,
|
||||
can_message_error=silence.reason or "You are muted in this channel.",
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = True
|
||||
else:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=True,
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = False
|
||||
|
||||
c.current_user_attributes = attribute
|
||||
c.users = users
|
||||
c.last_message_id = last_msg
|
||||
c.last_read_id = last_read_id
|
||||
return c
|
||||
|
||||
|
||||
# ChatMessage
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
ACTION = "action"
|
||||
MARKDOWN = "markdown"
|
||||
PLAIN = "plain"
|
||||
|
||||
|
||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||
message_id: int | None = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||
uuid: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ChatMessage(ChatMessageBase, table=True):
|
||||
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
|
||||
class ChatMessageResp(ChatMessageBase):
|
||||
sender: UserResp | None = None
|
||||
is_action: bool = False
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
|
||||
) -> "ChatMessageResp":
|
||||
m = cls.model_validate(db_message.model_dump())
|
||||
m.is_action = db_message.type == MessageType.ACTION
|
||||
if user:
|
||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
else:
|
||||
m.sender = await UserResp.from_db(
|
||||
db_message.user, session, RANKING_INCLUDES
|
||||
)
|
||||
return m
|
||||
|
||||
|
||||
# SilenceUser
|
||||
|
||||
|
||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
|
||||
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
|
||||
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
|
||||
banned_at: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextvars import ContextVar
|
||||
import json
|
||||
|
||||
@@ -51,6 +52,17 @@ async def get_db():
|
||||
yield session
|
||||
|
||||
|
||||
DBFactory = Callable[[], AsyncIterator[AsyncSession]]
|
||||
|
||||
|
||||
async def get_db_factory() -> DBFactory:
|
||||
async def _factory() -> AsyncIterator[AsyncSession]:
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
# Redis 依赖
|
||||
def get_redis():
|
||||
return redis_client
|
||||
|
||||
39
app/dependencies/param.py
Normal file
39
app/dependencies/param.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
|
||||
def BodyOrForm[T: BaseModel](model: type[T]):
|
||||
async def dependency(
|
||||
request: Request,
|
||||
) -> T:
|
||||
content_type = request.headers.get("content-type", "")
|
||||
|
||||
data: dict[str, Any] = {}
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
raise RequestValidationError(
|
||||
[
|
||||
{
|
||||
"loc": ("body",),
|
||||
"msg": "Invalid JSON body",
|
||||
"type": "value_error",
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
form = await request.form()
|
||||
data = dict(form)
|
||||
|
||||
try:
|
||||
return model(**data)
|
||||
except ValidationError as e:
|
||||
raise RequestValidationError(e.errors())
|
||||
|
||||
return dependency
|
||||
10
app/models/chat.py
Normal file
10
app/models/chat.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatEvent(BaseModel):
|
||||
event: str
|
||||
data: dict[str, Any] | None = None
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from app.signalr import signalr_router as signalr_router
|
||||
|
||||
from .auth import router as auth_router
|
||||
from .chat import chat_router as chat_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
from .file import file_router as file_router
|
||||
from .private import private_router as private_router
|
||||
@@ -17,6 +18,7 @@ __all__ = [
|
||||
"api_v1_router",
|
||||
"api_v2_router",
|
||||
"auth_router",
|
||||
"chat_router",
|
||||
"fetcher_router",
|
||||
"file_router",
|
||||
"private_router",
|
||||
|
||||
28
app/router/chat/__init__.py
Normal file
28
app/router/chat/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from . import channel, message # noqa: F401
|
||||
from .server import chat_router as chat_router
|
||||
|
||||
from fastapi import Query
|
||||
|
||||
__all__ = ["chat_router"]
|
||||
|
||||
|
||||
@router.get("/notifications")
|
||||
async def get_notifications(max_id: int | None = Query(None)):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
|
||||
return {
|
||||
"has_more": False,
|
||||
"notifications": [],
|
||||
"unread_count": 0,
|
||||
"notification_endpoint": notification_endpoint,
|
||||
}
|
||||
138
app/router/chat/channel.py
Normal file
138
app/router/chat/channel.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
)
|
||||
from app.database.lazer_user import User, UserResp
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class UpdateResponse(BaseModel):
|
||||
presence: list[ChatChannelResp] = Field(default_factory=list)
|
||||
silences: list[Any] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get("/chat/updates", response_model=UpdateResponse)
|
||||
async def get_update(
|
||||
history_since: int | None = Query(None),
|
||||
since: int | None = Query(None),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
includes: list[str] = Query(["presence"], alias="includes[]"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
if "presence" in includes:
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
resp.presence.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
server.channels.get(channel_id, []),
|
||||
current_user,
|
||||
redis,
|
||||
)
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.put("/chat/channels/{channel}/users/{user}", response_model=ChatChannelResp)
|
||||
async def join_channel(
|
||||
channel: str,
|
||||
user: str,
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
return await server.join_channel(current_user, db_channel, session)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/chat/channels/{channel}/users/{user}",
|
||||
status_code=204,
|
||||
)
|
||||
async def leave_channel(
|
||||
channel: str,
|
||||
user: str,
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
await server.leave_channel(current_user, db_channel, session)
|
||||
return
|
||||
|
||||
|
||||
@router.get("/chat/channels")
|
||||
async def get_channel_list(
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
channels = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
|
||||
)
|
||||
).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
assert channel.channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
server.channels.get(channel.channel_id, []),
|
||||
current_user,
|
||||
redis,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class GetChannelResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
users: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get("/chat/channels/{channel}")
|
||||
async def get_channel(
|
||||
channel: str,
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
assert db_channel.channel_id is not None
|
||||
return GetChannelResp(
|
||||
channel=await ChatChannelResp.from_db(
|
||||
db_channel,
|
||||
session,
|
||||
server.channels.get(db_channel.channel_id, []),
|
||||
current_user,
|
||||
redis,
|
||||
)
|
||||
)
|
||||
98
app/router/chat/message.py
Normal file
98
app/router/chat/message.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import ChatChannel, ChatMessage, MessageType
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.post("/chat/ack")
|
||||
async def keep_alive(
|
||||
history_since: int | None = Query(None),
|
||||
since: int | None = Query(None),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return {"silences": []}
|
||||
|
||||
|
||||
class MessageReq(BaseModel):
|
||||
message: str
|
||||
is_action: bool = False
|
||||
uuid: str | None = None
|
||||
|
||||
|
||||
@router.post("/chat/channels/{channel}/messages", response_model=ChatMessageResp)
|
||||
async def send_message(
|
||||
channel: str,
|
||||
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
msg = ChatMessage(
|
||||
channel_id=db_channel.channel_id,
|
||||
content=req.message,
|
||||
sender_id=current_user.id,
|
||||
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||
uuid=req.uuid,
|
||||
)
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(current_user)
|
||||
resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||
await server.send_message_to_channel(resp)
|
||||
return resp
|
||||
|
||||
|
||||
@router.get("/chat/channels/{channel}/messages", response_model=list[ChatMessageResp])
|
||||
async def get_message(
|
||||
channel: str,
|
||||
limit: int = Query(50, ge=1, le=50),
|
||||
since: int = Query(default=0, ge=0),
|
||||
until: int | None = Query(None),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
messages = await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
ChatMessage.channel_id == db_channel.channel_id,
|
||||
col(ChatMessage.message_id) > since,
|
||||
col(ChatMessage.message_id) < until if until is not None else True,
|
||||
)
|
||||
.order_by(col(ChatMessage.timestamp).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
resp.reverse()
|
||||
return resp
|
||||
|
||||
|
||||
@router.put("/chat/channels/{channel}/mark-as-read/{message}", status_code=204)
|
||||
async def mark_as_read(
|
||||
channel: str,
|
||||
message: int,
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
await server.mark_as_read(db_channel.channel_id, message)
|
||||
190
app/router/chat/server.py
Normal file
190
app/router/chat/server.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.database.chat import ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import DBFactory, get_db_factory, get_redis
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
from fastapi.websockets import WebSocketState
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class ChatServer:
|
||||
def __init__(self):
|
||||
self.connect_client: dict[int, WebSocket] = {}
|
||||
self.channels: dict[int, list[int]] = {}
|
||||
self.redis: Redis = get_redis()
|
||||
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _add_task(self, task):
|
||||
task = asyncio.create_task(task)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
def connect(self, user_id: int, client: WebSocket):
|
||||
self.connect_client[user_id] = client
|
||||
|
||||
def get_user_joined_channel(self, user_id: int) -> list[int]:
|
||||
return [
|
||||
channel_id
|
||||
for channel_id, users in self.channels.items()
|
||||
if user_id in users
|
||||
]
|
||||
|
||||
async def disconnect(self, user: User, session: AsyncSession):
|
||||
user_id = user.id
|
||||
if user_id in self.connect_client:
|
||||
del self.connect_client[user_id]
|
||||
for channel_id, channel in self.channels.items():
|
||||
if user_id in channel:
|
||||
channel.remove(user_id)
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
async def send_event(self, client: WebSocket, event: ChatEvent):
|
||||
if client.client_state == WebSocketState.CONNECTED:
|
||||
await client.send_text(event.model_dump_json())
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
for user_id in self.channels.get(channel_id, []):
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(client, event)
|
||||
|
||||
async def mark_as_read(self, channel_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_msg", message_id)
|
||||
|
||||
async def send_message_to_channel(self, message: ChatMessageResp):
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
),
|
||||
)
|
||||
)
|
||||
await self.mark_as_read(message.channel_id, message.message_id)
|
||||
|
||||
async def join_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> ChatChannelResp:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
if user_id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user_id)
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, self.channels[channel_id], user, self.redis
|
||||
)
|
||||
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(
|
||||
client,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||
self.channels[channel_id].remove(user_id)
|
||||
|
||||
if not self.channels.get(channel_id):
|
||||
del self.channels[channel_id]
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, self.channels.get(channel_id, []), user, self.redis
|
||||
)
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(
|
||||
client,
|
||||
ChatEvent(
|
||||
event="chat.channel.part",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
server = ChatServer()
|
||||
|
||||
chat_router = APIRouter(include_in_schema=False)
|
||||
|
||||
|
||||
async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
try:
|
||||
while True:
|
||||
packets = await ws.receive_json()
|
||||
if packets.get("event") == "chat.end":
|
||||
async for session in factory():
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
break
|
||||
await server.disconnect(user, session)
|
||||
await ws.close(code=1000)
|
||||
break
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(
|
||||
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
|
||||
else:
|
||||
logger.exception(f"RuntimeError in client {user_id}: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"Error in client {user_id}")
|
||||
|
||||
|
||||
@chat_router.websocket("/notification-server")
|
||||
async def chat_websocket(
|
||||
websocket: WebSocket,
|
||||
authorization: str = Header(...),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (
|
||||
user := await get_current_user(
|
||||
SecurityScopes(scopes=["chat.read"]), session, token_pw=token
|
||||
)
|
||||
) is None:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
login = await websocket.receive_json()
|
||||
if login.get("event") != "chat.start":
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
user_id = user.id
|
||||
assert user_id
|
||||
server.connect(user_id, websocket)
|
||||
channel = await ChatChannel.get(1, session)
|
||||
if channel is not None:
|
||||
await server.join_channel(user, channel, session)
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from fnmatch import fnmatch
|
||||
from typing import Any
|
||||
|
||||
from app.dependencies.database import get_redis_pubsub
|
||||
@@ -29,12 +30,17 @@ class RedisSubscriber:
|
||||
ignore_subscribe_messages=True, timeout=None
|
||||
)
|
||||
if message is not None and message["type"] == "message":
|
||||
method = self.handlers.get(message["channel"])
|
||||
if method:
|
||||
matched_handlers = []
|
||||
if message["channel"] in self.handlers:
|
||||
matched_handlers.extend(self.handlers[message["channel"]])
|
||||
for pattern, handlers in self.handlers.items():
|
||||
if fnmatch(message["channel"], pattern):
|
||||
matched_handlers.extend(handlers)
|
||||
if matched_handlers:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
handler(message["channel"], message["data"])
|
||||
for handler in method
|
||||
for handler in matched_handlers
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
2
main.py
2
main.py
@@ -13,6 +13,7 @@ from app.router import (
|
||||
api_v1_router,
|
||||
api_v2_router,
|
||||
auth_router,
|
||||
chat_router,
|
||||
fetcher_router,
|
||||
file_router,
|
||||
private_router,
|
||||
@@ -71,6 +72,7 @@ app = FastAPI(
|
||||
|
||||
app.include_router(api_v2_router)
|
||||
app.include_router(api_v1_router)
|
||||
app.include_router(chat_router)
|
||||
app.include_router(redirect_api_router)
|
||||
app.include_router(signalr_router)
|
||||
app.include_router(fetcher_router)
|
||||
|
||||
192
migrations/versions/dd33d89aa2c2_chat_add_chat.py
Normal file
192
migrations/versions/dd33d89aa2c2_chat_add_chat.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""chat: add chat
|
||||
|
||||
Revision ID: dd33d89aa2c2
|
||||
Revises: 9f6b27e8ea51
|
||||
Create Date: 2025-08-15 14:22:34.775877
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dd33d89aa2c2"
|
||||
down_revision: str | Sequence[str] | None = "9f6b27e8ea51"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
channel_table = op.create_table(
|
||||
"chat_channels",
|
||||
sa.Column("name", sa.VARCHAR(length=50), nullable=True),
|
||||
sa.Column("description", sa.VARCHAR(length=255), nullable=True),
|
||||
sa.Column("icon", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum(
|
||||
"PUBLIC",
|
||||
"PRIVATE",
|
||||
"MULTIPLAYER",
|
||||
"SPECTATOR",
|
||||
"TEMPORARY",
|
||||
"PM",
|
||||
"GROUP",
|
||||
"SYSTEM",
|
||||
"ANNOUNCE",
|
||||
"TEAM",
|
||||
name="channeltype",
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("channel_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_channels_channel_id"),
|
||||
"chat_channels",
|
||||
["channel_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_channels_description"),
|
||||
"chat_channels",
|
||||
["description"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_channels_name"), "chat_channels", ["name"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_channels_type"), "chat_channels", ["type"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"chat_messages",
|
||||
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||
sa.Column("content", sa.VARCHAR(length=1000), nullable=True),
|
||||
sa.Column("message_id", sa.Integer(), nullable=False),
|
||||
sa.Column("sender_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("timestamp", sa.DateTime(), nullable=True),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum("ACTION", "MARKDOWN", "PLAIN", name="messagetype"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("uuid", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["channel_id"],
|
||||
["chat_channels.channel_id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["sender_id"],
|
||||
["lazer_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("message_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_channel_id"),
|
||||
"chat_messages",
|
||||
["channel_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_message_id"),
|
||||
"chat_messages",
|
||||
["message_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_sender_id"), "chat_messages", ["sender_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_timestamp"), "chat_messages", ["timestamp"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_type"), "chat_messages", ["type"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"chat_silence_users",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||
sa.Column("until", sa.DateTime(), nullable=True),
|
||||
sa.Column("banned_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("reason", sa.VARCHAR(length=255), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["channel_id"],
|
||||
["chat_channels.channel_id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["lazer_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_channel_id"),
|
||||
"chat_silence_users",
|
||||
["channel_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_reason"),
|
||||
"chat_silence_users",
|
||||
["reason"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_until"),
|
||||
"chat_silence_users",
|
||||
["until"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_banned_at"),
|
||||
"chat_silence_users",
|
||||
["banned_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_user_id"),
|
||||
"chat_silence_users",
|
||||
["user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_id"),
|
||||
"chat_silence_users",
|
||||
["id"],
|
||||
unique=False,
|
||||
)
|
||||
op.bulk_insert(
|
||||
channel_table,
|
||||
[
|
||||
{
|
||||
"name": "osu!",
|
||||
"description": "General discussion for osu!",
|
||||
"type": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "announce",
|
||||
"description": "Official announcements",
|
||||
"type": "PUBLIC",
|
||||
},
|
||||
],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("chat_silence_users")
|
||||
op.drop_table("chat_messages")
|
||||
op.drop_table("chat_channels")
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user