Merge pull request #14 from GooGuTeam/feat/chat

feat: 添加聊天系统
This commit is contained in:
咕谷酱
2025-08-17 00:23:29 +08:00
committed by GitHub
29 changed files with 2255 additions and 29 deletions

3
app/const.py Normal file
View File

@@ -0,0 +1,3 @@
from __future__ import annotations
BANCHOBOT_ID = 2

View File

@@ -10,6 +10,13 @@ from .beatmapset import (
BeatmapsetResp,
)
from .best_score import BestScore
from .chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatMessage,
ChatMessageResp,
)
from .counts import (
CountResp,
MonthlyPlaycounts,
@@ -63,6 +70,11 @@ __all__ = [
"Beatmapset",
"BeatmapsetResp",
"BestScore",
"ChannelType",
"ChatChannel",
"ChatChannelResp",
"ChatMessage",
"ChatMessageResp",
"CountResp",
"DailyChallengeStats",
"DailyChallengeStatsResp",

240
app/database/chat.py Normal file
View File

@@ -0,0 +1,240 @@
from datetime import UTC, datetime
from enum import Enum
from typing import Self
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
from app.models.model import UTCBaseModel
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import (
VARCHAR,
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
# ChatChannel
class ChatUserAttributes(BaseModel):
can_message: bool
can_message_error: str | None = None
last_read_id: int
class ChannelType(str, Enum):
PUBLIC = "PUBLIC"
PRIVATE = "PRIVATE"
MULTIPLAYER = "MULTIPLAYER"
SPECTATOR = "SPECTATOR"
TEMPORARY = "TEMPORARY"
PM = "PM"
GROUP = "GROUP"
SYSTEM = "SYSTEM"
ANNOUNCE = "ANNOUNCE"
TEAM = "TEAM"
class ChatChannelBase(SQLModel):
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
icon: str | None = Field(default=None)
type: ChannelType = Field(index=True)
class ChatChannel(ChatChannelBase, table=True):
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
channel_id: int | None = Field(primary_key=True, index=True, default=None)
@classmethod
async def get(
cls, channel: str | int, session: AsyncSession
) -> "ChatChannel | None":
if isinstance(channel, int) or channel.isdigit():
channel_ = await session.get(ChatChannel, channel)
if channel_ is not None:
return channel_
return (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
@classmethod
async def get_pm_channel(
cls, user1: int, user2: int, session: AsyncSession
) -> "ChatChannel | None":
channel = await cls.get(f"pm_{user1}_{user2}", session)
if channel is None:
channel = await cls.get(f"pm_{user2}_{user1}", session)
return channel
class ChatChannelResp(ChatChannelBase):
channel_id: int
moderated: bool = False
uuid: str | None = None
current_user_attributes: ChatUserAttributes | None = None
last_read_id: int | None = None
last_message_id: int | None = None
recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
users: list[int] = Field(default_factory=list)
message_length_limit: int = 1000
@classmethod
async def from_db(
cls,
channel: ChatChannel,
session: AsyncSession,
user: User,
redis: Redis,
users: list[int] | None = None,
include_recent_messages: bool = False,
) -> Self:
c = cls.model_validate(channel)
silence = (
await session.exec(
select(SilenceUser).where(
SilenceUser.channel_id == channel.channel_id,
SilenceUser.user_id == user.id,
)
)
).first()
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
if last_msg and last_msg.isdigit():
last_msg = int(last_msg)
else:
last_msg = None
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
if last_read_id and last_read_id.isdigit():
last_read_id = int(last_read_id)
else:
last_read_id = last_msg
if silence is not None:
attribute = ChatUserAttributes(
can_message=False,
can_message_error=silence.reason or "You are muted in this channel.",
last_read_id=last_read_id or 0,
)
c.moderated = True
else:
attribute = ChatUserAttributes(
can_message=True,
last_read_id=last_read_id or 0,
)
c.moderated = False
c.current_user_attributes = attribute
if c.type != ChannelType.PUBLIC and users is not None:
c.users = users
c.last_message_id = last_msg
c.last_read_id = last_read_id
if include_recent_messages:
messages = (
await session.exec(
select(ChatMessage)
.where(ChatMessage.channel_id == channel.channel_id)
.order_by(col(ChatMessage.timestamp).desc())
.limit(10)
)
).all()
c.recent_messages = [
await ChatMessageResp.from_db(msg, session, user) for msg in messages
]
c.recent_messages.reverse()
if c.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec(
select(User.username).where(User.id == target_user_id)
)
c.name = target_name.one()
assert user.id
c.users = [target_user_id, user.id]
return c
# ChatMessage
class MessageType(str, Enum):
ACTION = "action"
MARKDOWN = "markdown"
PLAIN = "plain"
class ChatMessageBase(UTCBaseModel, SQLModel):
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
content: str = Field(sa_column=Column(VARCHAR(1000)))
message_id: int | None = Field(index=True, primary_key=True, default=None)
sender_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
timestamp: datetime = Field(
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
)
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
uuid: str | None = Field(default=None)
class ChatMessage(ChatMessageBase, table=True):
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
class ChatMessageResp(ChatMessageBase):
sender: UserResp | None = None
is_action: bool = False
@classmethod
async def from_db(
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
) -> "ChatMessageResp":
m = cls.model_validate(db_message.model_dump())
m.is_action = db_message.type == MessageType.ACTION
if user:
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
else:
m.sender = await UserResp.from_db(
db_message.user, session, RANKING_INCLUDES
)
return m
# SilenceUser
class SilenceUser(UTCBaseModel, SQLModel, table=True):
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field(primary_key=True, default=None, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
banned_at: datetime = Field(
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
)
class UserSilenceResp(SQLModel):
id: int
user_id: int
@classmethod
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
return cls(
id=db_silence.id,
user_id=db_silence.user_id,
)

View File

@@ -168,6 +168,46 @@ class User(AsyncAttrs, UserBase, table=True):
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
async def is_user_can_pm(
self, from_user: "User", session: AsyncSession
) -> tuple[bool, str]:
from .relationship import Relationship, RelationshipType
from_relationship = (
await session.exec(
select(Relationship).where(
Relationship.user_id == from_user.id,
Relationship.target_id == self.id,
)
)
).first()
if from_relationship and from_relationship.type == RelationshipType.BLOCK:
return False, "You have blocked the target user."
if from_user.pm_friends_only and (
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
):
return (
False,
"You have disabled non-friend communications "
"and target user is not your friend.",
)
relationship = (
await session.exec(
select(Relationship).where(
Relationship.user_id == self.id,
Relationship.target_id == from_user.id,
)
)
).first()
if relationship and relationship.type == RelationshipType.BLOCK:
return False, "Target user has blocked you."
if self.pm_friends_only and (
not relationship or relationship.type != RelationshipType.FOLLOW
):
return False, "Target user has disabled non-friend communications"
return True, ""
class UserResp(UserBase):
id: int | None = None

View File

@@ -54,7 +54,7 @@ class RoomBase(SQLModel, UTCBaseModel):
auto_skip: bool
auto_start_duration: int
status: RoomStatus
# TODO: channel_id
channel_id: int | None = None
class Room(AsyncAttrs, RoomBase, table=True):
@@ -84,6 +84,7 @@ class RoomResp(RoomBase):
current_playlist_item: PlaylistResp | None = None
current_user_score: PlaylistAggregateScore | None = None
recent_participants: list[UserResp] = Field(default_factory=list)
channel_id: int = 0
@classmethod
async def from_db(
@@ -93,7 +94,9 @@ class RoomResp(RoomBase):
include: list[str] = [],
user: User | None = None,
) -> "RoomResp":
resp = cls.model_validate(room.model_dump())
d = room.model_dump()
d["channel_id"] = d.get("channel_id", 0) or 0
resp = cls.model_validate(d)
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
difficulty_range = RoomDifficultyRange(
@@ -158,6 +161,7 @@ class RoomResp(RoomBase):
# duration = room.settings.duration,
starts_at=server_room.start_at,
participant_count=len(room.users),
channel_id=server_room.room.channel_id or 0,
)
return resp

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Callable
from contextvars import ContextVar
import json
@@ -51,6 +52,17 @@ async def get_db():
yield session
DBFactory = Callable[[], AsyncIterator[AsyncSession]]
async def get_db_factory() -> DBFactory:
async def _factory() -> AsyncIterator[AsyncSession]:
async with AsyncSession(engine) as session:
yield session
return _factory
# Redis 依赖
def get_redis():
return redis_client

51
app/dependencies/param.py Normal file
View File

@@ -0,0 +1,51 @@
from __future__ import annotations
import inspect
from typing import Any
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, ValidationError
def BodyOrForm[T: BaseModel](model: type[T]):
async def dependency(
request: Request,
) -> T:
content_type = request.headers.get("content-type", "")
data: dict[str, Any] = {}
if "application/json" in content_type:
try:
data = await request.json()
except Exception:
raise RequestValidationError(
[
{
"loc": ("body",),
"msg": "Invalid JSON body",
"type": "value_error",
}
]
)
else:
form = await request.form()
data = dict(form)
try:
return model(**data)
except ValidationError as e:
raise RequestValidationError(e.errors())
dependency.__signature__ = inspect.signature( # pyright: ignore[reportFunctionMemberAccess]
lambda x: None
).replace(
parameters=[
inspect.Parameter(
name=model.__name__.lower(),
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=model,
)
]
)
return dependency

10
app/models/chat.py Normal file
View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from typing import Any
from pydantic import BaseModel
class ChatEvent(BaseModel):
event: str
data: dict[str, Any] | None = None

View File

@@ -44,6 +44,7 @@ from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.database.room import Room
from app.signalr.hub import MultiplayerHub
HOST_LIMIT = 50
@@ -348,7 +349,7 @@ class MultiplayerRoom(BaseModel):
channel_id: int
@classmethod
def from_db(cls, room) -> "MultiplayerRoom":
def from_db(cls, room: "Room") -> "MultiplayerRoom":
"""
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
"""
@@ -358,7 +359,7 @@ class MultiplayerRoom(BaseModel):
host_user = MultiplayerRoomUser(user_id=room.host_id)
# playlist 转换
playlist = []
if hasattr(room, "playlist"):
if room.playlist:
for item in room.playlist:
playlist.append(
PlaylistItem(
@@ -396,7 +397,7 @@ class MultiplayerRoom(BaseModel):
match_state=None,
playlist=playlist,
active_countdowns=[],
channel_id=getattr(room, "channel_id", 0),
channel_id=room.channel_id,
)
@@ -624,7 +625,9 @@ class MultiplayerQueue:
async with AsyncSession(engine) as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
self.room.playlist.remove(item)
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
if found_item:
self.room.playlist.remove(found_item)
self.current_index = self.room.playlist.index(self.upcoming_items[0])
await self.update_order()

View File

@@ -88,6 +88,16 @@ class GameMode(str, Enum):
}[self]
return self
@classmethod
def parse(cls, v: str | int) -> "GameMode | None":
if isinstance(v, int) or v.isdigit():
return cls.from_int_extra(int(v))
v = v.lower()
try:
return cls[v]
except ValueError:
return None
class Rank(str, Enum):
X = "X"

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from app.signalr import signalr_router as signalr_router
from .auth import router as auth_router
from .chat import chat_router as chat_router
from .fetcher import fetcher_router as fetcher_router
from .file import file_router as file_router
from .private import private_router as private_router
@@ -17,6 +18,7 @@ __all__ = [
"api_v1_router",
"api_v2_router",
"auth_router",
"chat_router",
"fetcher_router",
"file_router",
"private_router",

View File

@@ -14,6 +14,7 @@ from app.auth import (
store_token,
)
from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
@@ -459,7 +460,7 @@ async def oauth_token(
# 存储令牌
await store_token(
db,
3,
BANCHOBOT_ID,
client_id,
scopes,
access_token,

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
from app.config import settings
from app.router.v2 import api_v2_router as router
from . import channel, message # noqa: F401
from .server import chat_router as chat_router
from fastapi import Query
__all__ = ["chat_router"]
@router.get(
"/notifications",
tags=["通知", "聊天"],
name="获取通知",
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
)
async def get_notifications(
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
):
if settings.server_url is not None:
notification_endpoint = f"{settings.server_url}notification-server".replace(
"http://", "ws://"
).replace("https://", "wss://")
else:
notification_endpoint = "/notification-server"
return {
"has_more": False,
"notifications": [],
"unread_count": 0,
"notification_endpoint": notification_endpoint,
}

View File

@@ -0,0 +1,573 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from datetime import timedelta
from math import ceil
import random
import shlex
from app.const import BANCHOBOT_ID
from app.database import ChatMessageResp
from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
from app.database.lazer_user import User
from app.database.score import Score
from app.database.statistics import UserStatistics, get_rank
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from app.models.mods import APIMod
from app.models.multiplayer_hub import (
ChangeTeamRequest,
ServerMultiplayerRoom,
StartMatchCountdownRequest,
)
from app.models.room import MatchType, QueueMode
from app.models.score import GameMode
from app.signalr.hub import MultiplayerHubs
from app.signalr.hub.hub import Client
from .server import server
from httpx import HTTPError
from sqlmodel import func, select
from sqlmodel.ext.asyncio.session import AsyncSession
HandlerResult = str | None | Awaitable[str | None]
Handler = Callable[[User, list[str], AsyncSession, ChatChannel], HandlerResult]
class Bot:
def __init__(self, bot_user_id: int = BANCHOBOT_ID) -> None:
self._handlers: dict[str, Handler] = {}
self.bot_user_id = bot_user_id
# decorator: @bot.command("ping")
def command(self, name: str) -> Callable[[Handler], Handler]:
def _decorator(func: Handler) -> Handler:
self._handlers[name.lower()] = func
return func
return _decorator
def parse(self, content: str) -> tuple[str, list[str]] | None:
if not content or not content.startswith("!"):
return None
try:
parts = shlex.split(content[1:])
except ValueError:
parts = content[1:].split()
if not parts:
return None
cmd = parts[0].lower()
args = parts[1:]
return cmd, args
async def try_handle(
self,
user: User,
channel: ChatChannel,
content: str,
session: AsyncSession,
) -> None:
parsed = self.parse(content)
if not parsed:
return
cmd, args = parsed
handler = self._handlers.get(cmd)
reply: str | None = None
if handler is None:
return
else:
res = handler(user, args, session, channel)
if asyncio.iscoroutine(res):
res = await res
reply = res # type: ignore[assignment]
if reply:
await self._send_reply(user, channel, reply, session)
async def _send_message(
self, channel: ChatChannel, content: str, session: AsyncSession
) -> None:
bot = await session.get(User, self.bot_user_id)
if bot is None:
return
channel_id = channel.channel_id
if channel_id is None:
return
assert bot.id is not None
msg = ChatMessage(
channel_id=channel_id,
content=content,
sender_id=bot.id,
type=MessageType.PLAIN,
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(bot)
resp = await ChatMessageResp.from_db(msg, session, bot)
await server.send_message_to_channel(resp)
async def _ensure_pm_channel(
self, user: User, session: AsyncSession
) -> ChatChannel | None:
user_id = user.id
if user_id is None:
return None
bot = await session.get(User, self.bot_user_id)
if bot is None or bot.id is None:
return None
channel = await ChatChannel.get_pm_channel(user_id, bot.id, session)
if channel is None:
channel = ChatChannel(
name=f"pm_{user_id}_{bot.id}",
description="Private message channel",
type=ChannelType.PM,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(user)
await session.refresh(bot)
await server.batch_join_channel([user, bot], channel, session)
return channel
async def _send_reply(
self,
user: User,
src_channel: ChatChannel,
content: str,
session: AsyncSession,
) -> None:
target_channel = src_channel
if src_channel.type == ChannelType.PUBLIC:
pm = await self._ensure_pm_channel(user, session)
if pm is not None:
target_channel = pm
await self._send_message(target_channel, content, session)
bot = Bot()
@bot.command("help")
async def _help(
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
cmds = sorted(bot._handlers.keys())
if args:
target = args[0].lower()
if target in bot._handlers:
return f"Usage: !{target} [args]"
return f"No such command: {target}"
if not cmds:
return "No available commands"
return "Available: " + ", ".join(f"!{c}" for c in cmds)
@bot.command("roll")
def _roll(
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
if len(args) > 0 and args[0].isdigit():
r = random.randint(1, int(args[0]))
else:
r = random.randint(1, 100)
return f"{user.username} rolls {r} point(s)"
@bot.command("stats")
async def _stats(
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
) -> str:
if len(args) < 1:
return "Usage: !stats <username> [gamemode]"
target_user = (
await session.exec(select(User).where(User.username == args[0]))
).first()
if not target_user:
return f"User '{args[0]}' not found."
gamemode = None
if len(args) >= 2:
gamemode = GameMode.parse(args[1].upper())
if gamemode is None:
subquery = (
select(func.max(Score.id))
.where(Score.user_id == target_user.id)
.scalar_subquery()
)
last_score = (
await session.exec(select(Score).where(Score.id == subquery))
).first()
if last_score is not None:
gamemode = last_score.gamemode
else:
gamemode = target_user.playmode
statistics = (
await session.exec(
select(UserStatistics).where(
UserStatistics.user_id == target_user.id,
UserStatistics.mode == gamemode,
)
)
).first()
if not statistics:
return f"User '{args[0]}' has no statistics."
return f"""Stats for {target_user.username} ({gamemode.name.lower()}):
Score: {statistics.total_score} (#{await get_rank(session, statistics)})
Plays: {statistics.play_count} (lv{ceil(statistics.level_current)})
Accuracy: {statistics.hit_accuracy}
PP: {statistics.pp}
"""
async def _mp_name(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp name <name>"
name = args[0]
try:
settings = room.room.settings.model_copy()
settings.name = name
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return f"Room name has changed to {name}"
except InvokeException as e:
return e.message
async def _mp_set(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp set <teammode> [<queuemode>]"
teammode = {"0": MatchType.HEAD_TO_HEAD, "2": MatchType.TEAM_VERSUS}.get(args[0])
if not teammode:
return "Invalid teammode. Use 0 for Head-to-Head or 2 for Team Versus."
queuemode = (
{
"0": QueueMode.HOST_ONLY,
"1": QueueMode.ALL_PLAYERS,
"2": QueueMode.ALL_PLAYERS_ROUND_ROBIN,
}.get(args[1])
if len(args) >= 2
else None
)
try:
settings = room.room.settings.model_copy()
settings.match_type = teammode
if queuemode:
settings.queue_mode = queuemode
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return f"Room setting 'teammode' has been changed to {teammode.name.lower()}"
except InvokeException as e:
return e.message
async def _mp_host(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp host <username>"
username = args[0]
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id:
return f"User '{username}' not found."
try:
await MultiplayerHubs.TransferHost(signalr_client, user_id)
return f"User '{username}' has been hosted in the room."
except InvokeException as e:
return e.message
async def _mp_start(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
timer = None
if len(args) >= 1 and args[0].isdigit():
timer = int(args[0])
try:
if timer is not None:
await MultiplayerHubs.SendMatchRequest(
signalr_client,
StartMatchCountdownRequest(duration=timedelta(seconds=timer)),
)
return ""
else:
await MultiplayerHubs.StartMatch(signalr_client)
return "Good luck! Enjoy game!"
except InvokeException as e:
return e.message
async def _mp_abort(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
try:
await MultiplayerHubs.AbortMatch(signalr_client)
return "Match aborted."
except InvokeException as e:
return e.message
async def _mp_team(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
):
if room.room.settings.match_type != MatchType.TEAM_VERSUS:
return "This command is only available in Team Versus mode."
if len(args) < 2:
return "Usage: !mp team <username> <colour>"
username = args[0]
team = {"red": 0, "blue": 1}.get(args[1])
if team is None:
return "Invalid team colour. Use 'red' or 'blue'."
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id:
return f"User '{username}' not found."
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
if not user_client:
return f"User '{username}' is not in the room."
try:
await MultiplayerHubs.SendMatchRequest(
user_client, ChangeTeamRequest(team_id=team)
)
return ""
except InvokeException as e:
return e.message
async def _mp_password(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
password = ""
if len(args) >= 1:
password = args[0]
try:
settings = room.room.settings.model_copy()
settings.password = password
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return "Room password has been set."
except InvokeException as e:
return e.message
async def _mp_kick(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp kick <username>"
username = args[0]
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id:
return f"User '{username}' not found."
try:
await MultiplayerHubs.KickUser(signalr_client, user_id)
return f"User '{username}' has been kicked from the room."
except InvokeException as e:
return e.message
async def _mp_map(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp map <mapid> [<playmode>]"
map_id = args[0]
if not map_id.isdigit():
return "Invalid map ID."
map_id = int(map_id)
playmode = GameMode.parse(args[1].upper()) if len(args) >= 2 else None
if playmode not in (
GameMode.OSU,
GameMode.TAIKO,
GameMode.FRUITS,
GameMode.MANIA,
None,
):
return "Invalid playmode."
try:
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
return (
f"Cannot convert to {playmode.value}. "
f"Original mode is {beatmap.mode.value}."
)
except HTTPError:
return "Beatmap not found"
try:
current_item = room.queue.current_item
item = current_item.model_copy(deep=True)
item.owner_id = signalr_client.user_id
item.beatmap_checksum = beatmap.checksum
item.required_mods = []
item.allowed_mods = []
item.freestyle = False
item.beatmap_id = map_id
if playmode is not None:
item.ruleset_id = int(playmode)
if item.expired:
item.id = 0
item.expired = False
item.played_at = None
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
else:
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
return ""
except InvokeException as e:
return e.message
async def _mp_mods(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp mods <mod1> [<mod2> ...]"
required_mods = []
allowed_mods = []
freestyle = False
for arg in args:
if arg == "None":
required_mods.clear()
allowed_mods.clear()
break
elif arg == "Freestyle":
freestyle = True
elif arg.startswith("+"):
mod = arg.removeprefix("+")
if len(mod) != 2:
return f"Invalid mod: {mod}."
allowed_mods.append(APIMod(acronym=mod))
else:
if len(arg) != 2:
return f"Invalid mod: {arg}."
required_mods.append(APIMod(acronym=arg))
try:
current_item = room.queue.current_item
item = current_item.model_copy(deep=True)
item.owner_id = signalr_client.user_id
item.freestyle = freestyle
if not freestyle:
item.allowed_mods = allowed_mods
else:
item.allowed_mods = []
item.required_mods = required_mods
if item.expired:
item.id = 0
item.expired = False
item.played_at = None
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
else:
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
return ""
except InvokeException as e:
return e.message
_MP_COMMANDS = {
"name": _mp_name,
"set": _mp_set,
"host": _mp_host,
"start": _mp_start,
"abort": _mp_abort,
"map": _mp_map,
"mods": _mp_mods,
"kick": _mp_kick,
"password": _mp_password,
"team": _mp_team,
}
_MP_HELP = """!mp name <name>
!mp set <teammode> [<queuemode>]
!mp host <host>
!mp start [<timer>]
!mp abort
!mp map <map> [<playmode>]
!mp mods <mod1> [<mod2> ...]
!mp kick <user>
!mp password [<password>]
!mp team <user> <team:red|blue>"""
@bot.command("mp")
async def _mp(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
if not channel.name.startswith("room_"):
return
room_id = int(channel.name[5:])
room = MultiplayerHubs.rooms.get(room_id)
if not room:
return
signalr_client = MultiplayerHubs.get_client_by_id(str(user.id))
if not signalr_client:
return
if len(args) < 1:
return f"Usage: !mp <{'|'.join(_MP_COMMANDS.keys())}> [args]"
command = args[0].lower()
if command not in _MP_COMMANDS:
return f"No such command: {command}"
return await _MP_COMMANDS[command](signalr_client, room, args[1:], session)

306
app/router/chat/channel.py Normal file
View File

@@ -0,0 +1,306 @@
from __future__ import annotations
from typing import Any, Literal, Self
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatMessage,
SilenceUser,
UserSilenceResp,
)
from app.database.lazer_user import User, UserResp
from app.dependencies.database import get_db, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class UpdateResponse(BaseModel):
presence: list[ChatChannelResp] = Field(default_factory=list)
silences: list[Any] = Field(default_factory=list)
@router.get(
"/chat/updates",
response_model=UpdateResponse,
name="获取更新",
description="获取当前用户所在频道的最新的禁言情况。",
tags=["聊天"],
)
async def get_update(
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
includes: list[str] = Query(
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
resp = UpdateResponse()
if "presence" in includes:
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
channel = await ChatChannel.get(channel_id, session)
if channel:
resp.presence.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel.type != ChannelType.PUBLIC
else None,
)
)
if "silences" in includes:
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
return resp
@router.put(
"/chat/channels/{channel}/users/{user}",
response_model=ChatChannelResp,
name="加入频道",
description="加入指定的公开/房间频道。",
tags=["聊天"],
)
async def join_channel(
channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
return await server.join_channel(current_user, db_channel, session)
@router.delete(
"/chat/channels/{channel}/users/{user}",
status_code=204,
name="离开频道",
description="将用户移出指定的公开/房间频道。",
tags=["聊天"],
)
async def leave_channel(
channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
await server.leave_channel(current_user, db_channel, session)
return
@router.get(
"/chat/channels",
response_model=list[ChatChannelResp],
name="获取频道列表",
description="获取所有公开频道。",
tags=["聊天"],
)
async def get_channel_list(
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
channels = (
await session.exec(
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
)
).all()
results = []
for channel in channels:
assert channel.channel_id is not None
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, [])
if channel.type != ChannelType.PUBLIC
else None,
)
)
return results
class GetChannelResp(BaseModel):
channel: ChatChannelResp
users: list[UserResp] = Field(default_factory=list)
@router.get(
"/chat/channels/{channel}",
response_model=GetChannelResp,
name="获取频道信息",
description="获取指定频道的信息。",
tags=["聊天"],
)
async def get_channel(
channel: str = Path(..., description="频道 ID/名称"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id is not None
users = []
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
if len(user_ids) != 2:
raise HTTPException(status_code=404, detail="Target user not found")
for id_ in user_ids:
if int(id_) == current_user.id:
continue
target_user = await session.get(User, int(id_))
if target_user is None:
raise HTTPException(status_code=404, detail="Target user not found")
users.extend([target_user, current_user])
break
return GetChannelResp(
channel=await ChatChannelResp.from_db(
db_channel,
session,
current_user,
redis,
server.channels.get(db_channel.channel_id, [])
if db_channel.type != ChannelType.PUBLIC
else None,
)
)
class CreateChannelReq(BaseModel):
class AnnounceChannel(BaseModel):
name: str
description: str
message: str | None = None
type: Literal["ANNOUNCE", "PM"] = "PM"
target_id: int | None = None
target_ids: list[int] | None = None
channel: AnnounceChannel | None = None
@model_validator(mode="after")
def check(self) -> Self:
if self.type == "PM":
if self.target_id is None:
raise ValueError("target_id must be set for PM channels")
else:
if self.target_ids is None or self.channel is None or self.message is None:
raise ValueError(
"target_ids, channel, and message must be set for ANNOUNCE channels"
)
return self
@router.post(
"/chat/channels",
response_model=ChatChannelResp,
name="创建频道",
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
tags=["聊天"],
)
async def create_channel(
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
if req.type == "PM":
target = await session.get(User, req.target_id)
if not target:
raise HTTPException(status_code=404, detail="Target user not found")
is_can_pm, block = await target.is_user_can_pm(current_user, session)
if not is_can_pm:
raise HTTPException(status_code=403, detail=block)
channel = await ChatChannel.get_pm_channel(
current_user.id, # pyright: ignore[reportArgumentType]
req.target_id, # pyright: ignore[reportArgumentType]
session,
)
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
channel = await ChatChannel.get(channel_name, session)
if channel is None:
channel = ChatChannel(
name=channel_name,
description=req.channel.description
if req.channel
else "Private message channel",
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(current_user)
if req.type == "PM":
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(
select(User).where(col(User.id).in_(req.target_ids or []))
)
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
assert channel.channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, []),
include_recent_messages=True,
)

243
app/router/chat/message.py Normal file
View File

@@ -0,0 +1,243 @@
from __future__ import annotations
from app.database import ChatMessageResp
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatMessage,
MessageType,
SilenceUser,
UserSilenceResp,
)
from app.database.lazer_user import User
from app.dependencies.database import get_db, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
from .banchobot import bot
from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class KeepAliveResp(BaseModel):
silences: list[UserSilenceResp] = Field(default_factory=list)
@router.post(
"/chat/ack",
name="保持连接",
response_model=KeepAliveResp,
description="保持公共频道的连接。同时返回最近的禁言列表。",
tags=["聊天"],
)
async def keep_alive(
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
):
resp = KeepAliveResp()
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
return resp
class MessageReq(BaseModel):
message: str
is_action: bool = False
uuid: str | None = None
@router.post(
"/chat/channels/{channel}/messages",
response_model=ChatMessageResp,
name="发送消息",
description="发送消息到指定频道。",
tags=["聊天"],
)
async def send_message(
channel: str = Path(..., description="频道 ID/名称"),
req: MessageReq = Depends(BodyOrForm(MessageReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
assert current_user.id
msg = ChatMessage(
channel_id=db_channel.channel_id,
content=req.message,
sender_id=current_user.id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(db_channel)
resp = await ChatMessageResp.from_db(msg, session, current_user)
is_bot_command = req.message.startswith("!")
await server.send_message_to_channel(
resp, is_bot_command and db_channel.type == ChannelType.PUBLIC
)
if is_bot_command:
await bot.try_handle(current_user, db_channel, req.message, session)
return resp
@router.get(
"/chat/channels/{channel}/messages",
response_model=list[ChatMessageResp],
name="获取消息",
description="获取指定频道的消息列表。",
tags=["聊天"],
)
async def get_message(
channel: str,
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
messages = await session.exec(
select(ChatMessage)
.where(
ChatMessage.channel_id == db_channel.channel_id,
col(ChatMessage.message_id) > since,
col(ChatMessage.message_id) < until if until is not None else True,
)
.order_by(col(ChatMessage.timestamp).desc())
.limit(limit)
)
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
resp.reverse()
return resp
@router.put(
"/chat/channels/{channel}/mark-as-read/{message}",
status_code=204,
name="标记消息为已读",
description="标记指定消息为已读。",
tags=["聊天"],
)
async def mark_as_read(
channel: str = Path(..., description="频道 ID/名称"),
message: int = Path(..., description="消息 ID"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
session: AsyncSession = Depends(get_db),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
await server.mark_as_read(db_channel.channel_id, message)
class PMReq(BaseModel):
target_id: int
message: str
is_action: bool = False
uuid: str | None = None
class NewPMResp(BaseModel):
channel: ChatChannelResp
message: ChatMessageResp
new_channel_id: int
@router.post(
"/chat/new",
name="创建私聊频道",
description="创建一个新的私聊频道。",
tags=["聊天"],
)
async def create_new_pm(
req: PMReq = Depends(BodyOrForm(PMReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
user_id = current_user.id
target = await session.get(User, req.target_id)
if target is None:
raise HTTPException(status_code=404, detail="Target user not found")
is_can_pm, block = await target.is_user_can_pm(current_user, session)
if not is_can_pm:
raise HTTPException(status_code=403, detail=block)
assert user_id
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
if channel is None:
channel = ChatChannel(
name=f"pm_{user_id}_{req.target_id}",
description="Private message channel",
type=ChannelType.PM,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(target)
await session.refresh(current_user)
assert channel.channel_id
await server.batch_join_channel([target, current_user], channel, session)
channel_resp = await ChatChannelResp.from_db(
channel, session, current_user, redis, server.channels[channel.channel_id]
)
msg = ChatMessage(
channel_id=channel.channel_id,
content=req.message,
sender_id=user_id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(channel)
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
await server.send_message_to_channel(message_resp)
return NewPMResp(
channel=channel_resp,
message=message_resp,
new_channel_id=channel_resp.channel_id,
)

285
app/router/chat/server.py Normal file
View File

@@ -0,0 +1,285 @@
from __future__ import annotations
import asyncio
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.lazer_user import User
from app.dependencies.database import (
DBFactory,
engine,
get_db_factory,
get_redis,
)
from app.dependencies.user import get_current_user
from app.log import logger
from app.models.chat import ChatEvent
from app.service.subscribers.chat import ChatSubscriber
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
from fastapi.websockets import WebSocketState
from redis.asyncio import Redis
from sqlmodel.ext.asyncio.session import AsyncSession
class ChatServer:
def __init__(self):
self.connect_client: dict[int, WebSocket] = {}
self.channels: dict[int, list[int]] = {}
self.redis: Redis = get_redis()
self.tasks: set[asyncio.Task] = set()
self.ChatSubscriber = ChatSubscriber()
self.ChatSubscriber.chat_server = self
self._subscribed = False
def _add_task(self, task):
task = asyncio.create_task(task)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
def connect(self, user_id: int, client: WebSocket):
self.connect_client[user_id] = client
def get_user_joined_channel(self, user_id: int) -> list[int]:
return [
channel_id
for channel_id, users in self.channels.items()
if user_id in users
]
async def disconnect(self, user: User, session: AsyncSession):
user_id = user.id
if user_id in self.connect_client:
del self.connect_client[user_id]
for channel_id, channel in self.channels.items():
if user_id in channel:
channel.remove(user_id)
channel = await ChatChannel.get(channel_id, session)
if channel:
await self.leave_channel(user, channel, session)
async def send_event(self, client: WebSocket, event: ChatEvent):
if client.client_state == WebSocketState.CONNECTED:
await client.send_text(event.model_dump_json())
async def broadcast(self, channel_id: int, event: ChatEvent):
for user_id in self.channels.get(channel_id, []):
client = self.connect_client.get(user_id)
if client:
await self.send_event(client, event)
async def mark_as_read(self, channel_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_msg", message_id)
async def send_message_to_channel(
self, message: ChatMessageResp, is_bot_command: bool = False
):
event = ChatEvent(
event="chat.message.new",
data={"messages": [message], "users": [message.sender]},
)
if is_bot_command:
client = self.connect_client.get(message.sender_id)
if client:
self._add_task(self.send_event(client, event))
else:
self._add_task(
self.broadcast(
message.channel_id,
event,
)
)
assert message.message_id
await self.mark_as_read(message.channel_id, message.message_id)
async def batch_join_channel(
self, users: list[User], channel: ChatChannel, session: AsyncSession
):
channel_id = channel.channel_id
assert channel_id is not None
not_joined = []
if channel_id not in self.channels:
self.channels[channel_id] = []
for user in users:
assert user.id is not None
if user.id not in self.channels[channel_id]:
self.channels[channel_id].append(user.id)
not_joined.append(user)
for user in not_joined:
assert user.id is not None
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id]
if channel.type != ChannelType.PUBLIC
else None,
)
client = self.connect_client.get(user.id)
if client:
await self.send_event(
client,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
async def join_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> ChatChannelResp:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id not in self.channels:
self.channels[channel_id] = []
if user_id not in self.channels[channel_id]:
self.channels[channel_id].append(user_id)
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
)
client = self.connect_client.get(user_id)
if client:
await self.send_event(
client,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
return channel_resp
async def leave_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> None:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id)
if (c := self.channels.get(channel_id)) is not None and not c:
del self.channels[channel_id]
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels.get(channel_id)
if channel.type != ChannelType.PUBLIC
else None,
)
client = self.connect_client.get(user_id)
if client:
await self.send_event(
client,
ChatEvent(
event="chat.channel.part",
data=channel_resp.model_dump(),
),
)
async def join_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.join_channel(user, channel, session)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.leave_channel(user, channel, session)
server = ChatServer()
chat_router = APIRouter(include_in_schema=False)
async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
try:
while True:
packets = await ws.receive_json()
if packets.get("event") == "chat.end":
async for session in factory():
user = await session.get(User, user_id)
if user is None:
break
await server.disconnect(user, session)
await ws.close(code=1000)
break
except WebSocketDisconnect as e:
logger.info(
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
)
except RuntimeError as e:
if "disconnect message" in str(e):
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
else:
logger.exception(f"RuntimeError in client {user_id}: {e}")
except Exception:
logger.exception(f"Error in client {user_id}")
@chat_router.websocket("/notification-server")
async def chat_websocket(
websocket: WebSocket,
authorization: str = Header(...),
factory: DBFactory = Depends(get_db_factory),
):
if not server._subscribed:
server._subscribed = True
await server.ChatSubscriber.start_subscribe()
async for session in factory():
token = authorization[7:]
if (
user := await get_current_user(
SecurityScopes(scopes=["chat.read"]), session, token_pw=token
)
) is None:
await websocket.close(code=1008)
return
await websocket.accept()
login = await websocket.receive_json()
if login.get("event") != "chat.start":
await websocket.close(code=1008)
return
user_id = user.id
assert user_id
server.connect(user_id, websocket)
channel = await ChatChannel.get(1, session)
if channel is not None:
await server.join_channel(user, channel, session)
await _listen_stop(websocket, user_id, factory)

View File

@@ -116,7 +116,7 @@ class APICreatedRoom(RoomResp):
async def _participate_room(
room_id: int, user_id: int, db_room: Room, session: AsyncSession
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
):
participated_user = (
await session.exec(
@@ -138,6 +138,8 @@ async def _participate_room(
participated_user.joined_at = datetime.now(UTC)
db_room.participant_count += 1
await redis.publish("chat:room:joined", f"{db_room.channel_id}:{user_id}")
@router.post(
"/rooms",
@@ -150,11 +152,12 @@ async def create_room(
room: APIUploadedRoom,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
assert current_user.id is not None
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db)
await _participate_room(db_room.id, user_id, db_room, db, redis)
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
created_room.error = ""
return created_room
@@ -219,11 +222,12 @@ async def add_user_to_room(
room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None:
await _participate_room(room_id, user_id, db_room, db)
await _participate_room(room_id, user_id, db_room, db, redis)
await db.commit()
await db.refresh(db_room)
resp = await RoomResp.from_db(db_room, db)
@@ -243,6 +247,7 @@ async def remove_user_from_room(
user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None:
@@ -257,6 +262,7 @@ async def remove_user_from_room(
if participated_user is not None:
participated_user.left_at = datetime.now(UTC)
db_room.participant_count -= 1
await redis.publish("chat:room:left", f"{db_room.channel_id}:{user_id}")
await db.commit()
return None
else:

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Literal
from app.const import BANCHOBOT_ID
from app.database import (
BeatmapPlaycounts,
BeatmapPlaycountsResp,
@@ -65,6 +66,7 @@ async def get_users(
include=SEARCH_INCLUDED,
)
for searched_user in searched_users
if searched_user.id != BANCHOBOT_ID
]
)
@@ -91,7 +93,7 @@ async def get_user_info_ruleset(
)
)
).first()
if not searched_user:
if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
return await UserResp.from_db(
searched_user,
@@ -123,7 +125,7 @@ async def get_user_info(
)
)
).first()
if not searched_user:
if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
return await UserResp.from_db(
searched_user,
@@ -148,7 +150,7 @@ async def get_user_beatmapsets(
offset: int = Query(0, ge=0, description="偏移量"),
):
user = await session.get(User, user_id)
if not user:
if not user or user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
if type in {
@@ -218,7 +220,7 @@ async def get_user_scores(
current_user: User = Security(get_current_user, scopes=["public"]),
):
db_user = await session.get(User, user_id)
if not db_user:
if not db_user or db_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
gamemode = mode or db_user.playmode
@@ -271,7 +273,7 @@ async def get_user_events(
session: AsyncSession = Depends(get_db),
):
db_user = await session.get(User, user)
if db_user is None:
if db_user is None or db_user.id == BANCHOBOT_ID:
raise HTTPException(404, "User Not found")
events = await db_user.awaitable_attrs.events
if limit is not None:

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from app.const import BANCHOBOT_ID
from app.database.lazer_user import User
from app.database.statistics import UserStatistics
from app.dependencies.database import engine
from app.models.score import GameMode
from sqlmodel import exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_banchobot():
async with AsyncSession(engine) as session:
is_exist = (
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
).first()
if not is_exist:
banchobot = User(
username="BanchoBot",
email="banchobot@ppy.sh",
is_bot=True,
pw_bcrypt="0",
id=BANCHOBOT_ID,
avatar_url="https://a.ppy.sh/3",
country_code="SH",
website="https://twitter.com/banchoboat",
)
session.add(banchobot)
statistics = UserStatistics(user_id=BANCHOBOT_ID, mode=GameMode.OSU)
session.add(statistics)
await session.commit()

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
import json
from app.const import BANCHOBOT_ID
from app.database.playlists import Playlist
from app.database.room import Room
from app.dependencies.database import engine, get_redis
@@ -26,12 +27,12 @@ async def create_daily_challenge_room(
return await create_playlist_room(
session=session,
name=str(today),
host_id=3,
host_id=BANCHOBOT_ID,
playlist=[
Playlist(
id=0,
room_id=0,
owner_id=3,
owner_id=BANCHOBOT_ID,
ruleset_id=ruleset_id,
beatmap_id=beatmap,
required_mods=required_mods,

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from app.config import settings
from app.const import BANCHOBOT_ID
from app.database.lazer_user import User
from app.database.statistics import UserStatistics
from app.dependencies.database import engine
@@ -15,6 +16,9 @@ async def create_rx_statistics():
async with AsyncSession(engine) as session:
users = (await session.exec(select(User.id))).all()
for i in users:
if i == BANCHOBOT_ID:
continue
if settings.enable_rx:
for mode in (
GameMode.OSURX,

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel
from app.database.playlists import Playlist
from app.database.room import APIUploadedRoom, Room
from app.dependencies.fetcher import get_fetcher
@@ -25,6 +26,18 @@ async def create_playlist_room_from_api(
session.add(db_room)
await session.commit()
await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Playlist room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
db_room.channel_id = channel.channel_id
await add_playlists_to_room(session, db_room.id, room.playlist, host_id)
await session.refresh(db_room)
return db_room
@@ -57,6 +70,18 @@ async def create_playlist_room(
session.add(db_room)
await session.commit()
await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Playlist room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
db_room.channel_id = channel.channel_id
await add_playlists_to_room(session, db_room.id, playlist, host_id)
await session.refresh(db_room)
return db_room

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from fnmatch import fnmatch
from typing import Any
from app.dependencies.database import get_redis_pubsub
@@ -23,18 +24,28 @@ class RedisSubscriber:
del self.handlers[channel]
await self.pubsub.unsubscribe(channel)
def add_handler(self, channel: str, handler: Callable[[str, str], Awaitable[Any]]):
if channel not in self.handlers:
self.handlers[channel] = []
self.handlers[channel].append(handler)
async def listen(self):
while True:
message = await self.pubsub.get_message(
ignore_subscribe_messages=True, timeout=None
)
if message is not None and message["type"] == "message":
method = self.handlers.get(message["channel"])
if method:
matched_handlers = []
if message["channel"] in self.handlers:
matched_handlers.extend(self.handlers[message["channel"]])
for pattern, handlers in self.handlers.items():
if fnmatch(message["channel"], pattern):
matched_handlers.extend(handlers)
if matched_handlers:
await asyncio.gather(
*[
handler(message["channel"], message["data"])
for handler in method
for handler in matched_handlers
]
)

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from .base import RedisSubscriber
if TYPE_CHECKING:
from app.router.chat.server import ChatServer
JOIN_CHANNEL = "chat:room:joined"
EXIT_CHANNEL = "chat:room:left"
class ChatSubscriber(RedisSubscriber):
def __init__(self):
super().__init__()
self.room_subscriber: dict[int, list[int]] = {}
self.chat_server: "ChatServer | None" = None
async def start_subscribe(self):
await self.subscribe(JOIN_CHANNEL)
self.add_handler(JOIN_CHANNEL, self.on_join_room)
await self.subscribe(EXIT_CHANNEL)
self.add_handler(EXIT_CHANNEL, self.on_leave_room)
self.start()
async def on_join_room(self, c: str, s: str):
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.join_room_channel(int(channel_id), int(user_id))
async def on_leave_room(self, c: str, s: str):
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.leave_room_channel(int(channel_id), int(user_id))

View File

@@ -6,6 +6,7 @@ from typing import override
from app.database import Room
from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel
from app.database.lazer_user import User
from app.database.multiplayer_event import MultiplayerEvent
from app.database.playlists import Playlist
@@ -172,6 +173,20 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
self.get_client_by_id(str(user_id)), server_room, user
)
def _ensure_in_room(self, client: Client) -> ServerMultiplayerRoom:
store = self.get_or_create_state(client)
if store.room_id == 0:
raise InvokeException("You are not in a room")
if store.room_id not in self.rooms:
raise InvokeException("Room does not exist")
server_room = self.rooms[store.room_id]
return server_room
def _ensure_host(self, client: Client, server_room: ServerMultiplayerRoom):
room = server_room.room
if room.host is None or room.host.user_id != client.user_id:
raise InvokeException("You are not the host of this room")
async def CreateRoom(self, client: Client, room: MultiplayerRoom):
logger.info(f"[MultiplayerHub] {client.user_id} creating room")
store = self.get_or_create_state(client)
@@ -195,6 +210,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await session.commit()
await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Multiplayer room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
room.channel_id = channel.channel_id # pyright: ignore[reportAttributeAccessIssue]
db_room.channel_id = channel.channel_id
item = room.playlist[0]
item.owner_id = client.user_id
room.room_id = db_room.id
@@ -280,6 +307,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if db_room is None:
raise InvokeException("Room does not exist in database")
db_room.participant_count += 1
redis = get_redis()
await redis.publish("chat:room:joined", f"{room.channel_id}:{user.user_id}")
return room
async def change_beatmap_availability(
@@ -914,6 +945,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if target_store:
target_store.room_id = 0
redis = get_redis()
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
async def end_room(self, room: ServerMultiplayerRoom):
assert room.room.host
async with AsyncSession(engine) as session:
@@ -1085,17 +1119,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
)
async def ChangeSettings(self, client: Client, settings: MultiplayerRoomSettings):
store = self.get_or_create_state(client)
if store.room_id == 0:
raise InvokeException("You are not in a room")
if store.room_id not in self.rooms:
raise InvokeException("Room does not exist")
server_room = self.rooms[store.room_id]
server_room = self._ensure_in_room(client)
self._ensure_host(client, server_room)
room = server_room.room
if room.host is None or room.host.user_id != client.user_id:
raise InvokeException("You are not the host of this room")
if room.state != MultiplayerRoomState.OPEN:
raise InvokeException("Cannot change settings while playing")

View File

@@ -13,6 +13,7 @@ from app.router import (
api_v1_router,
api_v2_router,
auth_router,
chat_router,
fetcher_router,
file_router,
private_router,
@@ -21,6 +22,7 @@ from app.router import (
)
from app.router.redirect import redirect_router
from app.service.calculate_all_user_rank import calculate_user_rank
from app.service.create_banchobot import create_banchobot
from app.service.daily_challenge import daily_challenge_job
from app.service.osu_rx_statistics import create_rx_statistics
from app.service.pp_recalculate import recalculate_all_players_pp
@@ -42,6 +44,7 @@ async def lifespan(app: FastAPI):
await calculate_user_rank(True)
init_scheduler()
await daily_challenge_job()
await create_banchobot()
# on shutdown
yield
stop_scheduler()
@@ -71,6 +74,7 @@ app = FastAPI(
app.include_router(api_v2_router)
app.include_router(api_v1_router)
app.include_router(chat_router)
app.include_router(redirect_api_router)
app.include_router(signalr_router)
app.include_router(fetcher_router)

View File

@@ -0,0 +1,192 @@
"""chat: add chat
Revision ID: dd33d89aa2c2
Revises: 9f6b27e8ea51
Create Date: 2025-08-15 14:22:34.775877
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = "dd33d89aa2c2"
down_revision: str | Sequence[str] | None = "9f6b27e8ea51"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
channel_table = op.create_table(
"chat_channels",
sa.Column("name", sa.VARCHAR(length=50), nullable=True),
sa.Column("description", sa.VARCHAR(length=255), nullable=True),
sa.Column("icon", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column(
"type",
sa.Enum(
"PUBLIC",
"PRIVATE",
"MULTIPLAYER",
"SPECTATOR",
"TEMPORARY",
"PM",
"GROUP",
"SYSTEM",
"ANNOUNCE",
"TEAM",
name="channeltype",
),
nullable=False,
),
sa.Column("channel_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("channel_id"),
)
op.create_index(
op.f("ix_chat_channels_channel_id"),
"chat_channels",
["channel_id"],
unique=False,
)
op.create_index(
op.f("ix_chat_channels_description"),
"chat_channels",
["description"],
unique=False,
)
op.create_index(
op.f("ix_chat_channels_name"), "chat_channels", ["name"], unique=False
)
op.create_index(
op.f("ix_chat_channels_type"), "chat_channels", ["type"], unique=False
)
op.create_table(
"chat_messages",
sa.Column("channel_id", sa.Integer(), nullable=False),
sa.Column("content", sa.VARCHAR(length=1000), nullable=True),
sa.Column("message_id", sa.Integer(), nullable=False),
sa.Column("sender_id", sa.BigInteger(), nullable=True),
sa.Column("timestamp", sa.DateTime(), nullable=True),
sa.Column(
"type",
sa.Enum("ACTION", "MARKDOWN", "PLAIN", name="messagetype"),
nullable=False,
),
sa.Column("uuid", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.ForeignKeyConstraint(
["channel_id"],
["chat_channels.channel_id"],
),
sa.ForeignKeyConstraint(
["sender_id"],
["lazer_users.id"],
),
sa.PrimaryKeyConstraint("message_id"),
)
op.create_index(
op.f("ix_chat_messages_channel_id"),
"chat_messages",
["channel_id"],
unique=False,
)
op.create_index(
op.f("ix_chat_messages_message_id"),
"chat_messages",
["message_id"],
unique=False,
)
op.create_index(
op.f("ix_chat_messages_sender_id"), "chat_messages", ["sender_id"], unique=False
)
op.create_index(
op.f("ix_chat_messages_timestamp"), "chat_messages", ["timestamp"], unique=False
)
op.create_index(
op.f("ix_chat_messages_type"), "chat_messages", ["type"], unique=False
)
op.create_table(
"chat_silence_users",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("channel_id", sa.Integer(), nullable=False),
sa.Column("until", sa.DateTime(), nullable=True),
sa.Column("banned_at", sa.DateTime(), nullable=False),
sa.Column("reason", sa.VARCHAR(length=255), nullable=True),
sa.ForeignKeyConstraint(
["channel_id"],
["chat_channels.channel_id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["lazer_users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_chat_silence_users_channel_id"),
"chat_silence_users",
["channel_id"],
unique=False,
)
op.create_index(
op.f("ix_chat_silence_users_reason"),
"chat_silence_users",
["reason"],
unique=False,
)
op.create_index(
op.f("ix_chat_silence_users_until"),
"chat_silence_users",
["until"],
unique=False,
)
op.create_index(
op.f("ix_chat_silence_users_banned_at"),
"chat_silence_users",
["banned_at"],
unique=False,
)
op.create_index(
op.f("ix_chat_silence_users_user_id"),
"chat_silence_users",
["user_id"],
unique=False,
)
op.create_index(
op.f("ix_chat_silence_users_id"),
"chat_silence_users",
["id"],
unique=False,
)
op.bulk_insert(
channel_table,
[
{
"name": "osu!",
"description": "General discussion for osu!",
"type": "PUBLIC",
},
{
"name": "announce",
"description": "Official announcements",
"type": "PUBLIC",
},
],
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("chat_silence_users")
op.drop_table("chat_messages")
op.drop_table("chat_channels")
# ### end Alembic commands ###

View File

@@ -0,0 +1,54 @@
"""room: add channel_id
Revision ID: df9f725a077c
Revises: dd33d89aa2c2
Create Date: 2025-08-16 08:05:28.748265
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "df9f725a077c"
down_revision: str | Sequence[str] | None = "dd33d89aa2c2"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True
)
op.alter_column(
"chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True
)
op.create_index(
op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False
)
op.add_column("rooms", sa.Column("channel_id", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("rooms", "channel_id")
op.drop_index(op.f("ix_chat_silence_users_id"), table_name="chat_silence_users")
op.alter_column(
"chat_silence_users",
"banned_at",
existing_type=mysql.DATETIME(),
nullable=False,
)
op.alter_column(
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False
)
# ### end Alembic commands ###