refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -43,9 +43,9 @@ async def get_notifications(
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace("http://", "ws://").replace(
|
||||
"https://", "wss://"
|
||||
)
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
query = select(UserNotification).where(
|
||||
@@ -96,21 +96,15 @@ async def _get_notifications(
|
||||
query = base_query.where(UserNotification.notification_id == identity.id)
|
||||
if identity.object_id is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_id) == identity.object_id
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.object_id) == identity.object_id)
|
||||
)
|
||||
if identity.object_type is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_type) == identity.object_type
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.object_type) == identity.object_type)
|
||||
)
|
||||
if identity.category is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.category) == identity.category
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.category) == identity.category)
|
||||
)
|
||||
result.update({n.notification_id: n for n in await session.exec(query)})
|
||||
return list(result.values())
|
||||
@@ -134,7 +128,6 @@ async def mark_notifications_as_read(
|
||||
for user_notification in user_notifications:
|
||||
user_notification.is_read = True
|
||||
|
||||
assert current_user.id
|
||||
await server.send_event(
|
||||
current_user.id,
|
||||
ChatEvent(
|
||||
|
||||
@@ -91,9 +91,7 @@ class Bot:
|
||||
if reply:
|
||||
await self._send_reply(user, channel, reply, session)
|
||||
|
||||
async def _send_message(
|
||||
self, channel: ChatChannel, content: str, session: AsyncSession
|
||||
) -> None:
|
||||
async def _send_message(self, channel: ChatChannel, content: str, session: AsyncSession) -> None:
|
||||
bot = await session.get(User, self.bot_user_id)
|
||||
if bot is None:
|
||||
return
|
||||
@@ -101,7 +99,6 @@ class Bot:
|
||||
if channel_id is None:
|
||||
return
|
||||
|
||||
assert bot.id is not None
|
||||
msg = ChatMessage(
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
@@ -115,9 +112,7 @@ class Bot:
|
||||
resp = await ChatMessageResp.from_db(msg, session, bot)
|
||||
await server.send_message_to_channel(resp)
|
||||
|
||||
async def _ensure_pm_channel(
|
||||
self, user: User, session: AsyncSession
|
||||
) -> ChatChannel | None:
|
||||
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
|
||||
user_id = user.id
|
||||
if user_id is None:
|
||||
return None
|
||||
@@ -160,9 +155,7 @@ bot = Bot()
|
||||
|
||||
|
||||
@bot.command("help")
|
||||
async def _help(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
async def _help(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
|
||||
cmds = sorted(bot._handlers.keys())
|
||||
if args:
|
||||
target = args[0].lower()
|
||||
@@ -175,9 +168,7 @@ async def _help(
|
||||
|
||||
|
||||
@bot.command("roll")
|
||||
def _roll(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
|
||||
if len(args) > 0 and args[0].isdigit():
|
||||
r = random.randint(1, int(args[0]))
|
||||
else:
|
||||
@@ -186,13 +177,9 @@ def _roll(
|
||||
|
||||
|
||||
@bot.command("stats")
|
||||
async def _stats(
|
||||
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
async def _stats(user: User, args: list[str], session: AsyncSession, channel: ChatChannel) -> str:
|
||||
if len(args) >= 1:
|
||||
target_user = (
|
||||
await session.exec(select(User).where(User.username == args[0]))
|
||||
).first()
|
||||
target_user = (await session.exec(select(User).where(User.username == args[0]))).first()
|
||||
if not target_user:
|
||||
return f"User '{args[0]}' not found."
|
||||
else:
|
||||
@@ -202,14 +189,8 @@ async def _stats(
|
||||
if len(args) >= 2:
|
||||
gamemode = GameMode.parse(args[1].upper())
|
||||
if gamemode is None:
|
||||
subquery = (
|
||||
select(func.max(Score.id))
|
||||
.where(Score.user_id == target_user.id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
last_score = (
|
||||
await session.exec(select(Score).where(Score.id == subquery))
|
||||
).first()
|
||||
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
|
||||
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
|
||||
if last_score is not None:
|
||||
gamemode = last_score.gamemode
|
||||
else:
|
||||
@@ -295,9 +276,7 @@ async def _mp_host(
|
||||
return "Usage: !mp host <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
@@ -362,24 +341,18 @@ async def _mp_team(
|
||||
if team is None:
|
||||
return "Invalid team colour. Use 'red' or 'blue'."
|
||||
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
|
||||
if not user_client:
|
||||
return f"User '{username}' is not in the room."
|
||||
if (
|
||||
user_client.user_id != signalr_client.user_id
|
||||
and room.room.host.user_id != signalr_client.user_id
|
||||
):
|
||||
assert room.room.host
|
||||
if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id:
|
||||
return "You are not allowed to change other users' teams."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.SendMatchRequest(
|
||||
user_client, ChangeTeamRequest(team_id=team)
|
||||
)
|
||||
await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team))
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
@@ -414,9 +387,7 @@ async def _mp_kick(
|
||||
return "Usage: !mp kick <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
@@ -456,10 +427,7 @@ async def _mp_map(
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
|
||||
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
|
||||
return (
|
||||
f"Cannot convert to {playmode.value}. "
|
||||
f"Original mode is {beatmap.mode.value}."
|
||||
)
|
||||
return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}."
|
||||
except HTTPError:
|
||||
return "Beatmap not found"
|
||||
|
||||
@@ -530,9 +498,7 @@ async def _mp_mods(
|
||||
if freestyle:
|
||||
item.allowed_mods = []
|
||||
elif freemod:
|
||||
item.allowed_mods = get_available_mods(
|
||||
current_item.ruleset_id, required_mods
|
||||
)
|
||||
item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods)
|
||||
else:
|
||||
item.allowed_mods = allowed_mods
|
||||
item.required_mods = required_mods
|
||||
@@ -601,14 +567,9 @@ async def _score(
|
||||
include_fail: bool = False,
|
||||
gamemode: GameMode | None = None,
|
||||
) -> str:
|
||||
q = (
|
||||
select(Score)
|
||||
.where(Score.user_id == user_id)
|
||||
.order_by(col(Score.id).desc())
|
||||
.options(joinedload(Score.beatmap))
|
||||
)
|
||||
q = select(Score).where(Score.user_id == user_id).order_by(col(Score.id).desc()).options(joinedload(Score.beatmap))
|
||||
if not include_fail:
|
||||
q = q.where(Score.passed.is_(True))
|
||||
q = q.where(col(Score.passed).is_(True))
|
||||
if gamemode is not None:
|
||||
q = q.where(Score.gamemode == gamemode)
|
||||
|
||||
@@ -619,17 +580,13 @@ async def _score(
|
||||
result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()})
|
||||
Played at {score.started_at}
|
||||
{score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()}
|
||||
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501
|
||||
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}"""
|
||||
if score.gamemode == GameMode.MANIA:
|
||||
keys = next(
|
||||
(mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None
|
||||
)
|
||||
keys = next((mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None)
|
||||
if keys is None:
|
||||
keys = f"{int(score.beatmap.cs)}K"
|
||||
p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1"
|
||||
result += (
|
||||
f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
|
||||
)
|
||||
result += f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -38,27 +38,18 @@ class UpdateResponse(BaseModel):
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
includes: list[str] = Query(
|
||||
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
|
||||
),
|
||||
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
if "presence" in includes:
|
||||
assert current_user.id
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_type = db_channel.type
|
||||
@@ -69,34 +60,20 @@ async def get_update(
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
if "silences" in includes:
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
return resp
|
||||
|
||||
|
||||
@@ -115,15 +92,9 @@ async def join_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -145,15 +116,9 @@ async def leave_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -173,27 +138,20 @@ async def get_channel_list(
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
channels = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
|
||||
)
|
||||
).all()
|
||||
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
channel_type = channel.type
|
||||
|
||||
assert channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
return results
|
||||
@@ -219,15 +177,9 @@ async def get_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -237,8 +189,6 @@ async def get_channel(
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
|
||||
users = []
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
@@ -259,9 +209,7 @@ async def get_channel(
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -284,9 +232,7 @@ class CreateChannelReq(BaseModel):
|
||||
raise ValueError("target_id must be set for PM channels")
|
||||
else:
|
||||
if self.target_ids is None or self.channel is None or self.message is None:
|
||||
raise ValueError(
|
||||
"target_ids, channel, and message must be set for ANNOUNCE channels"
|
||||
)
|
||||
raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels")
|
||||
return self
|
||||
|
||||
|
||||
@@ -312,24 +258,20 @@ async def create_channel(
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
channel = await ChatChannel.get_pm_channel(
|
||||
current_user.id, # pyright: ignore[reportArgumentType]
|
||||
current_user.id,
|
||||
req.target_id, # pyright: ignore[reportArgumentType]
|
||||
session,
|
||||
)
|
||||
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||
else:
|
||||
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel_name)
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
|
||||
channel = result.first()
|
||||
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=channel_name,
|
||||
description=req.channel.description
|
||||
if req.channel
|
||||
else "Private message channel",
|
||||
description=req.channel.description if req.channel else "Private message channel",
|
||||
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
|
||||
)
|
||||
session.add(channel)
|
||||
@@ -340,16 +282,13 @@ async def create_channel(
|
||||
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
else:
|
||||
target_users = await session.exec(
|
||||
select(User).where(col(User.id).in_(req.target_ids or []))
|
||||
)
|
||||
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id
|
||||
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
|
||||
@@ -41,33 +41,19 @@ class KeepAliveResp(BaseModel):
|
||||
)
|
||||
async def keep_alive(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
):
|
||||
resp = KeepAliveResp()
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
|
||||
return resp
|
||||
|
||||
@@ -93,15 +79,9 @@ async def send_message(
|
||||
):
|
||||
# 使用明确的查询来获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -111,9 +91,6 @@ async def send_message(
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
assert current_user.id
|
||||
|
||||
# 使用 Redis 消息系统发送消息 - 立即返回
|
||||
resp = await redis_message_system.send_message(
|
||||
channel_id=channel_id,
|
||||
@@ -125,9 +102,7 @@ async def send_message(
|
||||
|
||||
# 立即广播消息给所有客户端
|
||||
is_bot_command = req.message.startswith("!")
|
||||
await server.send_message_to_channel(
|
||||
resp, is_bot_command and channel_type == ChannelType.PUBLIC
|
||||
)
|
||||
await server.send_message_to_channel(resp, is_bot_command and channel_type == ChannelType.PUBLIC)
|
||||
|
||||
# 处理机器人命令
|
||||
if is_bot_command:
|
||||
@@ -147,14 +122,10 @@ async def send_message(
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
await server.new_private_notification(
|
||||
ChannelMessage.init(
|
||||
temp_msg, current_user, [int(u) for u in user_ids], channel_type
|
||||
)
|
||||
ChannelMessage.init(temp_msg, current_user, [int(u) for u in user_ids], channel_type)
|
||||
)
|
||||
elif channel_type == ChannelType.TEAM:
|
||||
await server.new_private_notification(
|
||||
ChannelMessageTeam.init(temp_msg, current_user)
|
||||
)
|
||||
await server.new_private_notification(ChannelMessageTeam.init(temp_msg, current_user))
|
||||
|
||||
return resp
|
||||
|
||||
@@ -176,22 +147,15 @@ async def get_message(
|
||||
):
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
# 使用 Redis 消息系统获取消息
|
||||
try:
|
||||
@@ -230,23 +194,15 @@ async def mark_as_read(
|
||||
):
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id
|
||||
assert current_user.id
|
||||
await server.mark_as_read(channel_id, current_user.id, message)
|
||||
|
||||
|
||||
@@ -283,7 +239,6 @@ async def create_new_pm(
|
||||
if not is_can_pm:
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
assert user_id
|
||||
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
@@ -297,7 +252,6 @@ async def create_new_pm(
|
||||
await session.refresh(target)
|
||||
await session.refresh(current_user)
|
||||
|
||||
assert channel.channel_id
|
||||
await server.batch_join_channel([target, current_user], channel, session)
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, current_user, redis, server.channels[channel.channel_id]
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
@@ -37,20 +38,11 @@ class ChatServer:
|
||||
self.ChatSubscriber.chat_server = self
|
||||
self._subscribed = False
|
||||
|
||||
def _add_task(self, task):
|
||||
task = asyncio.create_task(task)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
def connect(self, user_id: int, client: WebSocket):
|
||||
self.connect_client[user_id] = client
|
||||
|
||||
def get_user_joined_channel(self, user_id: int) -> list[int]:
|
||||
return [
|
||||
channel_id
|
||||
for channel_id, users in self.channels.items()
|
||||
if user_id in users
|
||||
]
|
||||
return [channel_id for channel_id, users in self.channels.items() if user_id in users]
|
||||
|
||||
async def disconnect(self, user: User, session: AsyncSession):
|
||||
user_id = user.id
|
||||
@@ -61,9 +53,7 @@ class ChatServer:
|
||||
channel.remove(user_id)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||
).first()
|
||||
if db_channel:
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
@@ -93,11 +83,10 @@ class ChatServer:
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
|
||||
async def send_message_to_channel(
|
||||
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||
):
|
||||
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
|
||||
logger.info(
|
||||
f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
|
||||
f"Sending message to channel {message.channel_id}, message_id: "
|
||||
f"{message.message_id}, is_bot_command: {is_bot_command}"
|
||||
)
|
||||
|
||||
event = ChatEvent(
|
||||
@@ -106,62 +95,44 @@ class ChatServer:
|
||||
)
|
||||
if is_bot_command:
|
||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
||||
self._add_task(self.send_event(message.sender_id, event))
|
||||
bg_tasks.add_task(self.send_event, message.sender_id, event)
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(
|
||||
f"Broadcasting message to all users in channel {message.channel_id}"
|
||||
)
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
bg_tasks.add_task(
|
||||
self.broadcast,
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(
|
||||
f"chat:{message.channel_id}:last_msg", message.message_id
|
||||
)
|
||||
logger.info(
|
||||
f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
|
||||
)
|
||||
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
else:
|
||||
logger.debug(
|
||||
f"Skipping last message update for message ID: {message.message_id}"
|
||||
)
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
|
||||
async def batch_join_channel(
|
||||
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||
):
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
not_joined = []
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
for user in users:
|
||||
assert user.id is not None
|
||||
if user.id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user.id)
|
||||
not_joined.append(user)
|
||||
|
||||
for user in not_joined:
|
||||
assert user.id is not None
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id]
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
await self.send_event(
|
||||
user.id,
|
||||
@@ -171,13 +142,9 @@ class ChatServer:
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> ChatChannelResp:
|
||||
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
@@ -202,13 +169,9 @@ class ChatServer:
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> None:
|
||||
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||
self.channels[channel_id].remove(user_id)
|
||||
@@ -221,9 +184,7 @@ class ChatServer:
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id)
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
@@ -236,11 +197,7 @@ class ChatServer:
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
@@ -253,11 +210,7 @@ class ChatServer:
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
@@ -270,13 +223,7 @@ class ChatServer:
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
id = await insert_notification(session, detail)
|
||||
users = (
|
||||
await session.exec(
|
||||
select(UserNotification).where(
|
||||
UserNotification.notification_id == id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
users = (await session.exec(select(UserNotification).where(UserNotification.notification_id == id))).all()
|
||||
for user_notification in users:
|
||||
data = user_notification.notification.model_dump()
|
||||
data["is_read"] = user_notification.is_read
|
||||
@@ -308,9 +255,7 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
await ws.close(code=1000)
|
||||
break
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(
|
||||
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}")
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
|
||||
@@ -332,11 +277,7 @@ async def chat_websocket(
|
||||
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (
|
||||
user := await get_current_user(
|
||||
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
|
||||
)
|
||||
) is None:
|
||||
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
@@ -346,12 +287,9 @@ async def chat_websocket(
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
user_id = user.id
|
||||
assert user_id
|
||||
server.connect(user_id, websocket)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user