refactor(project): make pyright & ruff happy

This commit is contained in:
MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions

View File

@@ -38,27 +38,18 @@ class UpdateResponse(BaseModel):
)
async def get_update(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
includes: list[str] = Query(
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
),
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
resp = UpdateResponse()
if "presence" in includes:
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
@@ -69,34 +60,20 @@ async def get_update(
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
if "silences" in includes:
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@@ -115,15 +92,9 @@ async def join_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -145,15 +116,9 @@ async def leave_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -173,27 +138,20 @@ async def get_channel_list(
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
channels = (
await session.exec(
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
)
).all()
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
results = []
for channel in channels:
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
assert channel_id is not None
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
return results
@@ -219,15 +177,9 @@ async def get_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -237,8 +189,6 @@ async def get_channel(
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
users = []
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
@@ -259,9 +209,7 @@ async def get_channel(
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
@@ -284,9 +232,7 @@ class CreateChannelReq(BaseModel):
raise ValueError("target_id must be set for PM channels")
else:
if self.target_ids is None or self.channel is None or self.message is None:
raise ValueError(
"target_ids, channel, and message must be set for ANNOUNCE channels"
)
raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels")
return self
@@ -312,24 +258,20 @@ async def create_channel(
raise HTTPException(status_code=403, detail=block)
channel = await ChatChannel.get_pm_channel(
current_user.id, # pyright: ignore[reportArgumentType]
current_user.id,
req.target_id, # pyright: ignore[reportArgumentType]
session,
)
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
result = await session.exec(
select(ChatChannel).where(ChatChannel.name == channel_name)
)
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
channel = result.first()
if channel is None:
channel = ChatChannel(
name=channel_name,
description=req.channel.description
if req.channel
else "Private message channel",
description=req.channel.description if req.channel else "Private message channel",
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
)
session.add(channel)
@@ -340,16 +282,13 @@ async def create_channel(
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(
select(User).where(col(User.id).in_(req.target_ids or []))
)
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
assert channel_id
return await ChatChannelResp.from_db(
channel,