feat(chat): support public channel chat
This commit is contained in:
@@ -10,6 +10,13 @@ from .beatmapset import (
|
|||||||
BeatmapsetResp,
|
BeatmapsetResp,
|
||||||
)
|
)
|
||||||
from .best_score import BestScore
|
from .best_score import BestScore
|
||||||
|
from .chat import (
|
||||||
|
ChannelType,
|
||||||
|
ChatChannel,
|
||||||
|
ChatChannelResp,
|
||||||
|
ChatMessage,
|
||||||
|
ChatMessageResp,
|
||||||
|
)
|
||||||
from .counts import (
|
from .counts import (
|
||||||
CountResp,
|
CountResp,
|
||||||
MonthlyPlaycounts,
|
MonthlyPlaycounts,
|
||||||
@@ -63,6 +70,11 @@ __all__ = [
|
|||||||
"Beatmapset",
|
"Beatmapset",
|
||||||
"BeatmapsetResp",
|
"BeatmapsetResp",
|
||||||
"BestScore",
|
"BestScore",
|
||||||
|
"ChannelType",
|
||||||
|
"ChatChannel",
|
||||||
|
"ChatChannelResp",
|
||||||
|
"ChatMessage",
|
||||||
|
"ChatMessageResp",
|
||||||
"CountResp",
|
"CountResp",
|
||||||
"DailyChallengeStats",
|
"DailyChallengeStats",
|
||||||
"DailyChallengeStatsResp",
|
"DailyChallengeStatsResp",
|
||||||
|
|||||||
193
app/database/chat.py
Normal file
193
app/database/chat.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
|
||||||
|
from app.models.model import UTCBaseModel
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
from sqlmodel import (
|
||||||
|
VARCHAR,
|
||||||
|
BigInteger,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Field,
|
||||||
|
ForeignKey,
|
||||||
|
Relationship,
|
||||||
|
SQLModel,
|
||||||
|
select,
|
||||||
|
)
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
# ChatChannel
|
||||||
|
|
||||||
|
|
||||||
|
class ChatUserAttributes(BaseModel):
|
||||||
|
can_message: bool
|
||||||
|
can_message_error: str | None = None
|
||||||
|
last_read_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelType(str, Enum):
|
||||||
|
PUBLIC = "PUBLIC"
|
||||||
|
PRIVATE = "PRIVATE"
|
||||||
|
MULTIPLAYER = "MULTIPLAYER"
|
||||||
|
SPECTATOR = "SPECTATOR"
|
||||||
|
TEMPORARY = "TEMPORARY"
|
||||||
|
PM = "PM"
|
||||||
|
GROUP = "GROUP"
|
||||||
|
SYSTEM = "SYSTEM"
|
||||||
|
ANNOUNCE = "ANNOUNCE"
|
||||||
|
TEAM = "TEAM"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatChannelBase(SQLModel):
|
||||||
|
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||||
|
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
|
||||||
|
icon: str | None = Field(default=None)
|
||||||
|
type: ChannelType = Field(index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatChannel(ChatChannelBase, table=True):
|
||||||
|
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
|
||||||
|
channel_id: int | None = Field(primary_key=True, index=True, default=None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get(
|
||||||
|
cls, channel: str | int, session: AsyncSession
|
||||||
|
) -> "ChatChannel | None":
|
||||||
|
if isinstance(channel, int) or channel.isdigit():
|
||||||
|
channel_ = await session.get(ChatChannel, channel)
|
||||||
|
if channel_ is not None:
|
||||||
|
return channel_
|
||||||
|
return (
|
||||||
|
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||||
|
).first()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatChannelResp(ChatChannelBase):
|
||||||
|
channel_id: int
|
||||||
|
moderated: bool = False
|
||||||
|
uuid: str | None = None
|
||||||
|
current_user_attributes: ChatUserAttributes | None = None
|
||||||
|
last_read_id: int | None = None
|
||||||
|
last_message_id: int | None = None
|
||||||
|
recent_messages: list[str] | None = None
|
||||||
|
users: list[int] | None = None
|
||||||
|
message_length_limit: int = 1000
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def from_db(
|
||||||
|
cls,
|
||||||
|
channel: ChatChannel,
|
||||||
|
session: AsyncSession,
|
||||||
|
users: list[int],
|
||||||
|
user: User,
|
||||||
|
redis: Redis,
|
||||||
|
) -> Self:
|
||||||
|
c = cls.model_validate(channel)
|
||||||
|
silence = (
|
||||||
|
await session.exec(
|
||||||
|
select(SilenceUser).where(
|
||||||
|
SilenceUser.channel_id == channel.channel_id,
|
||||||
|
SilenceUser.user_id == user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||||
|
if last_msg and last_msg.isdigit():
|
||||||
|
last_msg = int(last_msg)
|
||||||
|
else:
|
||||||
|
last_msg = None
|
||||||
|
|
||||||
|
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||||
|
if last_read_id and last_read_id.isdigit():
|
||||||
|
last_read_id = int(last_read_id)
|
||||||
|
else:
|
||||||
|
last_read_id = last_msg
|
||||||
|
|
||||||
|
if silence is not None:
|
||||||
|
attribute = ChatUserAttributes(
|
||||||
|
can_message=False,
|
||||||
|
can_message_error=silence.reason or "You are muted in this channel.",
|
||||||
|
last_read_id=last_read_id or 0,
|
||||||
|
)
|
||||||
|
c.moderated = True
|
||||||
|
else:
|
||||||
|
attribute = ChatUserAttributes(
|
||||||
|
can_message=True,
|
||||||
|
last_read_id=last_read_id or 0,
|
||||||
|
)
|
||||||
|
c.moderated = False
|
||||||
|
|
||||||
|
c.current_user_attributes = attribute
|
||||||
|
c.users = users
|
||||||
|
c.last_message_id = last_msg
|
||||||
|
c.last_read_id = last_read_id
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
# ChatMessage
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(str, Enum):
|
||||||
|
ACTION = "action"
|
||||||
|
MARKDOWN = "markdown"
|
||||||
|
PLAIN = "plain"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||||
|
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||||
|
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||||
|
message_id: int | None = Field(index=True, primary_key=True, default=None)
|
||||||
|
sender_id: int = Field(
|
||||||
|
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||||
|
)
|
||||||
|
timestamp: datetime = Field(
|
||||||
|
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||||
|
)
|
||||||
|
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||||
|
uuid: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(ChatMessageBase, table=True):
|
||||||
|
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
|
||||||
|
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageResp(ChatMessageBase):
|
||||||
|
sender: UserResp | None = None
|
||||||
|
is_action: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def from_db(
|
||||||
|
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
|
||||||
|
) -> "ChatMessageResp":
|
||||||
|
m = cls.model_validate(db_message.model_dump())
|
||||||
|
m.is_action = db_message.type == MessageType.ACTION
|
||||||
|
if user:
|
||||||
|
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||||
|
else:
|
||||||
|
m.sender = await UserResp.from_db(
|
||||||
|
db_message.user, session, RANKING_INCLUDES
|
||||||
|
)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
# SilenceUser
|
||||||
|
|
||||||
|
|
||||||
|
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||||
|
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
|
||||||
|
id: int | None = Field(primary_key=True, default=None, index=True)
|
||||||
|
user_id: int = Field(
|
||||||
|
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||||
|
)
|
||||||
|
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
|
||||||
|
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
|
||||||
|
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
|
||||||
|
banned_at: datetime = Field(
|
||||||
|
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||||
|
)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Callable
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -51,6 +52,17 @@ async def get_db():
|
|||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
DBFactory = Callable[[], AsyncIterator[AsyncSession]]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_factory() -> DBFactory:
|
||||||
|
async def _factory() -> AsyncIterator[AsyncSession]:
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
return _factory
|
||||||
|
|
||||||
|
|
||||||
# Redis 依赖
|
# Redis 依赖
|
||||||
def get_redis():
|
def get_redis():
|
||||||
return redis_client
|
return redis_client
|
||||||
|
|||||||
39
app/dependencies/param.py
Normal file
39
app/dependencies/param.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
def BodyOrForm[T: BaseModel](model: type[T]):
|
||||||
|
async def dependency(
|
||||||
|
request: Request,
|
||||||
|
) -> T:
|
||||||
|
content_type = request.headers.get("content-type", "")
|
||||||
|
|
||||||
|
data: dict[str, Any] = {}
|
||||||
|
if "application/json" in content_type:
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
except Exception:
|
||||||
|
raise RequestValidationError(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"loc": ("body",),
|
||||||
|
"msg": "Invalid JSON body",
|
||||||
|
"type": "value_error",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
form = await request.form()
|
||||||
|
data = dict(form)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return model(**data)
|
||||||
|
except ValidationError as e:
|
||||||
|
raise RequestValidationError(e.errors())
|
||||||
|
|
||||||
|
return dependency
|
||||||
10
app/models/chat.py
Normal file
10
app/models/chat.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ChatEvent(BaseModel):
|
||||||
|
event: str
|
||||||
|
data: dict[str, Any] | None = None
|
||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
from app.signalr import signalr_router as signalr_router
|
from app.signalr import signalr_router as signalr_router
|
||||||
|
|
||||||
from .auth import router as auth_router
|
from .auth import router as auth_router
|
||||||
|
from .chat import chat_router as chat_router
|
||||||
from .fetcher import fetcher_router as fetcher_router
|
from .fetcher import fetcher_router as fetcher_router
|
||||||
from .file import file_router as file_router
|
from .file import file_router as file_router
|
||||||
from .private import private_router as private_router
|
from .private import private_router as private_router
|
||||||
@@ -17,6 +18,7 @@ __all__ = [
|
|||||||
"api_v1_router",
|
"api_v1_router",
|
||||||
"api_v2_router",
|
"api_v2_router",
|
||||||
"auth_router",
|
"auth_router",
|
||||||
|
"chat_router",
|
||||||
"fetcher_router",
|
"fetcher_router",
|
||||||
"file_router",
|
"file_router",
|
||||||
"private_router",
|
"private_router",
|
||||||
|
|||||||
28
app/router/chat/__init__.py
Normal file
28
app/router/chat/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.router.v2 import api_v2_router as router
|
||||||
|
|
||||||
|
from . import channel, message # noqa: F401
|
||||||
|
from .server import chat_router as chat_router
|
||||||
|
|
||||||
|
from fastapi import Query
|
||||||
|
|
||||||
|
__all__ = ["chat_router"]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/notifications")
|
||||||
|
async def get_notifications(max_id: int | None = Query(None)):
|
||||||
|
if settings.server_url is not None:
|
||||||
|
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||||
|
"http://", "ws://"
|
||||||
|
).replace("https://", "wss://")
|
||||||
|
else:
|
||||||
|
notification_endpoint = "/notification-server"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"has_more": False,
|
||||||
|
"notifications": [],
|
||||||
|
"unread_count": 0,
|
||||||
|
"notification_endpoint": notification_endpoint,
|
||||||
|
}
|
||||||
138
app/router/chat/channel.py
Normal file
138
app/router/chat/channel.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.database.chat import (
|
||||||
|
ChannelType,
|
||||||
|
ChatChannel,
|
||||||
|
ChatChannelResp,
|
||||||
|
)
|
||||||
|
from app.database.lazer_user import User, UserResp
|
||||||
|
from app.dependencies.database import get_db, get_redis
|
||||||
|
from app.dependencies.user import get_current_user
|
||||||
|
from app.router.v2 import api_v2_router as router
|
||||||
|
|
||||||
|
from .server import server
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Query, Security
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateResponse(BaseModel):
|
||||||
|
presence: list[ChatChannelResp] = Field(default_factory=list)
|
||||||
|
silences: list[Any] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chat/updates", response_model=UpdateResponse)
|
||||||
|
async def get_update(
|
||||||
|
history_since: int | None = Query(None),
|
||||||
|
since: int | None = Query(None),
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
includes: list[str] = Query(["presence"], alias="includes[]"),
|
||||||
|
redis: Redis = Depends(get_redis),
|
||||||
|
):
|
||||||
|
resp = UpdateResponse()
|
||||||
|
if "presence" in includes:
|
||||||
|
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||||
|
for channel_id in channel_ids:
|
||||||
|
channel = await ChatChannel.get(channel_id, session)
|
||||||
|
if channel:
|
||||||
|
resp.presence.append(
|
||||||
|
await ChatChannelResp.from_db(
|
||||||
|
channel,
|
||||||
|
session,
|
||||||
|
server.channels.get(channel_id, []),
|
||||||
|
current_user,
|
||||||
|
redis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/chat/channels/{channel}/users/{user}", response_model=ChatChannelResp)
|
||||||
|
async def join_channel(
|
||||||
|
channel: str,
|
||||||
|
user: str,
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
|
|
||||||
|
if db_channel is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
return await server.join_channel(current_user, db_channel, session)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/chat/channels/{channel}/users/{user}",
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def leave_channel(
|
||||||
|
channel: str,
|
||||||
|
user: str,
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
|
|
||||||
|
if db_channel is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
await server.leave_channel(current_user, db_channel, session)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chat/channels")
|
||||||
|
async def get_channel_list(
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
redis: Redis = Depends(get_redis),
|
||||||
|
):
|
||||||
|
channels = (
|
||||||
|
await session.exec(
|
||||||
|
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
results = []
|
||||||
|
for channel in channels:
|
||||||
|
assert channel.channel_id is not None
|
||||||
|
results.append(
|
||||||
|
await ChatChannelResp.from_db(
|
||||||
|
channel,
|
||||||
|
session,
|
||||||
|
server.channels.get(channel.channel_id, []),
|
||||||
|
current_user,
|
||||||
|
redis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class GetChannelResp(BaseModel):
|
||||||
|
channel: ChatChannelResp
|
||||||
|
users: list[UserResp] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chat/channels/{channel}")
|
||||||
|
async def get_channel(
|
||||||
|
channel: str,
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
redis: Redis = Depends(get_redis),
|
||||||
|
):
|
||||||
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
|
if db_channel is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
assert db_channel.channel_id is not None
|
||||||
|
return GetChannelResp(
|
||||||
|
channel=await ChatChannelResp.from_db(
|
||||||
|
db_channel,
|
||||||
|
session,
|
||||||
|
server.channels.get(db_channel.channel_id, []),
|
||||||
|
current_user,
|
||||||
|
redis,
|
||||||
|
)
|
||||||
|
)
|
||||||
98
app/router/chat/message.py
Normal file
98
app/router/chat/message.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.database import ChatMessageResp
|
||||||
|
from app.database.chat import ChatChannel, ChatMessage, MessageType
|
||||||
|
from app.database.lazer_user import User
|
||||||
|
from app.dependencies.database import get_db
|
||||||
|
from app.dependencies.param import BodyOrForm
|
||||||
|
from app.dependencies.user import get_current_user
|
||||||
|
from app.router.v2 import api_v2_router as router
|
||||||
|
|
||||||
|
from .server import server
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Query, Security
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlmodel import col, select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat/ack")
|
||||||
|
async def keep_alive(
|
||||||
|
history_since: int | None = Query(None),
|
||||||
|
since: int | None = Query(None),
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
return {"silences": []}
|
||||||
|
|
||||||
|
|
||||||
|
class MessageReq(BaseModel):
|
||||||
|
message: str
|
||||||
|
is_action: bool = False
|
||||||
|
uuid: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat/channels/{channel}/messages", response_model=ChatMessageResp)
|
||||||
|
async def send_message(
|
||||||
|
channel: str,
|
||||||
|
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
|
if db_channel is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
msg = ChatMessage(
|
||||||
|
channel_id=db_channel.channel_id,
|
||||||
|
content=req.message,
|
||||||
|
sender_id=current_user.id,
|
||||||
|
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||||
|
uuid=req.uuid,
|
||||||
|
)
|
||||||
|
session.add(msg)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(msg)
|
||||||
|
await session.refresh(current_user)
|
||||||
|
resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||||
|
await server.send_message_to_channel(resp)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chat/channels/{channel}/messages", response_model=list[ChatMessageResp])
|
||||||
|
async def get_message(
|
||||||
|
channel: str,
|
||||||
|
limit: int = Query(50, ge=1, le=50),
|
||||||
|
since: int = Query(default=0, ge=0),
|
||||||
|
until: int | None = Query(None),
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
|
if db_channel is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
messages = await session.exec(
|
||||||
|
select(ChatMessage)
|
||||||
|
.where(
|
||||||
|
ChatMessage.channel_id == db_channel.channel_id,
|
||||||
|
col(ChatMessage.message_id) > since,
|
||||||
|
col(ChatMessage.message_id) < until if until is not None else True,
|
||||||
|
)
|
||||||
|
.order_by(col(ChatMessage.timestamp).desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||||
|
resp.reverse()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/chat/channels/{channel}/mark-as-read/{message}", status_code=204)
|
||||||
|
async def mark_as_read(
|
||||||
|
channel: str,
|
||||||
|
message: int,
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
|
if db_channel is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
await server.mark_as_read(db_channel.channel_id, message)
|
||||||
190
app/router/chat/server.py
Normal file
190
app/router/chat/server.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.database.chat import ChatChannel, ChatChannelResp, ChatMessageResp
|
||||||
|
from app.database.lazer_user import User
|
||||||
|
from app.dependencies.database import DBFactory, get_db_factory, get_redis
|
||||||
|
from app.dependencies.user import get_current_user
|
||||||
|
from app.log import logger
|
||||||
|
from app.models.chat import ChatEvent
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.security import SecurityScopes
|
||||||
|
from fastapi.websockets import WebSocketState
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
class ChatServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.connect_client: dict[int, WebSocket] = {}
|
||||||
|
self.channels: dict[int, list[int]] = {}
|
||||||
|
self.redis: Redis = get_redis()
|
||||||
|
|
||||||
|
self.tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
def _add_task(self, task):
|
||||||
|
task = asyncio.create_task(task)
|
||||||
|
self.tasks.add(task)
|
||||||
|
task.add_done_callback(self.tasks.discard)
|
||||||
|
|
||||||
|
def connect(self, user_id: int, client: WebSocket):
|
||||||
|
self.connect_client[user_id] = client
|
||||||
|
|
||||||
|
def get_user_joined_channel(self, user_id: int) -> list[int]:
|
||||||
|
return [
|
||||||
|
channel_id
|
||||||
|
for channel_id, users in self.channels.items()
|
||||||
|
if user_id in users
|
||||||
|
]
|
||||||
|
|
||||||
|
async def disconnect(self, user: User, session: AsyncSession):
|
||||||
|
user_id = user.id
|
||||||
|
if user_id in self.connect_client:
|
||||||
|
del self.connect_client[user_id]
|
||||||
|
for channel_id, channel in self.channels.items():
|
||||||
|
if user_id in channel:
|
||||||
|
channel.remove(user_id)
|
||||||
|
channel = await ChatChannel.get(channel_id, session)
|
||||||
|
if channel:
|
||||||
|
await self.leave_channel(user, channel, session)
|
||||||
|
|
||||||
|
async def send_event(self, client: WebSocket, event: ChatEvent):
|
||||||
|
if client.client_state == WebSocketState.CONNECTED:
|
||||||
|
await client.send_text(event.model_dump_json())
|
||||||
|
|
||||||
|
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||||
|
for user_id in self.channels.get(channel_id, []):
|
||||||
|
client = self.connect_client.get(user_id)
|
||||||
|
if client:
|
||||||
|
await self.send_event(client, event)
|
||||||
|
|
||||||
|
async def mark_as_read(self, channel_id: int, message_id: int):
|
||||||
|
await self.redis.set(f"chat:{channel_id}:last_msg", message_id)
|
||||||
|
|
||||||
|
async def send_message_to_channel(self, message: ChatMessageResp):
|
||||||
|
self._add_task(
|
||||||
|
self.broadcast(
|
||||||
|
message.channel_id,
|
||||||
|
ChatEvent(
|
||||||
|
event="chat.message.new",
|
||||||
|
data={"messages": [message], "users": [message.sender]},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await self.mark_as_read(message.channel_id, message.message_id)
|
||||||
|
|
||||||
|
async def join_channel(
|
||||||
|
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||||
|
) -> ChatChannelResp:
|
||||||
|
user_id = user.id
|
||||||
|
channel_id = channel.channel_id
|
||||||
|
assert channel_id is not None
|
||||||
|
|
||||||
|
if channel_id not in self.channels:
|
||||||
|
self.channels[channel_id] = []
|
||||||
|
if user_id not in self.channels[channel_id]:
|
||||||
|
self.channels[channel_id].append(user_id)
|
||||||
|
|
||||||
|
channel_resp = await ChatChannelResp.from_db(
|
||||||
|
channel, session, self.channels[channel_id], user, self.redis
|
||||||
|
)
|
||||||
|
|
||||||
|
client = self.connect_client.get(user_id)
|
||||||
|
if client:
|
||||||
|
await self.send_event(
|
||||||
|
client,
|
||||||
|
ChatEvent(
|
||||||
|
event="chat.channel.join",
|
||||||
|
data=channel_resp.model_dump(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return channel_resp
|
||||||
|
|
||||||
|
async def leave_channel(
|
||||||
|
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
user_id = user.id
|
||||||
|
channel_id = channel.channel_id
|
||||||
|
assert channel_id is not None
|
||||||
|
|
||||||
|
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||||
|
self.channels[channel_id].remove(user_id)
|
||||||
|
|
||||||
|
if not self.channels.get(channel_id):
|
||||||
|
del self.channels[channel_id]
|
||||||
|
|
||||||
|
channel_resp = await ChatChannelResp.from_db(
|
||||||
|
channel, session, self.channels.get(channel_id, []), user, self.redis
|
||||||
|
)
|
||||||
|
client = self.connect_client.get(user_id)
|
||||||
|
if client:
|
||||||
|
await self.send_event(
|
||||||
|
client,
|
||||||
|
ChatEvent(
|
||||||
|
event="chat.channel.part",
|
||||||
|
data=channel_resp.model_dump(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
server = ChatServer()
|
||||||
|
|
||||||
|
chat_router = APIRouter(include_in_schema=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
packets = await ws.receive_json()
|
||||||
|
if packets.get("event") == "chat.end":
|
||||||
|
async for session in factory():
|
||||||
|
user = await session.get(User, user_id)
|
||||||
|
if user is None:
|
||||||
|
break
|
||||||
|
await server.disconnect(user, session)
|
||||||
|
await ws.close(code=1000)
|
||||||
|
break
|
||||||
|
except WebSocketDisconnect as e:
|
||||||
|
logger.info(
|
||||||
|
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "disconnect message" in str(e):
|
||||||
|
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
|
||||||
|
else:
|
||||||
|
logger.exception(f"RuntimeError in client {user_id}: {e}")
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"Error in client {user_id}")
|
||||||
|
|
||||||
|
|
||||||
|
@chat_router.websocket("/notification-server")
|
||||||
|
async def chat_websocket(
|
||||||
|
websocket: WebSocket,
|
||||||
|
authorization: str = Header(...),
|
||||||
|
factory: DBFactory = Depends(get_db_factory),
|
||||||
|
):
|
||||||
|
async for session in factory():
|
||||||
|
token = authorization[7:]
|
||||||
|
if (
|
||||||
|
user := await get_current_user(
|
||||||
|
SecurityScopes(scopes=["chat.read"]), session, token_pw=token
|
||||||
|
)
|
||||||
|
) is None:
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
login = await websocket.receive_json()
|
||||||
|
if login.get("event") != "chat.start":
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
user_id = user.id
|
||||||
|
assert user_id
|
||||||
|
server.connect(user_id, websocket)
|
||||||
|
channel = await ChatChannel.get(1, session)
|
||||||
|
if channel is not None:
|
||||||
|
await server.join_channel(user, channel, session)
|
||||||
|
await _listen_stop(websocket, user_id, factory)
|
||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
from fnmatch import fnmatch
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.dependencies.database import get_redis_pubsub
|
from app.dependencies.database import get_redis_pubsub
|
||||||
@@ -29,12 +30,17 @@ class RedisSubscriber:
|
|||||||
ignore_subscribe_messages=True, timeout=None
|
ignore_subscribe_messages=True, timeout=None
|
||||||
)
|
)
|
||||||
if message is not None and message["type"] == "message":
|
if message is not None and message["type"] == "message":
|
||||||
method = self.handlers.get(message["channel"])
|
matched_handlers = []
|
||||||
if method:
|
if message["channel"] in self.handlers:
|
||||||
|
matched_handlers.extend(self.handlers[message["channel"]])
|
||||||
|
for pattern, handlers in self.handlers.items():
|
||||||
|
if fnmatch(message["channel"], pattern):
|
||||||
|
matched_handlers.extend(handlers)
|
||||||
|
if matched_handlers:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
handler(message["channel"], message["data"])
|
handler(message["channel"], message["data"])
|
||||||
for handler in method
|
for handler in matched_handlers
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
2
main.py
2
main.py
@@ -13,6 +13,7 @@ from app.router import (
|
|||||||
api_v1_router,
|
api_v1_router,
|
||||||
api_v2_router,
|
api_v2_router,
|
||||||
auth_router,
|
auth_router,
|
||||||
|
chat_router,
|
||||||
fetcher_router,
|
fetcher_router,
|
||||||
file_router,
|
file_router,
|
||||||
private_router,
|
private_router,
|
||||||
@@ -71,6 +72,7 @@ app = FastAPI(
|
|||||||
|
|
||||||
app.include_router(api_v2_router)
|
app.include_router(api_v2_router)
|
||||||
app.include_router(api_v1_router)
|
app.include_router(api_v1_router)
|
||||||
|
app.include_router(chat_router)
|
||||||
app.include_router(redirect_api_router)
|
app.include_router(redirect_api_router)
|
||||||
app.include_router(signalr_router)
|
app.include_router(signalr_router)
|
||||||
app.include_router(fetcher_router)
|
app.include_router(fetcher_router)
|
||||||
|
|||||||
192
migrations/versions/dd33d89aa2c2_chat_add_chat.py
Normal file
192
migrations/versions/dd33d89aa2c2_chat_add_chat.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""chat: add chat
|
||||||
|
|
||||||
|
Revision ID: dd33d89aa2c2
|
||||||
|
Revises: 9f6b27e8ea51
|
||||||
|
Create Date: 2025-08-15 14:22:34.775877
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlmodel
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "dd33d89aa2c2"
|
||||||
|
down_revision: str | Sequence[str] | None = "9f6b27e8ea51"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
channel_table = op.create_table(
|
||||||
|
"chat_channels",
|
||||||
|
sa.Column("name", sa.VARCHAR(length=50), nullable=True),
|
||||||
|
sa.Column("description", sa.VARCHAR(length=255), nullable=True),
|
||||||
|
sa.Column("icon", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"type",
|
||||||
|
sa.Enum(
|
||||||
|
"PUBLIC",
|
||||||
|
"PRIVATE",
|
||||||
|
"MULTIPLAYER",
|
||||||
|
"SPECTATOR",
|
||||||
|
"TEMPORARY",
|
||||||
|
"PM",
|
||||||
|
"GROUP",
|
||||||
|
"SYSTEM",
|
||||||
|
"ANNOUNCE",
|
||||||
|
"TEAM",
|
||||||
|
name="channeltype",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("channel_id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_channels_channel_id"),
|
||||||
|
"chat_channels",
|
||||||
|
["channel_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_channels_description"),
|
||||||
|
"chat_channels",
|
||||||
|
["description"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_channels_name"), "chat_channels", ["name"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_channels_type"), "chat_channels", ["type"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"chat_messages",
|
||||||
|
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("content", sa.VARCHAR(length=1000), nullable=True),
|
||||||
|
sa.Column("message_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("sender_id", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("timestamp", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"type",
|
||||||
|
sa.Enum("ACTION", "MARKDOWN", "PLAIN", name="messagetype"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("uuid", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["channel_id"],
|
||||||
|
["chat_channels.channel_id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["sender_id"],
|
||||||
|
["lazer_users.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("message_id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_messages_channel_id"),
|
||||||
|
"chat_messages",
|
||||||
|
["channel_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_messages_message_id"),
|
||||||
|
"chat_messages",
|
||||||
|
["message_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_messages_sender_id"), "chat_messages", ["sender_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_messages_timestamp"), "chat_messages", ["timestamp"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_messages_type"), "chat_messages", ["type"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"chat_silence_users",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("until", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column("banned_at", sa.DateTime(), nullable=False),
|
||||||
|
sa.Column("reason", sa.VARCHAR(length=255), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["channel_id"],
|
||||||
|
["chat_channels.channel_id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["lazer_users.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_silence_users_channel_id"),
|
||||||
|
"chat_silence_users",
|
||||||
|
["channel_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_silence_users_reason"),
|
||||||
|
"chat_silence_users",
|
||||||
|
["reason"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_silence_users_until"),
|
||||||
|
"chat_silence_users",
|
||||||
|
["until"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_silence_users_banned_at"),
|
||||||
|
"chat_silence_users",
|
||||||
|
["banned_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_silence_users_user_id"),
|
||||||
|
"chat_silence_users",
|
||||||
|
["user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_chat_silence_users_id"),
|
||||||
|
"chat_silence_users",
|
||||||
|
["id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.bulk_insert(
|
||||||
|
channel_table,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "osu!",
|
||||||
|
"description": "General discussion for osu!",
|
||||||
|
"type": "PUBLIC",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "announce",
|
||||||
|
"description": "Official announcements",
|
||||||
|
"type": "PUBLIC",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("chat_silence_users")
|
||||||
|
op.drop_table("chat_messages")
|
||||||
|
op.drop_table("chat_channels")
|
||||||
|
# ### end Alembic commands ###
|
||||||
Reference in New Issue
Block a user