3
app/const.py
Normal file
3
app/const.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
BANCHOBOT_ID = 2
|
||||
@@ -10,6 +10,13 @@ from .beatmapset import (
|
||||
BeatmapsetResp,
|
||||
)
|
||||
from .best_score import BestScore
|
||||
from .chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatMessage,
|
||||
ChatMessageResp,
|
||||
)
|
||||
from .counts import (
|
||||
CountResp,
|
||||
MonthlyPlaycounts,
|
||||
@@ -63,6 +70,11 @@ __all__ = [
|
||||
"Beatmapset",
|
||||
"BeatmapsetResp",
|
||||
"BestScore",
|
||||
"ChannelType",
|
||||
"ChatChannel",
|
||||
"ChatChannelResp",
|
||||
"ChatMessage",
|
||||
"ChatMessageResp",
|
||||
"CountResp",
|
||||
"DailyChallengeStats",
|
||||
"DailyChallengeStatsResp",
|
||||
|
||||
240
app/database/chat.py
Normal file
240
app/database/chat.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
|
||||
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
VARCHAR,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
# ChatChannel
|
||||
|
||||
|
||||
class ChatUserAttributes(BaseModel):
|
||||
can_message: bool
|
||||
can_message_error: str | None = None
|
||||
last_read_id: int
|
||||
|
||||
|
||||
class ChannelType(str, Enum):
|
||||
PUBLIC = "PUBLIC"
|
||||
PRIVATE = "PRIVATE"
|
||||
MULTIPLAYER = "MULTIPLAYER"
|
||||
SPECTATOR = "SPECTATOR"
|
||||
TEMPORARY = "TEMPORARY"
|
||||
PM = "PM"
|
||||
GROUP = "GROUP"
|
||||
SYSTEM = "SYSTEM"
|
||||
ANNOUNCE = "ANNOUNCE"
|
||||
TEAM = "TEAM"
|
||||
|
||||
|
||||
class ChatChannelBase(SQLModel):
|
||||
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
|
||||
icon: str | None = Field(default=None)
|
||||
type: ChannelType = Field(index=True)
|
||||
|
||||
|
||||
class ChatChannel(ChatChannelBase, table=True):
|
||||
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
|
||||
channel_id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls, channel: str | int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
if isinstance(channel, int) or channel.isdigit():
|
||||
channel_ = await session.get(ChatChannel, channel)
|
||||
if channel_ is not None:
|
||||
return channel_
|
||||
return (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
@classmethod
|
||||
async def get_pm_channel(
|
||||
cls, user1: int, user2: int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
channel = await cls.get(f"pm_{user1}_{user2}", session)
|
||||
if channel is None:
|
||||
channel = await cls.get(f"pm_{user2}_{user1}", session)
|
||||
return channel
|
||||
|
||||
|
||||
class ChatChannelResp(ChatChannelBase):
|
||||
channel_id: int
|
||||
moderated: bool = False
|
||||
uuid: str | None = None
|
||||
current_user_attributes: ChatUserAttributes | None = None
|
||||
last_read_id: int | None = None
|
||||
last_message_id: int | None = None
|
||||
recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
|
||||
users: list[int] = Field(default_factory=list)
|
||||
message_length_limit: int = 1000
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
channel: ChatChannel,
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
redis: Redis,
|
||||
users: list[int] | None = None,
|
||||
include_recent_messages: bool = False,
|
||||
) -> Self:
|
||||
c = cls.model_validate(channel)
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
if last_msg and last_msg.isdigit():
|
||||
last_msg = int(last_msg)
|
||||
else:
|
||||
last_msg = None
|
||||
|
||||
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
if last_read_id and last_read_id.isdigit():
|
||||
last_read_id = int(last_read_id)
|
||||
else:
|
||||
last_read_id = last_msg
|
||||
|
||||
if silence is not None:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=False,
|
||||
can_message_error=silence.reason or "You are muted in this channel.",
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = True
|
||||
else:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=True,
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = False
|
||||
|
||||
c.current_user_attributes = attribute
|
||||
if c.type != ChannelType.PUBLIC and users is not None:
|
||||
c.users = users
|
||||
c.last_message_id = last_msg
|
||||
c.last_read_id = last_read_id
|
||||
|
||||
if include_recent_messages:
|
||||
messages = (
|
||||
await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.channel_id == channel.channel_id)
|
||||
.order_by(col(ChatMessage.timestamp).desc())
|
||||
.limit(10)
|
||||
)
|
||||
).all()
|
||||
c.recent_messages = [
|
||||
await ChatMessageResp.from_db(msg, session, user) for msg in messages
|
||||
]
|
||||
c.recent_messages.reverse()
|
||||
|
||||
if c.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
target_name = await session.exec(
|
||||
select(User.username).where(User.id == target_user_id)
|
||||
)
|
||||
c.name = target_name.one()
|
||||
assert user.id
|
||||
c.users = [target_user_id, user.id]
|
||||
return c
|
||||
|
||||
|
||||
# ChatMessage
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
ACTION = "action"
|
||||
MARKDOWN = "markdown"
|
||||
PLAIN = "plain"
|
||||
|
||||
|
||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||
message_id: int | None = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||
uuid: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ChatMessage(ChatMessageBase, table=True):
|
||||
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
|
||||
class ChatMessageResp(ChatMessageBase):
|
||||
sender: UserResp | None = None
|
||||
is_action: bool = False
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
|
||||
) -> "ChatMessageResp":
|
||||
m = cls.model_validate(db_message.model_dump())
|
||||
m.is_action = db_message.type == MessageType.ACTION
|
||||
if user:
|
||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
else:
|
||||
m.sender = await UserResp.from_db(
|
||||
db_message.user, session, RANKING_INCLUDES
|
||||
)
|
||||
return m
|
||||
|
||||
|
||||
# SilenceUser
|
||||
|
||||
|
||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
|
||||
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
|
||||
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
|
||||
banned_at: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
class UserSilenceResp(SQLModel):
|
||||
id: int
|
||||
user_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
|
||||
return cls(
|
||||
id=db_silence.id,
|
||||
user_id=db_silence.user_id,
|
||||
)
|
||||
@@ -168,6 +168,46 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||
)
|
||||
|
||||
async def is_user_can_pm(
|
||||
self, from_user: "User", session: AsyncSession
|
||||
) -> tuple[bool, str]:
|
||||
from .relationship import Relationship, RelationshipType
|
||||
|
||||
from_relationship = (
|
||||
await session.exec(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == from_user.id,
|
||||
Relationship.target_id == self.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if from_relationship and from_relationship.type == RelationshipType.BLOCK:
|
||||
return False, "You have blocked the target user."
|
||||
if from_user.pm_friends_only and (
|
||||
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"You have disabled non-friend communications "
|
||||
"and target user is not your friend.",
|
||||
)
|
||||
|
||||
relationship = (
|
||||
await session.exec(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == self.id,
|
||||
Relationship.target_id == from_user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if relationship and relationship.type == RelationshipType.BLOCK:
|
||||
return False, "Target user has blocked you."
|
||||
if self.pm_friends_only and (
|
||||
not relationship or relationship.type != RelationshipType.FOLLOW
|
||||
):
|
||||
return False, "Target user has disabled non-friend communications"
|
||||
return True, ""
|
||||
|
||||
|
||||
class UserResp(UserBase):
|
||||
id: int | None = None
|
||||
|
||||
@@ -54,7 +54,7 @@ class RoomBase(SQLModel, UTCBaseModel):
|
||||
auto_skip: bool
|
||||
auto_start_duration: int
|
||||
status: RoomStatus
|
||||
# TODO: channel_id
|
||||
channel_id: int | None = None
|
||||
|
||||
|
||||
class Room(AsyncAttrs, RoomBase, table=True):
|
||||
@@ -84,6 +84,7 @@ class RoomResp(RoomBase):
|
||||
current_playlist_item: PlaylistResp | None = None
|
||||
current_user_score: PlaylistAggregateScore | None = None
|
||||
recent_participants: list[UserResp] = Field(default_factory=list)
|
||||
channel_id: int = 0
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
@@ -93,7 +94,9 @@ class RoomResp(RoomBase):
|
||||
include: list[str] = [],
|
||||
user: User | None = None,
|
||||
) -> "RoomResp":
|
||||
resp = cls.model_validate(room.model_dump())
|
||||
d = room.model_dump()
|
||||
d["channel_id"] = d.get("channel_id", 0) or 0
|
||||
resp = cls.model_validate(d)
|
||||
|
||||
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
|
||||
difficulty_range = RoomDifficultyRange(
|
||||
@@ -158,6 +161,7 @@ class RoomResp(RoomBase):
|
||||
# duration = room.settings.duration,
|
||||
starts_at=server_room.start_at,
|
||||
participant_count=len(room.users),
|
||||
channel_id=server_room.room.channel_id or 0,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextvars import ContextVar
|
||||
import json
|
||||
|
||||
@@ -51,6 +52,17 @@ async def get_db():
|
||||
yield session
|
||||
|
||||
|
||||
DBFactory = Callable[[], AsyncIterator[AsyncSession]]
|
||||
|
||||
|
||||
async def get_db_factory() -> DBFactory:
|
||||
async def _factory() -> AsyncIterator[AsyncSession]:
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
# Redis 依赖
|
||||
def get_redis():
|
||||
return redis_client
|
||||
|
||||
51
app/dependencies/param.py
Normal file
51
app/dependencies/param.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
|
||||
def BodyOrForm[T: BaseModel](model: type[T]):
|
||||
async def dependency(
|
||||
request: Request,
|
||||
) -> T:
|
||||
content_type = request.headers.get("content-type", "")
|
||||
|
||||
data: dict[str, Any] = {}
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
raise RequestValidationError(
|
||||
[
|
||||
{
|
||||
"loc": ("body",),
|
||||
"msg": "Invalid JSON body",
|
||||
"type": "value_error",
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
form = await request.form()
|
||||
data = dict(form)
|
||||
|
||||
try:
|
||||
return model(**data)
|
||||
except ValidationError as e:
|
||||
raise RequestValidationError(e.errors())
|
||||
|
||||
dependency.__signature__ = inspect.signature( # pyright: ignore[reportFunctionMemberAccess]
|
||||
lambda x: None
|
||||
).replace(
|
||||
parameters=[
|
||||
inspect.Parameter(
|
||||
name=model.__name__.lower(),
|
||||
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=model,
|
||||
)
|
||||
]
|
||||
)
|
||||
return dependency
|
||||
10
app/models/chat.py
Normal file
10
app/models/chat.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatEvent(BaseModel):
|
||||
event: str
|
||||
data: dict[str, Any] | None = None
|
||||
@@ -44,6 +44,7 @@ from sqlmodel import col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.room import Room
|
||||
from app.signalr.hub import MultiplayerHub
|
||||
|
||||
HOST_LIMIT = 50
|
||||
@@ -348,7 +349,7 @@ class MultiplayerRoom(BaseModel):
|
||||
channel_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, room) -> "MultiplayerRoom":
|
||||
def from_db(cls, room: "Room") -> "MultiplayerRoom":
|
||||
"""
|
||||
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
|
||||
"""
|
||||
@@ -358,7 +359,7 @@ class MultiplayerRoom(BaseModel):
|
||||
host_user = MultiplayerRoomUser(user_id=room.host_id)
|
||||
# playlist 转换
|
||||
playlist = []
|
||||
if hasattr(room, "playlist"):
|
||||
if room.playlist:
|
||||
for item in room.playlist:
|
||||
playlist.append(
|
||||
PlaylistItem(
|
||||
@@ -396,7 +397,7 @@ class MultiplayerRoom(BaseModel):
|
||||
match_state=None,
|
||||
playlist=playlist,
|
||||
active_countdowns=[],
|
||||
channel_id=getattr(room, "channel_id", 0),
|
||||
channel_id=room.channel_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -624,7 +625,9 @@ class MultiplayerQueue:
|
||||
async with AsyncSession(engine) as session:
|
||||
await Playlist.delete_item(item.id, self.room.room_id, session)
|
||||
|
||||
self.room.playlist.remove(item)
|
||||
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
|
||||
if found_item:
|
||||
self.room.playlist.remove(found_item)
|
||||
self.current_index = self.room.playlist.index(self.upcoming_items[0])
|
||||
|
||||
await self.update_order()
|
||||
|
||||
@@ -88,6 +88,16 @@ class GameMode(str, Enum):
|
||||
}[self]
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def parse(cls, v: str | int) -> "GameMode | None":
|
||||
if isinstance(v, int) or v.isdigit():
|
||||
return cls.from_int_extra(int(v))
|
||||
v = v.lower()
|
||||
try:
|
||||
return cls[v]
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
class Rank(str, Enum):
|
||||
X = "X"
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from app.signalr import signalr_router as signalr_router
|
||||
|
||||
from .auth import router as auth_router
|
||||
from .chat import chat_router as chat_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
from .file import file_router as file_router
|
||||
from .private import private_router as private_router
|
||||
@@ -17,6 +18,7 @@ __all__ = [
|
||||
"api_v1_router",
|
||||
"api_v2_router",
|
||||
"auth_router",
|
||||
"chat_router",
|
||||
"fetcher_router",
|
||||
"file_router",
|
||||
"private_router",
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.auth import (
|
||||
store_token,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies import get_db
|
||||
@@ -459,7 +460,7 @@ async def oauth_token(
|
||||
# 存储令牌
|
||||
await store_token(
|
||||
db,
|
||||
3,
|
||||
BANCHOBOT_ID,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
|
||||
35
app/router/chat/__init__.py
Normal file
35
app/router/chat/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from . import channel, message # noqa: F401
|
||||
from .server import chat_router as chat_router
|
||||
|
||||
from fastapi import Query
|
||||
|
||||
__all__ = ["chat_router"]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/notifications",
|
||||
tags=["通知", "聊天"],
|
||||
name="获取通知",
|
||||
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
|
||||
)
|
||||
async def get_notifications(
|
||||
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
|
||||
):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
|
||||
return {
|
||||
"has_more": False,
|
||||
"notifications": [],
|
||||
"unread_count": 0,
|
||||
"notification_endpoint": notification_endpoint,
|
||||
}
|
||||
573
app/router/chat/banchobot.py
Normal file
573
app/router/chat/banchobot.py
Normal file
@@ -0,0 +1,573 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import timedelta
|
||||
from math import ceil
|
||||
import random
|
||||
import shlex
|
||||
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
|
||||
from app.database.lazer_user import User
|
||||
from app.database.score import Score
|
||||
from app.database.statistics import UserStatistics, get_rank
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.exception import InvokeException
|
||||
from app.models.mods import APIMod
|
||||
from app.models.multiplayer_hub import (
|
||||
ChangeTeamRequest,
|
||||
ServerMultiplayerRoom,
|
||||
StartMatchCountdownRequest,
|
||||
)
|
||||
from app.models.room import MatchType, QueueMode
|
||||
from app.models.score import GameMode
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
from app.signalr.hub.hub import Client
|
||||
|
||||
from .server import server
|
||||
|
||||
from httpx import HTTPError
|
||||
from sqlmodel import func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
HandlerResult = str | None | Awaitable[str | None]
|
||||
Handler = Callable[[User, list[str], AsyncSession, ChatChannel], HandlerResult]
|
||||
|
||||
|
||||
class Bot:
|
||||
def __init__(self, bot_user_id: int = BANCHOBOT_ID) -> None:
|
||||
self._handlers: dict[str, Handler] = {}
|
||||
self.bot_user_id = bot_user_id
|
||||
|
||||
# decorator: @bot.command("ping")
|
||||
def command(self, name: str) -> Callable[[Handler], Handler]:
|
||||
def _decorator(func: Handler) -> Handler:
|
||||
self._handlers[name.lower()] = func
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
def parse(self, content: str) -> tuple[str, list[str]] | None:
|
||||
if not content or not content.startswith("!"):
|
||||
return None
|
||||
try:
|
||||
parts = shlex.split(content[1:])
|
||||
except ValueError:
|
||||
parts = content[1:].split()
|
||||
if not parts:
|
||||
return None
|
||||
cmd = parts[0].lower()
|
||||
args = parts[1:]
|
||||
return cmd, args
|
||||
|
||||
async def try_handle(
|
||||
self,
|
||||
user: User,
|
||||
channel: ChatChannel,
|
||||
content: str,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
parsed = self.parse(content)
|
||||
if not parsed:
|
||||
return
|
||||
cmd, args = parsed
|
||||
handler = self._handlers.get(cmd)
|
||||
|
||||
reply: str | None = None
|
||||
if handler is None:
|
||||
return
|
||||
else:
|
||||
res = handler(user, args, session, channel)
|
||||
if asyncio.iscoroutine(res):
|
||||
res = await res
|
||||
reply = res # type: ignore[assignment]
|
||||
|
||||
if reply:
|
||||
await self._send_reply(user, channel, reply, session)
|
||||
|
||||
async def _send_message(
|
||||
self, channel: ChatChannel, content: str, session: AsyncSession
|
||||
) -> None:
|
||||
bot = await session.get(User, self.bot_user_id)
|
||||
if bot is None:
|
||||
return
|
||||
channel_id = channel.channel_id
|
||||
if channel_id is None:
|
||||
return
|
||||
|
||||
assert bot.id is not None
|
||||
msg = ChatMessage(
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
sender_id=bot.id,
|
||||
type=MessageType.PLAIN,
|
||||
)
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(bot)
|
||||
resp = await ChatMessageResp.from_db(msg, session, bot)
|
||||
await server.send_message_to_channel(resp)
|
||||
|
||||
async def _ensure_pm_channel(
|
||||
self, user: User, session: AsyncSession
|
||||
) -> ChatChannel | None:
|
||||
user_id = user.id
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
bot = await session.get(User, self.bot_user_id)
|
||||
if bot is None or bot.id is None:
|
||||
return None
|
||||
|
||||
channel = await ChatChannel.get_pm_channel(user_id, bot.id, session)
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=f"pm_{user_id}_{bot.id}",
|
||||
description="Private message channel",
|
||||
type=ChannelType.PM,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(user)
|
||||
await session.refresh(bot)
|
||||
await server.batch_join_channel([user, bot], channel, session)
|
||||
return channel
|
||||
|
||||
async def _send_reply(
|
||||
self,
|
||||
user: User,
|
||||
src_channel: ChatChannel,
|
||||
content: str,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
target_channel = src_channel
|
||||
if src_channel.type == ChannelType.PUBLIC:
|
||||
pm = await self._ensure_pm_channel(user, session)
|
||||
if pm is not None:
|
||||
target_channel = pm
|
||||
await self._send_message(target_channel, content, session)
|
||||
|
||||
|
||||
bot = Bot()
|
||||
|
||||
|
||||
@bot.command("help")
|
||||
async def _help(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
cmds = sorted(bot._handlers.keys())
|
||||
if args:
|
||||
target = args[0].lower()
|
||||
if target in bot._handlers:
|
||||
return f"Usage: !{target} [args]"
|
||||
return f"No such command: {target}"
|
||||
if not cmds:
|
||||
return "No available commands"
|
||||
return "Available: " + ", ".join(f"!{c}" for c in cmds)
|
||||
|
||||
|
||||
@bot.command("roll")
|
||||
def _roll(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
if len(args) > 0 and args[0].isdigit():
|
||||
r = random.randint(1, int(args[0]))
|
||||
else:
|
||||
r = random.randint(1, 100)
|
||||
return f"{user.username} rolls {r} point(s)"
|
||||
|
||||
|
||||
@bot.command("stats")
|
||||
async def _stats(
|
||||
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !stats <username> [gamemode]"
|
||||
|
||||
target_user = (
|
||||
await session.exec(select(User).where(User.username == args[0]))
|
||||
).first()
|
||||
if not target_user:
|
||||
return f"User '{args[0]}' not found."
|
||||
|
||||
gamemode = None
|
||||
if len(args) >= 2:
|
||||
gamemode = GameMode.parse(args[1].upper())
|
||||
if gamemode is None:
|
||||
subquery = (
|
||||
select(func.max(Score.id))
|
||||
.where(Score.user_id == target_user.id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
last_score = (
|
||||
await session.exec(select(Score).where(Score.id == subquery))
|
||||
).first()
|
||||
if last_score is not None:
|
||||
gamemode = last_score.gamemode
|
||||
else:
|
||||
gamemode = target_user.playmode
|
||||
|
||||
statistics = (
|
||||
await session.exec(
|
||||
select(UserStatistics).where(
|
||||
UserStatistics.user_id == target_user.id,
|
||||
UserStatistics.mode == gamemode,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not statistics:
|
||||
return f"User '{args[0]}' has no statistics."
|
||||
|
||||
return f"""Stats for {target_user.username} ({gamemode.name.lower()}):
|
||||
Score: {statistics.total_score} (#{await get_rank(session, statistics)})
|
||||
Plays: {statistics.play_count} (lv{ceil(statistics.level_current)})
|
||||
Accuracy: {statistics.hit_accuracy}
|
||||
PP: {statistics.pp}
|
||||
"""
|
||||
|
||||
|
||||
async def _mp_name(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp name <name>"
|
||||
|
||||
name = args[0]
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.name = name
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return f"Room name has changed to {name}"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_set(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp set <teammode> [<queuemode>]"
|
||||
|
||||
teammode = {"0": MatchType.HEAD_TO_HEAD, "2": MatchType.TEAM_VERSUS}.get(args[0])
|
||||
if not teammode:
|
||||
return "Invalid teammode. Use 0 for Head-to-Head or 2 for Team Versus."
|
||||
queuemode = (
|
||||
{
|
||||
"0": QueueMode.HOST_ONLY,
|
||||
"1": QueueMode.ALL_PLAYERS,
|
||||
"2": QueueMode.ALL_PLAYERS_ROUND_ROBIN,
|
||||
}.get(args[1])
|
||||
if len(args) >= 2
|
||||
else None
|
||||
)
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.match_type = teammode
|
||||
if queuemode:
|
||||
settings.queue_mode = queuemode
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return f"Room setting 'teammode' has been changed to {teammode.name.lower()}"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_host(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp host <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.TransferHost(signalr_client, user_id)
|
||||
return f"User '{username}' has been hosted in the room."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_start(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
timer = None
|
||||
if len(args) >= 1 and args[0].isdigit():
|
||||
timer = int(args[0])
|
||||
|
||||
try:
|
||||
if timer is not None:
|
||||
await MultiplayerHubs.SendMatchRequest(
|
||||
signalr_client,
|
||||
StartMatchCountdownRequest(duration=timedelta(seconds=timer)),
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
await MultiplayerHubs.StartMatch(signalr_client)
|
||||
return "Good luck! Enjoy game!"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_abort(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
try:
|
||||
await MultiplayerHubs.AbortMatch(signalr_client)
|
||||
return "Match aborted."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_team(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
):
|
||||
if room.room.settings.match_type != MatchType.TEAM_VERSUS:
|
||||
return "This command is only available in Team Versus mode."
|
||||
|
||||
if len(args) < 2:
|
||||
return "Usage: !mp team <username> <colour>"
|
||||
|
||||
username = args[0]
|
||||
team = {"red": 0, "blue": 1}.get(args[1])
|
||||
if team is None:
|
||||
return "Invalid team colour. Use 'red' or 'blue'."
|
||||
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
|
||||
if not user_client:
|
||||
return f"User '{username}' is not in the room."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.SendMatchRequest(
|
||||
user_client, ChangeTeamRequest(team_id=team)
|
||||
)
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_password(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
password = ""
|
||||
if len(args) >= 1:
|
||||
password = args[0]
|
||||
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.password = password
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return "Room password has been set."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_kick(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp kick <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.KickUser(signalr_client, user_id)
|
||||
return f"User '{username}' has been kicked from the room."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_map(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp map <mapid> [<playmode>]"
|
||||
|
||||
map_id = args[0]
|
||||
if not map_id.isdigit():
|
||||
return "Invalid map ID."
|
||||
map_id = int(map_id)
|
||||
playmode = GameMode.parse(args[1].upper()) if len(args) >= 2 else None
|
||||
if playmode not in (
|
||||
GameMode.OSU,
|
||||
GameMode.TAIKO,
|
||||
GameMode.FRUITS,
|
||||
GameMode.MANIA,
|
||||
None,
|
||||
):
|
||||
return "Invalid playmode."
|
||||
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
|
||||
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
|
||||
return (
|
||||
f"Cannot convert to {playmode.value}. "
|
||||
f"Original mode is {beatmap.mode.value}."
|
||||
)
|
||||
except HTTPError:
|
||||
return "Beatmap not found"
|
||||
|
||||
try:
|
||||
current_item = room.queue.current_item
|
||||
item = current_item.model_copy(deep=True)
|
||||
item.owner_id = signalr_client.user_id
|
||||
item.beatmap_checksum = beatmap.checksum
|
||||
item.required_mods = []
|
||||
item.allowed_mods = []
|
||||
item.freestyle = False
|
||||
item.beatmap_id = map_id
|
||||
if playmode is not None:
|
||||
item.ruleset_id = int(playmode)
|
||||
if item.expired:
|
||||
item.id = 0
|
||||
item.expired = False
|
||||
item.played_at = None
|
||||
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
|
||||
else:
|
||||
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_mods(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp mods <mod1> [<mod2> ...]"
|
||||
|
||||
required_mods = []
|
||||
allowed_mods = []
|
||||
freestyle = False
|
||||
for arg in args:
|
||||
if arg == "None":
|
||||
required_mods.clear()
|
||||
allowed_mods.clear()
|
||||
break
|
||||
elif arg == "Freestyle":
|
||||
freestyle = True
|
||||
elif arg.startswith("+"):
|
||||
mod = arg.removeprefix("+")
|
||||
if len(mod) != 2:
|
||||
return f"Invalid mod: {mod}."
|
||||
allowed_mods.append(APIMod(acronym=mod))
|
||||
else:
|
||||
if len(arg) != 2:
|
||||
return f"Invalid mod: {arg}."
|
||||
required_mods.append(APIMod(acronym=arg))
|
||||
|
||||
try:
|
||||
current_item = room.queue.current_item
|
||||
item = current_item.model_copy(deep=True)
|
||||
item.owner_id = signalr_client.user_id
|
||||
item.freestyle = freestyle
|
||||
if not freestyle:
|
||||
item.allowed_mods = allowed_mods
|
||||
else:
|
||||
item.allowed_mods = []
|
||||
item.required_mods = required_mods
|
||||
if item.expired:
|
||||
item.id = 0
|
||||
item.expired = False
|
||||
item.played_at = None
|
||||
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
|
||||
else:
|
||||
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
_MP_COMMANDS = {
|
||||
"name": _mp_name,
|
||||
"set": _mp_set,
|
||||
"host": _mp_host,
|
||||
"start": _mp_start,
|
||||
"abort": _mp_abort,
|
||||
"map": _mp_map,
|
||||
"mods": _mp_mods,
|
||||
"kick": _mp_kick,
|
||||
"password": _mp_password,
|
||||
"team": _mp_team,
|
||||
}
|
||||
_MP_HELP = """!mp name <name>
|
||||
!mp set <teammode> [<queuemode>]
|
||||
!mp host <host>
|
||||
!mp start [<timer>]
|
||||
!mp abort
|
||||
!mp map <map> [<playmode>]
|
||||
!mp mods <mod1> [<mod2> ...]
|
||||
!mp kick <user>
|
||||
!mp password [<password>]
|
||||
!mp team <user> <team:red|blue>"""
|
||||
|
||||
|
||||
@bot.command("mp")
|
||||
async def _mp(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
|
||||
if not channel.name.startswith("room_"):
|
||||
return
|
||||
|
||||
room_id = int(channel.name[5:])
|
||||
room = MultiplayerHubs.rooms.get(room_id)
|
||||
if not room:
|
||||
return
|
||||
signalr_client = MultiplayerHubs.get_client_by_id(str(user.id))
|
||||
if not signalr_client:
|
||||
return
|
||||
|
||||
if len(args) < 1:
|
||||
return f"Usage: !mp <{'|'.join(_MP_COMMANDS.keys())}> [args]"
|
||||
|
||||
command = args[0].lower()
|
||||
if command not in _MP_COMMANDS:
|
||||
return f"No such command: {command}"
|
||||
|
||||
return await _MP_COMMANDS[command](signalr_client, room, args[1:], session)
|
||||
306
app/router/chat/channel.py
Normal file
306
app/router/chat/channel.py
Normal file
@@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatMessage,
|
||||
SilenceUser,
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.lazer_user import User, UserResp
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class UpdateResponse(BaseModel):
|
||||
presence: list[ChatChannelResp] = Field(default_factory=list)
|
||||
silences: list[Any] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/updates",
|
||||
response_model=UpdateResponse,
|
||||
name="获取更新",
|
||||
description="获取当前用户所在频道的最新的禁言情况。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_update(
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
includes: list[str] = Query(
|
||||
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
if "presence" in includes:
|
||||
assert current_user.id
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
resp.presence.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
if "silences" in includes:
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.put(
|
||||
"/chat/channels/{channel}/users/{user}",
|
||||
response_model=ChatChannelResp,
|
||||
name="加入频道",
|
||||
description="加入指定的公开/房间频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def join_channel(
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
return await server.join_channel(current_user, db_channel, session)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/chat/channels/{channel}/users/{user}",
|
||||
status_code=204,
|
||||
name="离开频道",
|
||||
description="将用户移出指定的公开/房间频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def leave_channel(
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
await server.leave_channel(current_user, db_channel, session)
|
||||
return
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels",
|
||||
response_model=list[ChatChannelResp],
|
||||
name="获取频道列表",
|
||||
description="获取所有公开频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_channel_list(
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
channels = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
|
||||
)
|
||||
).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
assert channel.channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel.channel_id, [])
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class GetChannelResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
users: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels/{channel}",
|
||||
response_model=GetChannelResp,
|
||||
name="获取频道信息",
|
||||
description="获取指定频道的信息。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_channel(
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
assert db_channel.channel_id is not None
|
||||
|
||||
users = []
|
||||
if db_channel.type == ChannelType.PM:
|
||||
user_ids = db_channel.name.split("_")[1:]
|
||||
if len(user_ids) != 2:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
for id_ in user_ids:
|
||||
if int(id_) == current_user.id:
|
||||
continue
|
||||
target_user = await session.get(User, int(id_))
|
||||
if target_user is None:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
users.extend([target_user, current_user])
|
||||
break
|
||||
|
||||
return GetChannelResp(
|
||||
channel=await ChatChannelResp.from_db(
|
||||
db_channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(db_channel.channel_id, [])
|
||||
if db_channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CreateChannelReq(BaseModel):
|
||||
class AnnounceChannel(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
message: str | None = None
|
||||
type: Literal["ANNOUNCE", "PM"] = "PM"
|
||||
target_id: int | None = None
|
||||
target_ids: list[int] | None = None
|
||||
channel: AnnounceChannel | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check(self) -> Self:
|
||||
if self.type == "PM":
|
||||
if self.target_id is None:
|
||||
raise ValueError("target_id must be set for PM channels")
|
||||
else:
|
||||
if self.target_ids is None or self.channel is None or self.message is None:
|
||||
raise ValueError(
|
||||
"target_ids, channel, and message must be set for ANNOUNCE channels"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/channels",
|
||||
response_model=ChatChannelResp,
|
||||
name="创建频道",
|
||||
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def create_channel(
|
||||
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
if req.type == "PM":
|
||||
target = await session.get(User, req.target_id)
|
||||
if not target:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
is_can_pm, block = await target.is_user_can_pm(current_user, session)
|
||||
if not is_can_pm:
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
channel = await ChatChannel.get_pm_channel(
|
||||
current_user.id, # pyright: ignore[reportArgumentType]
|
||||
req.target_id, # pyright: ignore[reportArgumentType]
|
||||
session,
|
||||
)
|
||||
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||
else:
|
||||
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||
channel = await ChatChannel.get(channel_name, session)
|
||||
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=channel_name,
|
||||
description=req.channel.description
|
||||
if req.channel
|
||||
else "Private message channel",
|
||||
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(current_user)
|
||||
if req.type == "PM":
|
||||
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
else:
|
||||
target_users = await session.exec(
|
||||
select(User).where(col(User.id).in_(req.target_ids or []))
|
||||
)
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
assert channel.channel_id
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel.channel_id, []),
|
||||
include_recent_messages=True,
|
||||
)
|
||||
243
app/router/chat/message.py
Normal file
243
app/router/chat/message.py
Normal file
@@ -0,0 +1,243 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatMessage,
|
||||
MessageType,
|
||||
SilenceUser,
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from .banchobot import bot
|
||||
from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class KeepAliveResp(BaseModel):
|
||||
silences: list[UserSilenceResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/ack",
|
||||
name="保持连接",
|
||||
response_model=KeepAliveResp,
|
||||
description="保持公共频道的连接。同时返回最近的禁言列表。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def keep_alive(
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
resp = KeepAliveResp()
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
class MessageReq(BaseModel):
|
||||
message: str
|
||||
is_action: bool = False
|
||||
uuid: str | None = None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/channels/{channel}/messages",
|
||||
response_model=ChatMessageResp,
|
||||
name="发送消息",
|
||||
description="发送消息到指定频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def send_message(
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
assert db_channel.channel_id
|
||||
assert current_user.id
|
||||
msg = ChatMessage(
|
||||
channel_id=db_channel.channel_id,
|
||||
content=req.message,
|
||||
sender_id=current_user.id,
|
||||
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||
uuid=req.uuid,
|
||||
)
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(current_user)
|
||||
await session.refresh(db_channel)
|
||||
resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||
is_bot_command = req.message.startswith("!")
|
||||
await server.send_message_to_channel(
|
||||
resp, is_bot_command and db_channel.type == ChannelType.PUBLIC
|
||||
)
|
||||
if is_bot_command:
|
||||
await bot.try_handle(current_user, db_channel, req.message, session)
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels/{channel}/messages",
|
||||
response_model=list[ChatMessageResp],
|
||||
name="获取消息",
|
||||
description="获取指定频道的消息列表。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_message(
|
||||
channel: str,
|
||||
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
|
||||
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
|
||||
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
messages = await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
ChatMessage.channel_id == db_channel.channel_id,
|
||||
col(ChatMessage.message_id) > since,
|
||||
col(ChatMessage.message_id) < until if until is not None else True,
|
||||
)
|
||||
.order_by(col(ChatMessage.timestamp).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
resp.reverse()
|
||||
return resp
|
||||
|
||||
|
||||
@router.put(
|
||||
"/chat/channels/{channel}/mark-as-read/{message}",
|
||||
status_code=204,
|
||||
name="标记消息为已读",
|
||||
description="标记指定消息为已读。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def mark_as_read(
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
message: int = Path(..., description="消息 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
assert db_channel.channel_id
|
||||
await server.mark_as_read(db_channel.channel_id, message)
|
||||
|
||||
|
||||
class PMReq(BaseModel):
|
||||
target_id: int
|
||||
message: str
|
||||
is_action: bool = False
|
||||
uuid: str | None = None
|
||||
|
||||
|
||||
class NewPMResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
message: ChatMessageResp
|
||||
new_channel_id: int
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/new",
|
||||
name="创建私聊频道",
|
||||
description="创建一个新的私聊频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def create_new_pm(
|
||||
req: PMReq = Depends(BodyOrForm(PMReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
user_id = current_user.id
|
||||
target = await session.get(User, req.target_id)
|
||||
if target is None:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
is_can_pm, block = await target.is_user_can_pm(current_user, session)
|
||||
if not is_can_pm:
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
assert user_id
|
||||
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=f"pm_{user_id}_{req.target_id}",
|
||||
description="Private message channel",
|
||||
type=ChannelType.PM,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(target)
|
||||
await session.refresh(current_user)
|
||||
|
||||
assert channel.channel_id
|
||||
await server.batch_join_channel([target, current_user], channel, session)
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, current_user, redis, server.channels[channel.channel_id]
|
||||
)
|
||||
msg = ChatMessage(
|
||||
channel_id=channel.channel_id,
|
||||
content=req.message,
|
||||
sender_id=user_id,
|
||||
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||
uuid=req.uuid,
|
||||
)
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(current_user)
|
||||
await session.refresh(channel)
|
||||
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||
await server.send_message_to_channel(message_resp)
|
||||
return NewPMResp(
|
||||
channel=channel_resp,
|
||||
message=message_resp,
|
||||
new_channel_id=channel_resp.channel_id,
|
||||
)
|
||||
285
app/router/chat/server.py
Normal file
285
app/router/chat/server.py
Normal file
@@ -0,0 +1,285 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import (
|
||||
DBFactory,
|
||||
engine,
|
||||
get_db_factory,
|
||||
get_redis,
|
||||
)
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
from fastapi.websockets import WebSocketState
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class ChatServer:
|
||||
def __init__(self):
|
||||
self.connect_client: dict[int, WebSocket] = {}
|
||||
self.channels: dict[int, list[int]] = {}
|
||||
self.redis: Redis = get_redis()
|
||||
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
self.ChatSubscriber = ChatSubscriber()
|
||||
self.ChatSubscriber.chat_server = self
|
||||
self._subscribed = False
|
||||
|
||||
def _add_task(self, task):
|
||||
task = asyncio.create_task(task)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
def connect(self, user_id: int, client: WebSocket):
|
||||
self.connect_client[user_id] = client
|
||||
|
||||
def get_user_joined_channel(self, user_id: int) -> list[int]:
|
||||
return [
|
||||
channel_id
|
||||
for channel_id, users in self.channels.items()
|
||||
if user_id in users
|
||||
]
|
||||
|
||||
async def disconnect(self, user: User, session: AsyncSession):
|
||||
user_id = user.id
|
||||
if user_id in self.connect_client:
|
||||
del self.connect_client[user_id]
|
||||
for channel_id, channel in self.channels.items():
|
||||
if user_id in channel:
|
||||
channel.remove(user_id)
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
async def send_event(self, client: WebSocket, event: ChatEvent):
|
||||
if client.client_state == WebSocketState.CONNECTED:
|
||||
await client.send_text(event.model_dump_json())
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
for user_id in self.channels.get(channel_id, []):
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(client, event)
|
||||
|
||||
async def mark_as_read(self, channel_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_msg", message_id)
|
||||
|
||||
async def send_message_to_channel(
|
||||
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||
):
|
||||
event = ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
)
|
||||
if is_bot_command:
|
||||
client = self.connect_client.get(message.sender_id)
|
||||
if client:
|
||||
self._add_task(self.send_event(client, event))
|
||||
else:
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
)
|
||||
assert message.message_id
|
||||
await self.mark_as_read(message.channel_id, message.message_id)
|
||||
|
||||
async def batch_join_channel(
|
||||
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||
):
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
not_joined = []
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
for user in users:
|
||||
assert user.id is not None
|
||||
if user.id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user.id)
|
||||
not_joined.append(user)
|
||||
|
||||
for user in not_joined:
|
||||
assert user.id is not None
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id]
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
client = self.connect_client.get(user.id)
|
||||
if client:
|
||||
await self.send_event(
|
||||
client,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> ChatChannelResp:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
if user_id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user_id)
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(
|
||||
client,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||
self.channels[channel_id].remove(user_id)
|
||||
|
||||
if (c := self.channels.get(channel_id)) is not None and not c:
|
||||
del self.channels[channel_id]
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id)
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
await self.send_event(
|
||||
client,
|
||||
ChatEvent(
|
||||
event="chat.channel.part",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with AsyncSession(engine) as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.join_channel(user, channel, session)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with AsyncSession(engine) as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
|
||||
server = ChatServer()
|
||||
|
||||
chat_router = APIRouter(include_in_schema=False)
|
||||
|
||||
|
||||
async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
try:
|
||||
while True:
|
||||
packets = await ws.receive_json()
|
||||
if packets.get("event") == "chat.end":
|
||||
async for session in factory():
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
break
|
||||
await server.disconnect(user, session)
|
||||
await ws.close(code=1000)
|
||||
break
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(
|
||||
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
|
||||
else:
|
||||
logger.exception(f"RuntimeError in client {user_id}: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"Error in client {user_id}")
|
||||
|
||||
|
||||
@chat_router.websocket("/notification-server")
|
||||
async def chat_websocket(
|
||||
websocket: WebSocket,
|
||||
authorization: str = Header(...),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
if not server._subscribed:
|
||||
server._subscribed = True
|
||||
await server.ChatSubscriber.start_subscribe()
|
||||
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (
|
||||
user := await get_current_user(
|
||||
SecurityScopes(scopes=["chat.read"]), session, token_pw=token
|
||||
)
|
||||
) is None:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
login = await websocket.receive_json()
|
||||
if login.get("event") != "chat.start":
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
user_id = user.id
|
||||
assert user_id
|
||||
server.connect(user_id, websocket)
|
||||
channel = await ChatChannel.get(1, session)
|
||||
if channel is not None:
|
||||
await server.join_channel(user, channel, session)
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
@@ -116,7 +116,7 @@ class APICreatedRoom(RoomResp):
|
||||
|
||||
|
||||
async def _participate_room(
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
|
||||
):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
@@ -138,6 +138,8 @@ async def _participate_room(
|
||||
participated_user.joined_at = datetime.now(UTC)
|
||||
db_room.participant_count += 1
|
||||
|
||||
await redis.publish("chat:room:joined", f"{db_room.channel_id}:{user_id}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rooms",
|
||||
@@ -150,11 +152,12 @@ async def create_room(
|
||||
room: APIUploadedRoom,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db)
|
||||
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
||||
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
|
||||
created_room.error = ""
|
||||
return created_room
|
||||
@@ -219,11 +222,12 @@ async def add_user_to_room(
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
await _participate_room(room_id, user_id, db_room, db)
|
||||
await _participate_room(room_id, user_id, db_room, db, redis)
|
||||
await db.commit()
|
||||
await db.refresh(db_room)
|
||||
resp = await RoomResp.from_db(db_room, db)
|
||||
@@ -243,6 +247,7 @@ async def remove_user_from_room(
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
@@ -257,6 +262,7 @@ async def remove_user_from_room(
|
||||
if participated_user is not None:
|
||||
participated_user.left_at = datetime.now(UTC)
|
||||
db_room.participant_count -= 1
|
||||
await redis.publish("chat:room:left", f"{db_room.channel_id}:{user_id}")
|
||||
await db.commit()
|
||||
return None
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import (
|
||||
BeatmapPlaycounts,
|
||||
BeatmapPlaycountsResp,
|
||||
@@ -65,6 +66,7 @@ async def get_users(
|
||||
include=SEARCH_INCLUDED,
|
||||
)
|
||||
for searched_user in searched_users
|
||||
if searched_user.id != BANCHOBOT_ID
|
||||
]
|
||||
)
|
||||
|
||||
@@ -91,7 +93,7 @@ async def get_user_info_ruleset(
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not searched_user:
|
||||
if not searched_user or searched_user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
return await UserResp.from_db(
|
||||
searched_user,
|
||||
@@ -123,7 +125,7 @@ async def get_user_info(
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not searched_user:
|
||||
if not searched_user or searched_user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
return await UserResp.from_db(
|
||||
searched_user,
|
||||
@@ -148,7 +150,7 @@ async def get_user_beatmapsets(
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
):
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
if not user or user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
|
||||
if type in {
|
||||
@@ -218,7 +220,7 @@ async def get_user_scores(
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
db_user = await session.get(User, user_id)
|
||||
if not db_user:
|
||||
if not db_user or db_user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
|
||||
gamemode = mode or db_user.playmode
|
||||
@@ -271,7 +273,7 @@ async def get_user_events(
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_user = await session.get(User, user)
|
||||
if db_user is None:
|
||||
if db_user is None or db_user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, "User Not found")
|
||||
events = await db_user.awaitable_attrs.events
|
||||
if limit is not None:
|
||||
|
||||
32
app/service/create_banchobot.py
Normal file
32
app/service/create_banchobot.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database.lazer_user import User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import engine
|
||||
from app.models.score import GameMode
|
||||
|
||||
from sqlmodel import exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def create_banchobot():
|
||||
async with AsyncSession(engine) as session:
|
||||
is_exist = (
|
||||
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
|
||||
).first()
|
||||
if not is_exist:
|
||||
banchobot = User(
|
||||
username="BanchoBot",
|
||||
email="banchobot@ppy.sh",
|
||||
is_bot=True,
|
||||
pw_bcrypt="0",
|
||||
id=BANCHOBOT_ID,
|
||||
avatar_url="https://a.ppy.sh/3",
|
||||
country_code="SH",
|
||||
website="https://twitter.com/banchoboat",
|
||||
)
|
||||
session.add(banchobot)
|
||||
statistics = UserStatistics(user_id=BANCHOBOT_ID, mode=GameMode.OSU)
|
||||
session.add(statistics)
|
||||
await session.commit()
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import Room
|
||||
from app.dependencies.database import engine, get_redis
|
||||
@@ -26,12 +27,12 @@ async def create_daily_challenge_room(
|
||||
return await create_playlist_room(
|
||||
session=session,
|
||||
name=str(today),
|
||||
host_id=3,
|
||||
host_id=BANCHOBOT_ID,
|
||||
playlist=[
|
||||
Playlist(
|
||||
id=0,
|
||||
room_id=0,
|
||||
owner_id=3,
|
||||
owner_id=BANCHOBOT_ID,
|
||||
ruleset_id=ruleset_id,
|
||||
beatmap_id=beatmap,
|
||||
required_mods=required_mods,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database.lazer_user import User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import engine
|
||||
@@ -15,6 +16,9 @@ async def create_rx_statistics():
|
||||
async with AsyncSession(engine) as session:
|
||||
users = (await session.exec(select(User.id))).all()
|
||||
for i in users:
|
||||
if i == BANCHOBOT_ID:
|
||||
continue
|
||||
|
||||
if settings.enable_rx:
|
||||
for mode in (
|
||||
GameMode.OSURX,
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.chat import ChannelType, ChatChannel
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import APIUploadedRoom, Room
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
@@ -25,6 +26,18 @@ async def create_playlist_room_from_api(
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
|
||||
channel = ChatChannel(
|
||||
name=f"room_{db_room.id}",
|
||||
description="Playlist room",
|
||||
type=ChannelType.MULTIPLAYER,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(db_room)
|
||||
db_room.channel_id = channel.channel_id
|
||||
|
||||
await add_playlists_to_room(session, db_room.id, room.playlist, host_id)
|
||||
await session.refresh(db_room)
|
||||
return db_room
|
||||
@@ -57,6 +70,18 @@ async def create_playlist_room(
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
|
||||
channel = ChatChannel(
|
||||
name=f"room_{db_room.id}",
|
||||
description="Playlist room",
|
||||
type=ChannelType.MULTIPLAYER,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(db_room)
|
||||
db_room.channel_id = channel.channel_id
|
||||
|
||||
await add_playlists_to_room(session, db_room.id, playlist, host_id)
|
||||
await session.refresh(db_room)
|
||||
return db_room
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from fnmatch import fnmatch
|
||||
from typing import Any
|
||||
|
||||
from app.dependencies.database import get_redis_pubsub
|
||||
@@ -23,18 +24,28 @@ class RedisSubscriber:
|
||||
del self.handlers[channel]
|
||||
await self.pubsub.unsubscribe(channel)
|
||||
|
||||
def add_handler(self, channel: str, handler: Callable[[str, str], Awaitable[Any]]):
|
||||
if channel not in self.handlers:
|
||||
self.handlers[channel] = []
|
||||
self.handlers[channel].append(handler)
|
||||
|
||||
async def listen(self):
|
||||
while True:
|
||||
message = await self.pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=None
|
||||
)
|
||||
if message is not None and message["type"] == "message":
|
||||
method = self.handlers.get(message["channel"])
|
||||
if method:
|
||||
matched_handlers = []
|
||||
if message["channel"] in self.handlers:
|
||||
matched_handlers.extend(self.handlers[message["channel"]])
|
||||
for pattern, handlers in self.handlers.items():
|
||||
if fnmatch(message["channel"], pattern):
|
||||
matched_handlers.extend(handlers)
|
||||
if matched_handlers:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
handler(message["channel"], message["data"])
|
||||
for handler in method
|
||||
for handler in matched_handlers
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
38
app/service/subscribers/chat.py
Normal file
38
app/service/subscribers/chat.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import RedisSubscriber
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.router.chat.server import ChatServer
|
||||
|
||||
|
||||
JOIN_CHANNEL = "chat:room:joined"
|
||||
EXIT_CHANNEL = "chat:room:left"
|
||||
|
||||
|
||||
class ChatSubscriber(RedisSubscriber):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.room_subscriber: dict[int, list[int]] = {}
|
||||
self.chat_server: "ChatServer | None" = None
|
||||
|
||||
async def start_subscribe(self):
|
||||
await self.subscribe(JOIN_CHANNEL)
|
||||
self.add_handler(JOIN_CHANNEL, self.on_join_room)
|
||||
await self.subscribe(EXIT_CHANNEL)
|
||||
self.add_handler(EXIT_CHANNEL, self.on_leave_room)
|
||||
self.start()
|
||||
|
||||
async def on_join_room(self, c: str, s: str):
|
||||
channel_id, user_id = s.split(":")
|
||||
if self.chat_server is None:
|
||||
return
|
||||
await self.chat_server.join_room_channel(int(channel_id), int(user_id))
|
||||
|
||||
async def on_leave_room(self, c: str, s: str):
|
||||
channel_id, user_id = s.split(":")
|
||||
if self.chat_server is None:
|
||||
return
|
||||
await self.chat_server.leave_room_channel(int(channel_id), int(user_id))
|
||||
@@ -6,6 +6,7 @@ from typing import override
|
||||
|
||||
from app.database import Room
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.chat import ChannelType, ChatChannel
|
||||
from app.database.lazer_user import User
|
||||
from app.database.multiplayer_event import MultiplayerEvent
|
||||
from app.database.playlists import Playlist
|
||||
@@ -172,6 +173,20 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
self.get_client_by_id(str(user_id)), server_room, user
|
||||
)
|
||||
|
||||
def _ensure_in_room(self, client: Client) -> ServerMultiplayerRoom:
|
||||
store = self.get_or_create_state(client)
|
||||
if store.room_id == 0:
|
||||
raise InvokeException("You are not in a room")
|
||||
if store.room_id not in self.rooms:
|
||||
raise InvokeException("Room does not exist")
|
||||
server_room = self.rooms[store.room_id]
|
||||
return server_room
|
||||
|
||||
def _ensure_host(self, client: Client, server_room: ServerMultiplayerRoom):
|
||||
room = server_room.room
|
||||
if room.host is None or room.host.user_id != client.user_id:
|
||||
raise InvokeException("You are not the host of this room")
|
||||
|
||||
async def CreateRoom(self, client: Client, room: MultiplayerRoom):
|
||||
logger.info(f"[MultiplayerHub] {client.user_id} creating room")
|
||||
store = self.get_or_create_state(client)
|
||||
@@ -195,6 +210,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
|
||||
channel = ChatChannel(
|
||||
name=f"room_{db_room.id}",
|
||||
description="Multiplayer room",
|
||||
type=ChannelType.MULTIPLAYER,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(db_room)
|
||||
room.channel_id = channel.channel_id # pyright: ignore[reportAttributeAccessIssue]
|
||||
db_room.channel_id = channel.channel_id
|
||||
|
||||
item = room.playlist[0]
|
||||
item.owner_id = client.user_id
|
||||
room.room_id = db_room.id
|
||||
@@ -280,6 +307,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
if db_room is None:
|
||||
raise InvokeException("Room does not exist in database")
|
||||
db_room.participant_count += 1
|
||||
|
||||
redis = get_redis()
|
||||
await redis.publish("chat:room:joined", f"{room.channel_id}:{user.user_id}")
|
||||
|
||||
return room
|
||||
|
||||
async def change_beatmap_availability(
|
||||
@@ -914,6 +945,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
if target_store:
|
||||
target_store.room_id = 0
|
||||
|
||||
redis = get_redis()
|
||||
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
|
||||
|
||||
async def end_room(self, room: ServerMultiplayerRoom):
|
||||
assert room.room.host
|
||||
async with AsyncSession(engine) as session:
|
||||
@@ -1085,17 +1119,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
)
|
||||
|
||||
async def ChangeSettings(self, client: Client, settings: MultiplayerRoomSettings):
|
||||
store = self.get_or_create_state(client)
|
||||
if store.room_id == 0:
|
||||
raise InvokeException("You are not in a room")
|
||||
if store.room_id not in self.rooms:
|
||||
raise InvokeException("Room does not exist")
|
||||
server_room = self.rooms[store.room_id]
|
||||
server_room = self._ensure_in_room(client)
|
||||
self._ensure_host(client, server_room)
|
||||
room = server_room.room
|
||||
|
||||
if room.host is None or room.host.user_id != client.user_id:
|
||||
raise InvokeException("You are not the host of this room")
|
||||
|
||||
if room.state != MultiplayerRoomState.OPEN:
|
||||
raise InvokeException("Cannot change settings while playing")
|
||||
|
||||
|
||||
4
main.py
4
main.py
@@ -13,6 +13,7 @@ from app.router import (
|
||||
api_v1_router,
|
||||
api_v2_router,
|
||||
auth_router,
|
||||
chat_router,
|
||||
fetcher_router,
|
||||
file_router,
|
||||
private_router,
|
||||
@@ -21,6 +22,7 @@ from app.router import (
|
||||
)
|
||||
from app.router.redirect import redirect_router
|
||||
from app.service.calculate_all_user_rank import calculate_user_rank
|
||||
from app.service.create_banchobot import create_banchobot
|
||||
from app.service.daily_challenge import daily_challenge_job
|
||||
from app.service.osu_rx_statistics import create_rx_statistics
|
||||
from app.service.pp_recalculate import recalculate_all_players_pp
|
||||
@@ -42,6 +44,7 @@ async def lifespan(app: FastAPI):
|
||||
await calculate_user_rank(True)
|
||||
init_scheduler()
|
||||
await daily_challenge_job()
|
||||
await create_banchobot()
|
||||
# on shutdown
|
||||
yield
|
||||
stop_scheduler()
|
||||
@@ -71,6 +74,7 @@ app = FastAPI(
|
||||
|
||||
app.include_router(api_v2_router)
|
||||
app.include_router(api_v1_router)
|
||||
app.include_router(chat_router)
|
||||
app.include_router(redirect_api_router)
|
||||
app.include_router(signalr_router)
|
||||
app.include_router(fetcher_router)
|
||||
|
||||
192
migrations/versions/dd33d89aa2c2_chat_add_chat.py
Normal file
192
migrations/versions/dd33d89aa2c2_chat_add_chat.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""chat: add chat
|
||||
|
||||
Revision ID: dd33d89aa2c2
|
||||
Revises: 9f6b27e8ea51
|
||||
Create Date: 2025-08-15 14:22:34.775877
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dd33d89aa2c2"
|
||||
down_revision: str | Sequence[str] | None = "9f6b27e8ea51"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
channel_table = op.create_table(
|
||||
"chat_channels",
|
||||
sa.Column("name", sa.VARCHAR(length=50), nullable=True),
|
||||
sa.Column("description", sa.VARCHAR(length=255), nullable=True),
|
||||
sa.Column("icon", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum(
|
||||
"PUBLIC",
|
||||
"PRIVATE",
|
||||
"MULTIPLAYER",
|
||||
"SPECTATOR",
|
||||
"TEMPORARY",
|
||||
"PM",
|
||||
"GROUP",
|
||||
"SYSTEM",
|
||||
"ANNOUNCE",
|
||||
"TEAM",
|
||||
name="channeltype",
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("channel_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_channels_channel_id"),
|
||||
"chat_channels",
|
||||
["channel_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_channels_description"),
|
||||
"chat_channels",
|
||||
["description"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_channels_name"), "chat_channels", ["name"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_channels_type"), "chat_channels", ["type"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"chat_messages",
|
||||
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||
sa.Column("content", sa.VARCHAR(length=1000), nullable=True),
|
||||
sa.Column("message_id", sa.Integer(), nullable=False),
|
||||
sa.Column("sender_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("timestamp", sa.DateTime(), nullable=True),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum("ACTION", "MARKDOWN", "PLAIN", name="messagetype"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("uuid", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["channel_id"],
|
||||
["chat_channels.channel_id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["sender_id"],
|
||||
["lazer_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("message_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_channel_id"),
|
||||
"chat_messages",
|
||||
["channel_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_message_id"),
|
||||
"chat_messages",
|
||||
["message_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_sender_id"), "chat_messages", ["sender_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_timestamp"), "chat_messages", ["timestamp"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_messages_type"), "chat_messages", ["type"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"chat_silence_users",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("channel_id", sa.Integer(), nullable=False),
|
||||
sa.Column("until", sa.DateTime(), nullable=True),
|
||||
sa.Column("banned_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("reason", sa.VARCHAR(length=255), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["channel_id"],
|
||||
["chat_channels.channel_id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["lazer_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_channel_id"),
|
||||
"chat_silence_users",
|
||||
["channel_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_reason"),
|
||||
"chat_silence_users",
|
||||
["reason"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_until"),
|
||||
"chat_silence_users",
|
||||
["until"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_banned_at"),
|
||||
"chat_silence_users",
|
||||
["banned_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_user_id"),
|
||||
"chat_silence_users",
|
||||
["user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_id"),
|
||||
"chat_silence_users",
|
||||
["id"],
|
||||
unique=False,
|
||||
)
|
||||
op.bulk_insert(
|
||||
channel_table,
|
||||
[
|
||||
{
|
||||
"name": "osu!",
|
||||
"description": "General discussion for osu!",
|
||||
"type": "PUBLIC",
|
||||
},
|
||||
{
|
||||
"name": "announce",
|
||||
"description": "Official announcements",
|
||||
"type": "PUBLIC",
|
||||
},
|
||||
],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("chat_silence_users")
|
||||
op.drop_table("chat_messages")
|
||||
op.drop_table("chat_channels")
|
||||
# ### end Alembic commands ###
|
||||
54
migrations/versions/df9f725a077c_room_add_channel_id.py
Normal file
54
migrations/versions/df9f725a077c_room_add_channel_id.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""room: add channel_id
|
||||
|
||||
Revision ID: df9f725a077c
|
||||
Revises: dd33d89aa2c2
|
||||
Create Date: 2025-08-16 08:05:28.748265
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "df9f725a077c"
|
||||
down_revision: str | Sequence[str] | None = "dd33d89aa2c2"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False
|
||||
)
|
||||
op.add_column("rooms", sa.Column("channel_id", sa.Integer(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("rooms", "channel_id")
|
||||
op.drop_index(op.f("ix_chat_silence_users_id"), table_name="chat_silence_users")
|
||||
op.alter_column(
|
||||
"chat_silence_users",
|
||||
"banned_at",
|
||||
existing_type=mysql.DATETIME(),
|
||||
nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user