feat(chat): support pm
This commit is contained in:
@@ -16,6 +16,7 @@ from sqlmodel import (
|
|||||||
ForeignKey,
|
ForeignKey,
|
||||||
Relationship,
|
Relationship,
|
||||||
SQLModel,
|
SQLModel,
|
||||||
|
col,
|
||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -65,6 +66,15 @@ class ChatChannel(ChatChannelBase, table=True):
|
|||||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_pm_channel(
|
||||||
|
cls, user1: int, user2: int, session: AsyncSession
|
||||||
|
) -> "ChatChannel | None":
|
||||||
|
channel = await cls.get(f"pm_{user1}_{user2}", session)
|
||||||
|
if channel is None:
|
||||||
|
channel = await cls.get(f"pm_{user2}_{user1}", session)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
|
||||||
class ChatChannelResp(ChatChannelBase):
|
class ChatChannelResp(ChatChannelBase):
|
||||||
channel_id: int
|
channel_id: int
|
||||||
@@ -73,8 +83,8 @@ class ChatChannelResp(ChatChannelBase):
|
|||||||
current_user_attributes: ChatUserAttributes | None = None
|
current_user_attributes: ChatUserAttributes | None = None
|
||||||
last_read_id: int | None = None
|
last_read_id: int | None = None
|
||||||
last_message_id: int | None = None
|
last_message_id: int | None = None
|
||||||
recent_messages: list[str] | None = None
|
recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
|
||||||
users: list[int] | None = None
|
users: list[int] = Field(default_factory=list)
|
||||||
message_length_limit: int = 1000
|
message_length_limit: int = 1000
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -82,9 +92,10 @@ class ChatChannelResp(ChatChannelBase):
|
|||||||
cls,
|
cls,
|
||||||
channel: ChatChannel,
|
channel: ChatChannel,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
users: list[int],
|
|
||||||
user: User,
|
user: User,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
|
users: list[int] | None = None,
|
||||||
|
include_recent_messages: bool = False,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
c = cls.model_validate(channel)
|
c = cls.model_validate(channel)
|
||||||
silence = (
|
silence = (
|
||||||
@@ -123,9 +134,33 @@ class ChatChannelResp(ChatChannelBase):
|
|||||||
c.moderated = False
|
c.moderated = False
|
||||||
|
|
||||||
c.current_user_attributes = attribute
|
c.current_user_attributes = attribute
|
||||||
c.users = users
|
if c.type != ChannelType.PUBLIC and users is not None:
|
||||||
|
c.users = users
|
||||||
c.last_message_id = last_msg
|
c.last_message_id = last_msg
|
||||||
c.last_read_id = last_read_id
|
c.last_read_id = last_read_id
|
||||||
|
|
||||||
|
if include_recent_messages:
|
||||||
|
messages = (
|
||||||
|
await session.exec(
|
||||||
|
select(ChatMessage)
|
||||||
|
.where(ChatMessage.channel_id == channel.channel_id)
|
||||||
|
.order_by(col(ChatMessage.timestamp).desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
c.recent_messages = [
|
||||||
|
await ChatMessageResp.from_db(msg, session, user) for msg in messages
|
||||||
|
]
|
||||||
|
c.recent_messages.reverse()
|
||||||
|
|
||||||
|
if c.type == ChannelType.PM and users and len(users) == 2:
|
||||||
|
target_user_id = next(u for u in users if u != user.id)
|
||||||
|
target_name = await session.exec(
|
||||||
|
select(User.username).where(User.id == target_user_id)
|
||||||
|
)
|
||||||
|
c.name = target_name.one()
|
||||||
|
assert user.id
|
||||||
|
c.users = [target_user_id, user.id]
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -168,6 +168,46 @@ class User(AsyncAttrs, UserBase, table=True):
|
|||||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def is_user_can_pm(
|
||||||
|
self, from_user: "User", session: AsyncSession
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
from .relationship import Relationship, RelationshipType
|
||||||
|
|
||||||
|
from_relationship = (
|
||||||
|
await session.exec(
|
||||||
|
select(Relationship).where(
|
||||||
|
Relationship.user_id == from_user.id,
|
||||||
|
Relationship.target_id == self.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if from_relationship and from_relationship.type == RelationshipType.BLOCK:
|
||||||
|
return False, "You have blocked the target user."
|
||||||
|
if from_user.pm_friends_only and (
|
||||||
|
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"You have disabled non-friend communications "
|
||||||
|
"and target user is not your friend.",
|
||||||
|
)
|
||||||
|
|
||||||
|
relationship = (
|
||||||
|
await session.exec(
|
||||||
|
select(Relationship).where(
|
||||||
|
Relationship.user_id == self.id,
|
||||||
|
Relationship.target_id == from_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if relationship and relationship.type == RelationshipType.BLOCK:
|
||||||
|
return False, "Target user has blocked you."
|
||||||
|
if self.pm_friends_only and (
|
||||||
|
not relationship or relationship.type != RelationshipType.FOLLOW
|
||||||
|
):
|
||||||
|
return False, "Target user has disabled non-friend communications"
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
class UserResp(UserBase):
|
class UserResp(UserBase):
|
||||||
id: int | None = None
|
id: int | None = None
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, Literal, Self
|
||||||
|
|
||||||
from app.database.chat import (
|
from app.database.chat import (
|
||||||
ChannelType,
|
ChannelType,
|
||||||
@@ -9,15 +9,16 @@ from app.database.chat import (
|
|||||||
)
|
)
|
||||||
from app.database.lazer_user import User, UserResp
|
from app.database.lazer_user import User, UserResp
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import get_db, get_redis
|
||||||
|
from app.dependencies.param import BodyOrForm
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.router.v2 import api_v2_router as router
|
from app.router.v2 import api_v2_router as router
|
||||||
|
|
||||||
from .server import server
|
from .server import server
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query, Security
|
from fastapi import Depends, HTTPException, Query, Security
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@@ -37,6 +38,7 @@ async def get_update(
|
|||||||
):
|
):
|
||||||
resp = UpdateResponse()
|
resp = UpdateResponse()
|
||||||
if "presence" in includes:
|
if "presence" in includes:
|
||||||
|
assert current_user.id
|
||||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||||
for channel_id in channel_ids:
|
for channel_id in channel_ids:
|
||||||
channel = await ChatChannel.get(channel_id, session)
|
channel = await ChatChannel.get(channel_id, session)
|
||||||
@@ -45,9 +47,11 @@ async def get_update(
|
|||||||
await ChatChannelResp.from_db(
|
await ChatChannelResp.from_db(
|
||||||
channel,
|
channel,
|
||||||
session,
|
session,
|
||||||
server.channels.get(channel_id, []),
|
|
||||||
current_user,
|
current_user,
|
||||||
redis,
|
redis,
|
||||||
|
server.channels.get(channel_id, [])
|
||||||
|
if channel.type != ChannelType.PUBLIC
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return resp
|
return resp
|
||||||
@@ -103,9 +107,11 @@ async def get_channel_list(
|
|||||||
await ChatChannelResp.from_db(
|
await ChatChannelResp.from_db(
|
||||||
channel,
|
channel,
|
||||||
session,
|
session,
|
||||||
server.channels.get(channel.channel_id, []),
|
|
||||||
current_user,
|
current_user,
|
||||||
redis,
|
redis,
|
||||||
|
server.channels.get(channel.channel_id, [])
|
||||||
|
if channel.type != ChannelType.PUBLIC
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
@@ -127,12 +133,111 @@ async def get_channel(
|
|||||||
if db_channel is None:
|
if db_channel is None:
|
||||||
raise HTTPException(status_code=404, detail="Channel not found")
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
assert db_channel.channel_id is not None
|
assert db_channel.channel_id is not None
|
||||||
|
|
||||||
|
users = []
|
||||||
|
if db_channel.type == ChannelType.PM:
|
||||||
|
user_ids = db_channel.name.split("_")[1:]
|
||||||
|
if len(user_ids) != 2:
|
||||||
|
raise HTTPException(status_code=404, detail="Target user not found")
|
||||||
|
for id_ in user_ids:
|
||||||
|
if int(id_) == current_user.id:
|
||||||
|
continue
|
||||||
|
target_user = await session.get(User, int(id_))
|
||||||
|
if target_user is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Target user not found")
|
||||||
|
users.extend([target_user, current_user])
|
||||||
|
break
|
||||||
|
|
||||||
return GetChannelResp(
|
return GetChannelResp(
|
||||||
channel=await ChatChannelResp.from_db(
|
channel=await ChatChannelResp.from_db(
|
||||||
db_channel,
|
db_channel,
|
||||||
session,
|
session,
|
||||||
server.channels.get(db_channel.channel_id, []),
|
|
||||||
current_user,
|
current_user,
|
||||||
redis,
|
redis,
|
||||||
|
server.channels.get(db_channel.channel_id, [])
|
||||||
|
if db_channel.type != ChannelType.PUBLIC
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateChannelReq(BaseModel):
|
||||||
|
class AnnounceChannel(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
message: str | None = None
|
||||||
|
type: Literal["ANNOUNCE", "PM"] = "PM"
|
||||||
|
target_id: int | None = None
|
||||||
|
target_ids: list[int] | None = None
|
||||||
|
channel: AnnounceChannel | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check(self) -> Self:
|
||||||
|
if self.type == "PM":
|
||||||
|
if self.target_id is None:
|
||||||
|
raise ValueError("target_id must be set for PM channels")
|
||||||
|
else:
|
||||||
|
if self.target_ids is None or self.channel is None or self.message is None:
|
||||||
|
raise ValueError(
|
||||||
|
"target_ids, channel, and message must be set for ANNOUNCE channels"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat/channels")
|
||||||
|
async def create_channel(
|
||||||
|
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
redis: Redis = Depends(get_redis),
|
||||||
|
):
|
||||||
|
if req.type == "PM":
|
||||||
|
target = await session.get(User, req.target_id)
|
||||||
|
if not target:
|
||||||
|
raise HTTPException(status_code=404, detail="Target user not found")
|
||||||
|
is_can_pm, block = await target.is_user_can_pm(current_user, session)
|
||||||
|
if not is_can_pm:
|
||||||
|
raise HTTPException(status_code=403, detail=block)
|
||||||
|
|
||||||
|
channel = await ChatChannel.get_pm_channel(
|
||||||
|
current_user.id, # pyright: ignore[reportArgumentType]
|
||||||
|
req.target_id, # pyright: ignore[reportArgumentType]
|
||||||
|
session,
|
||||||
|
)
|
||||||
|
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||||
|
else:
|
||||||
|
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||||
|
channel = await ChatChannel.get(channel_name, session)
|
||||||
|
|
||||||
|
if channel is None:
|
||||||
|
channel = ChatChannel(
|
||||||
|
name=channel_name,
|
||||||
|
description=req.channel.description
|
||||||
|
if req.channel
|
||||||
|
else "Private message channel",
|
||||||
|
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
|
||||||
|
)
|
||||||
|
session.add(channel)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(channel)
|
||||||
|
await session.refresh(current_user)
|
||||||
|
if req.type == "PM":
|
||||||
|
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||||
|
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||||
|
else:
|
||||||
|
target_users = await session.exec(
|
||||||
|
select(User).where(col(User.id).in_(req.target_ids or []))
|
||||||
|
)
|
||||||
|
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||||
|
|
||||||
|
await server.join_channel(current_user, channel, session)
|
||||||
|
assert channel.channel_id
|
||||||
|
return await ChatChannelResp.from_db(
|
||||||
|
channel,
|
||||||
|
session,
|
||||||
|
current_user,
|
||||||
|
redis,
|
||||||
|
server.channels.get(channel.channel_id, []),
|
||||||
|
include_recent_messages=True,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,9 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from app.database import ChatMessageResp
|
from app.database import ChatMessageResp
|
||||||
from app.database.chat import ChatChannel, ChatMessage, MessageType
|
from app.database.chat import (
|
||||||
|
ChannelType,
|
||||||
|
ChatChannel,
|
||||||
|
ChatChannelResp,
|
||||||
|
ChatMessage,
|
||||||
|
MessageType,
|
||||||
|
)
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import get_db, get_redis
|
||||||
from app.dependencies.param import BodyOrForm
|
from app.dependencies.param import BodyOrForm
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
from app.router.v2 import api_v2_router as router
|
from app.router.v2 import api_v2_router as router
|
||||||
@@ -12,6 +18,7 @@ from .server import server
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query, Security
|
from fastapi import Depends, HTTPException, Query, Security
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -42,6 +49,9 @@ async def send_message(
|
|||||||
db_channel = await ChatChannel.get(channel, session)
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
if db_channel is None:
|
if db_channel is None:
|
||||||
raise HTTPException(status_code=404, detail="Channel not found")
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
|
||||||
|
assert db_channel.channel_id
|
||||||
|
assert current_user.id
|
||||||
msg = ChatMessage(
|
msg = ChatMessage(
|
||||||
channel_id=db_channel.channel_id,
|
channel_id=db_channel.channel_id,
|
||||||
content=req.message,
|
content=req.message,
|
||||||
@@ -95,4 +105,73 @@ async def mark_as_read(
|
|||||||
db_channel = await ChatChannel.get(channel, session)
|
db_channel = await ChatChannel.get(channel, session)
|
||||||
if db_channel is None:
|
if db_channel is None:
|
||||||
raise HTTPException(status_code=404, detail="Channel not found")
|
raise HTTPException(status_code=404, detail="Channel not found")
|
||||||
|
assert db_channel.channel_id
|
||||||
await server.mark_as_read(db_channel.channel_id, message)
|
await server.mark_as_read(db_channel.channel_id, message)
|
||||||
|
|
||||||
|
|
||||||
|
class PMReq(BaseModel):
|
||||||
|
target_id: int
|
||||||
|
message: str
|
||||||
|
is_action: bool = False
|
||||||
|
uuid: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class NewPMResp(BaseModel):
|
||||||
|
channel: ChatChannelResp
|
||||||
|
message: ChatMessageResp
|
||||||
|
new_channel_id: int
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat/new")
|
||||||
|
async def create_new_pm(
|
||||||
|
req: PMReq = Depends(BodyOrForm(PMReq)),
|
||||||
|
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
redis: Redis = Depends(get_redis),
|
||||||
|
):
|
||||||
|
user_id = current_user.id
|
||||||
|
target = await session.get(User, req.target_id)
|
||||||
|
if target is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Target user not found")
|
||||||
|
is_can_pm, block = await target.is_user_can_pm(current_user, session)
|
||||||
|
if not is_can_pm:
|
||||||
|
raise HTTPException(status_code=403, detail=block)
|
||||||
|
|
||||||
|
assert user_id
|
||||||
|
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
|
||||||
|
if channel is None:
|
||||||
|
channel = ChatChannel(
|
||||||
|
name=f"pm_{user_id}_{req.target_id}",
|
||||||
|
description="Private message channel",
|
||||||
|
type=ChannelType.PM,
|
||||||
|
)
|
||||||
|
session.add(channel)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(channel)
|
||||||
|
await session.refresh(target)
|
||||||
|
await session.refresh(current_user)
|
||||||
|
|
||||||
|
assert channel.channel_id
|
||||||
|
await server.batch_join_channel([target, current_user], channel, session)
|
||||||
|
channel_resp = await ChatChannelResp.from_db(
|
||||||
|
channel, session, current_user, redis, server.channels[channel.channel_id]
|
||||||
|
)
|
||||||
|
msg = ChatMessage(
|
||||||
|
channel_id=channel.channel_id,
|
||||||
|
content=req.message,
|
||||||
|
sender_id=user_id,
|
||||||
|
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||||
|
uuid=req.uuid,
|
||||||
|
)
|
||||||
|
session.add(msg)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(msg)
|
||||||
|
await session.refresh(current_user)
|
||||||
|
await session.refresh(channel)
|
||||||
|
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||||
|
await server.send_message_to_channel(message_resp)
|
||||||
|
return NewPMResp(
|
||||||
|
channel=channel_resp,
|
||||||
|
message=message_resp,
|
||||||
|
new_channel_id=channel_resp.channel_id,
|
||||||
|
)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from app.database.chat import ChatChannel, ChatChannelResp, ChatMessageResp
|
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.dependencies.database import DBFactory, get_db_factory, get_redis
|
from app.dependencies.database import DBFactory, get_db_factory, get_redis
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
@@ -73,14 +73,50 @@ class ChatServer:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert message.message_id
|
||||||
await self.mark_as_read(message.channel_id, message.message_id)
|
await self.mark_as_read(message.channel_id, message.message_id)
|
||||||
|
|
||||||
|
async def batch_join_channel(
|
||||||
|
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||||
|
):
|
||||||
|
channel_id = channel.channel_id
|
||||||
|
assert channel_id is not None
|
||||||
|
|
||||||
|
if channel_id not in self.channels:
|
||||||
|
self.channels[channel_id] = []
|
||||||
|
for user_id in [user.id for user in users]:
|
||||||
|
assert user_id is not None
|
||||||
|
if user_id not in self.channels[channel_id]:
|
||||||
|
self.channels[channel_id].append(user_id)
|
||||||
|
|
||||||
|
for user in users:
|
||||||
|
assert user.id is not None
|
||||||
|
channel_resp = await ChatChannelResp.from_db(
|
||||||
|
channel,
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
self.redis,
|
||||||
|
self.channels[channel_id]
|
||||||
|
if channel.type != ChannelType.PUBLIC
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
client = self.connect_client.get(user.id)
|
||||||
|
if client:
|
||||||
|
await self.send_event(
|
||||||
|
client,
|
||||||
|
ChatEvent(
|
||||||
|
event="chat.channel.join",
|
||||||
|
data=channel_resp.model_dump(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
async def join_channel(
|
async def join_channel(
|
||||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||||
) -> ChatChannelResp:
|
) -> ChatChannelResp:
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
channel_id = channel.channel_id
|
channel_id = channel.channel_id
|
||||||
assert channel_id is not None
|
assert channel_id is not None
|
||||||
|
assert user_id is not None
|
||||||
|
|
||||||
if channel_id not in self.channels:
|
if channel_id not in self.channels:
|
||||||
self.channels[channel_id] = []
|
self.channels[channel_id] = []
|
||||||
@@ -88,7 +124,11 @@ class ChatServer:
|
|||||||
self.channels[channel_id].append(user_id)
|
self.channels[channel_id].append(user_id)
|
||||||
|
|
||||||
channel_resp = await ChatChannelResp.from_db(
|
channel_resp = await ChatChannelResp.from_db(
|
||||||
channel, session, self.channels[channel_id], user, self.redis
|
channel,
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
self.redis,
|
||||||
|
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
client = self.connect_client.get(user_id)
|
client = self.connect_client.get(user_id)
|
||||||
@@ -109,15 +149,16 @@ class ChatServer:
|
|||||||
user_id = user.id
|
user_id = user.id
|
||||||
channel_id = channel.channel_id
|
channel_id = channel.channel_id
|
||||||
assert channel_id is not None
|
assert channel_id is not None
|
||||||
|
assert user_id is not None
|
||||||
|
|
||||||
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||||
self.channels[channel_id].remove(user_id)
|
self.channels[channel_id].remove(user_id)
|
||||||
|
|
||||||
if not self.channels.get(channel_id):
|
if (c := self.channels.get(channel_id)) is not None and not c:
|
||||||
del self.channels[channel_id]
|
del self.channels[channel_id]
|
||||||
|
|
||||||
channel_resp = await ChatChannelResp.from_db(
|
channel_resp = await ChatChannelResp.from_db(
|
||||||
channel, session, self.channels.get(channel_id, []), user, self.redis
|
channel, session, user, self.redis, self.channels[channel_id]
|
||||||
)
|
)
|
||||||
client = self.connect_client.get(user_id)
|
client = self.connect_client.get(user_id)
|
||||||
if client:
|
if client:
|
||||||
|
|||||||
Reference in New Issue
Block a user