diff --git a/app/achievements/daily_challenge.py b/app/achievements/daily_challenge.py index b2b3f4c..fcca594 100644 --- a/app/achievements/daily_challenge.py +++ b/app/achievements/daily_challenge.py @@ -32,11 +32,9 @@ async def process_streak( ).first() if not stats: return False - if streak <= stats.daily_streak_best < next_streak: - return True - elif next_streak == 0 and stats.daily_streak_best >= streak: - return True - return False + return bool( + streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak) + ) MEDALS = { diff --git a/app/achievements/hush_hush.py b/app/achievements/hush_hush.py index a887693..bc9c53f 100644 --- a/app/achievements/hush_hush.py +++ b/app/achievements/hush_hush.py @@ -68,9 +68,7 @@ async def to_the_core( if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist: return False mods_ = mod_to_save(score.mods) - if "DT" not in mods_ or "NC" not in mods_: - return False - return True + return not ("DT" not in mods_ or "NC" not in mods_) async def wysi( @@ -83,9 +81,7 @@ async def wysi( return False if str(round(score.accuracy, ndigits=4))[3:] != "727": return False - if "xi" not in beatmap.beatmapset.artist: - return False - return True + return "xi" in beatmap.beatmapset.artist async def prepared( @@ -97,9 +93,7 @@ async def prepared( if score.rank != Rank.X and score.rank != Rank.XH: return False mods_ = mod_to_save(score.mods) - if "NF" not in mods_: - return False - return True + return "NF" in mods_ async def reckless_adandon( @@ -117,9 +111,7 @@ async def reckless_adandon( redis = get_redis() mods_ = score.mods.copy() attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) - if attribute.star_rating < 3: - return False - return True + return not attribute.star_rating < 3 async def lights_out( @@ -413,11 +405,10 @@ async def by_the_skin_of_the_teeth( return False for mod in score.mods: - if mod.get("acronym") == "AC": - if "settings" in mod and "minimum_accuracy" in mod["settings"]: - target_accuracy = mod["settings"]["minimum_accuracy"] - if isinstance(target_accuracy, int | float): - return abs(score.accuracy - float(target_accuracy)) < 0.0001 + if mod.get("acronym") == "AC" and "settings" in mod and "minimum_accuracy" in mod["settings"]: + target_accuracy = mod["settings"]["minimum_accuracy"] + if isinstance(target_accuracy, int | float): + return abs(score.accuracy - float(target_accuracy)) < 0.0001 return False diff --git a/app/achievements/mods.py b/app/achievements/mods.py index d157a15..0f3c728 100644 --- a/app/achievements/mods.py +++ b/app/achievements/mods.py @@ -19,9 +19,7 @@ async def process_mod( return False if not beatmap.beatmap_status.has_leaderboard(): return False - if len(score.mods) != 1 or score.mods[0]["acronym"] != mod: - return False - return True + return not (len(score.mods) != 1 or score.mods[0]["acronym"] != mod) async def process_category_mod( diff --git a/app/achievements/osu_combo.py b/app/achievements/osu_combo.py index 25abbbd..7039bae 100644 --- a/app/achievements/osu_combo.py +++ b/app/achievements/osu_combo.py @@ -22,11 +22,7 @@ async def process_combo( return False if next_combo != 0 and combo >= next_combo: return False - if combo <= score.max_combo < next_combo: - return True - elif next_combo == 0 and score.max_combo >= combo: - return True - return False + return bool(combo <= score.max_combo < next_combo or (next_combo == 0 and score.max_combo >= combo)) MEDALS: Medals = { diff --git a/app/achievements/osu_playcount.py b/app/achievements/osu_playcount.py index 934e1c0..b5e1e9b 100644 --- a/app/achievements/osu_playcount.py +++ b/app/achievements/osu_playcount.py @@ -35,11 +35,7 @@ async def process_playcount( ).first() if not stats: return False - if pc <= stats.play_count < next_pc: - return True - elif next_pc == 0 and stats.play_count >= pc: - return True - return False + return bool(pc <= stats.play_count < next_pc or (next_pc == 0 and stats.play_count >= pc)) MEDALS: Medals = { diff --git a/app/achievements/skill.py b/app/achievements/skill.py index 66123cc..43993d8 100644 --- a/app/achievements/skill.py +++ b/app/achievements/skill.py @@ -47,9 +47,7 @@ async def process_skill( attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) if attribute.star_rating < star or attribute.star_rating >= star + 1: return False - if type == "fc" and not score.is_perfect_combo: - return False - return True + return not (type == "fc" and not score.is_perfect_combo) MEDALS: Medals = { diff --git a/app/achievements/total_hits.py b/app/achievements/total_hits.py index 5f3d13d..93fb2c5 100644 --- a/app/achievements/total_hits.py +++ b/app/achievements/total_hits.py @@ -35,11 +35,7 @@ async def process_tth( ).first() if not stats: return False - if tth <= stats.total_hits < next_tth: - return True - elif next_tth == 0 and stats.play_count >= tth: - return True - return False + return bool(tth <= stats.total_hits < next_tth or (next_tth == 0 and stats.play_count >= tth)) MEDALS: Medals = { diff --git a/app/auth.py b/app/auth.py index 6e43f8b..8bef0a8 100644 --- a/app/auth.py +++ b/app/auth.py @@ -69,7 +69,7 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool: 2. MD5哈希 -> bcrypt验证 """ # 1. 明文密码转 MD5 - pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode() + pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode() # noqa: S324 # 2. 检查缓存 if bcrypt_hash in bcrypt_cache: @@ -103,7 +103,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: def get_password_hash(password: str) -> str: """生成密码哈希 - 使用 osu! 的方式""" # 1. 明文密码 -> MD5 - pw_md5 = hashlib.md5(password.encode()).hexdigest().encode() + pw_md5 = hashlib.md5(password.encode()).hexdigest().encode() # noqa: S324 # 2. MD5 -> bcrypt pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt()) return pw_bcrypt.decode() @@ -114,7 +114,7 @@ async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) - 验证用户身份 - 使用类似 from_login 的逻辑 """ # 1. 明文密码转 MD5 - pw_md5 = hashlib.md5(password.encode()).hexdigest() + pw_md5 = hashlib.md5(password.encode()).hexdigest() # noqa: S324 # 2. 根据用户名查找用户 user = None @@ -325,12 +325,7 @@ def _generate_totp_account_label(user: User) -> str: 根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性 """ - if settings.totp_use_username_in_label: - # 使用用户名作为主要标识 - primary_identifier = user.username - else: - # 使用邮箱作为标识 - primary_identifier = user.email + primary_identifier = user.username if settings.totp_use_username_in_label else user.email # 如果配置了服务名称,添加到标签中以便在认证器中区分 if settings.totp_service_name: diff --git a/app/calculator.py b/app/calculator.py index a408fcf..7c9eb7c 100644 --- a/app/calculator.py +++ b/app/calculator.py @@ -419,9 +419,8 @@ def too_dense(hit_objects: list[HitObject], per_1s: int, per_10s: int) -> bool: if len(hit_objects) > i + per_1s: if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000: return True - elif len(hit_objects) > i + per_10s: - if hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000: - return True + elif len(hit_objects) > i + per_10s and hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000: + return True return False @@ -448,10 +447,7 @@ def slider_is_sus(hit_objects: list[HitObject]) -> bool: def is_2b(hit_objects: list[HitObject]) -> bool: - for i in range(0, len(hit_objects) - 1): - if hit_objects[i] == hit_objects[i + 1].start_time: - return True - return False + return any(hit_objects[i] == hit_objects[i + 1].start_time for i in range(0, len(hit_objects) - 1)) def is_suspicious_beatmap(content: str) -> bool: diff --git a/app/config.py b/app/config.py index adb7794..b09c233 100644 --- a/app/config.py +++ b/app/config.py @@ -217,7 +217,7 @@ STORAGE_SETTINGS='{ # 服务器设置 host: Annotated[ str, - Field(default="0.0.0.0", description="服务器监听地址"), + Field(default="0.0.0.0", description="服务器监听地址"), # noqa: S104 "服务器设置", ] port: Annotated[ @@ -609,26 +609,26 @@ STORAGE_SETTINGS='{ ] @field_validator("fetcher_scopes", mode="before") + @classmethod def validate_fetcher_scopes(cls, v: Any) -> list[str]: if isinstance(v, str): return v.split(",") return v @field_validator("storage_settings", mode="after") + @classmethod def validate_storage_settings( cls, v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings, info: ValidationInfo, ) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings: - if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2: - if not isinstance(v, CloudflareR2Settings): - raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings") - elif info.data.get("storage_service") == StorageServiceType.LOCAL: - if not isinstance(v, LocalStorageSettings): - raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings") - elif info.data.get("storage_service") == StorageServiceType.AWS_S3: - if not isinstance(v, AWSS3StorageSettings): - raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings") + service = info.data.get("storage_service") + if service == StorageServiceType.CLOUDFLARE_R2 and not isinstance(v, CloudflareR2Settings): + raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings") + if service == StorageServiceType.LOCAL and not isinstance(v, LocalStorageSettings): + raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings") + if service == StorageServiceType.AWS_S3 and not isinstance(v, AWSS3StorageSettings): + raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings") return v diff --git a/app/database/beatmap.py b/app/database/beatmap.py index ff920b1..029915f 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -71,10 +71,10 @@ class Beatmap(BeatmapBase, table=True): failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}) @classmethod - async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap": + async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap": d = resp.model_dump() del d["beatmapset"] - beatmap = Beatmap.model_validate( + beatmap = cls.model_validate( { **d, "beatmapset_id": resp.beatmapset_id, @@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True): if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first(): session.add(beatmap) await session.commit() - beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one() - return beatmap + return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one() @classmethod async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]: @@ -250,7 +249,7 @@ async def calculate_beatmap_attributes( redis: Redis, fetcher: "Fetcher", ): - key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" + key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes" if await redis.exists(key): return BeatmapAttributes.model_validate_json(await redis.get(key)) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index d2df25e..7829c8c 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -130,7 +130,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod - async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset": + async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset": d = resp.model_dump() if resp.nominations: d["nominations_required"] = resp.nominations.required @@ -158,10 +158,15 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): return beatmapset @classmethod - async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset": + async def from_resp( + cls, + session: AsyncSession, + resp: "BeatmapsetResp", + from_: int = 0, + ) -> "Beatmapset": from .beatmap import Beatmap - beatmapset = await cls.from_resp_no_save(session, resp, from_=from_) + beatmapset = await cls.from_resp_no_save(resp) if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first(): session.add(beatmapset) await session.commit() diff --git a/app/database/chat.py b/app/database/chat.py index f0d1c47..b05c790 100644 --- a/app/database/chat.py +++ b/app/database/chat.py @@ -105,17 +105,11 @@ class ChatChannelResp(ChatChannelBase): ) ).first() - last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg") - if last_msg and last_msg.isdigit(): - last_msg = int(last_msg) - else: - last_msg = None + last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg") + last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None - last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") - if last_read_id and last_read_id.isdigit(): - last_read_id = int(last_read_id) - else: - last_read_id = last_msg + last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") + last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg if silence is not None: attribute = ChatUserAttributes( diff --git a/app/database/score.py b/app/database/score.py index fd9668f..172e8a4 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -520,12 +520,11 @@ async def _score_where( wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code)) else: return None - elif type == LeaderboardType.TEAM: - if user: - team_membership = await user.awaitable_attrs.team_membership - if team_membership: - team_id = team_membership.team_id - wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id))) + elif type == LeaderboardType.TEAM and user: + team_membership = await user.awaitable_attrs.team_membership + if team_membership: + team_id = team_membership.team_id + wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id))) if mods: if user and user.is_supporter: wheres.append( diff --git a/app/database/user.py b/app/database/user.py index 053c1b0..fbc9eed 100644 --- a/app/database/user.py +++ b/app/database/user.py @@ -256,8 +256,6 @@ class UserResp(UserBase): session: AsyncSession, include: list[str] = [], ruleset: GameMode | None = None, - *, - token_id: int | None = None, ) -> "UserResp": from app.dependencies.database import get_redis @@ -310,16 +308,16 @@ class UserResp(UserBase): ).all() ] - if "team" in include: - if team_membership := await obj.awaitable_attrs.team_membership: - u.team = team_membership.team + if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership): + u.team = team_membership.team if "account_history" in include: u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history] - if "daily_challenge_user_stats": - if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats: - u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats) + if "daily_challenge_user_stats" in include and ( + daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats + ): + u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats) if "statistics" in include: current_stattistics = None @@ -443,7 +441,7 @@ class MeResp(UserResp): from app.dependencies.database import get_redis from app.service.verification_service import LoginSessionService - u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id) + u = await super().from_db(obj, session, ALL_INCLUDED, ruleset) u.session_verified = ( not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id) if token_id diff --git a/app/dependencies/param.py b/app/dependencies/param.py index 174adde..9e640bd 100644 --- a/app/dependencies/param.py +++ b/app/dependencies/param.py @@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError from pydantic import BaseModel, ValidationError -def BodyOrForm[T: BaseModel](model: type[T]): +def BodyOrForm[T: BaseModel](model: type[T]): # noqa: N802 async def dependency( request: Request, ) -> T: diff --git a/app/dependencies/user.py b/app/dependencies/user.py index ff3ff53..7061bcb 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -119,10 +119,7 @@ async def get_client_user( if verify_method is None: # 智能选择验证方式(有TOTP优先TOTP) totp_key = await user.awaitable_attrs.totp_key - if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER: - verify_method = "totp" - else: - verify_method = "mail" + verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail" # 设置选择的验证方法到Redis中,避免重复选择 if api_version >= 20250913: diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py index 8801fef..a67c6f6 100644 --- a/app/fetcher/beatmapset.py +++ b/app/fetcher/beatmapset.py @@ -116,7 +116,7 @@ class BeatmapsetFetcher(BaseFetcher): # 序列化为 JSON 并生成 MD5 哈希 cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":")) - cache_hash = hashlib.md5(cache_json.encode()).hexdigest() + cache_hash = hashlib.md5(cache_json.encode(), usedforsecurity=False).hexdigest() logger.opt(colors=True).debug(f"[CacheKey] Query: {cache_data}, Hash: {cache_hash}") @@ -160,10 +160,10 @@ class BeatmapsetFetcher(BaseFetcher): cached_data = json.loads(cached_result) return SearchBeatmapsetsResp.model_validate(cached_data) except Exception as e: - logger.opt(colors=True).warning(f"Cache data invalid, fetching from API: {e}") + logger.warning(f"Cache data invalid, fetching from API: {e}") # 缓存未命中,从 API 获取数据 - logger.opt(colors=True).debug("Cache miss, fetching from API") + logger.debug("Cache miss, fetching from API") params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True) @@ -203,7 +203,7 @@ class BeatmapsetFetcher(BaseFetcher): try: await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1) except RateLimitError: - logger.opt(colors=True).info("Prefetch skipped due to rate limit") + logger.info("Prefetch skipped due to rate limit") bg_tasks.add_task(delayed_prefetch) @@ -227,14 +227,14 @@ class BeatmapsetFetcher(BaseFetcher): # 使用当前 cursor 请求下一页 next_query = query.model_copy() - logger.opt(colors=True).debug(f"Prefetching page {page + 1}") + logger.debug(f"Prefetching page {page + 1}") # 生成下一页的缓存键 next_cache_key = self._generate_cache_key(next_query, cursor) # 检查是否已经缓存 if await redis_client.exists(next_cache_key): - logger.opt(colors=True).debug(f"Page {page + 1} already cached") + logger.debug(f"Page {page + 1} already cached") # 尝试从缓存获取cursor继续预取 cached_data = await redis_client.get(next_cache_key) if cached_data: @@ -244,7 +244,7 @@ class BeatmapsetFetcher(BaseFetcher): cursor = data["cursor"] continue except Exception: - pass + logger.warning("Failed to parse cached data for cursor") break # 在预取页面之间添加延迟,避免突发请求 @@ -279,18 +279,18 @@ class BeatmapsetFetcher(BaseFetcher): ex=prefetch_ttl, ) - logger.opt(colors=True).debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)") + logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)") except RateLimitError: - logger.opt(colors=True).info("Prefetch stopped due to rate limit") + logger.info("Prefetch stopped due to rate limit") except Exception as e: - logger.opt(colors=True).warning(f"Prefetch failed: {e}") + logger.warning(f"Prefetch failed: {e}") async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None: """预热主页缓存""" homepage_queries = self._get_homepage_queries() - logger.opt(colors=True).info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)") + logger.info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)") for i, (query, cursor) in enumerate(homepage_queries): try: @@ -302,7 +302,7 @@ class BeatmapsetFetcher(BaseFetcher): # 检查是否已经缓存 if await redis_client.exists(cache_key): - logger.opt(colors=True).debug(f"Query {query.sort} already cached") + logger.debug(f"Query {query.sort} already cached") continue # 请求并缓存 @@ -325,15 +325,15 @@ class BeatmapsetFetcher(BaseFetcher): ex=cache_ttl, ) - logger.opt(colors=True).info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)") + logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)") if api_response.get("cursor"): try: await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2) except RateLimitError: - logger.opt(colors=True).info(f"Warmup prefetch skipped for {query.sort} due to rate limit") + logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit") except RateLimitError: - logger.opt(colors=True).warning(f"Warmup skipped for {query.sort} due to rate limit") + logger.warning(f"Warmup skipped for {query.sort} due to rate limit") except Exception as e: - logger.opt(colors=True).error(f"Failed to warmup cache for {query.sort}: {e}") + logger.error(f"Failed to warmup cache for {query.sort}: {e}") diff --git a/app/helpers/__init__.py b/app/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/helpers/geoip_helper.py b/app/helpers/geoip_helper.py index 3d65371..12c7c3c 100644 --- a/app/helpers/geoip_helper.py +++ b/app/helpers/geoip_helper.py @@ -1,19 +1,39 @@ """ -GeoLite2 Helper Class +GeoLite2 Helper Class (asynchronous) """ from __future__ import annotations +import asyncio +from contextlib import suppress import os from pathlib import Path import shutil import tarfile import tempfile import time +from typing import Any, Required, TypedDict +from app.log import logger + +import aiofiles import httpx import maxminddb + +class GeoIPLookupResult(TypedDict, total=False): + ip: Required[str] + country_iso: str + country_name: str + city_name: str + latitude: str + longitude: str + time_zone: str + postal_code: str + asn: int | None + organization: str + + BASE_URL = "https://download.maxmind.com/app/geoip_download" EDITIONS = { "City": "GeoLite2-City", @@ -25,161 +45,184 @@ EDITIONS = { class GeoIPHelper: def __init__( self, - dest_dir="./geoip", - license_key=None, - editions=None, - max_age_days=8, - timeout=60.0, + dest_dir: str | Path = Path("./geoip"), + license_key: str | None = None, + editions: list[str] | None = None, + max_age_days: int = 8, + timeout: float = 60.0, ): - self.dest_dir = dest_dir + self.dest_dir = Path(dest_dir).expanduser() self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY") - self.editions = editions or ["City", "ASN"] + self.editions = list(editions or ["City", "ASN"]) self.max_age_days = max_age_days self.timeout = timeout - self._readers = {} + self._readers: dict[str, maxminddb.Reader] = {} + self._update_lock = asyncio.Lock() @staticmethod - def _safe_extract(tar: tarfile.TarFile, path: str): - base = Path(path).resolve() - for m in tar.getmembers(): - target = (base / m.name).resolve() - if not str(target).startswith(str(base)): + def _safe_extract(tar: tarfile.TarFile, path: Path) -> None: + base = path.resolve() + for member in tar.getmembers(): + target = (base / member.name).resolve() + if not target.is_relative_to(base): # py312 raise RuntimeError("Unsafe path in tar file") - tar.extractall(path=path, filter="data") + tar.extractall(path=base, filter="data") - def _download_and_extract(self, edition_id: str) -> str: - """ - 下载并解压 mmdb 文件到 dest_dir,仅保留 .mmdb - - 跟随 302 重定向 - - 流式下载到临时文件 - - 临时目录退出后自动清理 - """ + @staticmethod + def _as_mapping(value: Any) -> dict[str, Any]: + return value if isinstance(value, dict) else {} + + @staticmethod + def _as_str(value: Any, default: str = "") -> str: + if isinstance(value, str): + return value + if value is None: + return default + return str(value) + + @staticmethod + def _as_int(value: Any) -> int | None: + return value if isinstance(value, int) else None + + @staticmethod + def _extract_tarball(src: Path, dest: Path) -> None: + with tarfile.open(src, "r:gz") as tar: + GeoIPHelper._safe_extract(tar, dest) + + @staticmethod + def _find_mmdb(root: Path) -> Path | None: + for candidate in root.rglob("*.mmdb"): + return candidate + return None + + def _latest_file_sync(self, edition_id: str) -> Path | None: + directory = self.dest_dir + if not directory.is_dir(): + return None + candidates = list(directory.glob(f"{edition_id}*.mmdb")) + if not candidates: + return None + return max(candidates, key=lambda p: p.stat().st_mtime) + + async def _latest_file(self, edition_id: str) -> Path | None: + return await asyncio.to_thread(self._latest_file_sync, edition_id) + + async def _download_and_extract(self, edition_id: str) -> Path: if not self.license_key: raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.") url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz" + tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp)) - with httpx.Client(follow_redirects=True, timeout=self.timeout) as client: - with client.stream("GET", url) as resp: + try: + tgz_path = tmp_dir / "db.tgz" + async with ( + httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client, + client.stream("GET", url) as resp, + ): resp.raise_for_status() - with tempfile.TemporaryDirectory() as tmpd: - tgz_path = os.path.join(tmpd, "db.tgz") - # 流式写入 - with open(tgz_path, "wb") as f: - for chunk in resp.iter_bytes(): - if chunk: - f.write(chunk) + async with aiofiles.open(tgz_path, "wb") as download_file: + async for chunk in resp.aiter_bytes(): + if chunk: + await download_file.write(chunk) - # 解压并只移动 .mmdb - with tarfile.open(tgz_path, "r:gz") as tar: - # 先安全检查与解压 - self._safe_extract(tar, tmpd) + await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir) + mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir) + if mmdb_path is None: + raise RuntimeError("未在压缩包中找到 .mmdb 文件") - # 递归找 .mmdb - mmdb_path = None - for root, _, files in os.walk(tmpd): - for fn in files: - if fn.endswith(".mmdb"): - mmdb_path = os.path.join(root, fn) - break - if mmdb_path: - break + await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True) + dst = self.dest_dir / mmdb_path.name + await asyncio.to_thread(shutil.move, mmdb_path, dst) + return dst + finally: + await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True) - if not mmdb_path: - raise RuntimeError("未在压缩包中找到 .mmdb 文件") + async def update(self, force: bool = False) -> None: + async with self._update_lock: + for edition in self.editions: + edition_id = EDITIONS[edition] + path = await self._latest_file(edition_id) + need_download = force or path is None - os.makedirs(self.dest_dir, exist_ok=True) - dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path)) - shutil.move(mmdb_path, dst) - return dst - - def _latest_file(self, edition_id: str): - if not os.path.isdir(self.dest_dir): - return None - files = [ - os.path.join(self.dest_dir, f) - for f in os.listdir(self.dest_dir) - if f.startswith(edition_id) and f.endswith(".mmdb") - ] - return max(files, key=os.path.getmtime) if files else None - - def update(self, force=False): - from app.log import logger - - for ed in self.editions: - eid = EDITIONS[ed] - path = self._latest_file(eid) - need = force or not path - - if path: - age_days = (time.time() - os.path.getmtime(path)) / 86400 - if age_days >= self.max_age_days: - need = True - logger.info( - f"{eid} database is {age_days:.1f} days old " - f"(max: {self.max_age_days}), will download new version" - ) + if path: + mtime = await asyncio.to_thread(path.stat) + age_days = (time.time() - mtime.st_mtime) / 86400 + if age_days >= self.max_age_days: + need_download = True + logger.info( + f"{edition_id} database is {age_days:.1f} days old " + f"(max: {self.max_age_days}), will download new version" + ) + else: + logger.info( + f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})" + ) else: - logger.info(f"{eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})") - else: - logger.info(f"{eid} database not found, will download") + logger.info(f"{edition_id} database not found, will download") - if need: - logger.info(f"Downloading {eid} database...") - path = self._download_and_extract(eid) - logger.info(f"{eid} database downloaded successfully") - else: - logger.info(f"Using existing {eid} database") + if need_download: + logger.info(f"Downloading {edition_id} database...") + path = await self._download_and_extract(edition_id) + logger.info(f"{edition_id} database downloaded successfully") + else: + logger.info(f"Using existing {edition_id} database") - old = self._readers.get(ed) - if old: - try: - old.close() - except Exception: - pass - if path is not None: - self._readers[ed] = maxminddb.open_database(path) + old_reader = self._readers.get(edition) + if old_reader: + with suppress(Exception): + old_reader.close() + if path is not None: + self._readers[edition] = maxminddb.open_database(str(path)) - def lookup(self, ip: str): - res = {"ip": ip} - # City - city_r = self._readers.get("City") - if city_r: - data = city_r.get(ip) - if data: - country = data.get("country") or {} - res["country_iso"] = country.get("iso_code") or "" - res["country_name"] = (country.get("names") or {}).get("en", "") - city = data.get("city") or {} - res["city_name"] = (city.get("names") or {}).get("en", "") - loc = data.get("location") or {} - res["latitude"] = str(loc.get("latitude") or "") - res["longitude"] = str(loc.get("longitude") or "") - res["time_zone"] = str(loc.get("time_zone") or "") - postal = data.get("postal") or {} - if "code" in postal: - res["postal_code"] = postal["code"] - # ASN - asn_r = self._readers.get("ASN") - if asn_r: - data = asn_r.get(ip) - if data: - res["asn"] = data.get("autonomous_system_number") - res["organization"] = data.get("autonomous_system_organization") + def lookup(self, ip: str) -> GeoIPLookupResult: + res: GeoIPLookupResult = {"ip": ip} + city_reader = self._readers.get("City") + if city_reader: + data = city_reader.get(ip) + if isinstance(data, dict): + country = self._as_mapping(data.get("country")) + res["country_iso"] = self._as_str(country.get("iso_code")) + country_names = self._as_mapping(country.get("names")) + res["country_name"] = self._as_str(country_names.get("en")) + + city = self._as_mapping(data.get("city")) + city_names = self._as_mapping(city.get("names")) + res["city_name"] = self._as_str(city_names.get("en")) + + location = self._as_mapping(data.get("location")) + latitude = location.get("latitude") + longitude = location.get("longitude") + res["latitude"] = str(latitude) if latitude is not None else "" + res["longitude"] = str(longitude) if longitude is not None else "" + res["time_zone"] = self._as_str(location.get("time_zone")) + + postal = self._as_mapping(data.get("postal")) + postal_code = postal.get("code") + if postal_code is not None: + res["postal_code"] = self._as_str(postal_code) + + asn_reader = self._readers.get("ASN") + if asn_reader: + data = asn_reader.get(ip) + if isinstance(data, dict): + res["asn"] = self._as_int(data.get("autonomous_system_number")) + res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="") return res - def close(self): - for r in self._readers.values(): - try: - r.close() - except Exception: - pass + def close(self) -> None: + for reader in self._readers.values(): + with suppress(Exception): + reader.close() self._readers = {} if __name__ == "__main__": - # 示例用法 - geo = GeoIPHelper(dest_dir="./geoip", license_key="") - geo.update() - print(geo.lookup("8.8.8.8")) - geo.close() + + async def _demo() -> None: + geo = GeoIPHelper(dest_dir="./geoip", license_key="") + await geo.update() + print(geo.lookup("8.8.8.8")) + geo.close() + + asyncio.run(_demo()) diff --git a/app/log.py b/app/log.py index ba0ef95..d2c5060 100644 --- a/app/log.py +++ b/app/log.py @@ -97,9 +97,7 @@ class InterceptHandler(logging.Handler): status_color = "green" elif 300 <= status < 400: status_color = "yellow" - elif 400 <= status < 500: - status_color = "red" - elif 500 <= status < 600: + elif 400 <= status < 500 or 500 <= status < 600: status_color = "red" return ( diff --git a/app/middleware/verify_session.py b/app/middleware/verify_session.py index 58ed754..dee6332 100644 --- a/app/middleware/verify_session.py +++ b/app/middleware/verify_session.py @@ -82,7 +82,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): return await call_next(request) # 启动验证流程 - return await self._initiate_verification(request, session_state) + return await self._initiate_verification(session_state) def _should_skip_verification(self, request: Request) -> bool: """检查是否应该跳过验证""" @@ -93,10 +93,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): return True # 非API请求跳过 - if not path.startswith("/api/"): - return True - - return False + return bool(not path.startswith("/api/")) def _requires_verification(self, request: Request, user: User) -> bool: """检查是否需要验证""" @@ -177,7 +174,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): logger.error(f"Error getting session state: {e}") return None - async def _initiate_verification(self, request: Request, state: SessionState) -> Response: + async def _initiate_verification(self, state: SessionState) -> Response: """启动验证流程""" try: method = await state.get_method() diff --git a/app/models/extended_auth.py b/app/models/extended_auth.py index b3fc831..35a3752 100644 --- a/app/models/extended_auth.py +++ b/app/models/extended_auth.py @@ -11,7 +11,7 @@ class ExtendedTokenResponse(BaseModel): """扩展的令牌响应,支持二次验证状态""" access_token: str | None = None - token_type: str = "Bearer" + token_type: str = "Bearer" # noqa: S105 expires_in: int | None = None refresh_token: str | None = None scope: str | None = None @@ -20,14 +20,3 @@ class ExtendedTokenResponse(BaseModel): requires_second_factor: bool = False verification_message: str | None = None user_id: int | None = None # 用于二次验证的用户ID - - -class SessionState(BaseModel): - """会话状态""" - - user_id: int - username: str - email: str - requires_verification: bool - session_token: str | None = None - verification_sent: bool = False diff --git a/app/models/notification.py b/app/models/notification.py index cc95e4c..ceef3b0 100644 --- a/app/models/notification.py +++ b/app/models/notification.py @@ -1,3 +1,4 @@ +# ruff: noqa: ARG002 from __future__ import annotations from abc import abstractmethod diff --git a/app/models/oauth.py b/app/models/oauth.py index f3db41f..ce4cabf 100644 --- a/app/models/oauth.py +++ b/app/models/oauth.py @@ -22,7 +22,7 @@ class TokenRequest(BaseModel): class TokenResponse(BaseModel): access_token: str - token_type: str = "Bearer" + token_type: str = "Bearer" # noqa: S105 expires_in: int refresh_token: str scope: str = "*" @@ -67,7 +67,7 @@ class RegistrationRequestErrors(BaseModel): class OAuth2ClientCredentialsBearer(OAuth2): def __init__( self, - tokenUrl: Annotated[ + tokenUrl: Annotated[ # noqa: N803 str, Doc( """ @@ -75,7 +75,7 @@ class OAuth2ClientCredentialsBearer(OAuth2): """ ), ], - refreshUrl: Annotated[ + refreshUrl: Annotated[ # noqa: N803 str | None, Doc( """ diff --git a/app/models/v1_user.py b/app/models/v1_user.py index 9868260..ea7d177 100644 --- a/app/models/v1_user.py +++ b/app/models/v1_user.py @@ -46,10 +46,10 @@ class PlayerStatsResponse(BaseModel): class PlayerEventItem(BaseModel): """玩家事件项目""" - userId: int + userId: int # noqa: N815 name: str - mapId: int | None = None - setId: int | None = None + mapId: int | None = None # noqa: N815 + setId: int | None = None # noqa: N815 artist: str | None = None title: str | None = None version: str | None = None @@ -88,7 +88,7 @@ class PlayerInfo(BaseModel): custom_badge_icon: str custom_badge_color: str userpage_content: str - recentFailed: int + recentFailed: int # noqa: N815 social_discord: str | None = None social_youtube: str | None = None social_twitter: str | None = None diff --git a/app/router/auth.py b/app/router/auth.py index ff163b3..1512642 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -126,21 +126,22 @@ async def register_user( try: # 获取客户端 IP 并查询地理位置 - country_code = "CN" # 默认国家代码 + country_code = None # 默认国家代码 try: # 查询 IP 地理位置 geo_info = geoip.lookup(client_ip) - if geo_info and geo_info.get("country_iso"): - country_code = geo_info["country_iso"] + if geo_info and (country_code := geo_info.get("country_iso")): logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}") else: logger.warning(f"Could not determine country for IP {client_ip}") except Exception as e: logger.warning(f"GeoIP lookup failed for {client_ip}: {e}") + if country_code is None: + country_code = "CN" # 创建新用户 - # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) + # 确保 AUTO_INCREMENT 值从3开始(ID=2是BanchoBot) result = await db.execute( text( "SELECT AUTO_INCREMENT FROM information_schema.TABLES " @@ -157,7 +158,7 @@ async def register_user( email=user_email, pw_bcrypt=get_password_hash(user_password), priv=1, # 普通用户权限 - country_code=country_code, # 根据 IP 地理位置设置国家 + country_code=country_code, join_date=utcnow(), last_visit=utcnow(), is_supporter=settings.enable_supporter_for_all_users, @@ -386,7 +387,7 @@ async def oauth_token( return TokenResponse( access_token=access_token, - token_type="Bearer", + token_type="Bearer", # noqa: S106 expires_in=settings.access_token_expire_minutes * 60, refresh_token=refresh_token_str, scope=scope, @@ -439,7 +440,7 @@ async def oauth_token( ) return TokenResponse( access_token=access_token, - token_type="Bearer", + token_type="Bearer", # noqa: S106 expires_in=settings.access_token_expire_minutes * 60, refresh_token=new_refresh_token, scope=scope, @@ -509,7 +510,7 @@ async def oauth_token( return TokenResponse( access_token=access_token, - token_type="Bearer", + token_type="Bearer", # noqa: S106 expires_in=settings.access_token_expire_minutes * 60, refresh_token=refresh_token_str, scope=" ".join(scopes), @@ -554,7 +555,7 @@ async def oauth_token( return TokenResponse( access_token=access_token, - token_type="Bearer", + token_type="Bearer", # noqa: S106 expires_in=settings.access_token_expire_minutes * 60, refresh_token=refresh_token_str, scope=" ".join(scopes), diff --git a/app/router/lio.py b/app/router/lio.py index 969b214..93e0088 100644 --- a/app/router/lio.py +++ b/app/router/lio.py @@ -130,7 +130,7 @@ def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_us "allowed_mods": item_data.get("allowed_mods", []), "expired": bool(item_data.get("expired", False)), "playlist_order": item_data.get("playlist_order", default_order), - "played_at": item_data.get("played_at", None), + "played_at": item_data.get("played_at"), "freestyle": bool(item_data.get("freestyle", True)), "beatmap_checksum": item_data.get("beatmap_checksum", ""), "star_rating": item_data.get("star_rating", 0.0), diff --git a/app/router/notification/banchobot.py b/app/router/notification/banchobot.py index 7b01347..a491b7d 100644 --- a/app/router/notification/banchobot.py +++ b/app/router/notification/banchobot.py @@ -157,10 +157,7 @@ async def _help(user: User, args: list[str], _session: AsyncSession, channel: Ch @bot.command("roll") def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str: - if len(args) > 0 and args[0].isdigit(): - r = random.randint(1, int(args[0])) - else: - r = random.randint(1, 100) + r = random.randint(1, int(args[0])) if len(args) > 0 and args[0].isdigit() else random.randint(1, 100) return f"{user.username} rolls {r} point(s)" @@ -179,10 +176,7 @@ async def _stats(user: User, args: list[str], session: AsyncSession, channel: Ch if gamemode is None: subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery() last_score = (await session.exec(select(Score).where(Score.id == subquery))).first() - if last_score is not None: - gamemode = last_score.gamemode - else: - gamemode = target_user.playmode + gamemode = last_score.gamemode if last_score is not None else target_user.playmode statistics = ( await session.exec( diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 9940801..0732124 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -313,10 +313,7 @@ async def chat_websocket( # 优先使用查询参数中的token,支持token或access_token参数名 auth_token = token or access_token if not auth_token and authorization: - if authorization.startswith("Bearer "): - auth_token = authorization[7:] - else: - auth_token = authorization + auth_token = authorization.removeprefix("Bearer ") if not auth_token: await websocket.close(code=1008, reason="Missing authentication token") diff --git a/app/router/redirect.py b/app/router/redirect.py index 7805fc4..bec9eca 100644 --- a/app/router/redirect.py +++ b/app/router/redirect.py @@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse redirect_router = APIRouter(include_in_schema=False) -@redirect_router.get("/users/{path:path}") +@redirect_router.get("/users/{path:path}") # noqa: FAST003 @redirect_router.get("/teams/{team_id}") @redirect_router.get("/u/{user_id}") @redirect_router.get("/b/{beatmap_id}") diff --git a/app/router/v1/beatmap.py b/app/router/v1/beatmap.py index b723713..6ca3775 100644 --- a/app/router/v1/beatmap.py +++ b/app/router/v1/beatmap.py @@ -168,10 +168,7 @@ async def get_beatmaps( elif beatmapset_id is not None: beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id) await beatmapset.awaitable_attrs.beatmaps - if len(beatmapset.beatmaps) > limit: - beatmaps = beatmapset.beatmaps[:limit] - else: - beatmaps = beatmapset.beatmaps + beatmaps = beatmapset.beatmaps[:limit] if len(beatmapset.beatmaps) > limit else beatmapset.beatmaps elif user is not None: where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user beatmapsets = (await session.exec(select(Beatmapset).where(where))).all() diff --git a/app/router/v2/beatmap.py b/app/router/v2/beatmap.py index 024a542..152e8f0 100644 --- a/app/router/v2/beatmap.py +++ b/app/router/v2/beatmap.py @@ -158,7 +158,10 @@ async def get_beatmap_attributes( if ruleset is None: beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id) ruleset = beatmap_db.mode - key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" + key = ( + f"beatmap:{beatmap_id}:{ruleset}:" + f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes" + ) if await redis.exists(key): return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] try: diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py index c4f2561..3cfcdb1 100644 --- a/app/router/v2/beatmapset.py +++ b/app/router/v2/beatmapset.py @@ -46,7 +46,6 @@ async def _save_to_db(sets: SearchBeatmapsetsResp): response_model=SearchBeatmapsetsResp, ) async def search_beatmapset( - db: Database, query: Annotated[SearchQueryModel, Query(...)], request: Request, background_tasks: BackgroundTasks, @@ -104,7 +103,7 @@ async def search_beatmapset( if cached_result: sets = SearchBeatmapsetsResp(**cached_result) # 处理资源代理 - processed_sets = await process_response_assets(sets, request) + processed_sets = await process_response_assets(sets) return processed_sets try: @@ -115,7 +114,7 @@ async def search_beatmapset( await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump()) # 处理资源代理 - processed_sets = await process_response_assets(sets, request) + processed_sets = await process_response_assets(sets) return processed_sets except HTTPError as e: raise HTTPException(status_code=500, detail=str(e)) from e @@ -140,7 +139,7 @@ async def lookup_beatmapset( cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id) if cached_resp: # 处理资源代理 - processed_resp = await process_response_assets(cached_resp, request) + processed_resp = await process_response_assets(cached_resp) return processed_resp try: @@ -151,7 +150,7 @@ async def lookup_beatmapset( await cache_service.cache_beatmap_lookup(beatmap_id, resp) # 处理资源代理 - processed_resp = await process_response_assets(resp, request) + processed_resp = await process_response_assets(resp) return processed_resp except HTTPError as exc: raise HTTPException(status_code=404, detail="Beatmap not found") from exc @@ -176,7 +175,7 @@ async def get_beatmapset( cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id) if cached_resp: # 处理资源代理 - processed_resp = await process_response_assets(cached_resp, request) + processed_resp = await process_response_assets(cached_resp) return processed_resp try: @@ -187,7 +186,7 @@ async def get_beatmapset( await cache_service.cache_beatmapset(resp) # 处理资源代理 - processed_resp = await process_response_assets(resp, request) + processed_resp = await process_response_assets(resp) return processed_resp except HTTPError as exc: raise HTTPException(status_code=404, detail="Beatmapset not found") from exc diff --git a/app/router/v2/room.py b/app/router/v2/room.py index decf936..34fb6f7 100644 --- a/app/router/v2/room.py +++ b/app/router/v2/room.py @@ -166,7 +166,6 @@ async def get_room( db: Database, room_id: Annotated[int, Path(..., description="房间 ID")], current_user: Annotated[User, Security(get_current_user, scopes=["public"])], - redis: Redis, category: Annotated[ str, Query( diff --git a/app/router/v2/score.py b/app/router/v2/score.py index 2b47c78..12be45d 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -847,10 +847,7 @@ async def reorder_score_pin( detail = "After score not found" if after_score_id else "Before score not found" raise HTTPException(status_code=404, detail=detail) - if after_score_id: - target_order = reference_score.pinned_order + 1 - else: - target_order = reference_score.pinned_order + target_order = reference_score.pinned_order + 1 if after_score_id else reference_score.pinned_order current_order = score_record.pinned_order diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py index 424b988..add3d70 100644 --- a/app/router/v2/session_verify.py +++ b/app/router/v2/session_verify.py @@ -40,7 +40,7 @@ class SessionReissueResponse(BaseModel): message: str -class VerifyFailed(Exception): +class VerifyFailedError(Exception): def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False): super().__init__(message) self.reason = reason @@ -93,10 +93,7 @@ async def verify_session( # 智能选择验证方法(参考osu-web实现) # API版本较老或用户未设置TOTP时强制使用邮件验证 # print(api_version, totp_key) - if api_version < 20240101 or totp_key is None: - verify_method = "mail" - else: - verify_method = "totp" + verify_method = "mail" if api_version < 20240101 or totp_key is None else "totp" await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis) login_method = verify_method @@ -109,7 +106,7 @@ async def verify_session( db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent ) verify_method = "mail" - raise VerifyFailed("用户TOTP已被删除,已切换到邮件验证") + raise VerifyFailedError("用户TOTP已被删除,已切换到邮件验证") # 如果未开启邮箱验证,则直接认为认证通过 # 正常不会进入到这里 @@ -120,16 +117,16 @@ async def verify_session( else: # 记录详细的验证失败原因(参考osu-web的错误处理) if len(verification_key) != 6: - raise VerifyFailed("TOTP验证码长度错误,应为6位数字", reason="incorrect_length") + raise VerifyFailedError("TOTP验证码长度错误,应为6位数字", reason="incorrect_length") elif not verification_key.isdigit(): - raise VerifyFailed("TOTP验证码格式错误,应为纯数字", reason="incorrect_format") + raise VerifyFailedError("TOTP验证码格式错误,应为纯数字", reason="incorrect_format") else: # 可能是密钥错误或者重放攻击 - raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key") + raise VerifyFailedError("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key") else: success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key) if not success: - raise VerifyFailed(f"邮件验证失败: {message}") + raise VerifyFailedError(f"邮件验证失败: {message}") await LoginLogService.record_login( db=db, @@ -144,7 +141,7 @@ async def verify_session( await db.commit() return Response(status_code=status.HTTP_204_NO_CONTENT) - except VerifyFailed as e: + except VerifyFailedError as e: await LoginLogService.record_failed_login( db=db, request=request, @@ -171,7 +168,9 @@ async def verify_session( ) error_response["reissued"] = True except Exception: - pass # 忽略重发邮件失败的错误 + log("Verification").exception( + f"Failed to resend verification email to user {current_user.id} (token: {token_id})" + ) return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response) diff --git a/app/router/v2/tags.py b/app/router/v2/tags.py index 644cd77..810656d 100644 --- a/app/router/v2/tags.py +++ b/app/router/v2/tags.py @@ -44,9 +44,7 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession .where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode)) ) ).first() - if user_beatmap_score is None: - return False - return True + return user_beatmap_score is not None @router.put( @@ -75,10 +73,9 @@ async def vote_beatmap_tags( .where(BeatmapTagVote.user_id == current_user.id) ) ).first() - if previous_votes is None: - if check_user_can_vote(current_user, beatmap_id, session): - new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id) - session.add(new_vote) + if previous_votes is None and check_user_can_vote(current_user, beatmap_id, session): + new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id) + session.add(new_vote) await session.commit() except ValueError: raise HTTPException(400, "Tag is not found") diff --git a/app/router/v2/user.py b/app/router/v2/user.py index 98a0b93..e4e41ea 100644 --- a/app/router/v2/user.py +++ b/app/router/v2/user.py @@ -91,7 +91,7 @@ async def get_users( # 处理资源代理 response = BatchUserResponse(users=cached_users) - processed_response = await process_response_assets(response, request) + processed_response = await process_response_assets(response) return processed_response else: searched_users = (await session.exec(select(User).limit(50))).all() @@ -109,7 +109,7 @@ async def get_users( # 处理资源代理 response = BatchUserResponse(users=users) - processed_response = await process_response_assets(response, request) + processed_response = await process_response_assets(response) return processed_response @@ -240,7 +240,7 @@ async def get_user_info( cached_user = await cache_service.get_user_from_cache(user_id_int) if cached_user: # 处理资源代理 - processed_user = await process_response_assets(cached_user, request) + processed_user = await process_response_assets(cached_user) return processed_user searched_user = ( @@ -263,7 +263,7 @@ async def get_user_info( background_task.add_task(cache_service.cache_user, user_resp) # 处理资源代理 - processed_user = await process_response_assets(user_resp, request) + processed_user = await process_response_assets(user_resp) return processed_user @@ -381,7 +381,7 @@ async def get_user_scores( user_id, type, include_fails, mode, limit, offset, is_legacy_api ) if cached_scores is not None: - processed_scores = await process_response_assets(cached_scores, request) + processed_scores = await process_response_assets(cached_scores) return processed_scores db_user = await session.get(User, user_id) @@ -438,5 +438,5 @@ async def get_user_scores( ) # 处理资源代理 - processed_scores = await process_response_assets(score_responses, request) + processed_scores = await process_response_assets(score_responses) return processed_scores diff --git a/app/service/asset_proxy_helper.py b/app/service/asset_proxy_helper.py index c654821..c41e77c 100644 --- a/app/service/asset_proxy_helper.py +++ b/app/service/asset_proxy_helper.py @@ -12,7 +12,7 @@ from app.service.asset_proxy_service import get_asset_proxy_service from fastapi import Request -async def process_response_assets(data: Any, request: Request) -> Any: +async def process_response_assets(data: Any) -> Any: """ 根据配置处理响应数据中的资源URL @@ -72,7 +72,7 @@ def asset_proxy_response(func): # 如果有request对象且启用了资源代理,则处理响应 if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path): - result = await process_response_assets(result, request) + result = await process_response_assets(result) return result diff --git a/app/service/beatmap_cache_service.py b/app/service/beatmap_cache_service.py index 3a81686..1a76195 100644 --- a/app/service/beatmap_cache_service.py +++ b/app/service/beatmap_cache_service.py @@ -113,6 +113,7 @@ class BeatmapCacheService: if size: total_size += size except Exception: + logger.debug(f"Failed to get size for key {key}") continue return { diff --git a/app/service/beatmapset_cache_service.py b/app/service/beatmapset_cache_service.py index df23f20..f255c8c 100644 --- a/app/service/beatmapset_cache_service.py +++ b/app/service/beatmapset_cache_service.py @@ -36,11 +36,8 @@ def safe_json_dumps(data) -> str: def generate_hash(data) -> str: """生成数据的MD5哈希值""" - if isinstance(data, str): - content = data - else: - content = safe_json_dumps(data) - return hashlib.md5(content.encode()).hexdigest() + content = data if isinstance(data, str) else safe_json_dumps(data) + return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest() class BeatmapsetCacheService: diff --git a/app/service/beatmapset_update_service.py b/app/service/beatmapset_update_service.py index 8852146..0cc8704 100644 --- a/app/service/beatmapset_update_service.py +++ b/app/service/beatmapset_update_service.py @@ -110,9 +110,7 @@ class ProcessingBeatmapset: changed_beatmaps = [] for bm in self.beatmapset.beatmaps: saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None) - if not saved: - changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED)) - elif saved["is_deleted"]: + if not saved or saved["is_deleted"]: changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED)) elif saved["md5"] != bm.checksum: changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED)) @@ -285,7 +283,7 @@ class BeatmapsetUpdateService: async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp): async with with_db() as session: db_beatmapset = await session.get(Beatmapset, beatmapset.id) - new_beatmapset = await Beatmapset.from_resp_no_save(session, beatmapset) + new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset) if db_beatmapset: await session.merge(new_beatmapset) await session.commit() @@ -356,5 +354,7 @@ def init_beatmapset_update_service(fetcher: "Fetcher") -> BeatmapsetUpdateServic def get_beatmapset_update_service() -> BeatmapsetUpdateService: + if service is None: + raise ValueError("BeatmapsetUpdateService is not initialized") assert service is not None, "BeatmapsetUpdateService is not initialized" return service diff --git a/app/service/login_log_service.py b/app/service/login_log_service.py index 6fa2f1a..0570493 100644 --- a/app/service/login_log_service.py +++ b/app/service/login_log_service.py @@ -128,7 +128,11 @@ class LoginLogService: login_success=False, login_method=login_method, user_agent=user_agent, - notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt", + notes=( + f"Failed login attempt on user {attempted_username}: {notes}" + if attempted_username + else "Failed login attempt" + ), ) diff --git a/app/service/password_reset_service.py b/app/service/password_reset_service.py index 5d831d6..429840f 100644 --- a/app/service/password_reset_service.py +++ b/app/service/password_reset_service.py @@ -120,7 +120,7 @@ class PasswordResetService: await redis.delete(reset_code_key) await redis.delete(rate_limit_key) except Exception: - pass + logger.warning("Failed to clean up Redis data after error") logger.exception("Redis operation failed") return False, "服务暂时不可用,请稍后重试" diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py index 8b4f3cf..0d9dd37 100644 --- a/app/service/ranking_cache_service.py +++ b/app/service/ranking_cache_service.py @@ -593,10 +593,7 @@ class RankingCacheService: async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None: """使地区排行榜缓存失效""" try: - if ruleset: - pattern = f"country_ranking:{ruleset}:*" - else: - pattern = "country_ranking:*" + pattern = f"country_ranking:{ruleset}:*" if ruleset else "country_ranking:*" keys = await self.redis.keys(pattern) if keys: @@ -608,10 +605,7 @@ class RankingCacheService: async def invalidate_team_cache(self, ruleset: GameMode | None = None) -> None: """使战队排行榜缓存失效""" try: - if ruleset: - pattern = f"team_ranking:{ruleset}:*" - else: - pattern = "team_ranking:*" + pattern = f"team_ranking:{ruleset}:*" if ruleset else "team_ranking:*" keys = await self.redis.keys(pattern) if keys: @@ -637,6 +631,7 @@ class RankingCacheService: if size: total_size += size except Exception: + logger.warning(f"Failed to get memory usage for key {key}") continue return { diff --git a/app/service/subscribers/__init__.py b/app/service/subscribers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/service/subscribers/chat.py b/app/service/subscribers/chat.py index 9512f9c..9241a7f 100644 --- a/app/service/subscribers/chat.py +++ b/app/service/subscribers/chat.py @@ -35,19 +35,19 @@ class ChatSubscriber(RedisSubscriber): self.add_handler(ON_NOTIFICATION, self.on_notification) self.start() - async def on_join_room(self, c: str, s: str): + async def on_join_room(self, c: str, s: str): # noqa: ARG002 channel_id, user_id = s.split(":") if self.chat_server is None: return await self.chat_server.join_room_channel(int(channel_id), int(user_id)) - async def on_leave_room(self, c: str, s: str): + async def on_leave_room(self, c: str, s: str): # noqa: ARG002 channel_id, user_id = s.split(":") if self.chat_server is None: return await self.chat_server.leave_room_channel(int(channel_id), int(user_id)) - async def on_notification(self, c: str, s: str): + async def on_notification(self, c: str, s: str): # noqa: ARG002 try: detail = TypeAdapter(NotificationDetails).validate_json(s) except ValueError: diff --git a/app/service/user_cache_service.py b/app/service/user_cache_service.py index ae0022e..4a40229 100644 --- a/app/service/user_cache_service.py +++ b/app/service/user_cache_service.py @@ -357,6 +357,7 @@ class UserCacheService: if size: total_size += size except Exception: + logger.warning(f"Failed to get memory usage for key {key}") continue return { diff --git a/app/service/verification_service.py b/app/service/verification_service.py index d1f2ee4..f7cc0ea 100644 --- a/app/service/verification_service.py +++ b/app/service/verification_service.py @@ -288,10 +288,6 @@ This email was sent automatically, please do not reply. redis: Redis, user_id: int, code: str, - ip_address: str | None = None, - user_agent: str | None = None, - client_id: int | None = None, - country_code: str | None = None, ) -> tuple[bool, str]: """验证邮箱验证码""" try: diff --git a/app/tasks/cache.py b/app/tasks/cache.py index 4a684f6..21934c9 100644 --- a/app/tasks/cache.py +++ b/app/tasks/cache.py @@ -41,7 +41,7 @@ async def warmup_cache() -> None: logger.info("Beatmap cache warmup completed successfully") except Exception as e: - logger.error("Beatmap cache warmup failed: %s", e) + logger.error(f"Beatmap cache warmup failed: {e}") async def refresh_ranking_cache() -> None: @@ -59,7 +59,7 @@ async def refresh_ranking_cache() -> None: logger.info("Ranking cache refresh completed successfully") except Exception as e: - logger.error("Ranking cache refresh failed: %s", e) + logger.error(f"Ranking cache refresh failed: {e}") async def schedule_user_cache_preload_task() -> None: @@ -93,14 +93,14 @@ async def schedule_user_cache_preload_task() -> None: if active_user_ids: user_ids = [row[0] for row in active_user_ids] await cache_service.preload_user_cache(session, user_ids) - logger.info("Preloaded cache for %s active users", len(user_ids)) + logger.info(f"Preloaded cache for {len(user_ids)} active users") else: logger.info("No active users found for cache preload") logger.info("User cache preload task completed successfully") except Exception as e: - logger.error("User cache preload task failed: %s", e) + logger.error(f"User cache preload task failed: {e}") async def schedule_user_cache_warmup_task() -> None: @@ -131,18 +131,18 @@ async def schedule_user_cache_warmup_task() -> None: if top_users: user_ids = list(top_users) await cache_service.preload_user_cache(session, user_ids) - logger.info("Warmed cache for top 100 users in %s", mode) + logger.info(f"Warmed cache for top 100 users in {mode}") await asyncio.sleep(1) except Exception as e: - logger.error("Failed to warm cache for %s: %s", mode, e) + logger.error(f"Failed to warm cache for {mode}: {e}") continue logger.info("User cache warmup task completed successfully") except Exception as e: - logger.error("User cache warmup task failed: %s", e) + logger.error(f"User cache warmup task failed: {e}") async def schedule_user_cache_cleanup_task() -> None: @@ -155,11 +155,11 @@ async def schedule_user_cache_cleanup_task() -> None: cache_service = get_user_cache_service(redis) stats = await cache_service.get_cache_stats() - logger.info("User cache stats: %s", stats) + logger.info(f"User cache stats: {stats}") logger.info("User cache cleanup task completed successfully") except Exception as e: - logger.error("User cache cleanup task failed: %s", e) + logger.error(f"User cache cleanup task failed: {e}") async def warmup_user_cache() -> None: @@ -167,7 +167,7 @@ async def warmup_user_cache() -> None: try: await schedule_user_cache_warmup_task() except Exception as e: - logger.error("User cache warmup failed: %s", e) + logger.error(f"User cache warmup failed: {e}") async def preload_user_cache() -> None: @@ -175,7 +175,7 @@ async def preload_user_cache() -> None: try: await schedule_user_cache_preload_task() except Exception as e: - logger.error("User cache preload failed: %s", e) + logger.error(f"User cache preload failed: {e}") async def cleanup_user_cache() -> None: @@ -183,7 +183,7 @@ async def cleanup_user_cache() -> None: try: await schedule_user_cache_cleanup_task() except Exception as e: - logger.error("User cache cleanup failed: %s", e) + logger.error(f"User cache cleanup failed: {e}") def register_cache_jobs() -> None: diff --git a/app/tasks/geoip.py b/app/tasks/geoip.py index 0d22ed8..2868346 100644 --- a/app/tasks/geoip.py +++ b/app/tasks/geoip.py @@ -5,8 +5,6 @@ Periodically update the MaxMind GeoIP database from __future__ import annotations -import asyncio - from app.config import settings from app.dependencies.geoip import get_geoip_helper from app.dependencies.scheduler import get_scheduler @@ -28,14 +26,10 @@ async def update_geoip_database(): try: logger.info("Starting scheduled GeoIP database update...") geoip = get_geoip_helper() - - # Run the synchronous update method in a background thread - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: geoip.update(force=False)) - + await geoip.update(force=False) logger.info("Scheduled GeoIP database update completed successfully") - except Exception as e: - logger.error(f"Scheduled GeoIP database update failed: {e}") + except Exception as exc: + logger.error(f"Scheduled GeoIP database update failed: {exc}") async def init_geoip(): @@ -45,13 +39,8 @@ async def init_geoip(): try: geoip = get_geoip_helper() logger.info("Initializing GeoIP database...") - - # Run the synchronous update method in a background thread - # force=False means only download if files don't exist or are expired - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: geoip.update(force=False)) - + await geoip.update(force=False) logger.info("GeoIP database initialization completed") - except Exception as e: - logger.error(f"GeoIP database initialization failed: {e}") + except Exception as exc: + logger.error(f"GeoIP database initialization failed: {exc}") # Do not raise an exception to avoid blocking application startup diff --git a/app/tasks/osu_rx_statistics.py b/app/tasks/osu_rx_statistics.py index 732d727..9b2a796 100644 --- a/app/tasks/osu_rx_statistics.py +++ b/app/tasks/osu_rx_statistics.py @@ -16,7 +16,7 @@ async def create_rx_statistics(): async with with_db() as session: users = (await session.exec(select(User.id))).all() total_users = len(users) - logger.info("Ensuring RX/AP statistics exist for %s users", total_users) + logger.info(f"Ensuring RX/AP statistics exist for {total_users} users") rx_created = 0 ap_created = 0 for i in users: @@ -57,7 +57,5 @@ async def create_rx_statistics(): await session.commit() if rx_created or ap_created: logger.success( - "Created %s RX statistics rows and %s AP statistics rows during backfill", - rx_created, - ap_created, + f"Created {rx_created} RX statistics rows and {ap_created} AP statistics rows during backfill" ) diff --git a/app/utils.py b/app/utils.py index 9a610e0..69b3dd7 100644 --- a/app/utils.py +++ b/app/utils.py @@ -258,10 +258,7 @@ class BackgroundTasks: self.tasks = set(tasks) if tasks else set() def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: - if is_async_callable(func): - coro = func(*args, **kwargs) - else: - coro = run_in_threadpool(func, *args, **kwargs) + coro = func(*args, **kwargs) if is_async_callable(func) else run_in_threadpool(func, *args, **kwargs) task = asyncio.create_task(coro) self.tasks.add(task) task.add_done_callback(self.tasks.discard) diff --git a/main.py b/main.py index 282fd18..27089b1 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import asynccontextmanager +import json from pathlib import Path from app.config import settings @@ -50,7 +51,7 @@ import sentry_sdk @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: FastAPI): # noqa: ARG001 # on startup init_mods() init_ranked_mods() @@ -223,26 +224,26 @@ async def health_check(): @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError): +async def validation_exception_handler(request: Request, exc: RequestValidationError): # noqa: ARG001 return JSONResponse( status_code=422, content={ - "error": exc.errors(), + "error": json.dumps(exc.errors()), }, ) @app.exception_handler(HTTPException) -async def http_exception_handler(requst: Request, exc: HTTPException): +async def http_exception_handler(request: Request, exc: HTTPException): # noqa: ARG001 return JSONResponse(status_code=exc.status_code, content={"error": exc.detail}) -if settings.secret_key == "your_jwt_secret_here": +if settings.secret_key == "your_jwt_secret_here": # noqa: S105 system_logger("Security").opt(colors=True).warning( "jwt_secret_key is unset. Your server is unsafe. " "Use this command to generate: openssl rand -hex 32." ) -if settings.osu_web_client_secret == "your_osu_web_client_secret_here": +if settings.osu_web_client_secret == "your_osu_web_client_secret_here": # noqa: S105 system_logger("Security").opt(colors=True).warning( "osu_web_client_secret is unset. Your server is unsafe. " "Use this command to generate: openssl rand -hex 40." diff --git a/migrations/versions/2025-08-22_d103d442dc24_add_password_reset_table.py b/migrations/versions/2025-08-22_d103d442dc24_add_password_reset_table.py index 9e15a12..aeba732 100644 --- a/migrations/versions/2025-08-22_d103d442dc24_add_password_reset_table.py +++ b/migrations/versions/2025-08-22_d103d442dc24_add_password_reset_table.py @@ -1,3 +1,4 @@ +# ruff: noqa """add_password_reset_table Revision ID: d103d442dc24 diff --git a/pyproject.toml b/pyproject.toml index 9862e00..2f99cda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,12 +55,20 @@ select = [ "ASYNC", # flake8-async "C4", # flake8-comprehensions "T10", # flake8-debugger - # "T20", # flake8-print "PYI", # flake8-pyi "PT", # flake8-pytest-style "Q", # flake8-quotes "TID", # flake8-tidy-imports "RUF", # Ruff-specific rules + "FAST", # FastAPI + "YTT", # flake8-2020 + "S", # flake8-bandit + "INP", # flake8-no-pep420 + "SIM", # flake8-simplify + "ARG", # flake8-unused-arguments + "PTH", # flake8-use-pathlib + "N", # pep8-naming + "FURB" # refurb ] ignore = [ "E402", # module-import-not-at-top-of-file @@ -68,10 +76,17 @@ ignore = [ "RUF001", # ambiguous-unicode-character-string "RUF002", # ambiguous-unicode-character-docstring "RUF003", # ambiguous-unicode-character-comment + "S101", # assert + "S311", # suspicious-non-cryptographic-random-usage ] [tool.ruff.lint.extend-per-file-ignores] "app/database/**/*.py" = ["I002"] +"tools/*.py" = ["PTH", "INP001"] +"migrations/**/*.py" = ["INP001"] +".github/**/*.py" = ["INP001"] +"app/achievements/*.py" = ["INP001", "ARG"] +"app/router/**/*.py" = ["ARG001"] [tool.ruff.lint.isort] force-sort-within-sections = true diff --git a/tools/fix_user_rank_event.py b/tools/fix_user_rank_event.py index 45b098a..67eab51 100644 --- a/tools/fix_user_rank_event.py +++ b/tools/fix_user_rank_event.py @@ -163,13 +163,19 @@ async def main(): # Show specific changes changes = [] - if "scorerank" in original_payload and "scorerank" in fixed_payload: - if original_payload["scorerank"] != fixed_payload["scorerank"]: - changes.append(f"scorerank: {original_payload['scorerank']} → {fixed_payload['scorerank']}") + if ( + "scorerank" in original_payload + and "scorerank" in fixed_payload + and original_payload["scorerank"] != fixed_payload["scorerank"] + ): + changes.append(f"scorerank: {original_payload['scorerank']} → {fixed_payload['scorerank']}") - if "mode" in original_payload and "mode" in fixed_payload: - if original_payload["mode"] != fixed_payload["mode"]: - changes.append(f"mode: {original_payload['mode']} → {fixed_payload['mode']}") + if ( + "mode" in original_payload + and "mode" in fixed_payload + and original_payload["mode"] != fixed_payload["mode"] + ): + changes.append(f"mode: {original_payload['mode']} → {fixed_payload['mode']}") if changes: print(f" Changes: {', '.join(changes)}")