feat(notification): support notification
This commit is contained in:
149
app/router/notification/__init__.py
Normal file
149
app/router/notification/__init__.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.config import settings
|
||||
from app.database.lazer_user import User
|
||||
from app.database.notification import Notification, UserNotification
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.models.chat import ChatEvent
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from . import channel, message # noqa: F401
|
||||
from .server import (
|
||||
chat_router as chat_router,
|
||||
server,
|
||||
)
|
||||
|
||||
from fastapi import Body, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, func, select
|
||||
|
||||
__all__ = ["chat_router"]
|
||||
|
||||
|
||||
class NotificationResp(BaseModel):
|
||||
has_more: bool
|
||||
notifications: list[Notification]
|
||||
unread_count: int
|
||||
notification_endpoint: str
|
||||
|
||||
|
||||
@router.get(
|
||||
"/notifications",
|
||||
tags=["通知", "聊天"],
|
||||
name="获取通知",
|
||||
description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。",
|
||||
response_model=NotificationResp,
|
||||
)
|
||||
async def get_notifications(
|
||||
session: Database,
|
||||
max_id: int | None = Query(None, description="获取 ID 小于此值的通知"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
query = select(UserNotification).where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
if max_id is not None:
|
||||
query = query.where(UserNotification.notification_id < max_id)
|
||||
notifications = (await session.exec(query)).all()
|
||||
total_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(UserNotification)
|
||||
.where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
unread_count = len(notifications)
|
||||
|
||||
return NotificationResp(
|
||||
has_more=unread_count < total_count,
|
||||
notifications=[notification.notification for notification in notifications],
|
||||
unread_count=unread_count,
|
||||
notification_endpoint=notification_endpoint,
|
||||
)
|
||||
|
||||
|
||||
class _IdentityReq(BaseModel):
|
||||
category: str | None = None
|
||||
id: int | None = None
|
||||
object_id: int | None = None
|
||||
object_type: int | None = None
|
||||
|
||||
|
||||
async def _get_notifications(
|
||||
session: Database, current_user: User, identities: list[_IdentityReq]
|
||||
) -> list[UserNotification]:
|
||||
result: dict[int, UserNotification] = {}
|
||||
base_query = select(UserNotification).where(
|
||||
UserNotification.user_id == current_user.id,
|
||||
col(UserNotification.is_read).is_(False),
|
||||
)
|
||||
for identity in identities:
|
||||
query = base_query
|
||||
if identity.id is not None:
|
||||
query = base_query.where(UserNotification.notification_id == identity.id)
|
||||
if identity.object_id is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_id) == identity.object_id
|
||||
)
|
||||
)
|
||||
if identity.object_type is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_type) == identity.object_type
|
||||
)
|
||||
)
|
||||
if identity.category is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.category) == identity.category
|
||||
)
|
||||
)
|
||||
result.update({n.notification_id: n for n in await session.exec(query)})
|
||||
return list(result.values())
|
||||
|
||||
|
||||
@router.post(
|
||||
"/notifications/mark-read",
|
||||
tags=["通知", "聊天"],
|
||||
name="标记通知为已读",
|
||||
description="标记当前用户的通知为已读。",
|
||||
status_code=204,
|
||||
)
|
||||
async def mark_notifications_as_read(
|
||||
session: Database,
|
||||
identities: list[_IdentityReq] = Body(default_factory=list),
|
||||
notifications: list[_IdentityReq] = Body(default_factory=list),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
identities.extend(notifications)
|
||||
user_notifications = await _get_notifications(session, current_user, identities)
|
||||
for user_notification in user_notifications:
|
||||
user_notification.is_read = True
|
||||
|
||||
assert current_user.id
|
||||
await server.send_event(
|
||||
current_user.id,
|
||||
ChatEvent(
|
||||
event="read",
|
||||
data={
|
||||
"notifications": [i.model_dump() for i in identities],
|
||||
"read_count": len(user_notifications),
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
},
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
649
app/router/notification/banchobot.py
Normal file
649
app/router/notification/banchobot.py
Normal file
@@ -0,0 +1,649 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import timedelta
|
||||
from math import ceil
|
||||
import random
|
||||
import shlex
|
||||
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
|
||||
from app.database.lazer_user import User
|
||||
from app.database.score import Score
|
||||
from app.database.statistics import UserStatistics, get_rank
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.exception import InvokeException
|
||||
from app.models.mods import APIMod, get_available_mods, mod_to_save
|
||||
from app.models.multiplayer_hub import (
|
||||
ChangeTeamRequest,
|
||||
ServerMultiplayerRoom,
|
||||
StartMatchCountdownRequest,
|
||||
)
|
||||
from app.models.room import MatchType, QueueMode, RoomStatus
|
||||
from app.models.score import GameMode
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
from app.signalr.hub.hub import Client
|
||||
|
||||
from .server import server
|
||||
|
||||
from httpx import HTTPError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
HandlerResult = str | None | Awaitable[str | None]
|
||||
Handler = Callable[[User, list[str], AsyncSession, ChatChannel], HandlerResult]
|
||||
|
||||
|
||||
class Bot:
|
||||
def __init__(self, bot_user_id: int = BANCHOBOT_ID) -> None:
|
||||
self._handlers: dict[str, Handler] = {}
|
||||
self.bot_user_id = bot_user_id
|
||||
|
||||
# decorator: @bot.command("ping")
|
||||
def command(self, name: str) -> Callable[[Handler], Handler]:
|
||||
def _decorator(func: Handler) -> Handler:
|
||||
self._handlers[name.lower()] = func
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
def parse(self, content: str) -> tuple[str, list[str]] | None:
|
||||
if not content or not content.startswith("!"):
|
||||
return None
|
||||
try:
|
||||
parts = shlex.split(content[1:])
|
||||
except ValueError:
|
||||
parts = content[1:].split()
|
||||
if not parts:
|
||||
return None
|
||||
cmd = parts[0].lower()
|
||||
args = parts[1:]
|
||||
return cmd, args
|
||||
|
||||
async def try_handle(
|
||||
self,
|
||||
user: User,
|
||||
channel: ChatChannel,
|
||||
content: str,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
parsed = self.parse(content)
|
||||
if not parsed:
|
||||
return
|
||||
cmd, args = parsed
|
||||
handler = self._handlers.get(cmd)
|
||||
|
||||
reply: str | None = None
|
||||
if handler is None:
|
||||
return
|
||||
else:
|
||||
try:
|
||||
res = handler(user, args, session, channel)
|
||||
if asyncio.iscoroutine(res):
|
||||
res = await res
|
||||
reply = res # type: ignore[assignment]
|
||||
except Exception:
|
||||
reply = "Unknown error occured."
|
||||
if reply:
|
||||
await self._send_reply(user, channel, reply, session)
|
||||
|
||||
async def _send_message(
|
||||
self, channel: ChatChannel, content: str, session: AsyncSession
|
||||
) -> None:
|
||||
bot = await session.get(User, self.bot_user_id)
|
||||
if bot is None:
|
||||
return
|
||||
channel_id = channel.channel_id
|
||||
if channel_id is None:
|
||||
return
|
||||
|
||||
assert bot.id is not None
|
||||
msg = ChatMessage(
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
sender_id=bot.id,
|
||||
type=MessageType.PLAIN,
|
||||
)
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(bot)
|
||||
resp = await ChatMessageResp.from_db(msg, session, bot)
|
||||
await server.send_message_to_channel(resp)
|
||||
|
||||
async def _ensure_pm_channel(
|
||||
self, user: User, session: AsyncSession
|
||||
) -> ChatChannel | None:
|
||||
user_id = user.id
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
bot = await session.get(User, self.bot_user_id)
|
||||
if bot is None or bot.id is None:
|
||||
return None
|
||||
|
||||
channel = await ChatChannel.get_pm_channel(user_id, bot.id, session)
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=f"pm_{user_id}_{bot.id}",
|
||||
description="Private message channel",
|
||||
type=ChannelType.PM,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(user)
|
||||
await session.refresh(bot)
|
||||
await server.batch_join_channel([user, bot], channel, session)
|
||||
return channel
|
||||
|
||||
async def _send_reply(
|
||||
self,
|
||||
user: User,
|
||||
src_channel: ChatChannel,
|
||||
content: str,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
target_channel = src_channel
|
||||
if src_channel.type == ChannelType.PUBLIC:
|
||||
pm = await self._ensure_pm_channel(user, session)
|
||||
if pm is not None:
|
||||
target_channel = pm
|
||||
await self._send_message(target_channel, content, session)
|
||||
|
||||
|
||||
bot = Bot()
|
||||
|
||||
|
||||
@bot.command("help")
|
||||
async def _help(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
cmds = sorted(bot._handlers.keys())
|
||||
if args:
|
||||
target = args[0].lower()
|
||||
if target in bot._handlers:
|
||||
return f"Usage: !{target} [args]"
|
||||
return f"No such command: {target}"
|
||||
if not cmds:
|
||||
return "No available commands"
|
||||
return "Available: " + ", ".join(f"!{c}" for c in cmds)
|
||||
|
||||
|
||||
@bot.command("roll")
|
||||
def _roll(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
if len(args) > 0 and args[0].isdigit():
|
||||
r = random.randint(1, int(args[0]))
|
||||
else:
|
||||
r = random.randint(1, 100)
|
||||
return f"{user.username} rolls {r} point(s)"
|
||||
|
||||
|
||||
@bot.command("stats")
|
||||
async def _stats(
|
||||
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
if len(args) >= 1:
|
||||
target_user = (
|
||||
await session.exec(select(User).where(User.username == args[0]))
|
||||
).first()
|
||||
if not target_user:
|
||||
return f"User '{args[0]}' not found."
|
||||
else:
|
||||
target_user = user
|
||||
|
||||
gamemode = None
|
||||
if len(args) >= 2:
|
||||
gamemode = GameMode.parse(args[1].upper())
|
||||
if gamemode is None:
|
||||
subquery = (
|
||||
select(func.max(Score.id))
|
||||
.where(Score.user_id == target_user.id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
last_score = (
|
||||
await session.exec(select(Score).where(Score.id == subquery))
|
||||
).first()
|
||||
if last_score is not None:
|
||||
gamemode = last_score.gamemode
|
||||
else:
|
||||
gamemode = target_user.playmode
|
||||
|
||||
statistics = (
|
||||
await session.exec(
|
||||
select(UserStatistics).where(
|
||||
UserStatistics.user_id == target_user.id,
|
||||
UserStatistics.mode == gamemode,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not statistics:
|
||||
return f"User '{args[0]}' has no statistics."
|
||||
|
||||
return f"""Stats for {target_user.username} ({gamemode.name.lower()}):
|
||||
Score: {statistics.total_score} (#{await get_rank(session, statistics)})
|
||||
Plays: {statistics.play_count} (lv{ceil(statistics.level_current)})
|
||||
Accuracy: {statistics.hit_accuracy}
|
||||
PP: {statistics.pp}
|
||||
"""
|
||||
|
||||
|
||||
async def _mp_name(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp name <name>"
|
||||
|
||||
name = args[0]
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.name = name
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return f"Room name has changed to {name}"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_set(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp set <teammode> [<queuemode>]"
|
||||
|
||||
teammode = {"0": MatchType.HEAD_TO_HEAD, "2": MatchType.TEAM_VERSUS}.get(args[0])
|
||||
if not teammode:
|
||||
return "Invalid teammode. Use 0 for Head-to-Head or 2 for Team Versus."
|
||||
queuemode = (
|
||||
{
|
||||
"0": QueueMode.HOST_ONLY,
|
||||
"1": QueueMode.ALL_PLAYERS,
|
||||
"2": QueueMode.ALL_PLAYERS_ROUND_ROBIN,
|
||||
}.get(args[1])
|
||||
if len(args) >= 2
|
||||
else None
|
||||
)
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.match_type = teammode
|
||||
if queuemode:
|
||||
settings.queue_mode = queuemode
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return f"Room setting 'teammode' has been changed to {teammode.name.lower()}"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_host(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp host <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.TransferHost(signalr_client, user_id)
|
||||
return f"User '{username}' has been hosted in the room."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_start(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
timer = None
|
||||
if len(args) >= 1 and args[0].isdigit():
|
||||
timer = int(args[0])
|
||||
|
||||
try:
|
||||
if timer is not None:
|
||||
await MultiplayerHubs.SendMatchRequest(
|
||||
signalr_client,
|
||||
StartMatchCountdownRequest(duration=timedelta(seconds=timer)),
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
await MultiplayerHubs.StartMatch(signalr_client)
|
||||
return "Good luck! Enjoy game!"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_abort(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
try:
|
||||
await MultiplayerHubs.AbortMatch(signalr_client)
|
||||
return "Match aborted."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_team(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
):
|
||||
if room.room.settings.match_type != MatchType.TEAM_VERSUS:
|
||||
return "This command is only available in Team Versus mode."
|
||||
|
||||
if len(args) < 2:
|
||||
return "Usage: !mp team <username> <colour>"
|
||||
|
||||
username = args[0]
|
||||
team = {"red": 0, "blue": 1}.get(args[1])
|
||||
if team is None:
|
||||
return "Invalid team colour. Use 'red' or 'blue'."
|
||||
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
|
||||
if not user_client:
|
||||
return f"User '{username}' is not in the room."
|
||||
if (
|
||||
user_client.user_id != signalr_client.user_id
|
||||
and room.room.host.user_id != signalr_client.user_id
|
||||
):
|
||||
return "You are not allowed to change other users' teams."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.SendMatchRequest(
|
||||
user_client, ChangeTeamRequest(team_id=team)
|
||||
)
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_password(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
password = ""
|
||||
if len(args) >= 1:
|
||||
password = args[0]
|
||||
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.password = password
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return "Room password has been set."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_kick(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp kick <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.KickUser(signalr_client, user_id)
|
||||
return f"User '{username}' has been kicked from the room."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_map(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp map <mapid> [<playmode>]"
|
||||
|
||||
if room.status != RoomStatus.IDLE:
|
||||
return "Cannot change map while the game is running."
|
||||
|
||||
map_id = args[0]
|
||||
if not map_id.isdigit():
|
||||
return "Invalid map ID."
|
||||
map_id = int(map_id)
|
||||
playmode = GameMode.parse(args[1].upper()) if len(args) >= 2 else None
|
||||
if playmode not in (
|
||||
GameMode.OSU,
|
||||
GameMode.TAIKO,
|
||||
GameMode.FRUITS,
|
||||
GameMode.MANIA,
|
||||
None,
|
||||
):
|
||||
return "Invalid playmode."
|
||||
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
|
||||
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
|
||||
return (
|
||||
f"Cannot convert to {playmode.value}. "
|
||||
f"Original mode is {beatmap.mode.value}."
|
||||
)
|
||||
except HTTPError:
|
||||
return "Beatmap not found"
|
||||
|
||||
try:
|
||||
current_item = room.queue.current_item
|
||||
item = current_item.model_copy(deep=True)
|
||||
item.owner_id = signalr_client.user_id
|
||||
item.beatmap_checksum = beatmap.checksum
|
||||
item.required_mods = []
|
||||
item.allowed_mods = []
|
||||
item.freestyle = False
|
||||
item.beatmap_id = map_id
|
||||
if playmode is not None:
|
||||
item.ruleset_id = int(playmode)
|
||||
if item.expired:
|
||||
item.id = 0
|
||||
item.expired = False
|
||||
item.played_at = None
|
||||
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
|
||||
else:
|
||||
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_mods(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp mods <mod1> [<mod2> ...]"
|
||||
|
||||
if room.status != RoomStatus.IDLE:
|
||||
return "Cannot change mods while the game is running."
|
||||
|
||||
required_mods = []
|
||||
allowed_mods = []
|
||||
freestyle = False
|
||||
freemod = False
|
||||
for arg in args:
|
||||
arg = arg.upper()
|
||||
if arg == "NONE":
|
||||
required_mods.clear()
|
||||
allowed_mods.clear()
|
||||
break
|
||||
elif arg == "FREESTYLE":
|
||||
freestyle = True
|
||||
elif arg == "FREEMOD":
|
||||
freemod = True
|
||||
elif arg.startswith("+"):
|
||||
mod = arg.removeprefix("+")
|
||||
if len(mod) != 2:
|
||||
return f"Invalid mod: {mod}."
|
||||
allowed_mods.append(APIMod(acronym=mod))
|
||||
else:
|
||||
if len(arg) != 2:
|
||||
return f"Invalid mod: {arg}."
|
||||
required_mods.append(APIMod(acronym=arg))
|
||||
|
||||
try:
|
||||
current_item = room.queue.current_item
|
||||
item = current_item.model_copy(deep=True)
|
||||
item.owner_id = signalr_client.user_id
|
||||
item.freestyle = freestyle
|
||||
if freestyle:
|
||||
item.allowed_mods = []
|
||||
elif freemod:
|
||||
item.allowed_mods = get_available_mods(
|
||||
current_item.ruleset_id, required_mods
|
||||
)
|
||||
else:
|
||||
item.allowed_mods = allowed_mods
|
||||
item.required_mods = required_mods
|
||||
if item.expired:
|
||||
item.id = 0
|
||||
item.expired = False
|
||||
item.played_at = None
|
||||
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
|
||||
else:
|
||||
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
_MP_COMMANDS = {
|
||||
"name": _mp_name,
|
||||
"set": _mp_set,
|
||||
"host": _mp_host,
|
||||
"start": _mp_start,
|
||||
"abort": _mp_abort,
|
||||
"map": _mp_map,
|
||||
"mods": _mp_mods,
|
||||
"kick": _mp_kick,
|
||||
"password": _mp_password,
|
||||
"team": _mp_team,
|
||||
}
|
||||
_MP_HELP = """!mp name <name>
|
||||
!mp set <teammode> [<queuemode>]
|
||||
!mp host <host>
|
||||
!mp start [<timer>]
|
||||
!mp abort
|
||||
!mp map <map> [<playmode>]
|
||||
!mp mods <mod1> [<mod2> ...]
|
||||
!mp kick <user>
|
||||
!mp password [<password>]
|
||||
!mp team <user> <team:red|blue>"""
|
||||
|
||||
|
||||
@bot.command("mp")
|
||||
async def _mp(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
|
||||
if not channel.name.startswith("room_"):
|
||||
return
|
||||
|
||||
room_id = int(channel.name[5:])
|
||||
room = MultiplayerHubs.rooms.get(room_id)
|
||||
if not room:
|
||||
return
|
||||
signalr_client = MultiplayerHubs.get_client_by_id(str(user.id))
|
||||
if not signalr_client:
|
||||
return
|
||||
|
||||
if len(args) < 1:
|
||||
return f"Usage: !mp <{'|'.join(_MP_COMMANDS.keys())}> [args]"
|
||||
|
||||
command = args[0].lower()
|
||||
if command not in _MP_COMMANDS:
|
||||
return f"No such command: {command}"
|
||||
|
||||
return await _MP_COMMANDS[command](signalr_client, room, args[1:], session)
|
||||
|
||||
|
||||
async def _score(
|
||||
user_id: int,
|
||||
session: AsyncSession,
|
||||
include_fail: bool = False,
|
||||
gamemode: GameMode | None = None,
|
||||
) -> str:
|
||||
q = (
|
||||
select(Score)
|
||||
.where(Score.user_id == user_id)
|
||||
.order_by(col(Score.id).desc())
|
||||
.options(joinedload(Score.beatmap))
|
||||
)
|
||||
if not include_fail:
|
||||
q = q.where(Score.passed.is_(True))
|
||||
if gamemode is not None:
|
||||
q = q.where(Score.gamemode == gamemode)
|
||||
|
||||
score = (await session.exec(q)).first()
|
||||
if score is None:
|
||||
return "You have no scores."
|
||||
|
||||
result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()})
|
||||
Played at {score.started_at}
|
||||
{score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()}
|
||||
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501
|
||||
if score.gamemode == GameMode.MANIA:
|
||||
keys = next(
|
||||
(mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None
|
||||
)
|
||||
if keys is None:
|
||||
keys = f"{int(score.beatmap.cs)}K"
|
||||
p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1"
|
||||
result += (
|
||||
f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@bot.command("re")
|
||||
async def _re(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
|
||||
gamemode = None
|
||||
if len(args) >= 1:
|
||||
gamemode = GameMode.parse(args[0])
|
||||
return await _score(user.id, session, include_fail=True, gamemode=gamemode)
|
||||
|
||||
|
||||
@bot.command("pr")
|
||||
async def _pr(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
|
||||
gamemode = None
|
||||
if len(args) >= 1:
|
||||
gamemode = GameMode.parse(args[0])
|
||||
return await _score(user.id, session, include_fail=False, gamemode=gamemode)
|
||||
305
app/router/notification/channel.py
Normal file
305
app/router/notification/channel.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatMessage,
|
||||
SilenceUser,
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.lazer_user import User, UserResp
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
class UpdateResponse(BaseModel):
|
||||
presence: list[ChatChannelResp] = Field(default_factory=list)
|
||||
silences: list[Any] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/updates",
|
||||
response_model=UpdateResponse,
|
||||
name="获取更新",
|
||||
description="获取当前用户所在频道的最新的禁言情况。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
includes: list[str] = Query(
|
||||
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
if "presence" in includes:
|
||||
assert current_user.id
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
resp.presence.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
if "silences" in includes:
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.put(
|
||||
"/chat/channels/{channel}/users/{user}",
|
||||
response_model=ChatChannelResp,
|
||||
name="加入频道",
|
||||
description="加入指定的公开/房间频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def join_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
return await server.join_channel(current_user, db_channel, session)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/chat/channels/{channel}/users/{user}",
|
||||
status_code=204,
|
||||
name="离开频道",
|
||||
description="将用户移出指定的公开/房间频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def leave_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
await server.leave_channel(current_user, db_channel, session)
|
||||
return
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels",
|
||||
response_model=list[ChatChannelResp],
|
||||
name="获取频道列表",
|
||||
description="获取所有公开频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_channel_list(
|
||||
session: Database,
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
channels = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
|
||||
)
|
||||
).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
assert channel.channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel.channel_id, [])
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class GetChannelResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
users: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels/{channel}",
|
||||
response_model=GetChannelResp,
|
||||
name="获取频道信息",
|
||||
description="获取指定频道的信息。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
assert db_channel.channel_id is not None
|
||||
|
||||
users = []
|
||||
if db_channel.type == ChannelType.PM:
|
||||
user_ids = db_channel.name.split("_")[1:]
|
||||
if len(user_ids) != 2:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
for id_ in user_ids:
|
||||
if int(id_) == current_user.id:
|
||||
continue
|
||||
target_user = await session.get(User, int(id_))
|
||||
if target_user is None:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
users.extend([target_user, current_user])
|
||||
break
|
||||
|
||||
return GetChannelResp(
|
||||
channel=await ChatChannelResp.from_db(
|
||||
db_channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(db_channel.channel_id, [])
|
||||
if db_channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CreateChannelReq(BaseModel):
|
||||
class AnnounceChannel(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
message: str | None = None
|
||||
type: Literal["ANNOUNCE", "PM"] = "PM"
|
||||
target_id: int | None = None
|
||||
target_ids: list[int] | None = None
|
||||
channel: AnnounceChannel | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check(self) -> Self:
|
||||
if self.type == "PM":
|
||||
if self.target_id is None:
|
||||
raise ValueError("target_id must be set for PM channels")
|
||||
else:
|
||||
if self.target_ids is None or self.channel is None or self.message is None:
|
||||
raise ValueError(
|
||||
"target_ids, channel, and message must be set for ANNOUNCE channels"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/channels",
|
||||
response_model=ChatChannelResp,
|
||||
name="创建频道",
|
||||
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def create_channel(
|
||||
session: Database,
|
||||
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
if req.type == "PM":
|
||||
target = await session.get(User, req.target_id)
|
||||
if not target:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
is_can_pm, block = await target.is_user_can_pm(current_user, session)
|
||||
if not is_can_pm:
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
channel = await ChatChannel.get_pm_channel(
|
||||
current_user.id, # pyright: ignore[reportArgumentType]
|
||||
req.target_id, # pyright: ignore[reportArgumentType]
|
||||
session,
|
||||
)
|
||||
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||
else:
|
||||
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||
channel = await ChatChannel.get(channel_name, session)
|
||||
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=channel_name,
|
||||
description=req.channel.description
|
||||
if req.channel
|
||||
else "Private message channel",
|
||||
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(current_user)
|
||||
if req.type == "PM":
|
||||
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
else:
|
||||
target_users = await session.exec(
|
||||
select(User).where(col(User.id).in_(req.target_ids or []))
|
||||
)
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
assert channel.channel_id
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel.channel_id, []),
|
||||
include_recent_messages=True,
|
||||
)
|
||||
253
app/router/notification/message.py
Normal file
253
app/router/notification/message.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatMessage,
|
||||
MessageType,
|
||||
SilenceUser,
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.notification import ChannelMessage, ChannelMessageTeam
|
||||
from app.router.v2 import api_v2_router as router
|
||||
|
||||
from .banchobot import bot
|
||||
from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
class KeepAliveResp(BaseModel):
|
||||
silences: list[UserSilenceResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/ack",
|
||||
name="保持连接",
|
||||
response_model=KeepAliveResp,
|
||||
description="保持公共频道的连接。同时返回最近的禁言列表。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def keep_alive(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
):
|
||||
resp = KeepAliveResp()
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
class MessageReq(BaseModel):
|
||||
message: str
|
||||
is_action: bool = False
|
||||
uuid: str | None = None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/channels/{channel}/messages",
|
||||
response_model=ChatMessageResp,
|
||||
name="发送消息",
|
||||
description="发送消息到指定频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def send_message(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
assert db_channel.channel_id
|
||||
assert current_user.id
|
||||
msg = ChatMessage(
|
||||
channel_id=db_channel.channel_id,
|
||||
content=req.message,
|
||||
sender_id=current_user.id,
|
||||
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||
uuid=req.uuid,
|
||||
)
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(current_user)
|
||||
await session.refresh(db_channel)
|
||||
resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||
is_bot_command = req.message.startswith("!")
|
||||
await server.send_message_to_channel(
|
||||
resp, is_bot_command and db_channel.type == ChannelType.PUBLIC
|
||||
)
|
||||
if is_bot_command:
|
||||
await bot.try_handle(current_user, db_channel, req.message, session)
|
||||
if db_channel.type == ChannelType.PM:
|
||||
user_ids = db_channel.name.split("_")[1:]
|
||||
await server.new_private_notification(
|
||||
ChannelMessage(
|
||||
msg, current_user, [int(u) for u in user_ids], db_channel.type
|
||||
)
|
||||
)
|
||||
elif db_channel.type == ChannelType.TEAM:
|
||||
await server.new_private_notification(ChannelMessageTeam(msg, current_user))
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels/{channel}/messages",
|
||||
response_model=list[ChatMessageResp],
|
||||
name="获取消息",
|
||||
description="获取指定频道的消息列表。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_message(
|
||||
session: Database,
|
||||
channel: str,
|
||||
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
|
||||
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
|
||||
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
messages = await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
ChatMessage.channel_id == db_channel.channel_id,
|
||||
col(ChatMessage.message_id) > since,
|
||||
col(ChatMessage.message_id) < until if until is not None else True,
|
||||
)
|
||||
.order_by(col(ChatMessage.timestamp).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
resp.reverse()
|
||||
return resp
|
||||
|
||||
|
||||
@router.put(
|
||||
"/chat/channels/{channel}/mark-as-read/{message}",
|
||||
status_code=204,
|
||||
name="标记消息为已读",
|
||||
description="标记指定消息为已读。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def mark_as_read(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
message: int = Path(..., description="消息 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
assert db_channel.channel_id
|
||||
assert current_user.id
|
||||
await server.mark_as_read(db_channel.channel_id, current_user.id, message)
|
||||
|
||||
|
||||
class PMReq(BaseModel):
|
||||
target_id: int
|
||||
message: str
|
||||
is_action: bool = False
|
||||
uuid: str | None = None
|
||||
|
||||
|
||||
class NewPMResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
message: ChatMessageResp
|
||||
new_channel_id: int
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/new",
|
||||
name="创建私聊频道",
|
||||
description="创建一个新的私聊频道。",
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def create_new_pm(
|
||||
session: Database,
|
||||
req: PMReq = Depends(BodyOrForm(PMReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
user_id = current_user.id
|
||||
target = await session.get(User, req.target_id)
|
||||
if target is None:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
is_can_pm, block = await target.is_user_can_pm(current_user, session)
|
||||
if not is_can_pm:
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
assert user_id
|
||||
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=f"pm_{user_id}_{req.target_id}",
|
||||
description="Private message channel",
|
||||
type=ChannelType.PM,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(target)
|
||||
await session.refresh(current_user)
|
||||
|
||||
assert channel.channel_id
|
||||
await server.batch_join_channel([target, current_user], channel, session)
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, current_user, redis, server.channels[channel.channel_id]
|
||||
)
|
||||
msg = ChatMessage(
|
||||
channel_id=channel.channel_id,
|
||||
content=req.message,
|
||||
sender_id=user_id,
|
||||
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||
uuid=req.uuid,
|
||||
)
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(current_user)
|
||||
await session.refresh(channel)
|
||||
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||
await server.send_message_to_channel(message_resp)
|
||||
return NewPMResp(
|
||||
channel=channel_resp,
|
||||
message=message_resp,
|
||||
new_channel_id=channel_resp.channel_id,
|
||||
)
|
||||
315
app/router/notification/server.py
Normal file
315
app/router/notification/server.py
Normal file
@@ -0,0 +1,315 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import overload
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database.lazer_user import User
|
||||
from app.database.notification import UserNotification, insert_notification
|
||||
from app.dependencies.database import (
|
||||
DBFactory,
|
||||
get_db_factory,
|
||||
get_redis,
|
||||
with_db,
|
||||
)
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
from fastapi.websockets import WebSocketState
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class ChatServer:
|
||||
def __init__(self):
|
||||
self.connect_client: dict[int, WebSocket] = {}
|
||||
self.channels: dict[int, list[int]] = {}
|
||||
self.redis: Redis = get_redis()
|
||||
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
self.ChatSubscriber = ChatSubscriber()
|
||||
self.ChatSubscriber.chat_server = self
|
||||
self._subscribed = False
|
||||
|
||||
def _add_task(self, task):
|
||||
task = asyncio.create_task(task)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
def connect(self, user_id: int, client: WebSocket):
|
||||
self.connect_client[user_id] = client
|
||||
|
||||
def get_user_joined_channel(self, user_id: int) -> list[int]:
|
||||
return [
|
||||
channel_id
|
||||
for channel_id, users in self.channels.items()
|
||||
if user_id in users
|
||||
]
|
||||
|
||||
async def disconnect(self, user: User, session: AsyncSession):
|
||||
user_id = user.id
|
||||
if user_id in self.connect_client:
|
||||
del self.connect_client[user_id]
|
||||
for channel_id, channel in self.channels.items():
|
||||
if user_id in channel:
|
||||
channel.remove(user_id)
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
@overload
|
||||
async def send_event(self, client: int, event: ChatEvent): ...
|
||||
|
||||
@overload
|
||||
async def send_event(self, client: WebSocket, event: ChatEvent): ...
|
||||
|
||||
async def send_event(self, client: WebSocket | int, event: ChatEvent):
|
||||
if isinstance(client, int):
|
||||
client_ = self.connect_client.get(client)
|
||||
if client_ is None:
|
||||
return
|
||||
client = client_
|
||||
if client.client_state == WebSocketState.CONNECTED:
|
||||
await client.send_text(event.model_dump_json())
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
for user_id in self.channels.get(channel_id, []):
|
||||
await self.send_event(user_id, event)
|
||||
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
|
||||
async def send_message_to_channel(
|
||||
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||
):
|
||||
event = ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
)
|
||||
if is_bot_command:
|
||||
self._add_task(self.send_event(message.sender_id, event))
|
||||
else:
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
)
|
||||
assert message.message_id
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
|
||||
async def batch_join_channel(
|
||||
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||
):
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
not_joined = []
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
for user in users:
|
||||
assert user.id is not None
|
||||
if user.id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user.id)
|
||||
not_joined.append(user)
|
||||
|
||||
for user in not_joined:
|
||||
assert user.id is not None
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id]
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
await self.send_event(
|
||||
user.id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> ChatChannelResp:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
if user_id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user_id)
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||
self.channels[channel_id].remove(user_id)
|
||||
|
||||
if (c := self.channels.get(channel_id)) is not None and not c:
|
||||
del self.channels[channel_id]
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id)
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.part",
|
||||
data=channel_resp.model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.join_channel(user, channel, session)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
id = await insert_notification(session, detail)
|
||||
users = (
|
||||
await session.exec(
|
||||
select(UserNotification).where(
|
||||
UserNotification.notification_id == id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
for user_notification in users:
|
||||
data = user_notification.notification.model_dump()
|
||||
data["is_read"] = user_notification.is_read
|
||||
data["details"] = user_notification.notification.details
|
||||
await server.send_event(
|
||||
user_notification.user_id,
|
||||
ChatEvent(
|
||||
event="new",
|
||||
data=data,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
server = ChatServer()
|
||||
|
||||
chat_router = APIRouter(include_in_schema=False)
|
||||
|
||||
|
||||
async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
try:
|
||||
while True:
|
||||
packets = await ws.receive_json()
|
||||
if packets.get("event") == "chat.end":
|
||||
async for session in factory():
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
break
|
||||
await server.disconnect(user, session)
|
||||
await ws.close(code=1000)
|
||||
break
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(
|
||||
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
|
||||
else:
|
||||
logger.exception(f"RuntimeError in client {user_id}: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"Error in client {user_id}")
|
||||
|
||||
|
||||
@chat_router.websocket("/notification-server")
|
||||
async def chat_websocket(
|
||||
websocket: WebSocket,
|
||||
authorization: str = Header(...),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
if not server._subscribed:
|
||||
server._subscribed = True
|
||||
await server.ChatSubscriber.start_subscribe()
|
||||
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (
|
||||
user := await get_current_user(
|
||||
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
|
||||
)
|
||||
) is None:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
login = await websocket.receive_json()
|
||||
if login.get("event") != "chat.start":
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
user_id = user.id
|
||||
assert user_id
|
||||
server.connect(user_id, websocket)
|
||||
channel = await ChatChannel.get(1, session)
|
||||
if channel is not None:
|
||||
await server.join_channel(user, channel, session)
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
Reference in New Issue
Block a user