From f992e4cc71a0decb2cbdfcdd34c36797344844bf Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 05:29:16 +0000 Subject: [PATCH 01/10] feat(chat): support public channel chat --- app/database/__init__.py | 12 ++ app/database/chat.py | 193 ++++++++++++++++++ app/dependencies/database.py | 12 ++ app/dependencies/param.py | 39 ++++ app/models/chat.py | 10 + app/router/__init__.py | 2 + app/router/chat/__init__.py | 28 +++ app/router/chat/channel.py | 138 +++++++++++++ app/router/chat/message.py | 98 +++++++++ app/router/chat/server.py | 190 +++++++++++++++++ app/service/subscribers/base.py | 12 +- main.py | 2 + .../versions/dd33d89aa2c2_chat_add_chat.py | 192 +++++++++++++++++ 13 files changed, 925 insertions(+), 3 deletions(-) create mode 100644 app/database/chat.py create mode 100644 app/dependencies/param.py create mode 100644 app/models/chat.py create mode 100644 app/router/chat/__init__.py create mode 100644 app/router/chat/channel.py create mode 100644 app/router/chat/message.py create mode 100644 app/router/chat/server.py create mode 100644 migrations/versions/dd33d89aa2c2_chat_add_chat.py diff --git a/app/database/__init__.py b/app/database/__init__.py index 0283923..d8794c4 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -10,6 +10,13 @@ from .beatmapset import ( BeatmapsetResp, ) from .best_score import BestScore +from .chat import ( + ChannelType, + ChatChannel, + ChatChannelResp, + ChatMessage, + ChatMessageResp, +) from .counts import ( CountResp, MonthlyPlaycounts, @@ -63,6 +70,11 @@ __all__ = [ "Beatmapset", "BeatmapsetResp", "BestScore", + "ChannelType", + "ChatChannel", + "ChatChannelResp", + "ChatMessage", + "ChatMessageResp", "CountResp", "DailyChallengeStats", "DailyChallengeStatsResp", diff --git a/app/database/chat.py b/app/database/chat.py new file mode 100644 index 0000000..565657d --- /dev/null +++ b/app/database/chat.py @@ -0,0 +1,193 @@ +from datetime import UTC, datetime +from enum import Enum +from typing import Self + +from app.database.lazer_user import RANKING_INCLUDES, User, UserResp +from app.models.model import UTCBaseModel + +from pydantic import BaseModel +from redis.asyncio import Redis +from sqlmodel import ( + VARCHAR, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +# ChatChannel + + +class ChatUserAttributes(BaseModel): + can_message: bool + can_message_error: str | None = None + last_read_id: int + + +class ChannelType(str, Enum): + PUBLIC = "PUBLIC" + PRIVATE = "PRIVATE" + MULTIPLAYER = "MULTIPLAYER" + SPECTATOR = "SPECTATOR" + TEMPORARY = "TEMPORARY" + PM = "PM" + GROUP = "GROUP" + SYSTEM = "SYSTEM" + ANNOUNCE = "ANNOUNCE" + TEAM = "TEAM" + + +class ChatChannelBase(SQLModel): + name: str = Field(sa_column=Column(VARCHAR(50), index=True)) + description: str = Field(sa_column=Column(VARCHAR(255), index=True)) + icon: str | None = Field(default=None) + type: ChannelType = Field(index=True) + + +class ChatChannel(ChatChannelBase, table=True): + __tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType] + channel_id: int | None = Field(primary_key=True, index=True, default=None) + + @classmethod + async def get( + cls, channel: str | int, session: AsyncSession + ) -> "ChatChannel | None": + if isinstance(channel, int) or channel.isdigit(): + channel_ = await session.get(ChatChannel, channel) + if channel_ is not None: + return channel_ + return ( + await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) + ).first() + + +class ChatChannelResp(ChatChannelBase): + channel_id: int + moderated: bool = False + uuid: str | None = None + current_user_attributes: ChatUserAttributes | None = None + last_read_id: int | None = None + last_message_id: int | None = None + recent_messages: list[str] | None = None + users: list[int] | None = None + message_length_limit: int = 1000 + + @classmethod + async def from_db( + cls, + channel: ChatChannel, + session: AsyncSession, + users: list[int], + user: User, + redis: Redis, + ) -> Self: + c = cls.model_validate(channel) + silence = ( + await session.exec( + select(SilenceUser).where( + SilenceUser.channel_id == channel.channel_id, + SilenceUser.user_id == user.id, + ) + ) + ).first() + + last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg") + if last_msg and last_msg.isdigit(): + last_msg = int(last_msg) + else: + last_msg = None + + last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") + if last_read_id and last_read_id.isdigit(): + last_read_id = int(last_read_id) + else: + last_read_id = last_msg + + if silence is not None: + attribute = ChatUserAttributes( + can_message=False, + can_message_error=silence.reason or "You are muted in this channel.", + last_read_id=last_read_id or 0, + ) + c.moderated = True + else: + attribute = ChatUserAttributes( + can_message=True, + last_read_id=last_read_id or 0, + ) + c.moderated = False + + c.current_user_attributes = attribute + c.users = users + c.last_message_id = last_msg + c.last_read_id = last_read_id + return c + + +# ChatMessage + + +class MessageType(str, Enum): + ACTION = "action" + MARKDOWN = "markdown" + PLAIN = "plain" + + +class ChatMessageBase(UTCBaseModel, SQLModel): + channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id") + content: str = Field(sa_column=Column(VARCHAR(1000))) + message_id: int | None = Field(index=True, primary_key=True, default=None) + sender_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + timestamp: datetime = Field( + sa_column=Column(DateTime, index=True), default=datetime.now(UTC) + ) + type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True) + uuid: str | None = Field(default=None) + + +class ChatMessage(ChatMessageBase, table=True): + __tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType] + user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) + + +class ChatMessageResp(ChatMessageBase): + sender: UserResp | None = None + is_action: bool = False + + @classmethod + async def from_db( + cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None + ) -> "ChatMessageResp": + m = cls.model_validate(db_message.model_dump()) + m.is_action = db_message.type == MessageType.ACTION + if user: + m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES) + else: + m.sender = await UserResp.from_db( + db_message.user, session, RANKING_INCLUDES + ) + return m + + +# SilenceUser + + +class SilenceUser(UTCBaseModel, SQLModel, table=True): + __tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType] + id: int | None = Field(primary_key=True, default=None, index=True) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True) + until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None) + reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True)) + banned_at: datetime = Field( + sa_column=Column(DateTime, index=True), default=datetime.now(UTC) + ) diff --git a/app/dependencies/database.py b/app/dependencies/database.py index bb4c065..f7e94ce 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import AsyncIterator, Callable from contextvars import ContextVar import json @@ -51,6 +52,17 @@ async def get_db(): yield session +DBFactory = Callable[[], AsyncIterator[AsyncSession]] + + +async def get_db_factory() -> DBFactory: + async def _factory() -> AsyncIterator[AsyncSession]: + async with AsyncSession(engine) as session: + yield session + + return _factory + + # Redis 依赖 def get_redis(): return redis_client diff --git a/app/dependencies/param.py b/app/dependencies/param.py new file mode 100644 index 0000000..174adde --- /dev/null +++ b/app/dependencies/param.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import Request +from fastapi.exceptions import RequestValidationError +from pydantic import BaseModel, ValidationError + + +def BodyOrForm[T: BaseModel](model: type[T]): + async def dependency( + request: Request, + ) -> T: + content_type = request.headers.get("content-type", "") + + data: dict[str, Any] = {} + if "application/json" in content_type: + try: + data = await request.json() + except Exception: + raise RequestValidationError( + [ + { + "loc": ("body",), + "msg": "Invalid JSON body", + "type": "value_error", + } + ] + ) + else: + form = await request.form() + data = dict(form) + + try: + return model(**data) + except ValidationError as e: + raise RequestValidationError(e.errors()) + + return dependency diff --git a/app/models/chat.py b/app/models/chat.py new file mode 100644 index 0000000..116342f --- /dev/null +++ b/app/models/chat.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + + +class ChatEvent(BaseModel): + event: str + data: dict[str, Any] | None = None diff --git a/app/router/__init__.py b/app/router/__init__.py index 57fc949..6be08fc 100644 --- a/app/router/__init__.py +++ b/app/router/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from app.signalr import signalr_router as signalr_router from .auth import router as auth_router +from .chat import chat_router as chat_router from .fetcher import fetcher_router as fetcher_router from .file import file_router as file_router from .private import private_router as private_router @@ -17,6 +18,7 @@ __all__ = [ "api_v1_router", "api_v2_router", "auth_router", + "chat_router", "fetcher_router", "file_router", "private_router", diff --git a/app/router/chat/__init__.py b/app/router/chat/__init__.py new file mode 100644 index 0000000..3ed7897 --- /dev/null +++ b/app/router/chat/__init__.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from app.config import settings +from app.router.v2 import api_v2_router as router + +from . import channel, message # noqa: F401 +from .server import chat_router as chat_router + +from fastapi import Query + +__all__ = ["chat_router"] + + +@router.get("/notifications") +async def get_notifications(max_id: int | None = Query(None)): + if settings.server_url is not None: + notification_endpoint = f"{settings.server_url}notification-server".replace( + "http://", "ws://" + ).replace("https://", "wss://") + else: + notification_endpoint = "/notification-server" + + return { + "has_more": False, + "notifications": [], + "unread_count": 0, + "notification_endpoint": notification_endpoint, + } diff --git a/app/router/chat/channel.py b/app/router/chat/channel.py new file mode 100644 index 0000000..798a5f2 --- /dev/null +++ b/app/router/chat/channel.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import Any + +from app.database.chat import ( + ChannelType, + ChatChannel, + ChatChannelResp, +) +from app.database.lazer_user import User, UserResp +from app.dependencies.database import get_db, get_redis +from app.dependencies.user import get_current_user +from app.router.v2 import api_v2_router as router + +from .server import server + +from fastapi import Depends, HTTPException, Query, Security +from pydantic import BaseModel, Field +from redis.asyncio import Redis +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + + +class UpdateResponse(BaseModel): + presence: list[ChatChannelResp] = Field(default_factory=list) + silences: list[Any] = Field(default_factory=list) + + +@router.get("/chat/updates", response_model=UpdateResponse) +async def get_update( + history_since: int | None = Query(None), + since: int | None = Query(None), + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), + includes: list[str] = Query(["presence"], alias="includes[]"), + redis: Redis = Depends(get_redis), +): + resp = UpdateResponse() + if "presence" in includes: + channel_ids = server.get_user_joined_channel(current_user.id) + for channel_id in channel_ids: + channel = await ChatChannel.get(channel_id, session) + if channel: + resp.presence.append( + await ChatChannelResp.from_db( + channel, + session, + server.channels.get(channel_id, []), + current_user, + redis, + ) + ) + return resp + + +@router.put("/chat/channels/{channel}/users/{user}", response_model=ChatChannelResp) +async def join_channel( + channel: str, + user: str, + current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + return await server.join_channel(current_user, db_channel, session) + + +@router.delete( + "/chat/channels/{channel}/users/{user}", + status_code=204, +) +async def leave_channel( + channel: str, + user: str, + current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + await server.leave_channel(current_user, db_channel, session) + return + + +@router.get("/chat/channels") +async def get_channel_list( + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + channels = ( + await session.exec( + select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC) + ) + ).all() + results = [] + for channel in channels: + assert channel.channel_id is not None + results.append( + await ChatChannelResp.from_db( + channel, + session, + server.channels.get(channel.channel_id, []), + current_user, + redis, + ) + ) + return results + + +class GetChannelResp(BaseModel): + channel: ChatChannelResp + users: list[UserResp] = Field(default_factory=list) + + +@router.get("/chat/channels/{channel}") +async def get_channel( + channel: str, + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + db_channel = await ChatChannel.get(channel, session) + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + assert db_channel.channel_id is not None + return GetChannelResp( + channel=await ChatChannelResp.from_db( + db_channel, + session, + server.channels.get(db_channel.channel_id, []), + current_user, + redis, + ) + ) diff --git a/app/router/chat/message.py b/app/router/chat/message.py new file mode 100644 index 0000000..5bab869 --- /dev/null +++ b/app/router/chat/message.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from app.database import ChatMessageResp +from app.database.chat import ChatChannel, ChatMessage, MessageType +from app.database.lazer_user import User +from app.dependencies.database import get_db +from app.dependencies.param import BodyOrForm +from app.dependencies.user import get_current_user +from app.router.v2 import api_v2_router as router + +from .server import server + +from fastapi import Depends, HTTPException, Query, Security +from pydantic import BaseModel +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +@router.post("/chat/ack") +async def keep_alive( + history_since: int | None = Query(None), + since: int | None = Query(None), + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), +): + return {"silences": []} + + +class MessageReq(BaseModel): + message: str + is_action: bool = False + uuid: str | None = None + + +@router.post("/chat/channels/{channel}/messages", response_model=ChatMessageResp) +async def send_message( + channel: str, + req: MessageReq = Depends(BodyOrForm(MessageReq)), + current_user: User = Security(get_current_user, scopes=["chat.write"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + msg = ChatMessage( + channel_id=db_channel.channel_id, + content=req.message, + sender_id=current_user.id, + type=MessageType.ACTION if req.is_action else MessageType.PLAIN, + uuid=req.uuid, + ) + session.add(msg) + await session.commit() + await session.refresh(msg) + await session.refresh(current_user) + resp = await ChatMessageResp.from_db(msg, session, current_user) + await server.send_message_to_channel(resp) + return resp + + +@router.get("/chat/channels/{channel}/messages", response_model=list[ChatMessageResp]) +async def get_message( + channel: str, + limit: int = Query(50, ge=1, le=50), + since: int = Query(default=0, ge=0), + until: int | None = Query(None), + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + messages = await session.exec( + select(ChatMessage) + .where( + ChatMessage.channel_id == db_channel.channel_id, + col(ChatMessage.message_id) > since, + col(ChatMessage.message_id) < until if until is not None else True, + ) + .order_by(col(ChatMessage.timestamp).desc()) + .limit(limit) + ) + resp = [await ChatMessageResp.from_db(msg, session) for msg in messages] + resp.reverse() + return resp + + +@router.put("/chat/channels/{channel}/mark-as-read/{message}", status_code=204) +async def mark_as_read( + channel: str, + message: int, + current_user: User = Security(get_current_user, scopes=["chat.read"]), + session: AsyncSession = Depends(get_db), +): + db_channel = await ChatChannel.get(channel, session) + if db_channel is None: + raise HTTPException(status_code=404, detail="Channel not found") + await server.mark_as_read(db_channel.channel_id, message) diff --git a/app/router/chat/server.py b/app/router/chat/server.py new file mode 100644 index 0000000..727f426 --- /dev/null +++ b/app/router/chat/server.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import asyncio + +from app.database.chat import ChatChannel, ChatChannelResp, ChatMessageResp +from app.database.lazer_user import User +from app.dependencies.database import DBFactory, get_db_factory, get_redis +from app.dependencies.user import get_current_user +from app.log import logger +from app.models.chat import ChatEvent + +from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect +from fastapi.security import SecurityScopes +from fastapi.websockets import WebSocketState +from redis.asyncio import Redis +from sqlmodel.ext.asyncio.session import AsyncSession + + +class ChatServer: + def __init__(self): + self.connect_client: dict[int, WebSocket] = {} + self.channels: dict[int, list[int]] = {} + self.redis: Redis = get_redis() + + self.tasks: set[asyncio.Task] = set() + + def _add_task(self, task): + task = asyncio.create_task(task) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) + + def connect(self, user_id: int, client: WebSocket): + self.connect_client[user_id] = client + + def get_user_joined_channel(self, user_id: int) -> list[int]: + return [ + channel_id + for channel_id, users in self.channels.items() + if user_id in users + ] + + async def disconnect(self, user: User, session: AsyncSession): + user_id = user.id + if user_id in self.connect_client: + del self.connect_client[user_id] + for channel_id, channel in self.channels.items(): + if user_id in channel: + channel.remove(user_id) + channel = await ChatChannel.get(channel_id, session) + if channel: + await self.leave_channel(user, channel, session) + + async def send_event(self, client: WebSocket, event: ChatEvent): + if client.client_state == WebSocketState.CONNECTED: + await client.send_text(event.model_dump_json()) + + async def broadcast(self, channel_id: int, event: ChatEvent): + for user_id in self.channels.get(channel_id, []): + client = self.connect_client.get(user_id) + if client: + await self.send_event(client, event) + + async def mark_as_read(self, channel_id: int, message_id: int): + await self.redis.set(f"chat:{channel_id}:last_msg", message_id) + + async def send_message_to_channel(self, message: ChatMessageResp): + self._add_task( + self.broadcast( + message.channel_id, + ChatEvent( + event="chat.message.new", + data={"messages": [message], "users": [message.sender]}, + ), + ) + ) + await self.mark_as_read(message.channel_id, message.message_id) + + async def join_channel( + self, user: User, channel: ChatChannel, session: AsyncSession + ) -> ChatChannelResp: + user_id = user.id + channel_id = channel.channel_id + assert channel_id is not None + + if channel_id not in self.channels: + self.channels[channel_id] = [] + if user_id not in self.channels[channel_id]: + self.channels[channel_id].append(user_id) + + channel_resp = await ChatChannelResp.from_db( + channel, session, self.channels[channel_id], user, self.redis + ) + + client = self.connect_client.get(user_id) + if client: + await self.send_event( + client, + ChatEvent( + event="chat.channel.join", + data=channel_resp.model_dump(), + ), + ) + + return channel_resp + + async def leave_channel( + self, user: User, channel: ChatChannel, session: AsyncSession + ) -> None: + user_id = user.id + channel_id = channel.channel_id + assert channel_id is not None + + if channel_id in self.channels and user_id in self.channels[channel_id]: + self.channels[channel_id].remove(user_id) + + if not self.channels.get(channel_id): + del self.channels[channel_id] + + channel_resp = await ChatChannelResp.from_db( + channel, session, self.channels.get(channel_id, []), user, self.redis + ) + client = self.connect_client.get(user_id) + if client: + await self.send_event( + client, + ChatEvent( + event="chat.channel.part", + data=channel_resp.model_dump(), + ), + ) + + +server = ChatServer() + +chat_router = APIRouter(include_in_schema=False) + + +async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory): + try: + while True: + packets = await ws.receive_json() + if packets.get("event") == "chat.end": + async for session in factory(): + user = await session.get(User, user_id) + if user is None: + break + await server.disconnect(user, session) + await ws.close(code=1000) + break + except WebSocketDisconnect as e: + logger.info( + f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}" + ) + except RuntimeError as e: + if "disconnect message" in str(e): + logger.info(f"[NotificationServer] Client {user_id} closed the connection.") + else: + logger.exception(f"RuntimeError in client {user_id}: {e}") + except Exception: + logger.exception(f"Error in client {user_id}") + + +@chat_router.websocket("/notification-server") +async def chat_websocket( + websocket: WebSocket, + authorization: str = Header(...), + factory: DBFactory = Depends(get_db_factory), +): + async for session in factory(): + token = authorization[7:] + if ( + user := await get_current_user( + SecurityScopes(scopes=["chat.read"]), session, token_pw=token + ) + ) is None: + await websocket.close(code=1008) + return + + await websocket.accept() + login = await websocket.receive_json() + if login.get("event") != "chat.start": + await websocket.close(code=1008) + return + user_id = user.id + assert user_id + server.connect(user_id, websocket) + channel = await ChatChannel.get(1, session) + if channel is not None: + await server.join_channel(user, channel, session) + await _listen_stop(websocket, user_id, factory) diff --git a/app/service/subscribers/base.py b/app/service/subscribers/base.py index 144dfd0..8fc5654 100644 --- a/app/service/subscribers/base.py +++ b/app/service/subscribers/base.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable +from fnmatch import fnmatch from typing import Any from app.dependencies.database import get_redis_pubsub @@ -29,12 +30,17 @@ class RedisSubscriber: ignore_subscribe_messages=True, timeout=None ) if message is not None and message["type"] == "message": - method = self.handlers.get(message["channel"]) - if method: + matched_handlers = [] + if message["channel"] in self.handlers: + matched_handlers.extend(self.handlers[message["channel"]]) + for pattern, handlers in self.handlers.items(): + if fnmatch(message["channel"], pattern): + matched_handlers.extend(handlers) + if matched_handlers: await asyncio.gather( *[ handler(message["channel"], message["data"]) - for handler in method + for handler in matched_handlers ] ) diff --git a/main.py b/main.py index c2739c4..66471cc 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ from app.router import ( api_v1_router, api_v2_router, auth_router, + chat_router, fetcher_router, file_router, private_router, @@ -71,6 +72,7 @@ app = FastAPI( app.include_router(api_v2_router) app.include_router(api_v1_router) +app.include_router(chat_router) app.include_router(redirect_api_router) app.include_router(signalr_router) app.include_router(fetcher_router) diff --git a/migrations/versions/dd33d89aa2c2_chat_add_chat.py b/migrations/versions/dd33d89aa2c2_chat_add_chat.py new file mode 100644 index 0000000..f548f55 --- /dev/null +++ b/migrations/versions/dd33d89aa2c2_chat_add_chat.py @@ -0,0 +1,192 @@ +"""chat: add chat + +Revision ID: dd33d89aa2c2 +Revises: 9f6b27e8ea51 +Create Date: 2025-08-15 14:22:34.775877 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +import sqlmodel + +# revision identifiers, used by Alembic. +revision: str = "dd33d89aa2c2" +down_revision: str | Sequence[str] | None = "9f6b27e8ea51" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + channel_table = op.create_table( + "chat_channels", + sa.Column("name", sa.VARCHAR(length=50), nullable=True), + sa.Column("description", sa.VARCHAR(length=255), nullable=True), + sa.Column("icon", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column( + "type", + sa.Enum( + "PUBLIC", + "PRIVATE", + "MULTIPLAYER", + "SPECTATOR", + "TEMPORARY", + "PM", + "GROUP", + "SYSTEM", + "ANNOUNCE", + "TEAM", + name="channeltype", + ), + nullable=False, + ), + sa.Column("channel_id", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("channel_id"), + ) + op.create_index( + op.f("ix_chat_channels_channel_id"), + "chat_channels", + ["channel_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_channels_description"), + "chat_channels", + ["description"], + unique=False, + ) + op.create_index( + op.f("ix_chat_channels_name"), "chat_channels", ["name"], unique=False + ) + op.create_index( + op.f("ix_chat_channels_type"), "chat_channels", ["type"], unique=False + ) + op.create_table( + "chat_messages", + sa.Column("channel_id", sa.Integer(), nullable=False), + sa.Column("content", sa.VARCHAR(length=1000), nullable=True), + sa.Column("message_id", sa.Integer(), nullable=False), + sa.Column("sender_id", sa.BigInteger(), nullable=True), + sa.Column("timestamp", sa.DateTime(), nullable=True), + sa.Column( + "type", + sa.Enum("ACTION", "MARKDOWN", "PLAIN", name="messagetype"), + nullable=False, + ), + sa.Column("uuid", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.ForeignKeyConstraint( + ["channel_id"], + ["chat_channels.channel_id"], + ), + sa.ForeignKeyConstraint( + ["sender_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("message_id"), + ) + op.create_index( + op.f("ix_chat_messages_channel_id"), + "chat_messages", + ["channel_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_messages_message_id"), + "chat_messages", + ["message_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_messages_sender_id"), "chat_messages", ["sender_id"], unique=False + ) + op.create_index( + op.f("ix_chat_messages_timestamp"), "chat_messages", ["timestamp"], unique=False + ) + op.create_index( + op.f("ix_chat_messages_type"), "chat_messages", ["type"], unique=False + ) + op.create_table( + "chat_silence_users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("channel_id", sa.Integer(), nullable=False), + sa.Column("until", sa.DateTime(), nullable=True), + sa.Column("banned_at", sa.DateTime(), nullable=False), + sa.Column("reason", sa.VARCHAR(length=255), nullable=True), + sa.ForeignKeyConstraint( + ["channel_id"], + ["chat_channels.channel_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_chat_silence_users_channel_id"), + "chat_silence_users", + ["channel_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_reason"), + "chat_silence_users", + ["reason"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_until"), + "chat_silence_users", + ["until"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_banned_at"), + "chat_silence_users", + ["banned_at"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_user_id"), + "chat_silence_users", + ["user_id"], + unique=False, + ) + op.create_index( + op.f("ix_chat_silence_users_id"), + "chat_silence_users", + ["id"], + unique=False, + ) + op.bulk_insert( + channel_table, + [ + { + "name": "osu!", + "description": "General discussion for osu!", + "type": "PUBLIC", + }, + { + "name": "announce", + "description": "Official announcements", + "type": "PUBLIC", + }, + ], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("chat_silence_users") + op.drop_table("chat_messages") + op.drop_table("chat_channels") + # ### end Alembic commands ### From 368bdfe58811bce30ebd67afb56a4fd74720c2a2 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 07:48:19 +0000 Subject: [PATCH 02/10] feat(chat): support pm --- app/database/chat.py | 43 ++++++++++++-- app/database/lazer_user.py | 40 +++++++++++++ app/router/chat/channel.py | 117 +++++++++++++++++++++++++++++++++++-- app/router/chat/message.py | 83 +++++++++++++++++++++++++- app/router/chat/server.py | 49 ++++++++++++++-- 5 files changed, 316 insertions(+), 16 deletions(-) diff --git a/app/database/chat.py b/app/database/chat.py index 565657d..777e2ac 100644 --- a/app/database/chat.py +++ b/app/database/chat.py @@ -16,6 +16,7 @@ from sqlmodel import ( ForeignKey, Relationship, SQLModel, + col, select, ) from sqlmodel.ext.asyncio.session import AsyncSession @@ -65,6 +66,15 @@ class ChatChannel(ChatChannelBase, table=True): await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) ).first() + @classmethod + async def get_pm_channel( + cls, user1: int, user2: int, session: AsyncSession + ) -> "ChatChannel | None": + channel = await cls.get(f"pm_{user1}_{user2}", session) + if channel is None: + channel = await cls.get(f"pm_{user2}_{user1}", session) + return channel + class ChatChannelResp(ChatChannelBase): channel_id: int @@ -73,8 +83,8 @@ class ChatChannelResp(ChatChannelBase): current_user_attributes: ChatUserAttributes | None = None last_read_id: int | None = None last_message_id: int | None = None - recent_messages: list[str] | None = None - users: list[int] | None = None + recent_messages: list["ChatMessageResp"] = Field(default_factory=list) + users: list[int] = Field(default_factory=list) message_length_limit: int = 1000 @classmethod @@ -82,9 +92,10 @@ class ChatChannelResp(ChatChannelBase): cls, channel: ChatChannel, session: AsyncSession, - users: list[int], user: User, redis: Redis, + users: list[int] | None = None, + include_recent_messages: bool = False, ) -> Self: c = cls.model_validate(channel) silence = ( @@ -123,9 +134,33 @@ class ChatChannelResp(ChatChannelBase): c.moderated = False c.current_user_attributes = attribute - c.users = users + if c.type != ChannelType.PUBLIC and users is not None: + c.users = users c.last_message_id = last_msg c.last_read_id = last_read_id + + if include_recent_messages: + messages = ( + await session.exec( + select(ChatMessage) + .where(ChatMessage.channel_id == channel.channel_id) + .order_by(col(ChatMessage.timestamp).desc()) + .limit(10) + ) + ).all() + c.recent_messages = [ + await ChatMessageResp.from_db(msg, session, user) for msg in messages + ] + c.recent_messages.reverse() + + if c.type == ChannelType.PM and users and len(users) == 2: + target_user_id = next(u for u in users if u != user.id) + target_name = await session.exec( + select(User.username).where(User.id == target_user_id) + ) + c.name = target_name.one() + assert user.id + c.users = [target_user_id, user.id] return c diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index ca767f5..e242fa4 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -168,6 +168,46 @@ class User(AsyncAttrs, UserBase, table=True): default=None, sa_column=Column(DateTime(timezone=True)), exclude=True ) + async def is_user_can_pm( + self, from_user: "User", session: AsyncSession + ) -> tuple[bool, str]: + from .relationship import Relationship, RelationshipType + + from_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == from_user.id, + Relationship.target_id == self.id, + ) + ) + ).first() + if from_relationship and from_relationship.type == RelationshipType.BLOCK: + return False, "You have blocked the target user." + if from_user.pm_friends_only and ( + not from_relationship or from_relationship.type != RelationshipType.FOLLOW + ): + return ( + False, + "You have disabled non-friend communications " + "and target user is not your friend.", + ) + + relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == self.id, + Relationship.target_id == from_user.id, + ) + ) + ).first() + if relationship and relationship.type == RelationshipType.BLOCK: + return False, "Target user has blocked you." + if self.pm_friends_only and ( + not relationship or relationship.type != RelationshipType.FOLLOW + ): + return False, "Target user has disabled non-friend communications" + return True, "" + class UserResp(UserBase): id: int | None = None diff --git a/app/router/chat/channel.py b/app/router/chat/channel.py index 798a5f2..ceb7b66 100644 --- a/app/router/chat/channel.py +++ b/app/router/chat/channel.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal, Self from app.database.chat import ( ChannelType, @@ -9,15 +9,16 @@ from app.database.chat import ( ) from app.database.lazer_user import User, UserResp from app.dependencies.database import get_db, get_redis +from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router from .server import server from fastapi import Depends, HTTPException, Query, Security -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from redis.asyncio import Redis -from sqlmodel import select +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -37,6 +38,7 @@ async def get_update( ): resp = UpdateResponse() if "presence" in includes: + assert current_user.id channel_ids = server.get_user_joined_channel(current_user.id) for channel_id in channel_ids: channel = await ChatChannel.get(channel_id, session) @@ -45,9 +47,11 @@ async def get_update( await ChatChannelResp.from_db( channel, session, - server.channels.get(channel_id, []), current_user, redis, + server.channels.get(channel_id, []) + if channel.type != ChannelType.PUBLIC + else None, ) ) return resp @@ -103,9 +107,11 @@ async def get_channel_list( await ChatChannelResp.from_db( channel, session, - server.channels.get(channel.channel_id, []), current_user, redis, + server.channels.get(channel.channel_id, []) + if channel.type != ChannelType.PUBLIC + else None, ) ) return results @@ -127,12 +133,111 @@ async def get_channel( if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") assert db_channel.channel_id is not None + + users = [] + if db_channel.type == ChannelType.PM: + user_ids = db_channel.name.split("_")[1:] + if len(user_ids) != 2: + raise HTTPException(status_code=404, detail="Target user not found") + for id_ in user_ids: + if int(id_) == current_user.id: + continue + target_user = await session.get(User, int(id_)) + if target_user is None: + raise HTTPException(status_code=404, detail="Target user not found") + users.extend([target_user, current_user]) + break + return GetChannelResp( channel=await ChatChannelResp.from_db( db_channel, session, - server.channels.get(db_channel.channel_id, []), current_user, redis, + server.channels.get(db_channel.channel_id, []) + if db_channel.type != ChannelType.PUBLIC + else None, ) ) + + +class CreateChannelReq(BaseModel): + class AnnounceChannel(BaseModel): + name: str + description: str + + message: str | None = None + type: Literal["ANNOUNCE", "PM"] = "PM" + target_id: int | None = None + target_ids: list[int] | None = None + channel: AnnounceChannel | None = None + + @model_validator(mode="after") + def check(self) -> Self: + if self.type == "PM": + if self.target_id is None: + raise ValueError("target_id must be set for PM channels") + else: + if self.target_ids is None or self.channel is None or self.message is None: + raise ValueError( + "target_ids, channel, and message must be set for ANNOUNCE channels" + ) + return self + + +@router.post("/chat/channels") +async def create_channel( + req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)), + current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + if req.type == "PM": + target = await session.get(User, req.target_id) + if not target: + raise HTTPException(status_code=404, detail="Target user not found") + is_can_pm, block = await target.is_user_can_pm(current_user, session) + if not is_can_pm: + raise HTTPException(status_code=403, detail=block) + + channel = await ChatChannel.get_pm_channel( + current_user.id, # pyright: ignore[reportArgumentType] + req.target_id, # pyright: ignore[reportArgumentType] + session, + ) + channel_name = f"pm_{current_user.id}_{req.target_id}" + else: + channel_name = req.channel.name if req.channel else "Unnamed Channel" + channel = await ChatChannel.get(channel_name, session) + + if channel is None: + channel = ChatChannel( + name=channel_name, + description=req.channel.description + if req.channel + else "Private message channel", + type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(current_user) + if req.type == "PM": + await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable] + await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable] + else: + target_users = await session.exec( + select(User).where(col(User.id).in_(req.target_ids or [])) + ) + await server.batch_join_channel([*target_users, current_user], channel, session) + + await server.join_channel(current_user, channel, session) + assert channel.channel_id + return await ChatChannelResp.from_db( + channel, + session, + current_user, + redis, + server.channels.get(channel.channel_id, []), + include_recent_messages=True, + ) diff --git a/app/router/chat/message.py b/app/router/chat/message.py index 5bab869..dc4f134 100644 --- a/app/router/chat/message.py +++ b/app/router/chat/message.py @@ -1,9 +1,15 @@ from __future__ import annotations from app.database import ChatMessageResp -from app.database.chat import ChatChannel, ChatMessage, MessageType +from app.database.chat import ( + ChannelType, + ChatChannel, + ChatChannelResp, + ChatMessage, + MessageType, +) from app.database.lazer_user import User -from app.dependencies.database import get_db +from app.dependencies.database import get_db, get_redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router @@ -12,6 +18,7 @@ from .server import server from fastapi import Depends, HTTPException, Query, Security from pydantic import BaseModel +from redis.asyncio import Redis from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -42,6 +49,9 @@ async def send_message( db_channel = await ChatChannel.get(channel, session) if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") + + assert db_channel.channel_id + assert current_user.id msg = ChatMessage( channel_id=db_channel.channel_id, content=req.message, @@ -95,4 +105,73 @@ async def mark_as_read( db_channel = await ChatChannel.get(channel, session) if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") + assert db_channel.channel_id await server.mark_as_read(db_channel.channel_id, message) + + +class PMReq(BaseModel): + target_id: int + message: str + is_action: bool = False + uuid: str | None = None + + +class NewPMResp(BaseModel): + channel: ChatChannelResp + message: ChatMessageResp + new_channel_id: int + + +@router.post("/chat/new") +async def create_new_pm( + req: PMReq = Depends(BodyOrForm(PMReq)), + current_user: User = Security(get_current_user, scopes=["chat.write"]), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + user_id = current_user.id + target = await session.get(User, req.target_id) + if target is None: + raise HTTPException(status_code=404, detail="Target user not found") + is_can_pm, block = await target.is_user_can_pm(current_user, session) + if not is_can_pm: + raise HTTPException(status_code=403, detail=block) + + assert user_id + channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session) + if channel is None: + channel = ChatChannel( + name=f"pm_{user_id}_{req.target_id}", + description="Private message channel", + type=ChannelType.PM, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(target) + await session.refresh(current_user) + + assert channel.channel_id + await server.batch_join_channel([target, current_user], channel, session) + channel_resp = await ChatChannelResp.from_db( + channel, session, current_user, redis, server.channels[channel.channel_id] + ) + msg = ChatMessage( + channel_id=channel.channel_id, + content=req.message, + sender_id=user_id, + type=MessageType.ACTION if req.is_action else MessageType.PLAIN, + uuid=req.uuid, + ) + session.add(msg) + await session.commit() + await session.refresh(msg) + await session.refresh(current_user) + await session.refresh(channel) + message_resp = await ChatMessageResp.from_db(msg, session, current_user) + await server.send_message_to_channel(message_resp) + return NewPMResp( + channel=channel_resp, + message=message_resp, + new_channel_id=channel_resp.channel_id, + ) diff --git a/app/router/chat/server.py b/app/router/chat/server.py index 727f426..2c91952 100644 --- a/app/router/chat/server.py +++ b/app/router/chat/server.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from app.database.chat import ChatChannel, ChatChannelResp, ChatMessageResp +from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.lazer_user import User from app.dependencies.database import DBFactory, get_db_factory, get_redis from app.dependencies.user import get_current_user @@ -73,14 +73,50 @@ class ChatServer: ), ) ) + assert message.message_id await self.mark_as_read(message.channel_id, message.message_id) + async def batch_join_channel( + self, users: list[User], channel: ChatChannel, session: AsyncSession + ): + channel_id = channel.channel_id + assert channel_id is not None + + if channel_id not in self.channels: + self.channels[channel_id] = [] + for user_id in [user.id for user in users]: + assert user_id is not None + if user_id not in self.channels[channel_id]: + self.channels[channel_id].append(user_id) + + for user in users: + assert user.id is not None + channel_resp = await ChatChannelResp.from_db( + channel, + session, + user, + self.redis, + self.channels[channel_id] + if channel.type != ChannelType.PUBLIC + else None, + ) + client = self.connect_client.get(user.id) + if client: + await self.send_event( + client, + ChatEvent( + event="chat.channel.join", + data=channel_resp.model_dump(), + ), + ) + async def join_channel( self, user: User, channel: ChatChannel, session: AsyncSession ) -> ChatChannelResp: user_id = user.id channel_id = channel.channel_id assert channel_id is not None + assert user_id is not None if channel_id not in self.channels: self.channels[channel_id] = [] @@ -88,7 +124,11 @@ class ChatServer: self.channels[channel_id].append(user_id) channel_resp = await ChatChannelResp.from_db( - channel, session, self.channels[channel_id], user, self.redis + channel, + session, + user, + self.redis, + self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None, ) client = self.connect_client.get(user_id) @@ -109,15 +149,16 @@ class ChatServer: user_id = user.id channel_id = channel.channel_id assert channel_id is not None + assert user_id is not None if channel_id in self.channels and user_id in self.channels[channel_id]: self.channels[channel_id].remove(user_id) - if not self.channels.get(channel_id): + if (c := self.channels.get(channel_id)) is not None and not c: del self.channels[channel_id] channel_resp = await ChatChannelResp.from_db( - channel, session, self.channels.get(channel_id, []), user, self.redis + channel, session, user, self.redis, self.channels[channel_id] ) client = self.connect_client.get(user_id) if client: From 3de73f242091ffaa6369f433ea85081c3a18a282 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 08:42:40 +0000 Subject: [PATCH 03/10] feat(chat): support mp/playlist chat --- app/database/room.py | 8 ++- app/models/multiplayer_hub.py | 7 +-- app/router/chat/server.py | 47 +++++++++++++++- app/router/v2/room.py | 12 +++-- app/service/room.py | 25 +++++++++ app/service/subscribers/base.py | 5 ++ app/service/subscribers/chat.py | 38 +++++++++++++ app/signalr/hub/multiplayer.py | 20 +++++++ .../df9f725a077c_room_add_channel_id.py | 54 +++++++++++++++++++ 9 files changed, 206 insertions(+), 10 deletions(-) create mode 100644 app/service/subscribers/chat.py create mode 100644 migrations/versions/df9f725a077c_room_add_channel_id.py diff --git a/app/database/room.py b/app/database/room.py index 368a04a..54497db 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -54,7 +54,7 @@ class RoomBase(SQLModel, UTCBaseModel): auto_skip: bool auto_start_duration: int status: RoomStatus - # TODO: channel_id + channel_id: int | None = None class Room(AsyncAttrs, RoomBase, table=True): @@ -84,6 +84,7 @@ class RoomResp(RoomBase): current_playlist_item: PlaylistResp | None = None current_user_score: PlaylistAggregateScore | None = None recent_participants: list[UserResp] = Field(default_factory=list) + channel_id: int = 0 @classmethod async def from_db( @@ -93,7 +94,9 @@ class RoomResp(RoomBase): include: list[str] = [], user: User | None = None, ) -> "RoomResp": - resp = cls.model_validate(room.model_dump()) + d = room.model_dump() + d["channel_id"] = d.get("channel_id", 0) or 0 + resp = cls.model_validate(d) stats = RoomPlaylistItemStats(count_active=0, count_total=0) difficulty_range = RoomDifficultyRange( @@ -158,6 +161,7 @@ class RoomResp(RoomBase): # duration = room.settings.duration, starts_at=server_room.start_at, participant_count=len(room.users), + channel_id=server_room.room.channel_id or 0, ) return resp diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 5f99136..9eec7b5 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -44,6 +44,7 @@ from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: + from app.database.room import Room from app.signalr.hub import MultiplayerHub HOST_LIMIT = 50 @@ -348,7 +349,7 @@ class MultiplayerRoom(BaseModel): channel_id: int @classmethod - def from_db(cls, room) -> "MultiplayerRoom": + def from_db(cls, room: "Room") -> "MultiplayerRoom": """ 将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型) """ @@ -358,7 +359,7 @@ class MultiplayerRoom(BaseModel): host_user = MultiplayerRoomUser(user_id=room.host_id) # playlist 转换 playlist = [] - if hasattr(room, "playlist"): + if room.playlist: for item in room.playlist: playlist.append( PlaylistItem( @@ -396,7 +397,7 @@ class MultiplayerRoom(BaseModel): match_state=None, playlist=playlist, active_countdowns=[], - channel_id=getattr(room, "channel_id", 0), + channel_id=room.channel_id, ) diff --git a/app/router/chat/server.py b/app/router/chat/server.py index 2c91952..6add821 100644 --- a/app/router/chat/server.py +++ b/app/router/chat/server.py @@ -4,10 +4,16 @@ import asyncio from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.lazer_user import User -from app.dependencies.database import DBFactory, get_db_factory, get_redis +from app.dependencies.database import ( + DBFactory, + engine, + get_db_factory, + get_redis, +) from app.dependencies.user import get_current_user from app.log import logger from app.models.chat import ChatEvent +from app.service.subscribers.chat import ChatSubscriber from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes @@ -23,6 +29,9 @@ class ChatServer: self.redis: Redis = get_redis() self.tasks: set[asyncio.Task] = set() + self.ChatSubscriber = ChatSubscriber() + self.ChatSubscriber.chat_server = self + self._subscribed = False def _add_task(self, task): task = asyncio.create_task(task) @@ -158,7 +167,13 @@ class ChatServer: del self.channels[channel_id] channel_resp = await ChatChannelResp.from_db( - channel, session, user, self.redis, self.channels[channel_id] + channel, + session, + user, + self.redis, + self.channels.get(channel_id) + if channel.type != ChannelType.PUBLIC + else None, ) client = self.connect_client.get(user_id) if client: @@ -170,6 +185,30 @@ class ChatServer: ), ) + async def join_room_channel(self, channel_id: int, user_id: int): + async with AsyncSession(engine) as session: + channel = await ChatChannel.get(channel_id, session) + if channel is None: + return + + user = await session.get(User, user_id) + if user is None: + return + + await self.join_channel(user, channel, session) + + async def leave_room_channel(self, channel_id: int, user_id: int): + async with AsyncSession(engine) as session: + channel = await ChatChannel.get(channel_id, session) + if channel is None: + return + + user = await session.get(User, user_id) + if user is None: + return + + await self.leave_channel(user, channel, session) + server = ChatServer() @@ -207,6 +246,10 @@ async def chat_websocket( authorization: str = Header(...), factory: DBFactory = Depends(get_db_factory), ): + if not server._subscribed: + server._subscribed = True + await server.ChatSubscriber.start_subscribe() + async for session in factory(): token = authorization[7:] if ( diff --git a/app/router/v2/room.py b/app/router/v2/room.py index ff62b4c..9a66e8a 100644 --- a/app/router/v2/room.py +++ b/app/router/v2/room.py @@ -116,7 +116,7 @@ class APICreatedRoom(RoomResp): async def _participate_room( - room_id: int, user_id: int, db_room: Room, session: AsyncSession + room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis ): participated_user = ( await session.exec( @@ -138,6 +138,8 @@ async def _participate_room( participated_user.joined_at = datetime.now(UTC) db_room.participant_count += 1 + await redis.publish("chat:room:joined", f"{db_room.channel_id}:{user_id}") + @router.post( "/rooms", @@ -150,11 +152,12 @@ async def create_room( room: APIUploadedRoom, db: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), + redis: Redis = Depends(get_redis), ): assert current_user.id is not None user_id = current_user.id db_room = await create_playlist_room_from_api(db, room, user_id) - await _participate_room(db_room.id, user_id, db_room, db) + await _participate_room(db_room.id, user_id, db_room, db, redis) created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db)) created_room.error = "" return created_room @@ -219,11 +222,12 @@ async def add_user_to_room( room_id: int = Path(..., description="房间 ID"), user_id: int = Path(..., description="用户 ID"), db: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), current_user: User = Security(get_client_user), ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is not None: - await _participate_room(room_id, user_id, db_room, db) + await _participate_room(room_id, user_id, db_room, db, redis) await db.commit() await db.refresh(db_room) resp = await RoomResp.from_db(db_room, db) @@ -243,6 +247,7 @@ async def remove_user_from_room( user_id: int = Path(..., description="用户 ID"), db: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), + redis: Redis = Depends(get_redis), ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is not None: @@ -257,6 +262,7 @@ async def remove_user_from_room( if participated_user is not None: participated_user.left_at = datetime.now(UTC) db_room.participant_count -= 1 + await redis.publish("chat:room:left", f"{db_room.channel_id}:{user_id}") await db.commit() return None else: diff --git a/app/service/room.py b/app/service/room.py index d11dced..d6fdcc6 100644 --- a/app/service/room.py +++ b/app/service/room.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta from app.database.beatmap import Beatmap +from app.database.chat import ChannelType, ChatChannel from app.database.playlists import Playlist from app.database.room import APIUploadedRoom, Room from app.dependencies.fetcher import get_fetcher @@ -25,6 +26,18 @@ async def create_playlist_room_from_api( session.add(db_room) await session.commit() await session.refresh(db_room) + + channel = ChatChannel( + name=f"room_{db_room.id}", + description="Playlist room", + type=ChannelType.MULTIPLAYER, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(db_room) + db_room.channel_id = channel.channel_id + await add_playlists_to_room(session, db_room.id, room.playlist, host_id) await session.refresh(db_room) return db_room @@ -57,6 +70,18 @@ async def create_playlist_room( session.add(db_room) await session.commit() await session.refresh(db_room) + + channel = ChatChannel( + name=f"room_{db_room.id}", + description="Playlist room", + type=ChannelType.MULTIPLAYER, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(db_room) + db_room.channel_id = channel.channel_id + await add_playlists_to_room(session, db_room.id, playlist, host_id) await session.refresh(db_room) return db_room diff --git a/app/service/subscribers/base.py b/app/service/subscribers/base.py index 8fc5654..693d4d8 100644 --- a/app/service/subscribers/base.py +++ b/app/service/subscribers/base.py @@ -24,6 +24,11 @@ class RedisSubscriber: del self.handlers[channel] await self.pubsub.unsubscribe(channel) + def add_handler(self, channel: str, handler: Callable[[str, str], Awaitable[Any]]): + if channel not in self.handlers: + self.handlers[channel] = [] + self.handlers[channel].append(handler) + async def listen(self): while True: message = await self.pubsub.get_message( diff --git a/app/service/subscribers/chat.py b/app/service/subscribers/chat.py new file mode 100644 index 0000000..0c61c06 --- /dev/null +++ b/app/service/subscribers/chat.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .base import RedisSubscriber + +if TYPE_CHECKING: + from app.router.chat.server import ChatServer + + +JOIN_CHANNEL = "chat:room:joined" +EXIT_CHANNEL = "chat:room:left" + + +class ChatSubscriber(RedisSubscriber): + def __init__(self): + super().__init__() + self.room_subscriber: dict[int, list[int]] = {} + self.chat_server: "ChatServer | None" = None + + async def start_subscribe(self): + await self.subscribe(JOIN_CHANNEL) + self.add_handler(JOIN_CHANNEL, self.on_join_room) + await self.subscribe(EXIT_CHANNEL) + self.add_handler(EXIT_CHANNEL, self.on_leave_room) + self.start() + + async def on_join_room(self, c: str, s: str): + channel_id, user_id = s.split(":") + if self.chat_server is None: + return + await self.chat_server.join_room_channel(int(channel_id), int(user_id)) + + async def on_leave_room(self, c: str, s: str): + channel_id, user_id = s.split(":") + if self.chat_server is None: + return + await self.chat_server.leave_room_channel(int(channel_id), int(user_id)) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 4ada615..20eecf6 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -6,6 +6,7 @@ from typing import override from app.database import Room from app.database.beatmap import Beatmap +from app.database.chat import ChannelType, ChatChannel from app.database.lazer_user import User from app.database.multiplayer_event import MultiplayerEvent from app.database.playlists import Playlist @@ -195,6 +196,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await session.commit() await session.refresh(db_room) + channel = ChatChannel( + name=f"room_{db_room.id}", + description="Multiplayer room", + type=ChannelType.MULTIPLAYER, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(db_room) + room.channel_id = channel.channel_id # pyright: ignore[reportAttributeAccessIssue] + db_room.channel_id = channel.channel_id + item = room.playlist[0] item.owner_id = client.user_id room.room_id = db_room.id @@ -280,6 +293,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if db_room is None: raise InvokeException("Room does not exist in database") db_room.participant_count += 1 + + redis = get_redis() + await redis.publish("chat:room:joined", f"{room.channel_id}:{user.user_id}") + return room async def change_beatmap_availability( @@ -914,6 +931,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if target_store: target_store.room_id = 0 + redis = get_redis() + await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}") + async def end_room(self, room: ServerMultiplayerRoom): assert room.room.host async with AsyncSession(engine) as session: diff --git a/migrations/versions/df9f725a077c_room_add_channel_id.py b/migrations/versions/df9f725a077c_room_add_channel_id.py new file mode 100644 index 0000000..85e7d93 --- /dev/null +++ b/migrations/versions/df9f725a077c_room_add_channel_id.py @@ -0,0 +1,54 @@ +"""room: add channel_id + +Revision ID: df9f725a077c +Revises: dd33d89aa2c2 +Create Date: 2025-08-16 08:05:28.748265 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "df9f725a077c" +down_revision: str | Sequence[str] | None = "dd33d89aa2c2" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True + ) + op.alter_column( + "chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True + ) + op.create_index( + op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False + ) + op.add_column("rooms", sa.Column("channel_id", sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("rooms", "channel_id") + op.drop_index(op.f("ix_chat_silence_users_id"), table_name="chat_silence_users") + op.alter_column( + "chat_silence_users", + "banned_at", + existing_type=mysql.DATETIME(), + nullable=False, + ) + op.alter_column( + "chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False + ) + # ### end Alembic commands ### From e1d42743d3131728eb63bfa9de823ebc546bdcf2 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 10:31:46 +0000 Subject: [PATCH 04/10] feat(chat): support BanchoBot --- app/const.py | 3 + app/models/score.py | 10 ++ app/router/auth.py | 2 +- app/router/chat/banchobot.py | 211 +++++++++++++++++++++++++++++++ app/router/chat/message.py | 7 +- app/router/chat/server.py | 39 ++++-- app/router/v2/user.py | 12 +- app/service/create_banchobot.py | 30 +++++ app/service/daily_challenge.py | 5 +- app/service/osu_rx_statistics.py | 4 + main.py | 2 + 11 files changed, 302 insertions(+), 23 deletions(-) create mode 100644 app/const.py create mode 100644 app/router/chat/banchobot.py create mode 100644 app/service/create_banchobot.py diff --git a/app/const.py b/app/const.py new file mode 100644 index 0000000..78ad45c --- /dev/null +++ b/app/const.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +BANCHOBOT_ID = 2 diff --git a/app/models/score.py b/app/models/score.py index 703b55c..53b61c5 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -88,6 +88,16 @@ class GameMode(str, Enum): }[self] return self + @classmethod + def parse(cls, v: str | int) -> "GameMode | None": + if isinstance(v, int) or v.isdigit(): + return cls.from_int_extra(int(v)) + v = v.lower() + try: + return cls[v] + except ValueError: + return None + class Rank(str, Enum): X = "X" diff --git a/app/router/auth.py b/app/router/auth.py index a1149fc..88528e2 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -459,7 +459,7 @@ async def oauth_token( # 存储令牌 await store_token( db, - 3, + 2, client_id, scopes, access_token, diff --git a/app/router/chat/banchobot.py b/app/router/chat/banchobot.py new file mode 100644 index 0000000..76c8beb --- /dev/null +++ b/app/router/chat/banchobot.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from math import ceil +import random +import shlex + +from app.const import BANCHOBOT_ID +from app.database import ChatMessageResp +from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType +from app.database.lazer_user import User +from app.database.score import Score +from app.database.statistics import UserStatistics, get_rank +from app.models.score import GameMode + +from .server import server + +from sqlmodel import func, select +from sqlmodel.ext.asyncio.session import AsyncSession + +HandlerResult = str | None | Awaitable[str | None] +Handler = Callable[[User, list[str], AsyncSession], HandlerResult] + + +class Bot: + def __init__(self, bot_user_id: int = BANCHOBOT_ID) -> None: + self._handlers: dict[str, Handler] = {} + self.bot_user_id = bot_user_id + + # decorator: @bot.command("ping") + def command(self, name: str) -> Callable[[Handler], Handler]: + def _decorator(func: Handler) -> Handler: + self._handlers[name.lower()] = func + return func + + return _decorator + + def parse(self, content: str) -> tuple[str, list[str]] | None: + if not content or not content.startswith("!"): + return None + try: + parts = shlex.split(content[1:]) + except ValueError: + parts = content[1:].split() + if not parts: + return None + cmd = parts[0].lower() + args = parts[1:] + return cmd, args + + async def try_handle( + self, + user: User, + channel: ChatChannel, + content: str, + session: AsyncSession, + ) -> None: + parsed = self.parse(content) + if not parsed: + return + cmd, args = parsed + handler = self._handlers.get(cmd) + + reply: str | None = None + if handler is None: + return + else: + res = handler(user, args, session) + if asyncio.iscoroutine(res): + res = await res + reply = res # type: ignore[assignment] + + if reply: + await self._send_reply(user, channel, reply, session) + + async def _send_message( + self, channel: ChatChannel, content: str, session: AsyncSession + ) -> None: + bot = await session.get(User, self.bot_user_id) + if bot is None: + return + channel_id = channel.channel_id + if channel_id is None: + return + + assert bot.id is not None + msg = ChatMessage( + channel_id=channel_id, + content=content, + sender_id=bot.id, + type=MessageType.PLAIN, + ) + session.add(msg) + await session.commit() + await session.refresh(msg) + await session.refresh(bot) + resp = await ChatMessageResp.from_db(msg, session, bot) + await server.send_message_to_channel(resp) + + async def _ensure_pm_channel( + self, user: User, session: AsyncSession + ) -> ChatChannel | None: + user_id = user.id + if user_id is None: + return None + + bot = await session.get(User, self.bot_user_id) + if bot is None or bot.id is None: + return None + + channel = await ChatChannel.get_pm_channel(user_id, bot.id, session) + if channel is None: + channel = ChatChannel( + name=f"pm_{user_id}_{bot.id}", + description="Private message channel", + type=ChannelType.PM, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(user) + await session.refresh(bot) + await server.batch_join_channel([user, bot], channel, session) + return channel + + async def _send_reply( + self, + user: User, + src_channel: ChatChannel, + content: str, + session: AsyncSession, + ) -> None: + target_channel = src_channel + if src_channel.type == ChannelType.PUBLIC: + pm = await self._ensure_pm_channel(user, session) + if pm is not None: + target_channel = pm + await self._send_message(target_channel, content, session) + + +bot = Bot() + + +@bot.command("help") +async def _help(user: User, args: list[str], _session: AsyncSession) -> str: + cmds = sorted(bot._handlers.keys()) + if args: + target = args[0].lower() + if target in bot._handlers: + return f"Use: !{target} [args]" + return f"No such command: {target}" + if not cmds: + return "No available commands" + return "Available: " + ", ".join(f"!{c}" for c in cmds) + + +@bot.command("roll") +def _roll(user: User, args: list[str], _session: AsyncSession) -> str: + if len(args) > 0 and args[0].isdigit(): + r = random.randint(1, int(args[0])) + else: + r = random.randint(1, 100) + return f"{user.username} rolls {r} point(s)" + + +@bot.command("stats") +async def _stats(user: User, args: list[str], session: AsyncSession) -> str: + if len(args) < 1: + return "Usage: !stats " + + target_user = ( + await session.exec(select(User).where(User.username == args[0])) + ).first() + if not target_user: + return f"User '{args[0]}' not found." + + gamemode = None + if len(args) >= 2: + gamemode = GameMode.parse(args[1].upper()) + if gamemode is None: + subquery = ( + select(func.max(Score.id)) + .where(Score.user_id == target_user.id) + .scalar_subquery() + ) + last_score = ( + await session.exec(select(Score).where(Score.id == subquery)) + ).first() + if last_score is not None: + gamemode = last_score.gamemode + else: + gamemode = target_user.playmode + + statistics = ( + await session.exec( + select(UserStatistics).where( + UserStatistics.user_id == target_user.id, + UserStatistics.mode == gamemode, + ) + ) + ).first() + if not statistics: + return f"User '{args[0]}' has no statistics." + + return f"""Stats for {target_user.username} ({gamemode.name.lower()}): +Score: {statistics.total_score} (#{await get_rank(session, statistics)}) +Plays: {statistics.play_count} (lv{ceil(statistics.level_current)}) +Accuracy: {statistics.hit_accuracy} +PP: {statistics.pp} +""" diff --git a/app/router/chat/message.py b/app/router/chat/message.py index dc4f134..318aefe 100644 --- a/app/router/chat/message.py +++ b/app/router/chat/message.py @@ -14,6 +14,7 @@ from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router +from .banchobot import bot from .server import server from fastapi import Depends, HTTPException, Query, Security @@ -63,8 +64,12 @@ async def send_message( await session.commit() await session.refresh(msg) await session.refresh(current_user) + await session.refresh(db_channel) resp = await ChatMessageResp.from_db(msg, session, current_user) - await server.send_message_to_channel(resp) + is_bot_command = req.message.startswith("!") + await server.send_message_to_channel(resp, is_bot_command) + if is_bot_command: + await bot.try_handle(current_user, db_channel, req.message, session) return resp diff --git a/app/router/chat/server.py b/app/router/chat/server.py index 6add821..b91dd02 100644 --- a/app/router/chat/server.py +++ b/app/router/chat/server.py @@ -72,16 +72,24 @@ class ChatServer: async def mark_as_read(self, channel_id: int, message_id: int): await self.redis.set(f"chat:{channel_id}:last_msg", message_id) - async def send_message_to_channel(self, message: ChatMessageResp): - self._add_task( - self.broadcast( - message.channel_id, - ChatEvent( - event="chat.message.new", - data={"messages": [message], "users": [message.sender]}, - ), - ) + async def send_message_to_channel( + self, message: ChatMessageResp, is_bot_command: bool = False + ): + event = ChatEvent( + event="chat.message.new", + data={"messages": [message], "users": [message.sender]}, ) + if is_bot_command: + client = self.connect_client.get(message.sender_id) + if client: + self._add_task(self.send_event(client, event)) + else: + self._add_task( + self.broadcast( + message.channel_id, + event, + ) + ) assert message.message_id await self.mark_as_read(message.channel_id, message.message_id) @@ -91,14 +99,17 @@ class ChatServer: channel_id = channel.channel_id assert channel_id is not None + not_joined = [] + if channel_id not in self.channels: self.channels[channel_id] = [] - for user_id in [user.id for user in users]: - assert user_id is not None - if user_id not in self.channels[channel_id]: - self.channels[channel_id].append(user_id) - for user in users: + assert user.id is not None + if user.id not in self.channels[channel_id]: + self.channels[channel_id].append(user.id) + not_joined.append(user) + + for user in not_joined: assert user.id is not None channel_resp = await ChatChannelResp.from_db( channel, diff --git a/app/router/v2/user.py b/app/router/v2/user.py index dece5be..123e120 100644 --- a/app/router/v2/user.py +++ b/app/router/v2/user.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta from typing import Literal +from app.const import BANCHOBOT_ID from app.database import ( BeatmapPlaycounts, BeatmapPlaycountsResp, @@ -65,6 +66,7 @@ async def get_users( include=SEARCH_INCLUDED, ) for searched_user in searched_users + if searched_user.id != BANCHOBOT_ID ] ) @@ -91,7 +93,7 @@ async def get_user_info_ruleset( ) ) ).first() - if not searched_user: + if not searched_user or searched_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") return await UserResp.from_db( searched_user, @@ -123,7 +125,7 @@ async def get_user_info( ) ) ).first() - if not searched_user: + if not searched_user or searched_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") return await UserResp.from_db( searched_user, @@ -148,7 +150,7 @@ async def get_user_beatmapsets( offset: int = Query(0, ge=0, description="偏移量"), ): user = await session.get(User, user_id) - if not user: + if not user or user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") if type in { @@ -218,7 +220,7 @@ async def get_user_scores( current_user: User = Security(get_current_user, scopes=["public"]), ): db_user = await session.get(User, user_id) - if not db_user: + if not db_user or db_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") gamemode = mode or db_user.playmode @@ -271,7 +273,7 @@ async def get_user_events( session: AsyncSession = Depends(get_db), ): db_user = await session.get(User, user) - if db_user is None: + if db_user is None or db_user.id == BANCHOBOT_ID: raise HTTPException(404, "User Not found") events = await db_user.awaitable_attrs.events if limit is not None: diff --git a/app/service/create_banchobot.py b/app/service/create_banchobot.py new file mode 100644 index 0000000..0d855cf --- /dev/null +++ b/app/service/create_banchobot.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from app.const import BANCHOBOT_ID +from app.database.lazer_user import User +from app.database.statistics import UserStatistics +from app.dependencies.database import engine +from app.models.score import GameMode + +from sqlmodel import exists, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +async def create_banchobot(): + async with AsyncSession(engine) as session: + is_exist = (await session.exec(select(exists()).where(User.id == 2))).first() + if not is_exist: + banchobot = User( + username="BanchoBot", + email="banchobot@ppy.sh", + is_bot=True, + pw_bcrypt="0", + id=BANCHOBOT_ID, + avatar_url="https://a.ppy.sh/3", + country_code="SH", + website="https://twitter.com/banchoboat", + ) + session.add(banchobot) + statistics = UserStatistics(user_id=BANCHOBOT_ID, mode=GameMode.OSU) + session.add(statistics) + await session.commit() diff --git a/app/service/daily_challenge.py b/app/service/daily_challenge.py index ec7f9d0..f3dbf05 100644 --- a/app/service/daily_challenge.py +++ b/app/service/daily_challenge.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta import json +from app.const import BANCHOBOT_ID from app.database.playlists import Playlist from app.database.room import Room from app.dependencies.database import engine, get_redis @@ -26,12 +27,12 @@ async def create_daily_challenge_room( return await create_playlist_room( session=session, name=str(today), - host_id=3, + host_id=BANCHOBOT_ID, playlist=[ Playlist( id=0, room_id=0, - owner_id=3, + owner_id=BANCHOBOT_ID, ruleset_id=ruleset_id, beatmap_id=beatmap, required_mods=required_mods, diff --git a/app/service/osu_rx_statistics.py b/app/service/osu_rx_statistics.py index 8a0441f..60f94ce 100644 --- a/app/service/osu_rx_statistics.py +++ b/app/service/osu_rx_statistics.py @@ -1,6 +1,7 @@ from __future__ import annotations from app.config import settings +from app.const import BANCHOBOT_ID from app.database.lazer_user import User from app.database.statistics import UserStatistics from app.dependencies.database import engine @@ -15,6 +16,9 @@ async def create_rx_statistics(): async with AsyncSession(engine) as session: users = (await session.exec(select(User.id))).all() for i in users: + if i == BANCHOBOT_ID: + continue + if settings.enable_rx: for mode in ( GameMode.OSURX, diff --git a/main.py b/main.py index 66471cc..526a92e 100644 --- a/main.py +++ b/main.py @@ -22,6 +22,7 @@ from app.router import ( ) from app.router.redirect import redirect_router from app.service.calculate_all_user_rank import calculate_user_rank +from app.service.create_banchobot import create_banchobot from app.service.daily_challenge import daily_challenge_job from app.service.osu_rx_statistics import create_rx_statistics from app.service.pp_recalculate import recalculate_all_players_pp @@ -43,6 +44,7 @@ async def lifespan(app: FastAPI): await calculate_user_rank(True) init_scheduler() await daily_challenge_job() + await create_banchobot() # on shutdown yield stop_scheduler() From 99018f45e52e63319b5deca140bdc5388cd2bc20 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 14:55:50 +0000 Subject: [PATCH 05/10] feat(chat): support mp command --- app/models/multiplayer_hub.py | 4 +- app/router/chat/banchobot.py | 374 ++++++++++++++++++++++++++++++++- app/signalr/hub/multiplayer.py | 25 ++- 3 files changed, 387 insertions(+), 16 deletions(-) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 9eec7b5..c315300 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -625,7 +625,9 @@ class MultiplayerQueue: async with AsyncSession(engine) as session: await Playlist.delete_item(item.id, self.room.room_id, session) - self.room.playlist.remove(item) + found_item = next((i for i in self.room.playlist if i.id == item.id), None) + if found_item: + self.room.playlist.remove(found_item) self.current_index = self.room.playlist.index(self.upcoming_items[0]) await self.update_order() diff --git a/app/router/chat/banchobot.py b/app/router/chat/banchobot.py index 76c8beb..a92eb15 100644 --- a/app/router/chat/banchobot.py +++ b/app/router/chat/banchobot.py @@ -2,25 +2,39 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable +from datetime import timedelta from math import ceil import random import shlex from app.const import BANCHOBOT_ID from app.database import ChatMessageResp +from app.database.beatmap import Beatmap from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType from app.database.lazer_user import User from app.database.score import Score from app.database.statistics import UserStatistics, get_rank +from app.dependencies.fetcher import get_fetcher +from app.exception import InvokeException +from app.models.mods import APIMod +from app.models.multiplayer_hub import ( + ChangeTeamRequest, + ServerMultiplayerRoom, + StartMatchCountdownRequest, +) +from app.models.room import MatchType, QueueMode from app.models.score import GameMode +from app.signalr.hub import MultiplayerHubs +from app.signalr.hub.hub import Client from .server import server +from httpx import HTTPError from sqlmodel import func, select from sqlmodel.ext.asyncio.session import AsyncSession HandlerResult = str | None | Awaitable[str | None] -Handler = Callable[[User, list[str], AsyncSession], HandlerResult] +Handler = Callable[[User, list[str], AsyncSession, ChatChannel], HandlerResult] class Bot: @@ -66,7 +80,7 @@ class Bot: if handler is None: return else: - res = handler(user, args, session) + res = handler(user, args, session, channel) if asyncio.iscoroutine(res): res = await res reply = res # type: ignore[assignment] @@ -143,7 +157,9 @@ bot = Bot() @bot.command("help") -async def _help(user: User, args: list[str], _session: AsyncSession) -> str: +async def _help( + user: User, args: list[str], _session: AsyncSession, channel: ChatChannel +) -> str: cmds = sorted(bot._handlers.keys()) if args: target = args[0].lower() @@ -156,7 +172,9 @@ async def _help(user: User, args: list[str], _session: AsyncSession) -> str: @bot.command("roll") -def _roll(user: User, args: list[str], _session: AsyncSession) -> str: +def _roll( + user: User, args: list[str], _session: AsyncSession, channel: ChatChannel +) -> str: if len(args) > 0 and args[0].isdigit(): r = random.randint(1, int(args[0])) else: @@ -165,9 +183,11 @@ def _roll(user: User, args: list[str], _session: AsyncSession) -> str: @bot.command("stats") -async def _stats(user: User, args: list[str], session: AsyncSession) -> str: +async def _stats( + user: User, args: list[str], session: AsyncSession, channel: ChatChannel +) -> str: if len(args) < 1: - return "Usage: !stats " + return "Usage: !stats [gamemode]" target_user = ( await session.exec(select(User).where(User.username == args[0])) @@ -209,3 +229,345 @@ Plays: {statistics.play_count} (lv{ceil(statistics.level_current)}) Accuracy: {statistics.hit_accuracy} PP: {statistics.pp} """ + + +async def _mp_name( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + if len(args) < 1: + return "Usage: !mp name " + + name = args[0] + try: + settings = room.room.settings.model_copy() + settings.name = name + await MultiplayerHubs.ChangeSettings(signalr_client, settings) + return f"Room name has changed to {name}" + except InvokeException as e: + return e.message + + +async def _mp_set( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + if len(args) < 1: + return "Usage: !mp set []" + + teammode = {"0": MatchType.HEAD_TO_HEAD, "2": MatchType.TEAM_VERSUS}.get(args[0]) + if not teammode: + return "Invalid teammode. Use 0 for Head-to-Head or 2 for Team Versus." + queuemode = ( + { + "0": QueueMode.HOST_ONLY, + "1": QueueMode.ALL_PLAYERS, + "2": QueueMode.ALL_PLAYERS_ROUND_ROBIN, + }.get(args[1]) + if len(args) >= 2 + else None + ) + try: + settings = room.room.settings.model_copy() + settings.match_type = teammode + if queuemode: + settings.queue_mode = queuemode + await MultiplayerHubs.ChangeSettings(signalr_client, settings) + return f"Room setting 'teammode' has been changed to {teammode.name.lower()}" + except InvokeException as e: + return e.message + + +async def _mp_host( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + if len(args) < 1: + return "Usage: !mp host " + + username = args[0] + user_id = ( + await session.exec(select(User.id).where(User.username == username)) + ).first() + if not user_id: + return f"User '{username}' not found." + + try: + await MultiplayerHubs.TransferHost(signalr_client, user_id) + return f"User '{username}' has been hosted in the room." + except InvokeException as e: + return e.message + + +async def _mp_start( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + timer = None + if len(args) >= 1 and args[0].isdigit(): + timer = int(args[0]) + + try: + if timer is not None: + await MultiplayerHubs.SendMatchRequest( + signalr_client, + StartMatchCountdownRequest(duration=timedelta(seconds=timer)), + ) + return "" + else: + await MultiplayerHubs.StartMatch(signalr_client) + return "Good luck! Enjoy game!" + except InvokeException as e: + return e.message + + +async def _mp_abort( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + try: + await MultiplayerHubs.AbortMatch(signalr_client) + return "Match aborted." + except InvokeException as e: + return e.message + + +async def _mp_team( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +): + if room.room.settings.match_type != MatchType.TEAM_VERSUS: + return "This command is only available in Team Versus mode." + + if len(args) < 2: + return "Usage: !mp team " + + username = args[0] + team = {"red": 0, "blue": 1}.get(args[1]) + if team is None: + return "Invalid team colour. Use 'red' or 'blue'." + + user_id = ( + await session.exec(select(User.id).where(User.username == username)) + ).first() + if not user_id: + return f"User '{username}' not found." + user_client = MultiplayerHubs.get_client_by_id(str(user_id)) + if not user_client: + return f"User '{username}' is not in the room." + + try: + await MultiplayerHubs.SendMatchRequest( + user_client, ChangeTeamRequest(team_id=team) + ) + return "" + except InvokeException as e: + return e.message + + +async def _mp_password( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + password = "" + if len(args) >= 1: + password = args[0] + + try: + settings = room.room.settings.model_copy() + settings.password = password + await MultiplayerHubs.ChangeSettings(signalr_client, settings) + return "Room password has been set." + except InvokeException as e: + return e.message + + +async def _mp_kick( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + if len(args) < 1: + return "Usage: !mp kick " + + username = args[0] + user_id = ( + await session.exec(select(User.id).where(User.username == username)) + ).first() + if not user_id: + return f"User '{username}' not found." + + try: + await MultiplayerHubs.KickUser(signalr_client, user_id) + return f"User '{username}' has been kicked from the room." + except InvokeException as e: + return e.message + + +async def _mp_map( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + if len(args) < 1: + return "Usage: !mp map []" + + map_id = args[0] + if not map_id.isdigit(): + return "Invalid map ID." + map_id = int(map_id) + playmode = GameMode.parse(args[1].upper()) if len(args) >= 2 else None + if playmode not in ( + GameMode.OSU, + GameMode.TAIKO, + GameMode.FRUITS, + GameMode.MANIA, + None, + ): + return "Invalid playmode." + + try: + beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id) + if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode: + return ( + f"Cannot convert to {playmode.value}. " + f"Original mode is {beatmap.mode.value}." + ) + except HTTPError: + return "Beatmap not found" + + try: + current_item = room.queue.current_item + item = current_item.model_copy(deep=True) + item.owner_id = signalr_client.user_id + item.beatmap_checksum = beatmap.checksum + item.required_mods = [] + item.allowed_mods = [] + item.freestyle = False + item.beatmap_id = map_id + if playmode is not None: + item.ruleset_id = int(playmode) + if item.expired: + item.id = 0 + item.expired = False + item.played_at = None + await MultiplayerHubs.AddPlaylistItem(signalr_client, item) + else: + await MultiplayerHubs.EditPlaylistItem(signalr_client, item) + return "" + except InvokeException as e: + return e.message + + +async def _mp_mods( + signalr_client: Client, + room: ServerMultiplayerRoom, + args: list[str], + session: AsyncSession, +) -> str: + if len(args) < 1: + return "Usage: !mp mods [ ...]" + + required_mods = [] + allowed_mods = [] + freestyle = False + for arg in args: + if arg == "None": + required_mods.clear() + allowed_mods.clear() + break + elif arg == "Freestyle": + freestyle = True + elif arg.startswith("+"): + mod = arg.removeprefix("+") + if len(mod) != 2: + return f"Invalid mod: {mod}." + allowed_mods.append(APIMod(acronym=mod)) + else: + if len(arg) != 2: + return f"Invalid mod: {arg}." + required_mods.append(APIMod(acronym=arg)) + + try: + current_item = room.queue.current_item + item = current_item.model_copy(deep=True) + item.owner_id = signalr_client.user_id + item.freestyle = freestyle + if not freestyle: + item.allowed_mods = allowed_mods + else: + item.allowed_mods = [] + item.required_mods = required_mods + if item.expired: + item.id = 0 + item.expired = False + item.played_at = None + await MultiplayerHubs.AddPlaylistItem(signalr_client, item) + else: + await MultiplayerHubs.EditPlaylistItem(signalr_client, item) + return "" + except InvokeException as e: + return e.message + + +_MP_COMMANDS = { + "name": _mp_name, + "set": _mp_set, + "host": _mp_host, + "start": _mp_start, + "abort": _mp_abort, + "map": _mp_map, + "mods": _mp_mods, + "kick": _mp_kick, + "password": _mp_password, + "team": _mp_team, +} +_MP_HELP = """!mp name +!mp set [] +!mp host +!mp start [] +!mp abort +!mp map [] +!mp mods [ ...] +!mp kick +!mp password [] +!mp team """ + + +@bot.command("mp") +async def _mp(user: User, args: list[str], session: AsyncSession, channel: ChatChannel): + if not channel.name.startswith("room_"): + return + + room_id = int(channel.name[5:]) + room = MultiplayerHubs.rooms.get(room_id) + if not room: + return + signalr_client = MultiplayerHubs.get_client_by_id(str(user.id)) + if not signalr_client: + return + + if len(args) < 1: + return f"Usage: !mp <{'|'.join(_MP_COMMANDS.keys())}> [args]" + + command = args[0].lower() + if command not in _MP_COMMANDS: + return f"No such command: {command}" + + return await _MP_COMMANDS[command](signalr_client, room, args[1:], session) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 20eecf6..128fb91 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -173,6 +173,20 @@ class MultiplayerHub(Hub[MultiplayerClientState]): self.get_client_by_id(str(user_id)), server_room, user ) + def _ensure_in_room(self, client: Client) -> ServerMultiplayerRoom: + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + return server_room + + def _ensure_host(self, client: Client, server_room: ServerMultiplayerRoom): + room = server_room.room + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + async def CreateRoom(self, client: Client, room: MultiplayerRoom): logger.info(f"[MultiplayerHub] {client.user_id} creating room") store = self.get_or_create_state(client) @@ -1105,17 +1119,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async def ChangeSettings(self, client: Client, settings: MultiplayerRoomSettings): - store = self.get_or_create_state(client) - if store.room_id == 0: - raise InvokeException("You are not in a room") - if store.room_id not in self.rooms: - raise InvokeException("Room does not exist") - server_room = self.rooms[store.room_id] + server_room = self._ensure_in_room(client) + self._ensure_host(client, server_room) room = server_room.room - if room.host is None or room.host.user_id != client.user_id: - raise InvokeException("You are not the host of this room") - if room.state != MultiplayerRoomState.OPEN: raise InvokeException("Cannot change settings while playing") From 3f3afab4808b2ba84f1f17753a1eb920305c8299 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 15:01:42 +0000 Subject: [PATCH 06/10] fix(chat): broadcast bot command in non-public channels --- app/router/chat/message.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/app/router/chat/message.py b/app/router/chat/message.py index 318aefe..5c7c1bd 100644 --- a/app/router/chat/message.py +++ b/app/router/chat/message.py @@ -67,7 +67,9 @@ async def send_message( await session.refresh(db_channel) resp = await ChatMessageResp.from_db(msg, session, current_user) is_bot_command = req.message.startswith("!") - await server.send_message_to_channel(resp, is_bot_command) + await server.send_message_to_channel( + resp, is_bot_command and db_channel.type == ChannelType.PUBLIC + ) if is_bot_command: await bot.try_handle(current_user, db_channel, req.message, session) return resp From 4eace3f84e77d52a741d2dacd603fefbed09499b Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 15:21:11 +0000 Subject: [PATCH 07/10] feat(chat): complete sliences --- app/database/chat.py | 12 ++++++++++++ app/router/chat/channel.py | 26 ++++++++++++++++++++++++++ app/router/chat/message.py | 32 ++++++++++++++++++++++++++++++-- 3 files changed, 68 insertions(+), 2 deletions(-) diff --git a/app/database/chat.py b/app/database/chat.py index 777e2ac..29334d6 100644 --- a/app/database/chat.py +++ b/app/database/chat.py @@ -226,3 +226,15 @@ class SilenceUser(UTCBaseModel, SQLModel, table=True): banned_at: datetime = Field( sa_column=Column(DateTime, index=True), default=datetime.now(UTC) ) + + +class UserSilenceResp(SQLModel): + id: int + user_id: int + + @classmethod + def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp": + return cls( + id=db_silence.id, + user_id=db_silence.user_id, + ) diff --git a/app/router/chat/channel.py b/app/router/chat/channel.py index ceb7b66..2b42dfe 100644 --- a/app/router/chat/channel.py +++ b/app/router/chat/channel.py @@ -6,6 +6,9 @@ from app.database.chat import ( ChannelType, ChatChannel, ChatChannelResp, + ChatMessage, + SilenceUser, + UserSilenceResp, ) from app.database.lazer_user import User, UserResp from app.dependencies.database import get_db, get_redis @@ -54,6 +57,29 @@ async def get_update( else None, ) ) + if "sliences" in includes: + if history_since: + silences = ( + await session.exec( + select(SilenceUser).where(col(SilenceUser.id) > history_since) + ) + ).all() + resp.silences.extend( + [UserSilenceResp.from_db(silence) for silence in silences] + ) + elif since: + msg = await session.get(ChatMessage, since) + if msg: + silences = ( + await session.exec( + select(SilenceUser).where( + col(SilenceUser.banned_at) > msg.timestamp + ) + ) + ).all() + resp.silences.extend( + [UserSilenceResp.from_db(silence) for silence in silences] + ) return resp diff --git a/app/router/chat/message.py b/app/router/chat/message.py index 5c7c1bd..6a4a2e4 100644 --- a/app/router/chat/message.py +++ b/app/router/chat/message.py @@ -7,6 +7,8 @@ from app.database.chat import ( ChatChannelResp, ChatMessage, MessageType, + SilenceUser, + UserSilenceResp, ) from app.database.lazer_user import User from app.dependencies.database import get_db, get_redis @@ -18,12 +20,16 @@ from .banchobot import bot from .server import server from fastapi import Depends, HTTPException, Query, Security -from pydantic import BaseModel +from pydantic import BaseModel, Field from redis.asyncio import Redis from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession +class KeepAliveResp(BaseModel): + silences: list[UserSilenceResp] = Field(default_factory=list) + + @router.post("/chat/ack") async def keep_alive( history_since: int | None = Query(None), @@ -31,7 +37,29 @@ async def keep_alive( current_user: User = Security(get_current_user, scopes=["chat.read"]), session: AsyncSession = Depends(get_db), ): - return {"silences": []} + resp = KeepAliveResp() + if history_since: + silences = ( + await session.exec( + select(SilenceUser).where(col(SilenceUser.id) > history_since) + ) + ).all() + resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences]) + elif since: + msg = await session.get(ChatMessage, since) + if msg: + silences = ( + await session.exec( + select(SilenceUser).where( + col(SilenceUser.banned_at) > msg.timestamp + ) + ) + ).all() + resp.silences.extend( + [UserSilenceResp.from_db(silence) for silence in silences] + ) + + return resp class MessageReq(BaseModel): From 87a3928e20fdb7e473ea00d83412f35d18d3d241 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 15:35:27 +0000 Subject: [PATCH 08/10] docs(chat): add API docs --- app/dependencies/param.py | 12 +++++++ app/router/chat/__init__.py | 11 +++++-- app/router/chat/channel.py | 65 +++++++++++++++++++++++++++++-------- app/router/chat/message.py | 59 +++++++++++++++++++++++++-------- 4 files changed, 117 insertions(+), 30 deletions(-) diff --git a/app/dependencies/param.py b/app/dependencies/param.py index 174adde..28a30c6 100644 --- a/app/dependencies/param.py +++ b/app/dependencies/param.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from typing import Any from fastapi import Request @@ -36,4 +37,15 @@ def BodyOrForm[T: BaseModel](model: type[T]): except ValidationError as e: raise RequestValidationError(e.errors()) + dependency.__signature__ = inspect.signature( # pyright: ignore[reportFunctionMemberAccess] + lambda x: None + ).replace( + parameters=[ + inspect.Parameter( + name=model.__name__.lower(), + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=model, + ) + ] + ) return dependency diff --git a/app/router/chat/__init__.py b/app/router/chat/__init__.py index 3ed7897..f9fc0b3 100644 --- a/app/router/chat/__init__.py +++ b/app/router/chat/__init__.py @@ -11,8 +11,15 @@ from fastapi import Query __all__ = ["chat_router"] -@router.get("/notifications") -async def get_notifications(max_id: int | None = Query(None)): +@router.get( + "/notifications", + tags=["通知", "聊天"], + name="获取通知", + description="获取当前用户未读通知。根据 ID 排序。同时返回通知服务器入口。", +) +async def get_notifications( + max_id: int | None = Query(None, description="获取 ID 小于此值的通知"), +): if settings.server_url is not None: notification_endpoint = f"{settings.server_url}notification-server".replace( "http://", "ws://" diff --git a/app/router/chat/channel.py b/app/router/chat/channel.py index 2b42dfe..a514c1d 100644 --- a/app/router/chat/channel.py +++ b/app/router/chat/channel.py @@ -18,7 +18,7 @@ from app.router.v2 import api_v2_router as router from .server import server -from fastapi import Depends, HTTPException, Query, Security +from fastapi import Depends, HTTPException, Path, Query, Security from pydantic import BaseModel, Field, model_validator from redis.asyncio import Redis from sqlmodel import col, select @@ -30,13 +30,23 @@ class UpdateResponse(BaseModel): silences: list[Any] = Field(default_factory=list) -@router.get("/chat/updates", response_model=UpdateResponse) +@router.get( + "/chat/updates", + response_model=UpdateResponse, + name="获取更新", + description="获取当前用户所在频道的最新的禁言情况。", + tags=["聊天"], +) async def get_update( - history_since: int | None = Query(None), - since: int | None = Query(None), + history_since: int | None = Query( + None, description="获取自此禁言 ID 之后的禁言记录" + ), + since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), + includes: list[str] = Query( + ["presence", "silences"], alias="includes[]", description="要包含的更新类型" + ), current_user: User = Security(get_current_user, scopes=["chat.read"]), session: AsyncSession = Depends(get_db), - includes: list[str] = Query(["presence"], alias="includes[]"), redis: Redis = Depends(get_redis), ): resp = UpdateResponse() @@ -83,10 +93,16 @@ async def get_update( return resp -@router.put("/chat/channels/{channel}/users/{user}", response_model=ChatChannelResp) +@router.put( + "/chat/channels/{channel}/users/{user}", + response_model=ChatChannelResp, + name="加入频道", + description="加入指定的公开/房间频道。", + tags=["聊天"], +) async def join_channel( - channel: str, - user: str, + channel: str = Path(..., description="频道 ID/名称"), + user: str = Path(..., description="用户 ID"), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), session: AsyncSession = Depends(get_db), ): @@ -100,10 +116,13 @@ async def join_channel( @router.delete( "/chat/channels/{channel}/users/{user}", status_code=204, + name="离开频道", + description="将用户移出指定的公开/房间频道。", + tags=["聊天"], ) async def leave_channel( - channel: str, - user: str, + channel: str = Path(..., description="频道 ID/名称"), + user: str = Path(..., description="用户 ID"), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), session: AsyncSession = Depends(get_db), ): @@ -115,7 +134,13 @@ async def leave_channel( return -@router.get("/chat/channels") +@router.get( + "/chat/channels", + response_model=list[ChatChannelResp], + name="获取频道列表", + description="获取所有公开频道。", + tags=["聊天"], +) async def get_channel_list( current_user: User = Security(get_current_user, scopes=["chat.read"]), session: AsyncSession = Depends(get_db), @@ -148,9 +173,15 @@ class GetChannelResp(BaseModel): users: list[UserResp] = Field(default_factory=list) -@router.get("/chat/channels/{channel}") +@router.get( + "/chat/channels/{channel}", + response_model=GetChannelResp, + name="获取频道信息", + description="获取指定频道的信息。", + tags=["聊天"], +) async def get_channel( - channel: str, + channel: str = Path(..., description="频道 ID/名称"), current_user: User = Security(get_current_user, scopes=["chat.read"]), session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), @@ -211,7 +242,13 @@ class CreateChannelReq(BaseModel): return self -@router.post("/chat/channels") +@router.post( + "/chat/channels", + response_model=ChatChannelResp, + name="创建频道", + description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。", + tags=["聊天"], +) async def create_channel( req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), diff --git a/app/router/chat/message.py b/app/router/chat/message.py index 6a4a2e4..84f5c41 100644 --- a/app/router/chat/message.py +++ b/app/router/chat/message.py @@ -19,7 +19,7 @@ from app.router.v2 import api_v2_router as router from .banchobot import bot from .server import server -from fastapi import Depends, HTTPException, Query, Security +from fastapi import Depends, HTTPException, Path, Query, Security from pydantic import BaseModel, Field from redis.asyncio import Redis from sqlmodel import col, select @@ -30,10 +30,18 @@ class KeepAliveResp(BaseModel): silences: list[UserSilenceResp] = Field(default_factory=list) -@router.post("/chat/ack") +@router.post( + "/chat/ack", + name="保持连接", + response_model=KeepAliveResp, + description="保持公共频道的连接。同时返回最近的禁言列表。", + tags=["聊天"], +) async def keep_alive( - history_since: int | None = Query(None), - since: int | None = Query(None), + history_since: int | None = Query( + None, description="获取自此禁言 ID 之后的禁言记录" + ), + since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), current_user: User = Security(get_current_user, scopes=["chat.read"]), session: AsyncSession = Depends(get_db), ): @@ -68,9 +76,15 @@ class MessageReq(BaseModel): uuid: str | None = None -@router.post("/chat/channels/{channel}/messages", response_model=ChatMessageResp) +@router.post( + "/chat/channels/{channel}/messages", + response_model=ChatMessageResp, + name="发送消息", + description="发送消息到指定频道。", + tags=["聊天"], +) async def send_message( - channel: str, + channel: str = Path(..., description="频道 ID/名称"), req: MessageReq = Depends(BodyOrForm(MessageReq)), current_user: User = Security(get_current_user, scopes=["chat.write"]), session: AsyncSession = Depends(get_db), @@ -103,12 +117,18 @@ async def send_message( return resp -@router.get("/chat/channels/{channel}/messages", response_model=list[ChatMessageResp]) +@router.get( + "/chat/channels/{channel}/messages", + response_model=list[ChatMessageResp], + name="获取消息", + description="获取指定频道的消息列表。", + tags=["聊天"], +) async def get_message( channel: str, - limit: int = Query(50, ge=1, le=50), - since: int = Query(default=0, ge=0), - until: int | None = Query(None), + limit: int = Query(50, ge=1, le=50, description="获取消息的数量"), + since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"), + until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"), current_user: User = Security(get_current_user, scopes=["chat.read"]), session: AsyncSession = Depends(get_db), ): @@ -130,10 +150,16 @@ async def get_message( return resp -@router.put("/chat/channels/{channel}/mark-as-read/{message}", status_code=204) +@router.put( + "/chat/channels/{channel}/mark-as-read/{message}", + status_code=204, + name="标记消息为已读", + description="标记指定消息为已读。", + tags=["聊天"], +) async def mark_as_read( - channel: str, - message: int, + channel: str = Path(..., description="频道 ID/名称"), + message: int = Path(..., description="消息 ID"), current_user: User = Security(get_current_user, scopes=["chat.read"]), session: AsyncSession = Depends(get_db), ): @@ -157,7 +183,12 @@ class NewPMResp(BaseModel): new_channel_id: int -@router.post("/chat/new") +@router.post( + "/chat/new", + name="创建私聊频道", + description="创建一个新的私聊频道。", + tags=["聊天"], +) async def create_new_pm( req: PMReq = Depends(BodyOrForm(PMReq)), current_user: User = Security(get_current_user, scopes=["chat.write"]), From 76dc41f78c5bea1a10976ef652ad286892a3bb9a Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 15:39:06 +0000 Subject: [PATCH 09/10] chore(chat): typo --- app/router/chat/banchobot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/router/chat/banchobot.py b/app/router/chat/banchobot.py index a92eb15..65f1c8b 100644 --- a/app/router/chat/banchobot.py +++ b/app/router/chat/banchobot.py @@ -164,7 +164,7 @@ async def _help( if args: target = args[0].lower() if target in bot._handlers: - return f"Use: !{target} [args]" + return f"Usage: !{target} [args]" return f"No such command: {target}" if not cmds: return "No available commands" From 24bfda4e0c12b702e86f35587f98dba6aaed3b21 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 15:49:22 +0000 Subject: [PATCH 10/10] fix(chat): resolve copilot's review --- app/router/auth.py | 3 ++- app/router/chat/channel.py | 2 +- app/service/create_banchobot.py | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/app/router/auth.py b/app/router/auth.py index 88528e2..d563ed3 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -14,6 +14,7 @@ from app.auth import ( store_token, ) from app.config import settings +from app.const import BANCHOBOT_ID from app.database import DailyChallengeStats, OAuthClient, User from app.database.statistics import UserStatistics from app.dependencies import get_db @@ -459,7 +460,7 @@ async def oauth_token( # 存储令牌 await store_token( db, - 2, + BANCHOBOT_ID, client_id, scopes, access_token, diff --git a/app/router/chat/channel.py b/app/router/chat/channel.py index a514c1d..059fc50 100644 --- a/app/router/chat/channel.py +++ b/app/router/chat/channel.py @@ -67,7 +67,7 @@ async def get_update( else None, ) ) - if "sliences" in includes: + if "silences" in includes: if history_since: silences = ( await session.exec( diff --git a/app/service/create_banchobot.py b/app/service/create_banchobot.py index 0d855cf..aa89fe6 100644 --- a/app/service/create_banchobot.py +++ b/app/service/create_banchobot.py @@ -12,7 +12,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession async def create_banchobot(): async with AsyncSession(engine) as session: - is_exist = (await session.exec(select(exists()).where(User.id == 2))).first() + is_exist = ( + await session.exec(select(exists()).where(User.id == BANCHOBOT_ID)) + ).first() if not is_exist: banchobot = User( username="BanchoBot",