add message redis
This commit is contained in:
@@ -53,16 +53,24 @@ async def get_update(
|
||||
assert current_user.id
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
if db_channel:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_type = db_channel.type
|
||||
|
||||
resp.presence.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
db_channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
@@ -105,7 +113,19 @@ async def join_channel(
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -125,7 +145,19 @@ async def leave_channel(
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -152,15 +184,19 @@ async def get_channel_list(
|
||||
).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
assert channel.channel_id is not None
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
channel_type = channel.type
|
||||
|
||||
assert channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel.channel_id, [])
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
@@ -185,14 +221,33 @@ async def get_channel(
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
assert db_channel.channel_id is not None
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
|
||||
users = []
|
||||
if db_channel.type == ChannelType.PM:
|
||||
user_ids = db_channel.name.split("_")[1:]
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
if len(user_ids) != 2:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
for id_ in user_ids:
|
||||
@@ -210,8 +265,8 @@ async def get_channel(
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(db_channel.channel_id, [])
|
||||
if db_channel.type != ChannelType.PUBLIC
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
@@ -270,7 +325,8 @@ async def create_channel(
|
||||
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||
else:
|
||||
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||
channel = await ChatChannel.get(channel_name, session)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
|
||||
channel = result.first()
|
||||
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
@@ -294,12 +350,16 @@ async def create_channel(
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
assert channel.channel_id
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id
|
||||
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel.channel_id, []),
|
||||
server.channels.get(channel_id, []),
|
||||
include_recent_messages=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user