feat(chat): support BanchoBot
This commit is contained in:
3
app/const.py
Normal file
3
app/const.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
BANCHOBOT_ID = 2
|
||||||
@@ -88,6 +88,16 @@ class GameMode(str, Enum):
|
|||||||
}[self]
|
}[self]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(cls, v: str | int) -> "GameMode | None":
|
||||||
|
if isinstance(v, int) or v.isdigit():
|
||||||
|
return cls.from_int_extra(int(v))
|
||||||
|
v = v.lower()
|
||||||
|
try:
|
||||||
|
return cls[v]
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Rank(str, Enum):
|
class Rank(str, Enum):
|
||||||
X = "X"
|
X = "X"
|
||||||
|
|||||||
@@ -459,7 +459,7 @@ async def oauth_token(
|
|||||||
# 存储令牌
|
# 存储令牌
|
||||||
await store_token(
|
await store_token(
|
||||||
db,
|
db,
|
||||||
3,
|
2,
|
||||||
client_id,
|
client_id,
|
||||||
scopes,
|
scopes,
|
||||||
access_token,
|
access_token,
|
||||||
|
|||||||
211
app/router/chat/banchobot.py
Normal file
211
app/router/chat/banchobot.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from math import ceil
|
||||||
|
import random
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
from app.const import BANCHOBOT_ID
|
||||||
|
from app.database import ChatMessageResp
|
||||||
|
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
|
||||||
|
from app.database.lazer_user import User
|
||||||
|
from app.database.score import Score
|
||||||
|
from app.database.statistics import UserStatistics, get_rank
|
||||||
|
from app.models.score import GameMode
|
||||||
|
|
||||||
|
from .server import server
|
||||||
|
|
||||||
|
from sqlmodel import func, select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
HandlerResult = str | None | Awaitable[str | None]
|
||||||
|
Handler = Callable[[User, list[str], AsyncSession], HandlerResult]
|
||||||
|
|
||||||
|
|
||||||
|
class Bot:
|
||||||
|
def __init__(self, bot_user_id: int = BANCHOBOT_ID) -> None:
|
||||||
|
self._handlers: dict[str, Handler] = {}
|
||||||
|
self.bot_user_id = bot_user_id
|
||||||
|
|
||||||
|
# decorator: @bot.command("ping")
|
||||||
|
def command(self, name: str) -> Callable[[Handler], Handler]:
|
||||||
|
def _decorator(func: Handler) -> Handler:
|
||||||
|
self._handlers[name.lower()] = func
|
||||||
|
return func
|
||||||
|
|
||||||
|
return _decorator
|
||||||
|
|
||||||
|
def parse(self, content: str) -> tuple[str, list[str]] | None:
|
||||||
|
if not content or not content.startswith("!"):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parts = shlex.split(content[1:])
|
||||||
|
except ValueError:
|
||||||
|
parts = content[1:].split()
|
||||||
|
if not parts:
|
||||||
|
return None
|
||||||
|
cmd = parts[0].lower()
|
||||||
|
args = parts[1:]
|
||||||
|
return cmd, args
|
||||||
|
|
||||||
|
async def try_handle(
|
||||||
|
self,
|
||||||
|
user: User,
|
||||||
|
channel: ChatChannel,
|
||||||
|
content: str,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
parsed = self.parse(content)
|
||||||
|
if not parsed:
|
||||||
|
return
|
||||||
|
cmd, args = parsed
|
||||||
|
handler = self._handlers.get(cmd)
|
||||||
|
|
||||||
|
reply: str | None = None
|
||||||
|
if handler is None:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
res = handler(user, args, session)
|
||||||
|
if asyncio.iscoroutine(res):
|
||||||
|
res = await res
|
||||||
|
reply = res # type: ignore[assignment]
|
||||||
|
|
||||||
|
if reply:
|
||||||
|
await self._send_reply(user, channel, reply, session)
|
||||||
|
|
||||||
|
async def _send_message(
|
||||||
|
self, channel: ChatChannel, content: str, session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
bot = await session.get(User, self.bot_user_id)
|
||||||
|
if bot is None:
|
||||||
|
return
|
||||||
|
channel_id = channel.channel_id
|
||||||
|
if channel_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert bot.id is not None
|
||||||
|
msg = ChatMessage(
|
||||||
|
channel_id=channel_id,
|
||||||
|
content=content,
|
||||||
|
sender_id=bot.id,
|
||||||
|
type=MessageType.PLAIN,
|
||||||
|
)
|
||||||
|
session.add(msg)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(msg)
|
||||||
|
await session.refresh(bot)
|
||||||
|
resp = await ChatMessageResp.from_db(msg, session, bot)
|
||||||
|
await server.send_message_to_channel(resp)
|
||||||
|
|
||||||
|
async def _ensure_pm_channel(
|
||||||
|
self, user: User, session: AsyncSession
|
||||||
|
) -> ChatChannel | None:
|
||||||
|
user_id = user.id
|
||||||
|
if user_id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bot = await session.get(User, self.bot_user_id)
|
||||||
|
if bot is None or bot.id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
channel = await ChatChannel.get_pm_channel(user_id, bot.id, session)
|
||||||
|
if channel is None:
|
||||||
|
channel = ChatChannel(
|
||||||
|
name=f"pm_{user_id}_{bot.id}",
|
||||||
|
description="Private message channel",
|
||||||
|
type=ChannelType.PM,
|
||||||
|
)
|
||||||
|
session.add(channel)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(channel)
|
||||||
|
await session.refresh(user)
|
||||||
|
await session.refresh(bot)
|
||||||
|
await server.batch_join_channel([user, bot], channel, session)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
async def _send_reply(
|
||||||
|
self,
|
||||||
|
user: User,
|
||||||
|
src_channel: ChatChannel,
|
||||||
|
content: str,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
target_channel = src_channel
|
||||||
|
if src_channel.type == ChannelType.PUBLIC:
|
||||||
|
pm = await self._ensure_pm_channel(user, session)
|
||||||
|
if pm is not None:
|
||||||
|
target_channel = pm
|
||||||
|
await self._send_message(target_channel, content, session)
|
||||||
|
|
||||||
|
|
||||||
|
bot = Bot()
|
||||||
|
|
||||||
|
|
||||||
|
@bot.command("help")
|
||||||
|
async def _help(user: User, args: list[str], _session: AsyncSession) -> str:
|
||||||
|
cmds = sorted(bot._handlers.keys())
|
||||||
|
if args:
|
||||||
|
target = args[0].lower()
|
||||||
|
if target in bot._handlers:
|
||||||
|
return f"Use: !{target} [args]"
|
||||||
|
return f"No such command: {target}"
|
||||||
|
if not cmds:
|
||||||
|
return "No available commands"
|
||||||
|
return "Available: " + ", ".join(f"!{c}" for c in cmds)
|
||||||
|
|
||||||
|
|
||||||
|
@bot.command("roll")
|
||||||
|
def _roll(user: User, args: list[str], _session: AsyncSession) -> str:
|
||||||
|
if len(args) > 0 and args[0].isdigit():
|
||||||
|
r = random.randint(1, int(args[0]))
|
||||||
|
else:
|
||||||
|
r = random.randint(1, 100)
|
||||||
|
return f"{user.username} rolls {r} point(s)"
|
||||||
|
|
||||||
|
|
||||||
|
@bot.command("stats")
|
||||||
|
async def _stats(user: User, args: list[str], session: AsyncSession) -> str:
|
||||||
|
if len(args) < 1:
|
||||||
|
return "Usage: !stats <username>"
|
||||||
|
|
||||||
|
target_user = (
|
||||||
|
await session.exec(select(User).where(User.username == args[0]))
|
||||||
|
).first()
|
||||||
|
if not target_user:
|
||||||
|
return f"User '{args[0]}' not found."
|
||||||
|
|
||||||
|
gamemode = None
|
||||||
|
if len(args) >= 2:
|
||||||
|
gamemode = GameMode.parse(args[1].upper())
|
||||||
|
if gamemode is None:
|
||||||
|
subquery = (
|
||||||
|
select(func.max(Score.id))
|
||||||
|
.where(Score.user_id == target_user.id)
|
||||||
|
.scalar_subquery()
|
||||||
|
)
|
||||||
|
last_score = (
|
||||||
|
await session.exec(select(Score).where(Score.id == subquery))
|
||||||
|
).first()
|
||||||
|
if last_score is not None:
|
||||||
|
gamemode = last_score.gamemode
|
||||||
|
else:
|
||||||
|
gamemode = target_user.playmode
|
||||||
|
|
||||||
|
statistics = (
|
||||||
|
await session.exec(
|
||||||
|
select(UserStatistics).where(
|
||||||
|
UserStatistics.user_id == target_user.id,
|
||||||
|
UserStatistics.mode == gamemode,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if not statistics:
|
||||||
|
return f"User '{args[0]}' has no statistics."
|
||||||
|
|
||||||
|
return f"""Stats for {target_user.username} ({gamemode.name.lower()}):
|
||||||
|
Score: {statistics.total_score} (#{await get_rank(session, statistics)})
|
||||||
|
Plays: {statistics.play_count} (lv{ceil(statistics.level_current)})
|
||||||
|
Accuracy: {statistics.hit_accuracy}
|
||||||
|
PP: {statistics.pp}
|
||||||
|
"""
|
||||||
@@ -14,6 +14,7 @@ from app.dependencies.param import BodyOrForm
|
|||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.router.v2 import api_v2_router as router
|
from app.router.v2 import api_v2_router as router
|
||||||
|
|
||||||
|
from .banchobot import bot
|
||||||
from .server import server
|
from .server import server
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query, Security
|
from fastapi import Depends, HTTPException, Query, Security
|
||||||
@@ -63,8 +64,12 @@ async def send_message(
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(msg)
|
await session.refresh(msg)
|
||||||
await session.refresh(current_user)
|
await session.refresh(current_user)
|
||||||
|
await session.refresh(db_channel)
|
||||||
resp = await ChatMessageResp.from_db(msg, session, current_user)
|
resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||||
await server.send_message_to_channel(resp)
|
is_bot_command = req.message.startswith("!")
|
||||||
|
await server.send_message_to_channel(resp, is_bot_command)
|
||||||
|
if is_bot_command:
|
||||||
|
await bot.try_handle(current_user, db_channel, req.message, session)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -72,16 +72,24 @@ class ChatServer:
|
|||||||
async def mark_as_read(self, channel_id: int, message_id: int):
|
async def mark_as_read(self, channel_id: int, message_id: int):
|
||||||
await self.redis.set(f"chat:{channel_id}:last_msg", message_id)
|
await self.redis.set(f"chat:{channel_id}:last_msg", message_id)
|
||||||
|
|
||||||
async def send_message_to_channel(self, message: ChatMessageResp):
|
async def send_message_to_channel(
|
||||||
self._add_task(
|
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||||
self.broadcast(
|
):
|
||||||
message.channel_id,
|
event = ChatEvent(
|
||||||
ChatEvent(
|
event="chat.message.new",
|
||||||
event="chat.message.new",
|
data={"messages": [message], "users": [message.sender]},
|
||||||
data={"messages": [message], "users": [message.sender]},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
if is_bot_command:
|
||||||
|
client = self.connect_client.get(message.sender_id)
|
||||||
|
if client:
|
||||||
|
self._add_task(self.send_event(client, event))
|
||||||
|
else:
|
||||||
|
self._add_task(
|
||||||
|
self.broadcast(
|
||||||
|
message.channel_id,
|
||||||
|
event,
|
||||||
|
)
|
||||||
|
)
|
||||||
assert message.message_id
|
assert message.message_id
|
||||||
await self.mark_as_read(message.channel_id, message.message_id)
|
await self.mark_as_read(message.channel_id, message.message_id)
|
||||||
|
|
||||||
@@ -91,14 +99,17 @@ class ChatServer:
|
|||||||
channel_id = channel.channel_id
|
channel_id = channel.channel_id
|
||||||
assert channel_id is not None
|
assert channel_id is not None
|
||||||
|
|
||||||
|
not_joined = []
|
||||||
|
|
||||||
if channel_id not in self.channels:
|
if channel_id not in self.channels:
|
||||||
self.channels[channel_id] = []
|
self.channels[channel_id] = []
|
||||||
for user_id in [user.id for user in users]:
|
|
||||||
assert user_id is not None
|
|
||||||
if user_id not in self.channels[channel_id]:
|
|
||||||
self.channels[channel_id].append(user_id)
|
|
||||||
|
|
||||||
for user in users:
|
for user in users:
|
||||||
|
assert user.id is not None
|
||||||
|
if user.id not in self.channels[channel_id]:
|
||||||
|
self.channels[channel_id].append(user.id)
|
||||||
|
not_joined.append(user)
|
||||||
|
|
||||||
|
for user in not_joined:
|
||||||
assert user.id is not None
|
assert user.id is not None
|
||||||
channel_resp = await ChatChannelResp.from_db(
|
channel_resp = await ChatChannelResp.from_db(
|
||||||
channel,
|
channel,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database import (
|
from app.database import (
|
||||||
BeatmapPlaycounts,
|
BeatmapPlaycounts,
|
||||||
BeatmapPlaycountsResp,
|
BeatmapPlaycountsResp,
|
||||||
@@ -65,6 +66,7 @@ async def get_users(
|
|||||||
include=SEARCH_INCLUDED,
|
include=SEARCH_INCLUDED,
|
||||||
)
|
)
|
||||||
for searched_user in searched_users
|
for searched_user in searched_users
|
||||||
|
if searched_user.id != BANCHOBOT_ID
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -91,7 +93,7 @@ async def get_user_info_ruleset(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
).first()
|
).first()
|
||||||
if not searched_user:
|
if not searched_user or searched_user.id == BANCHOBOT_ID:
|
||||||
raise HTTPException(404, detail="User not found")
|
raise HTTPException(404, detail="User not found")
|
||||||
return await UserResp.from_db(
|
return await UserResp.from_db(
|
||||||
searched_user,
|
searched_user,
|
||||||
@@ -123,7 +125,7 @@ async def get_user_info(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
).first()
|
).first()
|
||||||
if not searched_user:
|
if not searched_user or searched_user.id == BANCHOBOT_ID:
|
||||||
raise HTTPException(404, detail="User not found")
|
raise HTTPException(404, detail="User not found")
|
||||||
return await UserResp.from_db(
|
return await UserResp.from_db(
|
||||||
searched_user,
|
searched_user,
|
||||||
@@ -148,7 +150,7 @@ async def get_user_beatmapsets(
|
|||||||
offset: int = Query(0, ge=0, description="偏移量"),
|
offset: int = Query(0, ge=0, description="偏移量"),
|
||||||
):
|
):
|
||||||
user = await session.get(User, user_id)
|
user = await session.get(User, user_id)
|
||||||
if not user:
|
if not user or user.id == BANCHOBOT_ID:
|
||||||
raise HTTPException(404, detail="User not found")
|
raise HTTPException(404, detail="User not found")
|
||||||
|
|
||||||
if type in {
|
if type in {
|
||||||
@@ -218,7 +220,7 @@ async def get_user_scores(
|
|||||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
):
|
):
|
||||||
db_user = await session.get(User, user_id)
|
db_user = await session.get(User, user_id)
|
||||||
if not db_user:
|
if not db_user or db_user.id == BANCHOBOT_ID:
|
||||||
raise HTTPException(404, detail="User not found")
|
raise HTTPException(404, detail="User not found")
|
||||||
|
|
||||||
gamemode = mode or db_user.playmode
|
gamemode = mode or db_user.playmode
|
||||||
@@ -271,7 +273,7 @@ async def get_user_events(
|
|||||||
session: AsyncSession = Depends(get_db),
|
session: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
db_user = await session.get(User, user)
|
db_user = await session.get(User, user)
|
||||||
if db_user is None:
|
if db_user is None or db_user.id == BANCHOBOT_ID:
|
||||||
raise HTTPException(404, "User Not found")
|
raise HTTPException(404, "User Not found")
|
||||||
events = await db_user.awaitable_attrs.events
|
events = await db_user.awaitable_attrs.events
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
|
|||||||
30
app/service/create_banchobot.py
Normal file
30
app/service/create_banchobot.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.const import BANCHOBOT_ID
|
||||||
|
from app.database.lazer_user import User
|
||||||
|
from app.database.statistics import UserStatistics
|
||||||
|
from app.dependencies.database import engine
|
||||||
|
from app.models.score import GameMode
|
||||||
|
|
||||||
|
from sqlmodel import exists, select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
async def create_banchobot():
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
is_exist = (await session.exec(select(exists()).where(User.id == 2))).first()
|
||||||
|
if not is_exist:
|
||||||
|
banchobot = User(
|
||||||
|
username="BanchoBot",
|
||||||
|
email="banchobot@ppy.sh",
|
||||||
|
is_bot=True,
|
||||||
|
pw_bcrypt="0",
|
||||||
|
id=BANCHOBOT_ID,
|
||||||
|
avatar_url="https://a.ppy.sh/3",
|
||||||
|
country_code="SH",
|
||||||
|
website="https://twitter.com/banchoboat",
|
||||||
|
)
|
||||||
|
session.add(banchobot)
|
||||||
|
statistics = UserStatistics(user_id=BANCHOBOT_ID, mode=GameMode.OSU)
|
||||||
|
session.add(statistics)
|
||||||
|
await session.commit()
|
||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database.playlists import Playlist
|
from app.database.playlists import Playlist
|
||||||
from app.database.room import Room
|
from app.database.room import Room
|
||||||
from app.dependencies.database import engine, get_redis
|
from app.dependencies.database import engine, get_redis
|
||||||
@@ -26,12 +27,12 @@ async def create_daily_challenge_room(
|
|||||||
return await create_playlist_room(
|
return await create_playlist_room(
|
||||||
session=session,
|
session=session,
|
||||||
name=str(today),
|
name=str(today),
|
||||||
host_id=3,
|
host_id=BANCHOBOT_ID,
|
||||||
playlist=[
|
playlist=[
|
||||||
Playlist(
|
Playlist(
|
||||||
id=0,
|
id=0,
|
||||||
room_id=0,
|
room_id=0,
|
||||||
owner_id=3,
|
owner_id=BANCHOBOT_ID,
|
||||||
ruleset_id=ruleset_id,
|
ruleset_id=ruleset_id,
|
||||||
beatmap_id=beatmap,
|
beatmap_id=beatmap,
|
||||||
required_mods=required_mods,
|
required_mods=required_mods,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.database.statistics import UserStatistics
|
from app.database.statistics import UserStatistics
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import engine
|
||||||
@@ -15,6 +16,9 @@ async def create_rx_statistics():
|
|||||||
async with AsyncSession(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
users = (await session.exec(select(User.id))).all()
|
users = (await session.exec(select(User.id))).all()
|
||||||
for i in users:
|
for i in users:
|
||||||
|
if i == BANCHOBOT_ID:
|
||||||
|
continue
|
||||||
|
|
||||||
if settings.enable_rx:
|
if settings.enable_rx:
|
||||||
for mode in (
|
for mode in (
|
||||||
GameMode.OSURX,
|
GameMode.OSURX,
|
||||||
|
|||||||
2
main.py
2
main.py
@@ -22,6 +22,7 @@ from app.router import (
|
|||||||
)
|
)
|
||||||
from app.router.redirect import redirect_router
|
from app.router.redirect import redirect_router
|
||||||
from app.service.calculate_all_user_rank import calculate_user_rank
|
from app.service.calculate_all_user_rank import calculate_user_rank
|
||||||
|
from app.service.create_banchobot import create_banchobot
|
||||||
from app.service.daily_challenge import daily_challenge_job
|
from app.service.daily_challenge import daily_challenge_job
|
||||||
from app.service.osu_rx_statistics import create_rx_statistics
|
from app.service.osu_rx_statistics import create_rx_statistics
|
||||||
from app.service.pp_recalculate import recalculate_all_players_pp
|
from app.service.pp_recalculate import recalculate_all_players_pp
|
||||||
@@ -43,6 +44,7 @@ async def lifespan(app: FastAPI):
|
|||||||
await calculate_user_rank(True)
|
await calculate_user_rank(True)
|
||||||
init_scheduler()
|
init_scheduler()
|
||||||
await daily_challenge_job()
|
await daily_challenge_job()
|
||||||
|
await create_banchobot()
|
||||||
# on shutdown
|
# on shutdown
|
||||||
yield
|
yield
|
||||||
stop_scheduler()
|
stop_scheduler()
|
||||||
|
|||||||
Reference in New Issue
Block a user