add message redis

This commit is contained in:
咕谷酱
2025-08-22 01:49:03 +08:00
parent 36b695b531
commit 1fe603f416
11 changed files with 1461 additions and 86 deletions

View File

@@ -53,16 +53,24 @@ async def get_update(
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
channel = await ChatChannel.get(channel_id, session)
if channel:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
resp.presence.append(
await ChatChannelResp.from_db(
channel,
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel.type != ChannelType.PUBLIC
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -105,7 +113,19 @@ async def join_channel(
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -125,7 +145,19 @@ async def leave_channel(
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -152,15 +184,19 @@ async def get_channel_list(
).all()
results = []
for channel in channels:
assert channel.channel_id is not None
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
assert channel_id is not None
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, [])
if channel.type != ChannelType.PUBLIC
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -185,14 +221,33 @@ async def get_channel(
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id is not None
# 立即提取需要的属性
channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
users = []
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
if len(user_ids) != 2:
raise HTTPException(status_code=404, detail="Target user not found")
for id_ in user_ids:
@@ -210,8 +265,8 @@ async def get_channel(
session,
current_user,
redis,
server.channels.get(db_channel.channel_id, [])
if db_channel.type != ChannelType.PUBLIC
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -270,7 +325,8 @@ async def create_channel(
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
channel = await ChatChannel.get(channel_name, session)
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
channel = result.first()
if channel is None:
channel = ChatChannel(
@@ -294,12 +350,16 @@ async def create_channel(
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
assert channel.channel_id
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
assert channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, []),
server.channels.get(channel_id, []),
include_recent_messages=True,
)