From e1d42743d3131728eb63bfa9de823ebc546bdcf2 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 10:31:46 +0000 Subject: [PATCH] feat(chat): support BanchoBot --- app/const.py | 3 + app/models/score.py | 10 ++ app/router/auth.py | 2 +- app/router/chat/banchobot.py | 211 +++++++++++++++++++++++++++++++ app/router/chat/message.py | 7 +- app/router/chat/server.py | 39 ++++-- app/router/v2/user.py | 12 +- app/service/create_banchobot.py | 30 +++++ app/service/daily_challenge.py | 5 +- app/service/osu_rx_statistics.py | 4 + main.py | 2 + 11 files changed, 302 insertions(+), 23 deletions(-) create mode 100644 app/const.py create mode 100644 app/router/chat/banchobot.py create mode 100644 app/service/create_banchobot.py diff --git a/app/const.py b/app/const.py new file mode 100644 index 0000000..78ad45c --- /dev/null +++ b/app/const.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +BANCHOBOT_ID = 2 diff --git a/app/models/score.py b/app/models/score.py index 703b55c..53b61c5 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -88,6 +88,16 @@ class GameMode(str, Enum): }[self] return self + @classmethod + def parse(cls, v: str | int) -> "GameMode | None": + if isinstance(v, int) or v.isdigit(): + return cls.from_int_extra(int(v)) + v = v.lower() + try: + return cls[v] + except ValueError: + return None + class Rank(str, Enum): X = "X" diff --git a/app/router/auth.py b/app/router/auth.py index a1149fc..88528e2 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -459,7 +459,7 @@ async def oauth_token( # 存储令牌 await store_token( db, - 3, + 2, client_id, scopes, access_token, diff --git a/app/router/chat/banchobot.py b/app/router/chat/banchobot.py new file mode 100644 index 0000000..76c8beb --- /dev/null +++ b/app/router/chat/banchobot.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from math import ceil +import random +import shlex + +from app.const import BANCHOBOT_ID +from app.database import ChatMessageResp +from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType +from app.database.lazer_user import User +from app.database.score import Score +from app.database.statistics import UserStatistics, get_rank +from app.models.score import GameMode + +from .server import server + +from sqlmodel import func, select +from sqlmodel.ext.asyncio.session import AsyncSession + +HandlerResult = str | None | Awaitable[str | None] +Handler = Callable[[User, list[str], AsyncSession], HandlerResult] + + +class Bot: + def __init__(self, bot_user_id: int = BANCHOBOT_ID) -> None: + self._handlers: dict[str, Handler] = {} + self.bot_user_id = bot_user_id + + # decorator: @bot.command("ping") + def command(self, name: str) -> Callable[[Handler], Handler]: + def _decorator(func: Handler) -> Handler: + self._handlers[name.lower()] = func + return func + + return _decorator + + def parse(self, content: str) -> tuple[str, list[str]] | None: + if not content or not content.startswith("!"): + return None + try: + parts = shlex.split(content[1:]) + except ValueError: + parts = content[1:].split() + if not parts: + return None + cmd = parts[0].lower() + args = parts[1:] + return cmd, args + + async def try_handle( + self, + user: User, + channel: ChatChannel, + content: str, + session: AsyncSession, + ) -> None: + parsed = self.parse(content) + if not parsed: + return + cmd, args = parsed + handler = self._handlers.get(cmd) + + reply: str | None = None + if handler is None: + return + else: + res = handler(user, args, session) + if asyncio.iscoroutine(res): + res = await res + reply = res # type: ignore[assignment] + + if reply: + await self._send_reply(user, channel, reply, session) + + async def _send_message( + self, channel: ChatChannel, content: str, session: AsyncSession + ) -> None: + bot = await session.get(User, self.bot_user_id) + if bot is None: + return + channel_id = channel.channel_id + if channel_id is None: + return + + assert bot.id is not None + msg = ChatMessage( + channel_id=channel_id, + content=content, + sender_id=bot.id, + type=MessageType.PLAIN, + ) + session.add(msg) + await session.commit() + await session.refresh(msg) + await session.refresh(bot) + resp = await ChatMessageResp.from_db(msg, session, bot) + await server.send_message_to_channel(resp) + + async def _ensure_pm_channel( + self, user: User, session: AsyncSession + ) -> ChatChannel | None: + user_id = user.id + if user_id is None: + return None + + bot = await session.get(User, self.bot_user_id) + if bot is None or bot.id is None: + return None + + channel = await ChatChannel.get_pm_channel(user_id, bot.id, session) + if channel is None: + channel = ChatChannel( + name=f"pm_{user_id}_{bot.id}", + description="Private message channel", + type=ChannelType.PM, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(user) + await session.refresh(bot) + await server.batch_join_channel([user, bot], channel, session) + return channel + + async def _send_reply( + self, + user: User, + src_channel: ChatChannel, + content: str, + session: AsyncSession, + ) -> None: + target_channel = src_channel + if src_channel.type == ChannelType.PUBLIC: + pm = await self._ensure_pm_channel(user, session) + if pm is not None: + target_channel = pm + await self._send_message(target_channel, content, session) + + +bot = Bot() + + +@bot.command("help") +async def _help(user: User, args: list[str], _session: AsyncSession) -> str: + cmds = sorted(bot._handlers.keys()) + if args: + target = args[0].lower() + if target in bot._handlers: + return f"Use: !{target} [args]" + return f"No such command: {target}" + if not cmds: + return "No available commands" + return "Available: " + ", ".join(f"!{c}" for c in cmds) + + +@bot.command("roll") +def _roll(user: User, args: list[str], _session: AsyncSession) -> str: + if len(args) > 0 and args[0].isdigit(): + r = random.randint(1, int(args[0])) + else: + r = random.randint(1, 100) + return f"{user.username} rolls {r} point(s)" + + +@bot.command("stats") +async def _stats(user: User, args: list[str], session: AsyncSession) -> str: + if len(args) < 1: + return "Usage: !stats " + + target_user = ( + await session.exec(select(User).where(User.username == args[0])) + ).first() + if not target_user: + return f"User '{args[0]}' not found." + + gamemode = None + if len(args) >= 2: + gamemode = GameMode.parse(args[1].upper()) + if gamemode is None: + subquery = ( + select(func.max(Score.id)) + .where(Score.user_id == target_user.id) + .scalar_subquery() + ) + last_score = ( + await session.exec(select(Score).where(Score.id == subquery)) + ).first() + if last_score is not None: + gamemode = last_score.gamemode + else: + gamemode = target_user.playmode + + statistics = ( + await session.exec( + select(UserStatistics).where( + UserStatistics.user_id == target_user.id, + UserStatistics.mode == gamemode, + ) + ) + ).first() + if not statistics: + return f"User '{args[0]}' has no statistics." + + return f"""Stats for {target_user.username} ({gamemode.name.lower()}): +Score: {statistics.total_score} (#{await get_rank(session, statistics)}) +Plays: {statistics.play_count} (lv{ceil(statistics.level_current)}) +Accuracy: {statistics.hit_accuracy} +PP: {statistics.pp} +""" diff --git a/app/router/chat/message.py b/app/router/chat/message.py index dc4f134..318aefe 100644 --- a/app/router/chat/message.py +++ b/app/router/chat/message.py @@ -14,6 +14,7 @@ from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router +from .banchobot import bot from .server import server from fastapi import Depends, HTTPException, Query, Security @@ -63,8 +64,12 @@ async def send_message( await session.commit() await session.refresh(msg) await session.refresh(current_user) + await session.refresh(db_channel) resp = await ChatMessageResp.from_db(msg, session, current_user) - await server.send_message_to_channel(resp) + is_bot_command = req.message.startswith("!") + await server.send_message_to_channel(resp, is_bot_command) + if is_bot_command: + await bot.try_handle(current_user, db_channel, req.message, session) return resp diff --git a/app/router/chat/server.py b/app/router/chat/server.py index 6add821..b91dd02 100644 --- a/app/router/chat/server.py +++ b/app/router/chat/server.py @@ -72,16 +72,24 @@ class ChatServer: async def mark_as_read(self, channel_id: int, message_id: int): await self.redis.set(f"chat:{channel_id}:last_msg", message_id) - async def send_message_to_channel(self, message: ChatMessageResp): - self._add_task( - self.broadcast( - message.channel_id, - ChatEvent( - event="chat.message.new", - data={"messages": [message], "users": [message.sender]}, - ), - ) + async def send_message_to_channel( + self, message: ChatMessageResp, is_bot_command: bool = False + ): + event = ChatEvent( + event="chat.message.new", + data={"messages": [message], "users": [message.sender]}, ) + if is_bot_command: + client = self.connect_client.get(message.sender_id) + if client: + self._add_task(self.send_event(client, event)) + else: + self._add_task( + self.broadcast( + message.channel_id, + event, + ) + ) assert message.message_id await self.mark_as_read(message.channel_id, message.message_id) @@ -91,14 +99,17 @@ class ChatServer: channel_id = channel.channel_id assert channel_id is not None + not_joined = [] + if channel_id not in self.channels: self.channels[channel_id] = [] - for user_id in [user.id for user in users]: - assert user_id is not None - if user_id not in self.channels[channel_id]: - self.channels[channel_id].append(user_id) - for user in users: + assert user.id is not None + if user.id not in self.channels[channel_id]: + self.channels[channel_id].append(user.id) + not_joined.append(user) + + for user in not_joined: assert user.id is not None channel_resp = await ChatChannelResp.from_db( channel, diff --git a/app/router/v2/user.py b/app/router/v2/user.py index dece5be..123e120 100644 --- a/app/router/v2/user.py +++ b/app/router/v2/user.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta from typing import Literal +from app.const import BANCHOBOT_ID from app.database import ( BeatmapPlaycounts, BeatmapPlaycountsResp, @@ -65,6 +66,7 @@ async def get_users( include=SEARCH_INCLUDED, ) for searched_user in searched_users + if searched_user.id != BANCHOBOT_ID ] ) @@ -91,7 +93,7 @@ async def get_user_info_ruleset( ) ) ).first() - if not searched_user: + if not searched_user or searched_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") return await UserResp.from_db( searched_user, @@ -123,7 +125,7 @@ async def get_user_info( ) ) ).first() - if not searched_user: + if not searched_user or searched_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") return await UserResp.from_db( searched_user, @@ -148,7 +150,7 @@ async def get_user_beatmapsets( offset: int = Query(0, ge=0, description="偏移量"), ): user = await session.get(User, user_id) - if not user: + if not user or user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") if type in { @@ -218,7 +220,7 @@ async def get_user_scores( current_user: User = Security(get_current_user, scopes=["public"]), ): db_user = await session.get(User, user_id) - if not db_user: + if not db_user or db_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") gamemode = mode or db_user.playmode @@ -271,7 +273,7 @@ async def get_user_events( session: AsyncSession = Depends(get_db), ): db_user = await session.get(User, user) - if db_user is None: + if db_user is None or db_user.id == BANCHOBOT_ID: raise HTTPException(404, "User Not found") events = await db_user.awaitable_attrs.events if limit is not None: diff --git a/app/service/create_banchobot.py b/app/service/create_banchobot.py new file mode 100644 index 0000000..0d855cf --- /dev/null +++ b/app/service/create_banchobot.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from app.const import BANCHOBOT_ID +from app.database.lazer_user import User +from app.database.statistics import UserStatistics +from app.dependencies.database import engine +from app.models.score import GameMode + +from sqlmodel import exists, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +async def create_banchobot(): + async with AsyncSession(engine) as session: + is_exist = (await session.exec(select(exists()).where(User.id == 2))).first() + if not is_exist: + banchobot = User( + username="BanchoBot", + email="banchobot@ppy.sh", + is_bot=True, + pw_bcrypt="0", + id=BANCHOBOT_ID, + avatar_url="https://a.ppy.sh/3", + country_code="SH", + website="https://twitter.com/banchoboat", + ) + session.add(banchobot) + statistics = UserStatistics(user_id=BANCHOBOT_ID, mode=GameMode.OSU) + session.add(statistics) + await session.commit() diff --git a/app/service/daily_challenge.py b/app/service/daily_challenge.py index ec7f9d0..f3dbf05 100644 --- a/app/service/daily_challenge.py +++ b/app/service/daily_challenge.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta import json +from app.const import BANCHOBOT_ID from app.database.playlists import Playlist from app.database.room import Room from app.dependencies.database import engine, get_redis @@ -26,12 +27,12 @@ async def create_daily_challenge_room( return await create_playlist_room( session=session, name=str(today), - host_id=3, + host_id=BANCHOBOT_ID, playlist=[ Playlist( id=0, room_id=0, - owner_id=3, + owner_id=BANCHOBOT_ID, ruleset_id=ruleset_id, beatmap_id=beatmap, required_mods=required_mods, diff --git a/app/service/osu_rx_statistics.py b/app/service/osu_rx_statistics.py index 8a0441f..60f94ce 100644 --- a/app/service/osu_rx_statistics.py +++ b/app/service/osu_rx_statistics.py @@ -1,6 +1,7 @@ from __future__ import annotations from app.config import settings +from app.const import BANCHOBOT_ID from app.database.lazer_user import User from app.database.statistics import UserStatistics from app.dependencies.database import engine @@ -15,6 +16,9 @@ async def create_rx_statistics(): async with AsyncSession(engine) as session: users = (await session.exec(select(User.id))).all() for i in users: + if i == BANCHOBOT_ID: + continue + if settings.enable_rx: for mode in ( GameMode.OSURX, diff --git a/main.py b/main.py index 66471cc..526a92e 100644 --- a/main.py +++ b/main.py @@ -22,6 +22,7 @@ from app.router import ( ) from app.router.redirect import redirect_router from app.service.calculate_all_user_rank import calculate_user_rank +from app.service.create_banchobot import create_banchobot from app.service.daily_challenge import daily_challenge_job from app.service.osu_rx_statistics import create_rx_statistics from app.service.pp_recalculate import recalculate_all_players_pp @@ -43,6 +44,7 @@ async def lifespan(app: FastAPI): await calculate_user_rank(True) init_scheduler() await daily_challenge_job() + await create_banchobot() # on shutdown yield stop_scheduler()