refactor(project): make pyright & ruff happy

This commit is contained in:
MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions

View File

@@ -17,6 +17,7 @@ from app.log import logger
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
@@ -37,20 +38,11 @@ class ChatServer:
self.ChatSubscriber.chat_server = self
self._subscribed = False
def _add_task(self, task):
task = asyncio.create_task(task)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
def connect(self, user_id: int, client: WebSocket):
self.connect_client[user_id] = client
def get_user_joined_channel(self, user_id: int) -> list[int]:
return [
channel_id
for channel_id, users in self.channels.items()
if user_id in users
]
return [channel_id for channel_id, users in self.channels.items() if user_id in users]
async def disconnect(self, user: User, session: AsyncSession):
user_id = user.id
@@ -61,9 +53,7 @@ class ChatServer:
channel.remove(user_id)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
).first()
if db_channel:
await self.leave_channel(user, db_channel, session)
@@ -93,11 +83,10 @@ class ChatServer:
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
async def send_message_to_channel(
self, message: ChatMessageResp, is_bot_command: bool = False
):
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
logger.info(
f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
f"Sending message to channel {message.channel_id}, message_id: "
f"{message.message_id}, is_bot_command: {is_bot_command}"
)
event = ChatEvent(
@@ -106,62 +95,44 @@ class ChatServer:
)
if is_bot_command:
logger.info(f"Sending bot command to user {message.sender_id}")
self._add_task(self.send_event(message.sender_id, event))
bg_tasks.add_task(self.send_event, message.sender_id, event)
else:
# 总是广播消息无论是临时ID还是真实ID
logger.info(
f"Broadcasting message to all users in channel {message.channel_id}"
)
self._add_task(
self.broadcast(
message.channel_id,
event,
)
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
bg_tasks.add_task(
self.broadcast,
message.channel_id,
event,
)
# 只有真实消息 ID正数且非零才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数所以这里应该都能正常处理
if message.message_id and message.message_id > 0:
await self.mark_as_read(
message.channel_id, message.sender_id, message.message_id
)
await self.redis.set(
f"chat:{message.channel_id}:last_msg", message.message_id
)
logger.info(
f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
)
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
else:
logger.debug(
f"Skipping last message update for message ID: {message.message_id}"
)
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
async def batch_join_channel(
self, users: list[User], channel: ChatChannel, session: AsyncSession
):
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
channel_id = channel.channel_id
assert channel_id is not None
not_joined = []
if channel_id not in self.channels:
self.channels[channel_id] = []
for user in users:
assert user.id is not None
if user.id not in self.channels[channel_id]:
self.channels[channel_id].append(user.id)
not_joined.append(user)
for user in not_joined:
assert user.id is not None
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id]
if channel.type != ChannelType.PUBLIC
else None,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
)
await self.send_event(
user.id,
@@ -171,13 +142,9 @@ class ChatServer:
),
)
async def join_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> ChatChannelResp:
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id not in self.channels:
self.channels[channel_id] = []
@@ -202,13 +169,9 @@ class ChatServer:
return channel_resp
async def leave_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> None:
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id)
@@ -221,9 +184,7 @@ class ChatServer:
session,
user,
self.redis,
self.channels.get(channel_id)
if channel.type != ChannelType.PUBLIC
else None,
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
)
await self.send_event(
user_id,
@@ -236,11 +197,7 @@ class ChatServer:
async def join_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel is None:
return
@@ -253,11 +210,7 @@ class ChatServer:
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel is None:
return
@@ -270,13 +223,7 @@ class ChatServer:
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
id = await insert_notification(session, detail)
users = (
await session.exec(
select(UserNotification).where(
UserNotification.notification_id == id
)
)
).all()
users = (await session.exec(select(UserNotification).where(UserNotification.notification_id == id))).all()
for user_notification in users:
data = user_notification.notification.model_dump()
data["is_read"] = user_notification.is_read
@@ -308,9 +255,7 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
await ws.close(code=1000)
break
except WebSocketDisconnect as e:
logger.info(
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
)
logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}")
except RuntimeError as e:
if "disconnect message" in str(e):
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
@@ -332,11 +277,7 @@ async def chat_websocket(
async for session in factory():
token = authorization[7:]
if (
user := await get_current_user(
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
)
) is None:
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
await websocket.close(code=1008)
return
@@ -346,12 +287,9 @@ async def chat_websocket(
await websocket.close(code=1008)
return
user_id = user.id
assert user_id
server.connect(user_id, websocket)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
if db_channel is not None:
await server.join_channel(user, db_channel, session)