feat(notification): support notification

This commit is contained in:
MingxuanGame
2025-08-21 07:22:44 +00:00
parent 6ac9a124ea
commit 9fb0d0c198
13 changed files with 626 additions and 70 deletions

View File

@@ -0,0 +1,149 @@
from __future__ import annotations
from datetime import UTC, datetime
from app.config import settings
from app.database.lazer_user import User
from app.database.notification import Notification, UserNotification
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.models.chat import ChatEvent
from app.router.v2 import api_v2_router as router
from . import channel, message # noqa: F401
from .server import (
chat_router as chat_router,
server,
)
from fastapi import Body, Query, Security
from pydantic import BaseModel
from sqlmodel import col, func, select
__all__ = ["chat_router"]
class NotificationResp(BaseModel):
has_more: bool
notifications: list[Notification]
unread_count: int
notification_endpoint: str
@router.get(
"/notifications",
tags=["通知", "聊天"],
name="获取通知",
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
response_model=NotificationResp,
)
async def get_notifications(
session: Database,
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
current_user: User = Security(get_client_user),
):
if settings.server_url is not None:
notification_endpoint = f"{settings.server_url}notification-server".replace(
"http://", "ws://"
).replace("https://", "wss://")
else:
notification_endpoint = "/notification-server"
query = select(UserNotification).where(
UserNotification.user_id == current_user.id,
col(UserNotification.is_read).is_(False),
)
if max_id is not None:
query = query.where(UserNotification.notification_id < max_id)
notifications = (await session.exec(query)).all()
total_count = (
await session.exec(
select(func.count())
.select_from(UserNotification)
.where(
UserNotification.user_id == current_user.id,
col(UserNotification.is_read).is_(False),
)
)
).one()
unread_count = len(notifications)
return NotificationResp(
has_more=unread_count < total_count,
notifications=[notification.notification for notification in notifications],
unread_count=unread_count,
notification_endpoint=notification_endpoint,
)
class _IdentityReq(BaseModel):
category: str | None = None
id: int | None = None
object_id: int | None = None
object_type: int | None = None
async def _get_notifications(
session: Database, current_user: User, identities: list[_IdentityReq]
) -> list[UserNotification]:
result: dict[int, UserNotification] = {}
base_query = select(UserNotification).where(
UserNotification.user_id == current_user.id,
col(UserNotification.is_read).is_(False),
)
for identity in identities:
query = base_query
if identity.id is not None:
query = base_query.where(UserNotification.notification_id == identity.id)
if identity.object_id is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.object_id) == identity.object_id
)
)
if identity.object_type is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.object_type) == identity.object_type
)
)
if identity.category is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.category) == identity.category
)
)
result.update({n.notification_id: n for n in await session.exec(query)})
return list(result.values())
@router.post(
"/notifications/mark-read",
tags=["通知", "聊天"],
name="标记通知为已读",
description="标记当前用户的通知为已读。",
status_code=204,
)
async def mark_notifications_as_read(
session: Database,
identities: list[_IdentityReq] = Body(default_factory=list),
notifications: list[_IdentityReq] = Body(default_factory=list),
current_user: User = Security(get_client_user),
):
identities.extend(notifications)
user_notifications = await _get_notifications(session, current_user, identities)
for user_notification in user_notifications:
user_notification.is_read = True
assert current_user.id
await server.send_event(
current_user.id,
ChatEvent(
event="read",
data={
"notifications": [i.model_dump() for i in identities],
"read_count": len(user_notifications),
"timestamp": datetime.now(UTC).isoformat(),
},
),
)
await session.commit()

View File

@@ -0,0 +1,649 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from datetime import timedelta
from math import ceil
import random
import shlex
from app.const import BANCHOBOT_ID
from app.database import ChatMessageResp
from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
from app.database.lazer_user import User
from app.database.score import Score
from app.database.statistics import UserStatistics, get_rank
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from app.models.mods import APIMod, get_available_mods, mod_to_save
from app.models.multiplayer_hub import (
ChangeTeamRequest,
ServerMultiplayerRoom,
StartMatchCountdownRequest,
)
from app.models.room import MatchType, QueueMode, RoomStatus
from app.models.score import GameMode
from app.signalr.hub import MultiplayerHubs
from app.signalr.hub.hub import Client
from .server import server
from httpx import HTTPError
from sqlalchemy.orm import joinedload
from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
HandlerResult = str | None | Awaitable[str | None]
Handler = Callable[[User, list[str], AsyncSession, ChatChannel], HandlerResult]
class Bot:
def __init__(self, bot_user_id: int = BANCHOBOT_ID) -> None:
self._handlers: dict[str, Handler] = {}
self.bot_user_id = bot_user_id
# decorator: @bot.command("ping")
def command(self, name: str) -> Callable[[Handler], Handler]:
def _decorator(func: Handler) -> Handler:
self._handlers[name.lower()] = func
return func
return _decorator
def parse(self, content: str) -> tuple[str, list[str]] | None:
if not content or not content.startswith("!"):
return None
try:
parts = shlex.split(content[1:])
except ValueError:
parts = content[1:].split()
if not parts:
return None
cmd = parts[0].lower()
args = parts[1:]
return cmd, args
async def try_handle(
self,
user: User,
channel: ChatChannel,
content: str,
session: AsyncSession,
) -> None:
parsed = self.parse(content)
if not parsed:
return
cmd, args = parsed
handler = self._handlers.get(cmd)
reply: str | None = None
if handler is None:
return
else:
try:
res = handler(user, args, session, channel)
if asyncio.iscoroutine(res):
res = await res
reply = res # type: ignore[assignment]
except Exception:
reply = "Unknown error occured."
if reply:
await self._send_reply(user, channel, reply, session)
async def _send_message(
self, channel: ChatChannel, content: str, session: AsyncSession
) -> None:
bot = await session.get(User, self.bot_user_id)
if bot is None:
return
channel_id = channel.channel_id
if channel_id is None:
return
assert bot.id is not None
msg = ChatMessage(
channel_id=channel_id,
content=content,
sender_id=bot.id,
type=MessageType.PLAIN,
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(bot)
resp = await ChatMessageResp.from_db(msg, session, bot)
await server.send_message_to_channel(resp)
async def _ensure_pm_channel(
self, user: User, session: AsyncSession
) -> ChatChannel | None:
user_id = user.id
if user_id is None:
return None
bot = await session.get(User, self.bot_user_id)
if bot is None or bot.id is None:
return None
channel = await ChatChannel.get_pm_channel(user_id, bot.id, session)
if channel is None:
channel = ChatChannel(
name=f"pm_{user_id}_{bot.id}",
description="Private message channel",
type=ChannelType.PM,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(user)
await session.refresh(bot)
await server.batch_join_channel([user, bot], channel, session)
return channel
async def _send_reply(
self,
user: User,
src_channel: ChatChannel,
content: str,
session: AsyncSession,
) -> None:
target_channel = src_channel
if src_channel.type == ChannelType.PUBLIC:
pm = await self._ensure_pm_channel(user, session)
if pm is not None:
target_channel = pm
await self._send_message(target_channel, content, session)
bot = Bot()
@bot.command("help")
async def _help(
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
cmds = sorted(bot._handlers.keys())
if args:
target = args[0].lower()
if target in bot._handlers:
return f"Usage: !{target} [args]"
return f"No such command: {target}"
if not cmds:
return "No available commands"
return "Available: " + ", ".join(f"!{c}" for c in cmds)
@bot.command("roll")
def _roll(
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
if len(args) > 0 and args[0].isdigit():
r = random.randint(1, int(args[0]))
else:
r = random.randint(1, 100)
return f"{user.username} rolls {r} point(s)"
@bot.command("stats")
async def _stats(
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
) -> str:
if len(args) >= 1:
target_user = (
await session.exec(select(User).where(User.username == args[0]))
).first()
if not target_user:
return f"User '{args[0]}' not found."
else:
target_user = user
gamemode = None
if len(args) >= 2:
gamemode = GameMode.parse(args[1].upper())
if gamemode is None:
subquery = (
select(func.max(Score.id))
.where(Score.user_id == target_user.id)
.scalar_subquery()
)
last_score = (
await session.exec(select(Score).where(Score.id == subquery))
).first()
if last_score is not None:
gamemode = last_score.gamemode
else:
gamemode = target_user.playmode
statistics = (
await session.exec(
select(UserStatistics).where(
UserStatistics.user_id == target_user.id,
UserStatistics.mode == gamemode,
)
)
).first()
if not statistics:
return f"User '{args[0]}' has no statistics."
return f"""Stats for {target_user.username} ({gamemode.name.lower()}):
Score: {statistics.total_score} (#{await get_rank(session, statistics)})
Plays: {statistics.play_count} (lv{ceil(statistics.level_current)})
Accuracy: {statistics.hit_accuracy}
PP: {statistics.pp}
"""
async def _mp_name(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp name <name>"
name = args[0]
try:
settings = room.room.settings.model_copy()
settings.name = name
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return f"Room name has changed to {name}"
except InvokeException as e:
return e.message
async def _mp_set(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp set <teammode> [<queuemode>]"
teammode = {"0": MatchType.HEAD_TO_HEAD, "2": MatchType.TEAM_VERSUS}.get(args[0])
if not teammode:
return "Invalid teammode. Use 0 for Head-to-Head or 2 for Team Versus."
queuemode = (
{
"0": QueueMode.HOST_ONLY,
"1": QueueMode.ALL_PLAYERS,
"2": QueueMode.ALL_PLAYERS_ROUND_ROBIN,
}.get(args[1])
if len(args) >= 2
else None
)
try:
settings = room.room.settings.model_copy()
settings.match_type = teammode
if queuemode:
settings.queue_mode = queuemode
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return f"Room setting 'teammode' has been changed to {teammode.name.lower()}"
except InvokeException as e:
return e.message
async def _mp_host(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp host <username>"
username = args[0]
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id:
return f"User '{username}' not found."
try:
await MultiplayerHubs.TransferHost(signalr_client, user_id)
return f"User '{username}' has been hosted in the room."
except InvokeException as e:
return e.message
async def _mp_start(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
timer = None
if len(args) >= 1 and args[0].isdigit():
timer = int(args[0])
try:
if timer is not None:
await MultiplayerHubs.SendMatchRequest(
signalr_client,
StartMatchCountdownRequest(duration=timedelta(seconds=timer)),
)
return ""
else:
await MultiplayerHubs.StartMatch(signalr_client)
return "Good luck! Enjoy game!"
except InvokeException as e:
return e.message
async def _mp_abort(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
try:
await MultiplayerHubs.AbortMatch(signalr_client)
return "Match aborted."
except InvokeException as e:
return e.message
async def _mp_team(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
):
if room.room.settings.match_type != MatchType.TEAM_VERSUS:
return "This command is only available in Team Versus mode."
if len(args) < 2:
return "Usage: !mp team <username> <colour>"
username = args[0]
team = {"red": 0, "blue": 1}.get(args[1])
if team is None:
return "Invalid team colour. Use 'red' or 'blue'."
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id:
return f"User '{username}' not found."
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
if not user_client:
return f"User '{username}' is not in the room."
if (
user_client.user_id != signalr_client.user_id
and room.room.host.user_id != signalr_client.user_id
):
return "You are not allowed to change other users' teams."
try:
await MultiplayerHubs.SendMatchRequest(
user_client, ChangeTeamRequest(team_id=team)
)
return ""
except InvokeException as e:
return e.message
async def _mp_password(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
password = ""
if len(args) >= 1:
password = args[0]
try:
settings = room.room.settings.model_copy()
settings.password = password
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return "Room password has been set."
except InvokeException as e:
return e.message
async def _mp_kick(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp kick <username>"
username = args[0]
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id:
return f"User '{username}' not found."
try:
await MultiplayerHubs.KickUser(signalr_client, user_id)
return f"User '{username}' has been kicked from the room."
except InvokeException as e:
return e.message
async def _mp_map(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp map <mapid> [<playmode>]"
if room.status != RoomStatus.IDLE:
return "Cannot change map while the game is running."
map_id = args[0]
if not map_id.isdigit():
return "Invalid map ID."
map_id = int(map_id)
playmode = GameMode.parse(args[1].upper()) if len(args) >= 2 else None
if playmode not in (
GameMode.OSU,
GameMode.TAIKO,
GameMode.FRUITS,
GameMode.MANIA,
None,
):
return "Invalid playmode."
try:
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
return (
f"Cannot convert to {playmode.value}. "
f"Original mode is {beatmap.mode.value}."
)
except HTTPError:
return "Beatmap not found"
try:
current_item = room.queue.current_item
item = current_item.model_copy(deep=True)
item.owner_id = signalr_client.user_id
item.beatmap_checksum = beatmap.checksum
item.required_mods = []
item.allowed_mods = []
item.freestyle = False
item.beatmap_id = map_id
if playmode is not None:
item.ruleset_id = int(playmode)
if item.expired:
item.id = 0
item.expired = False
item.played_at = None
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
else:
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
return ""
except InvokeException as e:
return e.message
async def _mp_mods(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp mods <mod1> [<mod2> ...]"
if room.status != RoomStatus.IDLE:
return "Cannot change mods while the game is running."
required_mods = []
allowed_mods = []
freestyle = False
freemod = False
for arg in args:
arg = arg.upper()
if arg == "NONE":
required_mods.clear()
allowed_mods.clear()
break
elif arg == "FREESTYLE":
freestyle = True
elif arg == "FREEMOD":
freemod = True
elif arg.startswith("+"):
mod = arg.removeprefix("+")
if len(mod) != 2:
return f"Invalid mod: {mod}."
allowed_mods.append(APIMod(acronym=mod))
else:
if len(arg) != 2:
return f"Invalid mod: {arg}."
required_mods.append(APIMod(acronym=arg))
try:
current_item = room.queue.current_item
item = current_item.model_copy(deep=True)
item.owner_id = signalr_client.user_id
item.freestyle = freestyle
if freestyle:
item.allowed_mods = []
elif freemod:
item.allowed_mods = get_available_mods(
current_item.ruleset_id, required_mods
)
else:
item.allowed_mods = allowed_mods
item.required_mods = required_mods
if item.expired:
item.id = 0
item.expired = False
item.played_at = None
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
else:
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
return ""
except InvokeException as e:
return e.message
_MP_COMMANDS = {
"name": _mp_name,
"set": _mp_set,
"host": _mp_host,
"start": _mp_start,
"abort": _mp_abort,
"map": _mp_map,
"mods": _mp_mods,
"kick": _mp_kick,
"password": _mp_password,
"team": _mp_team,
}
_MP_HELP = """!mp name <name>
!mp set <teammode> [<queuemode>]
!mp host <host>
!mp start [<timer>]
!mp abort
!mp map <map> [<playmode>]
!mp mods <mod1> [<mod2> ...]
!mp kick <user>
!mp password [<password>]
!mp team <user> <team:red|blue>"""
@bot.command("mp")
async def _mp(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
if not channel.name.startswith("room_"):
return
room_id = int(channel.name[5:])
room = MultiplayerHubs.rooms.get(room_id)
if not room:
return
signalr_client = MultiplayerHubs.get_client_by_id(str(user.id))
if not signalr_client:
return
if len(args) < 1:
return f"Usage: !mp <{'|'.join(_MP_COMMANDS.keys())}> [args]"
command = args[0].lower()
if command not in _MP_COMMANDS:
return f"No such command: {command}"
return await _MP_COMMANDS[command](signalr_client, room, args[1:], session)
async def _score(
user_id: int,
session: AsyncSession,
include_fail: bool = False,
gamemode: GameMode | None = None,
) -> str:
q = (
select(Score)
.where(Score.user_id == user_id)
.order_by(col(Score.id).desc())
.options(joinedload(Score.beatmap))
)
if not include_fail:
q = q.where(Score.passed.is_(True))
if gamemode is not None:
q = q.where(Score.gamemode == gamemode)
score = (await session.exec(q)).first()
if score is None:
return "You have no scores."
result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()})
Played at {score.started_at}
{score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()}
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501
if score.gamemode == GameMode.MANIA:
keys = next(
(mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None
)
if keys is None:
keys = f"{int(score.beatmap.cs)}K"
p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1"
result += (
f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
)
return result
@bot.command("re")
async def _re(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
gamemode = None
if len(args) >= 1:
gamemode = GameMode.parse(args[0])
return await _score(user.id, session, include_fail=True, gamemode=gamemode)
@bot.command("pr")
async def _pr(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
gamemode = None
if len(args) >= 1:
gamemode = GameMode.parse(args[0])
return await _score(user.id, session, include_fail=False, gamemode=gamemode)

View File

@@ -0,0 +1,305 @@
from __future__ import annotations
from typing import Any, Literal, Self
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatMessage,
SilenceUser,
UserSilenceResp,
)
from app.database.lazer_user import User, UserResp
from app.dependencies.database import Database, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator
from redis.asyncio import Redis
from sqlmodel import col, select
class UpdateResponse(BaseModel):
presence: list[ChatChannelResp] = Field(default_factory=list)
silences: list[Any] = Field(default_factory=list)
@router.get(
"/chat/updates",
response_model=UpdateResponse,
name="获取更新",
description="获取当前用户所在频道的最新的禁言情况。",
tags=["聊天"],
)
async def get_update(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
includes: list[str] = Query(
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
resp = UpdateResponse()
if "presence" in includes:
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
channel = await ChatChannel.get(channel_id, session)
if channel:
resp.presence.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel.type != ChannelType.PUBLIC
else None,
)
)
if "silences" in includes:
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
return resp
@router.put(
"/chat/channels/{channel}/users/{user}",
response_model=ChatChannelResp,
name="加入频道",
description="加入指定的公开/房间频道。",
tags=["聊天"],
)
async def join_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
return await server.join_channel(current_user, db_channel, session)
@router.delete(
"/chat/channels/{channel}/users/{user}",
status_code=204,
name="离开频道",
description="将用户移出指定的公开/房间频道。",
tags=["聊天"],
)
async def leave_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
await server.leave_channel(current_user, db_channel, session)
return
@router.get(
"/chat/channels",
response_model=list[ChatChannelResp],
name="获取频道列表",
description="获取所有公开频道。",
tags=["聊天"],
)
async def get_channel_list(
session: Database,
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
channels = (
await session.exec(
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
)
).all()
results = []
for channel in channels:
assert channel.channel_id is not None
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, [])
if channel.type != ChannelType.PUBLIC
else None,
)
)
return results
class GetChannelResp(BaseModel):
channel: ChatChannelResp
users: list[UserResp] = Field(default_factory=list)
@router.get(
"/chat/channels/{channel}",
response_model=GetChannelResp,
name="获取频道信息",
description="获取指定频道的信息。",
tags=["聊天"],
)
async def get_channel(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id is not None
users = []
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
if len(user_ids) != 2:
raise HTTPException(status_code=404, detail="Target user not found")
for id_ in user_ids:
if int(id_) == current_user.id:
continue
target_user = await session.get(User, int(id_))
if target_user is None:
raise HTTPException(status_code=404, detail="Target user not found")
users.extend([target_user, current_user])
break
return GetChannelResp(
channel=await ChatChannelResp.from_db(
db_channel,
session,
current_user,
redis,
server.channels.get(db_channel.channel_id, [])
if db_channel.type != ChannelType.PUBLIC
else None,
)
)
class CreateChannelReq(BaseModel):
class AnnounceChannel(BaseModel):
name: str
description: str
message: str | None = None
type: Literal["ANNOUNCE", "PM"] = "PM"
target_id: int | None = None
target_ids: list[int] | None = None
channel: AnnounceChannel | None = None
@model_validator(mode="after")
def check(self) -> Self:
if self.type == "PM":
if self.target_id is None:
raise ValueError("target_id must be set for PM channels")
else:
if self.target_ids is None or self.channel is None or self.message is None:
raise ValueError(
"target_ids, channel, and message must be set for ANNOUNCE channels"
)
return self
@router.post(
"/chat/channels",
response_model=ChatChannelResp,
name="创建频道",
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
tags=["聊天"],
)
async def create_channel(
session: Database,
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
redis: Redis = Depends(get_redis),
):
if req.type == "PM":
target = await session.get(User, req.target_id)
if not target:
raise HTTPException(status_code=404, detail="Target user not found")
is_can_pm, block = await target.is_user_can_pm(current_user, session)
if not is_can_pm:
raise HTTPException(status_code=403, detail=block)
channel = await ChatChannel.get_pm_channel(
current_user.id, # pyright: ignore[reportArgumentType]
req.target_id, # pyright: ignore[reportArgumentType]
session,
)
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
channel = await ChatChannel.get(channel_name, session)
if channel is None:
channel = ChatChannel(
name=channel_name,
description=req.channel.description
if req.channel
else "Private message channel",
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(current_user)
if req.type == "PM":
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(
select(User).where(col(User.id).in_(req.target_ids or []))
)
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
assert channel.channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, []),
include_recent_messages=True,
)

View File

@@ -0,0 +1,253 @@
from __future__ import annotations
from app.database import ChatMessageResp
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatMessage,
MessageType,
SilenceUser,
UserSilenceResp,
)
from app.database.lazer_user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.models.notification import ChannelMessage, ChannelMessageTeam
from app.router.v2 import api_v2_router as router
from .banchobot import bot
from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlmodel import col, select
class KeepAliveResp(BaseModel):
silences: list[UserSilenceResp] = Field(default_factory=list)
@router.post(
"/chat/ack",
name="保持连接",
response_model=KeepAliveResp,
description="保持公共频道的连接。同时返回最近的禁言列表。",
tags=["聊天"],
)
async def keep_alive(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
resp = KeepAliveResp()
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
return resp
class MessageReq(BaseModel):
message: str
is_action: bool = False
uuid: str | None = None
@router.post(
"/chat/channels/{channel}/messages",
response_model=ChatMessageResp,
name="发送消息",
description="发送消息到指定频道。",
tags=["聊天"],
)
async def send_message(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
req: MessageReq = Depends(BodyOrForm(MessageReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
assert current_user.id
msg = ChatMessage(
channel_id=db_channel.channel_id,
content=req.message,
sender_id=current_user.id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(db_channel)
resp = await ChatMessageResp.from_db(msg, session, current_user)
is_bot_command = req.message.startswith("!")
await server.send_message_to_channel(
resp, is_bot_command and db_channel.type == ChannelType.PUBLIC
)
if is_bot_command:
await bot.try_handle(current_user, db_channel, req.message, session)
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
await server.new_private_notification(
ChannelMessage(
msg, current_user, [int(u) for u in user_ids], db_channel.type
)
)
elif db_channel.type == ChannelType.TEAM:
await server.new_private_notification(ChannelMessageTeam(msg, current_user))
return resp
@router.get(
"/chat/channels/{channel}/messages",
response_model=list[ChatMessageResp],
name="获取消息",
description="获取指定频道的消息列表。",
tags=["聊天"],
)
async def get_message(
session: Database,
channel: str,
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
messages = await session.exec(
select(ChatMessage)
.where(
ChatMessage.channel_id == db_channel.channel_id,
col(ChatMessage.message_id) > since,
col(ChatMessage.message_id) < until if until is not None else True,
)
.order_by(col(ChatMessage.timestamp).desc())
.limit(limit)
)
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
resp.reverse()
return resp
@router.put(
"/chat/channels/{channel}/mark-as-read/{message}",
status_code=204,
name="标记消息为已读",
description="标记指定消息为已读。",
tags=["聊天"],
)
async def mark_as_read(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
message: int = Path(..., description="消息 ID"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
db_channel = await ChatChannel.get(channel, session)
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
assert current_user.id
await server.mark_as_read(db_channel.channel_id, current_user.id, message)
class PMReq(BaseModel):
target_id: int
message: str
is_action: bool = False
uuid: str | None = None
class NewPMResp(BaseModel):
channel: ChatChannelResp
message: ChatMessageResp
new_channel_id: int
@router.post(
"/chat/new",
name="创建私聊频道",
description="创建一个新的私聊频道。",
tags=["聊天"],
)
async def create_new_pm(
session: Database,
req: PMReq = Depends(BodyOrForm(PMReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
redis: Redis = Depends(get_redis),
):
user_id = current_user.id
target = await session.get(User, req.target_id)
if target is None:
raise HTTPException(status_code=404, detail="Target user not found")
is_can_pm, block = await target.is_user_can_pm(current_user, session)
if not is_can_pm:
raise HTTPException(status_code=403, detail=block)
assert user_id
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
if channel is None:
channel = ChatChannel(
name=f"pm_{user_id}_{req.target_id}",
description="Private message channel",
type=ChannelType.PM,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(target)
await session.refresh(current_user)
assert channel.channel_id
await server.batch_join_channel([target, current_user], channel, session)
channel_resp = await ChatChannelResp.from_db(
channel, session, current_user, redis, server.channels[channel.channel_id]
)
msg = ChatMessage(
channel_id=channel.channel_id,
content=req.message,
sender_id=user_id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(channel)
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
await server.send_message_to_channel(message_resp)
return NewPMResp(
channel=channel_resp,
message=message_resp,
new_channel_id=channel_resp.channel_id,
)

View File

@@ -0,0 +1,315 @@
from __future__ import annotations
import asyncio
from typing import overload
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.lazer_user import User
from app.database.notification import UserNotification, insert_notification
from app.dependencies.database import (
DBFactory,
get_db_factory,
get_redis,
with_db,
)
from app.dependencies.user import get_current_user
from app.log import logger
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
from fastapi.websockets import WebSocketState
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
class ChatServer:
def __init__(self):
self.connect_client: dict[int, WebSocket] = {}
self.channels: dict[int, list[int]] = {}
self.redis: Redis = get_redis()
self.tasks: set[asyncio.Task] = set()
self.ChatSubscriber = ChatSubscriber()
self.ChatSubscriber.chat_server = self
self._subscribed = False
def _add_task(self, task):
task = asyncio.create_task(task)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
def connect(self, user_id: int, client: WebSocket):
self.connect_client[user_id] = client
def get_user_joined_channel(self, user_id: int) -> list[int]:
return [
channel_id
for channel_id, users in self.channels.items()
if user_id in users
]
async def disconnect(self, user: User, session: AsyncSession):
user_id = user.id
if user_id in self.connect_client:
del self.connect_client[user_id]
for channel_id, channel in self.channels.items():
if user_id in channel:
channel.remove(user_id)
channel = await ChatChannel.get(channel_id, session)
if channel:
await self.leave_channel(user, channel, session)
@overload
async def send_event(self, client: int, event: ChatEvent): ...
@overload
async def send_event(self, client: WebSocket, event: ChatEvent): ...
async def send_event(self, client: WebSocket | int, event: ChatEvent):
if isinstance(client, int):
client_ = self.connect_client.get(client)
if client_ is None:
return
client = client_
if client.client_state == WebSocketState.CONNECTED:
await client.send_text(event.model_dump_json())
async def broadcast(self, channel_id: int, event: ChatEvent):
for user_id in self.channels.get(channel_id, []):
await self.send_event(user_id, event)
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
async def send_message_to_channel(
self, message: ChatMessageResp, is_bot_command: bool = False
):
event = ChatEvent(
event="chat.message.new",
data={"messages": [message], "users": [message.sender]},
)
if is_bot_command:
self._add_task(self.send_event(message.sender_id, event))
else:
self._add_task(
self.broadcast(
message.channel_id,
event,
)
)
assert message.message_id
await self.mark_as_read(
message.channel_id, message.sender_id, message.message_id
)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
async def batch_join_channel(
self, users: list[User], channel: ChatChannel, session: AsyncSession
):
channel_id = channel.channel_id
assert channel_id is not None
not_joined = []
if channel_id not in self.channels:
self.channels[channel_id] = []
for user in users:
assert user.id is not None
if user.id not in self.channels[channel_id]:
self.channels[channel_id].append(user.id)
not_joined.append(user)
for user in not_joined:
assert user.id is not None
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id]
if channel.type != ChannelType.PUBLIC
else None,
)
await self.send_event(
user.id,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
async def join_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> ChatChannelResp:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id not in self.channels:
self.channels[channel_id] = []
if user_id not in self.channels[channel_id]:
self.channels[channel_id].append(user_id)
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
return channel_resp
async def leave_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> None:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id)
if (c := self.channels.get(channel_id)) is not None and not c:
del self.channels[channel_id]
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels.get(channel_id)
if channel.type != ChannelType.PUBLIC
else None,
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.part",
data=channel_resp.model_dump(),
),
)
async def join_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.join_channel(user, channel, session)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.leave_channel(user, channel, session)
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
id = await insert_notification(session, detail)
users = (
await session.exec(
select(UserNotification).where(
UserNotification.notification_id == id
)
)
).all()
for user_notification in users:
data = user_notification.notification.model_dump()
data["is_read"] = user_notification.is_read
data["details"] = user_notification.notification.details
await server.send_event(
user_notification.user_id,
ChatEvent(
event="new",
data=data,
),
)
server = ChatServer()
chat_router = APIRouter(include_in_schema=False)
async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
try:
while True:
packets = await ws.receive_json()
if packets.get("event") == "chat.end":
async for session in factory():
user = await session.get(User, user_id)
if user is None:
break
await server.disconnect(user, session)
await ws.close(code=1000)
break
except WebSocketDisconnect as e:
logger.info(
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
)
except RuntimeError as e:
if "disconnect message" in str(e):
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
else:
logger.exception(f"RuntimeError in client {user_id}: {e}")
except Exception:
logger.exception(f"Error in client {user_id}")
@chat_router.websocket("/notification-server")
async def chat_websocket(
websocket: WebSocket,
authorization: str = Header(...),
factory: DBFactory = Depends(get_db_factory),
):
if not server._subscribed:
server._subscribed = True
await server.ChatSubscriber.start_subscribe()
async for session in factory():
token = authorization[7:]
if (
user := await get_current_user(
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
)
) is None:
await websocket.close(code=1008)
return
await websocket.accept()
login = await websocket.receive_json()
if login.get("event") != "chat.start":
await websocket.close(code=1008)
return
user_id = user.id
assert user_id
server.connect(user_id, websocket)
channel = await ChatChannel.get(1, session)
if channel is not None:
await server.join_channel(user, channel, session)
await _listen_stop(websocket, user_id, factory)