refactor(project): make pyright & ruff happy

This commit is contained in:
MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions

View File

@@ -19,17 +19,11 @@ class Achievement(NamedTuple):
@property
def url(self) -> str:
return (
self.medal_url
or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png"
)
return self.medal_url or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png"
@property
def url2x(self) -> str:
return (
self.medal_url2x
or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png"
)
return self.medal_url2x or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png"
MedalProcessor = Callable[[AsyncSession, "Score", "Beatmap"], Awaitable[bool]]

View File

@@ -11,7 +11,8 @@ class APIMe(UserResp):
"""
/me 端点的响应模型
对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段
session_verified 字段已经在 UserResp 中定义,这里不需要重复定义
"""
pass

View File

@@ -95,11 +95,7 @@ class SearchQueryModel(BaseModel):
q: str = Field("", description="搜索关键词")
c: Annotated[
list[
Literal[
"recommended", "converts", "follows", "spotlights", "featured_artists"
]
],
list[Literal["recommended", "converts", "follows", "spotlights", "featured_artists"]],
BeforeValidator(_parse_list),
PlainSerializer(lambda x: ".".join(x)),
] = Field(
@@ -188,12 +184,10 @@ class SearchQueryModel(BaseModel):
list[Literal["video", "storyboard"]],
BeforeValidator(_parse_list),
PlainSerializer(lambda x: ".".join(x)),
] = Field(
default_factory=list, description=("其他video 有视频 / storyboard 有故事板")
] = Field(default_factory=list, description=("其他video 有视频 / storyboard 有故事板"))
r: Annotated[list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))] = Field(
default_factory=list, description="成绩"
)
r: Annotated[
list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))
] = Field(default_factory=list, description="成绩")
played: bool = Field(
default=False,
description="玩过",

View File

@@ -9,12 +9,13 @@ from pydantic import BaseModel
class ExtendedTokenResponse(BaseModel):
"""扩展的令牌响应,支持二次验证状态"""
access_token: str | None = None
token_type: str = "Bearer"
expires_in: int | None = None
refresh_token: str | None = None
scope: str | None = None
# 二次验证相关字段
requires_second_factor: bool = False
verification_message: str | None = None
@@ -23,6 +24,7 @@ class ExtendedTokenResponse(BaseModel):
class SessionState(BaseModel):
"""会话状态"""
user_id: int
username: str
email: str

View File

@@ -145,9 +145,7 @@ class MultiplayerPlaylistItemStats(BaseModel):
class MultiplayerRoomStats(BaseModel):
room_id: int
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(
default_factory=dict
)
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict)
class MultiplayerRoomScoreSetEvent(BaseModel):

View File

@@ -174,11 +174,7 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
return True
ranked_mods = RANKED_MODS[ruleset_id]
for mod in mods:
if (
app_settings.enable_rx
and mod["acronym"] == "RX"
and ruleset_id in {0, 1, 2}
):
if app_settings.enable_rx and mod["acronym"] == "RX" and ruleset_id in {0, 1, 2}:
continue
if app_settings.enable_ap and mod["acronym"] == "AP" and ruleset_id == 0:
continue
@@ -251,10 +247,7 @@ def get_available_mods(ruleset_id: int, required_mods: list[APIMod]) -> list[API
if mod_acronym in incompatible_mods:
continue
if any(
required_acronym in mod_data["IncompatibleMods"]
for required_acronym in required_mod_acronyms
):
if any(required_acronym in mod_data["IncompatibleMods"] for required_acronym in required_mod_acronyms):
continue
if mod_data.get("UserPlayable", False):

View File

@@ -121,32 +121,21 @@ class PlaylistItem(BaseModel):
star_rating: float
freestyle: bool
def _validate_mod_for_ruleset(
self, mod: APIMod, ruleset_key: int, context: str = "mod"
) -> None:
def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None:
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
# Check if mod is valid for ruleset
if (
typed_ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[typed_ruleset_key]
):
raise InvokeException(
f"{context} {mod['acronym']} is invalid for this ruleset"
)
if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]:
raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset")
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
# Check if mod is unplayable in multiplayer
if mod_settings.get("UserPlayable", True) is False:
raise InvokeException(
f"{context} {mod['acronym']} is not playable by users"
)
raise InvokeException(f"{context} {mod['acronym']} is not playable by users")
if mod_settings.get("ValidForMultiplayer", True) is False:
raise InvokeException(
f"{context} {mod['acronym']} is not valid for multiplayer"
)
raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer")
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
from typing import Literal, cast
@@ -159,10 +148,7 @@ class PlaylistItem(BaseModel):
incompatible = set(mod1_settings.get("IncompatibleMods", []))
for mod2 in mods[i + 1 :]:
if mod2["acronym"] in incompatible:
raise InvokeException(
f"Mods {mod1['acronym']} and "
f"{mod2['acronym']} are incompatible"
)
raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible")
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
from typing import Literal, cast
@@ -178,10 +164,7 @@ class PlaylistItem(BaseModel):
conflicting_allowed = allowed_acronyms & incompatible
if conflicting_allowed:
conflict_list = ", ".join(conflicting_allowed)
raise InvokeException(
f"Required mod {req_acronym} conflicts with "
f"allowed mods: {conflict_list}"
)
raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}")
def validate_playlist_item_mods(self) -> None:
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
@@ -219,10 +202,7 @@ class PlaylistItem(BaseModel):
# Check if mods are valid for the ruleset
for mod in proposed_mods:
if (
ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[ruleset_key]
):
if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]:
all_proposed_valid = False
continue
valid_mods.append(mod)
@@ -252,9 +232,7 @@ class PlaylistItem(BaseModel):
# Check compatibility with required mods
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
all_mod_acronyms = {
mod["acronym"] for mod in final_valid_mods
} | required_mod_acronyms
all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms
# Check for incompatibility between required and user mods
filtered_valid_mods = []
@@ -288,9 +266,7 @@ class PlaylistItem(BaseModel):
class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0
time_remaining: timedelta
is_exclusive: Annotated[
bool, Field(default=True), SignalRMeta(member_ignore=True)
] = True
is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True
class MatchStartCountdown(_MultiplayerCountdown):
@@ -305,17 +281,13 @@ class ServerShuttingDownCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[2]] = 2
MultiplayerCountdown = (
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
)
MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
class MultiplayerRoomUser(BaseModel):
user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability(
state=DownloadState.UNKNOWN, download_progress=None
)
availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None)
mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None
ruleset_id: int | None = None # freestyle
@@ -358,9 +330,7 @@ class MultiplayerRoom(BaseModel):
expired=item.expired,
playlist_order=item.playlist_order,
played_at=item.played_at,
star_rating=item.beatmap.difficulty_rating
if item.beatmap is not None
else 0.0,
star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0,
freestyle=item.freestyle,
)
)
@@ -425,9 +395,7 @@ class MultiplayerQueue:
user_item_groups[item.owner_id] = []
user_item_groups[item.owner_id].append(item)
max_items = max(
(len(items) for items in user_item_groups.values()), default=0
)
max_items = max((len(items) for items in user_item_groups.values()), default=0)
for i in range(max_items):
current_set = []
@@ -436,20 +404,13 @@ class MultiplayerQueue:
current_set.append(items[i])
if is_first_set:
current_set.sort(
key=lambda item: (item.playlist_order, item.id)
)
current_set.sort(key=lambda item: (item.playlist_order, item.id))
ordered_active_items.extend(current_set)
first_set_order_by_user_id = {
item.owner_id: idx
for idx, item in enumerate(ordered_active_items)
item.owner_id: idx for idx, item in enumerate(ordered_active_items)
}
else:
current_set.sort(
key=lambda item: first_set_order_by_user_id.get(
item.owner_id, 0
)
)
current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0))
ordered_active_items.extend(current_set)
is_first_set = False
@@ -464,9 +425,7 @@ class MultiplayerQueue:
continue
item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed(
self.server_room, item, beatmap_changed=False
)
await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False)
async def update_current_item(self):
upcoming_items = self.upcoming_items
@@ -494,16 +453,7 @@ class MultiplayerQueue:
raise InvokeException("You are not the host")
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if (
len(
[
True
for u in self.room.playlist
if u.owner_id == user.user_id and not u.expired
]
)
>= limit
):
if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit:
raise InvokeException(f"You can only have {limit} items in the queue")
if item.freestyle and len(item.allowed_mods) > 0:
@@ -512,9 +462,7 @@ class MultiplayerQueue:
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=item.beatmap_id
)
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
@@ -538,29 +486,19 @@ class MultiplayerQueue:
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=item.beatmap_id
)
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
existing_item = next(
(i for i in self.room.playlist if i.id == item.id), None
)
existing_item = next((i for i in self.room.playlist if i.id == item.id), None)
if existing_item is None:
raise InvokeException(
"Attempted to change an item that doesn't exist"
)
raise InvokeException("Attempted to change an item that doesn't exist")
if existing_item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to change an item which is not owned by the user"
)
raise InvokeException("Attempted to change an item which is not owned by the user")
if existing_item.expired:
raise InvokeException(
"Attempted to change an item which has already been played"
)
raise InvokeException("Attempted to change an item which has already been played")
item.validate_playlist_item_mods()
item.owner_id = user.user_id
@@ -578,8 +516,7 @@ class MultiplayerQueue:
await self.hub.playlist_changed(
self.server_room,
item,
beatmap_changed=item.beatmap_checksum
!= existing_item.beatmap_checksum,
beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum,
)
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
@@ -600,14 +537,10 @@ class MultiplayerQueue:
raise InvokeException("The only item in the room cannot be removed")
if item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to remove an item which is not owned by the user"
)
raise InvokeException("Attempted to remove an item which is not owned by the user")
if item.expired:
raise InvokeException(
"Attempted to remove an item which has already been played"
)
raise InvokeException("Attempted to remove an item which has already been played")
async with with_db() as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
@@ -668,9 +601,7 @@ class CountdownInfo:
def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown
self.duration = (
countdown.time_remaining
if countdown.time_remaining > timedelta(seconds=0)
else timedelta(seconds=0)
countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0)
)
@@ -704,9 +635,7 @@ class MatchTypeHandler(ABC):
async def handle_join(self, user: MultiplayerRoomUser): ...
@abstractmethod
async def handle_request(
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
@abstractmethod
async def handle_leave(self, user: MultiplayerRoomUser): ...
@@ -723,9 +652,7 @@ class HeadToHeadHandler(MatchTypeHandler):
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@@ -762,9 +689,7 @@ class TeamVersusHandler(MatchTypeHandler):
team_counts = defaultdict(int)
for user in self.room.room.users:
if user.match_state is not None and isinstance(
user.match_state, TeamVersusUserState
):
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
team_counts[user.match_state.team_id] += 1
if team_counts:
@@ -798,9 +723,7 @@ class TeamVersusHandler(MatchTypeHandler):
def get_details(self) -> MatchStartedEventDetail:
teams: dict[int, Literal["blue", "red"]] = {}
for user in self.room.room.users:
if user.match_state is not None and isinstance(
user.match_state, TeamVersusUserState
):
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
return detail
@@ -843,9 +766,7 @@ class ServerMultiplayerRoom:
self._tracked_countdown = {}
async def set_handler(self):
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
self
)
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self)
for i in self.room.users:
await self.match_type_handler.handle_join(i)
@@ -871,9 +792,7 @@ class ServerMultiplayerRoom:
info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown)
self._tracked_countdown[countdown.id] = info
await self.hub.send_match_event(
self, CountdownStartedEvent(countdown=info.countdown)
)
await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown))
info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown):

View File

@@ -53,7 +53,7 @@ class NotificationName(str, Enum):
NotificationName.BEATMAP_OWNER_CHANGE: "beatmap_owner_change",
NotificationName.BEATMAPSET_DISCUSSION_LOCK: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_POST_NEW: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem", # noqa: E501
NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem",
NotificationName.BEATMAPSET_DISCUSSION_REVIEW_NEW: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_UNLOCK: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISQUALIFY: "beatmapset_state",
@@ -164,17 +164,11 @@ class ChannelMessageTeam(ChannelMessageBase):
from app.database import TeamMember
user_team_id = (
await session.exec(
select(TeamMember.team_id).where(TeamMember.user_id == self._user.id)
)
await session.exec(select(TeamMember.team_id).where(TeamMember.user_id == self._user.id))
).first()
if not user_team_id:
return []
user_ids = (
await session.exec(
select(TeamMember.user_id).where(TeamMember.team_id == user_team_id)
)
).all()
user_ids = (await session.exec(select(TeamMember.user_id).where(TeamMember.team_id == user_team_id))).all()
return list(user_ids)

View File

@@ -197,9 +197,7 @@ class SoloScoreSubmissionInfo(BaseModel):
# check incompatible mods
for mod in mods:
if mod["acronym"] in incompatible_mods:
raise ValueError(
f"Mod {mod['acronym']} is incompatible with other mods"
)
raise ValueError(f"Mod {mod['acronym']} is incompatible with other mods")
setting_mods = API_MODS[info.data["ruleset_id"]].get(mod["acronym"])
if not setting_mods:
raise ValueError(f"Invalid mod: {mod['acronym']}")

View File

@@ -22,9 +22,7 @@ class SignalRUnionMessage(BaseModel):
class Transport(BaseModel):
transport: str
transfer_formats: list[str] = Field(
default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
)
transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats")
class NegotiateResponse(BaseModel):

View File

@@ -89,9 +89,7 @@ class LegacyReplayFrame(BaseModel):
mouse_y: float | None = None
button_state: int
header: Annotated[
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
]
header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)]
class FrameDataBundle(BaseModel):