diff --git a/app/achievements/daily_challenge.py b/app/achievements/daily_challenge.py
index b2b3f4c..fcca594 100644
--- a/app/achievements/daily_challenge.py
+++ b/app/achievements/daily_challenge.py
@@ -32,11 +32,9 @@ async def process_streak(
).first()
if not stats:
return False
- if streak <= stats.daily_streak_best < next_streak:
- return True
- elif next_streak == 0 and stats.daily_streak_best >= streak:
- return True
- return False
+ return bool(
+ streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak)
+ )
MEDALS = {
diff --git a/app/achievements/hush_hush.py b/app/achievements/hush_hush.py
index a887693..bc9c53f 100644
--- a/app/achievements/hush_hush.py
+++ b/app/achievements/hush_hush.py
@@ -68,9 +68,7 @@ async def to_the_core(
if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
return False
mods_ = mod_to_save(score.mods)
- if "DT" not in mods_ or "NC" not in mods_:
- return False
- return True
+ return not ("DT" not in mods_ or "NC" not in mods_)
async def wysi(
@@ -83,9 +81,7 @@ async def wysi(
return False
if str(round(score.accuracy, ndigits=4))[3:] != "727":
return False
- if "xi" not in beatmap.beatmapset.artist:
- return False
- return True
+ return "xi" in beatmap.beatmapset.artist
async def prepared(
@@ -97,9 +93,7 @@ async def prepared(
if score.rank != Rank.X and score.rank != Rank.XH:
return False
mods_ = mod_to_save(score.mods)
- if "NF" not in mods_:
- return False
- return True
+ return "NF" in mods_
async def reckless_adandon(
@@ -117,9 +111,7 @@ async def reckless_adandon(
redis = get_redis()
mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
- if attribute.star_rating < 3:
- return False
- return True
+ return not attribute.star_rating < 3
async def lights_out(
@@ -413,11 +405,10 @@ async def by_the_skin_of_the_teeth(
return False
for mod in score.mods:
- if mod.get("acronym") == "AC":
- if "settings" in mod and "minimum_accuracy" in mod["settings"]:
- target_accuracy = mod["settings"]["minimum_accuracy"]
- if isinstance(target_accuracy, int | float):
- return abs(score.accuracy - float(target_accuracy)) < 0.0001
+ if mod.get("acronym") == "AC" and "settings" in mod and "minimum_accuracy" in mod["settings"]:
+ target_accuracy = mod["settings"]["minimum_accuracy"]
+ if isinstance(target_accuracy, int | float):
+ return abs(score.accuracy - float(target_accuracy)) < 0.0001
return False
diff --git a/app/achievements/mods.py b/app/achievements/mods.py
index d157a15..0f3c728 100644
--- a/app/achievements/mods.py
+++ b/app/achievements/mods.py
@@ -19,9 +19,7 @@ async def process_mod(
return False
if not beatmap.beatmap_status.has_leaderboard():
return False
- if len(score.mods) != 1 or score.mods[0]["acronym"] != mod:
- return False
- return True
+ return not (len(score.mods) != 1 or score.mods[0]["acronym"] != mod)
async def process_category_mod(
diff --git a/app/achievements/osu_combo.py b/app/achievements/osu_combo.py
index 25abbbd..7039bae 100644
--- a/app/achievements/osu_combo.py
+++ b/app/achievements/osu_combo.py
@@ -22,11 +22,7 @@ async def process_combo(
return False
if next_combo != 0 and combo >= next_combo:
return False
- if combo <= score.max_combo < next_combo:
- return True
- elif next_combo == 0 and score.max_combo >= combo:
- return True
- return False
+ return bool(combo <= score.max_combo < next_combo or (next_combo == 0 and score.max_combo >= combo))
MEDALS: Medals = {
diff --git a/app/achievements/osu_playcount.py b/app/achievements/osu_playcount.py
index 934e1c0..b5e1e9b 100644
--- a/app/achievements/osu_playcount.py
+++ b/app/achievements/osu_playcount.py
@@ -35,11 +35,7 @@ async def process_playcount(
).first()
if not stats:
return False
- if pc <= stats.play_count < next_pc:
- return True
- elif next_pc == 0 and stats.play_count >= pc:
- return True
- return False
+ return bool(pc <= stats.play_count < next_pc or (next_pc == 0 and stats.play_count >= pc))
MEDALS: Medals = {
diff --git a/app/achievements/skill.py b/app/achievements/skill.py
index 66123cc..43993d8 100644
--- a/app/achievements/skill.py
+++ b/app/achievements/skill.py
@@ -47,9 +47,7 @@ async def process_skill(
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < star or attribute.star_rating >= star + 1:
return False
- if type == "fc" and not score.is_perfect_combo:
- return False
- return True
+ return not (type == "fc" and not score.is_perfect_combo)
MEDALS: Medals = {
diff --git a/app/achievements/total_hits.py b/app/achievements/total_hits.py
index 5f3d13d..93fb2c5 100644
--- a/app/achievements/total_hits.py
+++ b/app/achievements/total_hits.py
@@ -35,11 +35,7 @@ async def process_tth(
).first()
if not stats:
return False
- if tth <= stats.total_hits < next_tth:
- return True
- elif next_tth == 0 and stats.play_count >= tth:
- return True
- return False
+ return bool(tth <= stats.total_hits < next_tth or (next_tth == 0 and stats.play_count >= tth))
MEDALS: Medals = {
diff --git a/app/auth.py b/app/auth.py
index 6e43f8b..8bef0a8 100644
--- a/app/auth.py
+++ b/app/auth.py
@@ -69,7 +69,7 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
2. MD5哈希 -> bcrypt验证
"""
# 1. 明文密码转 MD5
- pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()
+ pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode() # noqa: S324
# 2. 检查缓存
if bcrypt_hash in bcrypt_cache:
@@ -103,7 +103,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def get_password_hash(password: str) -> str:
"""生成密码哈希 - 使用 osu! 的方式"""
# 1. 明文密码 -> MD5
- pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()
+ pw_md5 = hashlib.md5(password.encode()).hexdigest().encode() # noqa: S324
# 2. MD5 -> bcrypt
pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt())
return pw_bcrypt.decode()
@@ -114,7 +114,7 @@ async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -
验证用户身份 - 使用类似 from_login 的逻辑
"""
# 1. 明文密码转 MD5
- pw_md5 = hashlib.md5(password.encode()).hexdigest()
+ pw_md5 = hashlib.md5(password.encode()).hexdigest() # noqa: S324
# 2. 根据用户名查找用户
user = None
@@ -325,12 +325,7 @@ def _generate_totp_account_label(user: User) -> str:
根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性
"""
- if settings.totp_use_username_in_label:
- # 使用用户名作为主要标识
- primary_identifier = user.username
- else:
- # 使用邮箱作为标识
- primary_identifier = user.email
+ primary_identifier = user.username if settings.totp_use_username_in_label else user.email
# 如果配置了服务名称,添加到标签中以便在认证器中区分
if settings.totp_service_name:
diff --git a/app/calculator.py b/app/calculator.py
index a408fcf..7c9eb7c 100644
--- a/app/calculator.py
+++ b/app/calculator.py
@@ -419,9 +419,8 @@ def too_dense(hit_objects: list[HitObject], per_1s: int, per_10s: int) -> bool:
if len(hit_objects) > i + per_1s:
if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000:
return True
- elif len(hit_objects) > i + per_10s:
- if hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
- return True
+ elif len(hit_objects) > i + per_10s and hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
+ return True
return False
@@ -448,10 +447,7 @@ def slider_is_sus(hit_objects: list[HitObject]) -> bool:
def is_2b(hit_objects: list[HitObject]) -> bool:
- for i in range(0, len(hit_objects) - 1):
- if hit_objects[i] == hit_objects[i + 1].start_time:
- return True
- return False
+ return any(hit_objects[i] == hit_objects[i + 1].start_time for i in range(0, len(hit_objects) - 1))
def is_suspicious_beatmap(content: str) -> bool:
diff --git a/app/config.py b/app/config.py
index adb7794..b09c233 100644
--- a/app/config.py
+++ b/app/config.py
@@ -217,7 +217,7 @@ STORAGE_SETTINGS='{
# 服务器设置
host: Annotated[
str,
- Field(default="0.0.0.0", description="服务器监听地址"),
+ Field(default="0.0.0.0", description="服务器监听地址"), # noqa: S104
"服务器设置",
]
port: Annotated[
@@ -609,26 +609,26 @@ STORAGE_SETTINGS='{
]
@field_validator("fetcher_scopes", mode="before")
+ @classmethod
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
if isinstance(v, str):
return v.split(",")
return v
@field_validator("storage_settings", mode="after")
+ @classmethod
def validate_storage_settings(
cls,
v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
info: ValidationInfo,
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
- if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
- if not isinstance(v, CloudflareR2Settings):
- raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
- elif info.data.get("storage_service") == StorageServiceType.LOCAL:
- if not isinstance(v, LocalStorageSettings):
- raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
- elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
- if not isinstance(v, AWSS3StorageSettings):
- raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
+ service = info.data.get("storage_service")
+ if service == StorageServiceType.CLOUDFLARE_R2 and not isinstance(v, CloudflareR2Settings):
+ raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
+ if service == StorageServiceType.LOCAL and not isinstance(v, LocalStorageSettings):
+ raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
+ if service == StorageServiceType.AWS_S3 and not isinstance(v, AWSS3StorageSettings):
+ raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
return v
diff --git a/app/database/beatmap.py b/app/database/beatmap.py
index ff920b1..029915f 100644
--- a/app/database/beatmap.py
+++ b/app/database/beatmap.py
@@ -71,10 +71,10 @@ class Beatmap(BeatmapBase, table=True):
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
- async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
+ async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump()
del d["beatmapset"]
- beatmap = Beatmap.model_validate(
+ beatmap = cls.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
@@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True):
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
session.add(beatmap)
await session.commit()
- beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
- return beatmap
+ return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
@classmethod
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
@@ -250,7 +249,7 @@ async def calculate_beatmap_attributes(
redis: Redis,
fetcher: "Fetcher",
):
- key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
+ key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key))
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py
index d2df25e..7829c8c 100644
--- a/app/database/beatmapset.py
+++ b/app/database/beatmapset.py
@@ -130,7 +130,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
- async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
+ async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
d = resp.model_dump()
if resp.nominations:
d["nominations_required"] = resp.nominations.required
@@ -158,10 +158,15 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
return beatmapset
@classmethod
- async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
+ async def from_resp(
+ cls,
+ session: AsyncSession,
+ resp: "BeatmapsetResp",
+ from_: int = 0,
+ ) -> "Beatmapset":
from .beatmap import Beatmap
- beatmapset = await cls.from_resp_no_save(session, resp, from_=from_)
+ beatmapset = await cls.from_resp_no_save(resp)
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
session.add(beatmapset)
await session.commit()
diff --git a/app/database/chat.py b/app/database/chat.py
index f0d1c47..b05c790 100644
--- a/app/database/chat.py
+++ b/app/database/chat.py
@@ -105,17 +105,11 @@ class ChatChannelResp(ChatChannelBase):
)
).first()
- last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
- if last_msg and last_msg.isdigit():
- last_msg = int(last_msg)
- else:
- last_msg = None
+ last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
+ last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
- last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
- if last_read_id and last_read_id.isdigit():
- last_read_id = int(last_read_id)
- else:
- last_read_id = last_msg
+ last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
+ last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
if silence is not None:
attribute = ChatUserAttributes(
diff --git a/app/database/score.py b/app/database/score.py
index fd9668f..172e8a4 100644
--- a/app/database/score.py
+++ b/app/database/score.py
@@ -520,12 +520,11 @@ async def _score_where(
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
else:
return None
- elif type == LeaderboardType.TEAM:
- if user:
- team_membership = await user.awaitable_attrs.team_membership
- if team_membership:
- team_id = team_membership.team_id
- wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
+ elif type == LeaderboardType.TEAM and user:
+ team_membership = await user.awaitable_attrs.team_membership
+ if team_membership:
+ team_id = team_membership.team_id
+ wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
if mods:
if user and user.is_supporter:
wheres.append(
diff --git a/app/database/user.py b/app/database/user.py
index 053c1b0..fbc9eed 100644
--- a/app/database/user.py
+++ b/app/database/user.py
@@ -256,8 +256,6 @@ class UserResp(UserBase):
session: AsyncSession,
include: list[str] = [],
ruleset: GameMode | None = None,
- *,
- token_id: int | None = None,
) -> "UserResp":
from app.dependencies.database import get_redis
@@ -310,16 +308,16 @@ class UserResp(UserBase):
).all()
]
- if "team" in include:
- if team_membership := await obj.awaitable_attrs.team_membership:
- u.team = team_membership.team
+ if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership):
+ u.team = team_membership.team
if "account_history" in include:
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
- if "daily_challenge_user_stats":
- if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
- u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
+ if "daily_challenge_user_stats" in include and (
+ daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats
+ ):
+ u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
if "statistics" in include:
current_stattistics = None
@@ -443,7 +441,7 @@ class MeResp(UserResp):
from app.dependencies.database import get_redis
from app.service.verification_service import LoginSessionService
- u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id)
+ u = await super().from_db(obj, session, ALL_INCLUDED, ruleset)
u.session_verified = (
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
if token_id
diff --git a/app/dependencies/param.py b/app/dependencies/param.py
index 174adde..9e640bd 100644
--- a/app/dependencies/param.py
+++ b/app/dependencies/param.py
@@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, ValidationError
-def BodyOrForm[T: BaseModel](model: type[T]):
+def BodyOrForm[T: BaseModel](model: type[T]): # noqa: N802
async def dependency(
request: Request,
) -> T:
diff --git a/app/dependencies/user.py b/app/dependencies/user.py
index ff3ff53..7061bcb 100644
--- a/app/dependencies/user.py
+++ b/app/dependencies/user.py
@@ -119,10 +119,7 @@ async def get_client_user(
if verify_method is None:
# 智能选择验证方式(有TOTP优先TOTP)
totp_key = await user.awaitable_attrs.totp_key
- if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
- verify_method = "totp"
- else:
- verify_method = "mail"
+ verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail"
# 设置选择的验证方法到Redis中,避免重复选择
if api_version >= 20250913:
diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py
index 8801fef..a67c6f6 100644
--- a/app/fetcher/beatmapset.py
+++ b/app/fetcher/beatmapset.py
@@ -116,7 +116,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 序列化为 JSON 并生成 MD5 哈希
cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
- cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
+ cache_hash = hashlib.md5(cache_json.encode(), usedforsecurity=False).hexdigest()
logger.opt(colors=True).debug(f"[CacheKey] Query: {cache_data}, Hash: {cache_hash}")
@@ -160,10 +160,10 @@ class BeatmapsetFetcher(BaseFetcher):
cached_data = json.loads(cached_result)
return SearchBeatmapsetsResp.model_validate(cached_data)
except Exception as e:
- logger.opt(colors=True).warning(f"Cache data invalid, fetching from API: {e}")
+ logger.warning(f"Cache data invalid, fetching from API: {e}")
# 缓存未命中,从 API 获取数据
- logger.opt(colors=True).debug("Cache miss, fetching from API")
+ logger.debug("Cache miss, fetching from API")
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
@@ -203,7 +203,7 @@ class BeatmapsetFetcher(BaseFetcher):
try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
except RateLimitError:
- logger.opt(colors=True).info("Prefetch skipped due to rate limit")
+ logger.info("Prefetch skipped due to rate limit")
bg_tasks.add_task(delayed_prefetch)
@@ -227,14 +227,14 @@ class BeatmapsetFetcher(BaseFetcher):
# 使用当前 cursor 请求下一页
next_query = query.model_copy()
- logger.opt(colors=True).debug(f"Prefetching page {page + 1}")
+ logger.debug(f"Prefetching page {page + 1}")
# 生成下一页的缓存键
next_cache_key = self._generate_cache_key(next_query, cursor)
# 检查是否已经缓存
if await redis_client.exists(next_cache_key):
- logger.opt(colors=True).debug(f"Page {page + 1} already cached")
+ logger.debug(f"Page {page + 1} already cached")
# 尝试从缓存获取cursor继续预取
cached_data = await redis_client.get(next_cache_key)
if cached_data:
@@ -244,7 +244,7 @@ class BeatmapsetFetcher(BaseFetcher):
cursor = data["cursor"]
continue
except Exception:
- pass
+ logger.warning("Failed to parse cached data for cursor")
break
# 在预取页面之间添加延迟,避免突发请求
@@ -279,18 +279,18 @@ class BeatmapsetFetcher(BaseFetcher):
ex=prefetch_ttl,
)
- logger.opt(colors=True).debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
+ logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
except RateLimitError:
- logger.opt(colors=True).info("Prefetch stopped due to rate limit")
+ logger.info("Prefetch stopped due to rate limit")
except Exception as e:
- logger.opt(colors=True).warning(f"Prefetch failed: {e}")
+ logger.warning(f"Prefetch failed: {e}")
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
"""预热主页缓存"""
homepage_queries = self._get_homepage_queries()
- logger.opt(colors=True).info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
+ logger.info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
for i, (query, cursor) in enumerate(homepage_queries):
try:
@@ -302,7 +302,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 检查是否已经缓存
if await redis_client.exists(cache_key):
- logger.opt(colors=True).debug(f"Query {query.sort} already cached")
+ logger.debug(f"Query {query.sort} already cached")
continue
# 请求并缓存
@@ -325,15 +325,15 @@ class BeatmapsetFetcher(BaseFetcher):
ex=cache_ttl,
)
- logger.opt(colors=True).info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
+ logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
if api_response.get("cursor"):
try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
except RateLimitError:
- logger.opt(colors=True).info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
+ logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
except RateLimitError:
- logger.opt(colors=True).warning(f"Warmup skipped for {query.sort} due to rate limit")
+ logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
except Exception as e:
- logger.opt(colors=True).error(f"Failed to warmup cache for {query.sort}: {e}")
+ logger.error(f"Failed to warmup cache for {query.sort}: {e}")
diff --git a/app/helpers/__init__.py b/app/helpers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/helpers/geoip_helper.py b/app/helpers/geoip_helper.py
index 3d65371..12c7c3c 100644
--- a/app/helpers/geoip_helper.py
+++ b/app/helpers/geoip_helper.py
@@ -1,19 +1,39 @@
"""
-GeoLite2 Helper Class
+GeoLite2 Helper Class (asynchronous)
"""
from __future__ import annotations
+import asyncio
+from contextlib import suppress
import os
from pathlib import Path
import shutil
import tarfile
import tempfile
import time
+from typing import Any, Required, TypedDict
+from app.log import logger
+
+import aiofiles
import httpx
import maxminddb
+
+class GeoIPLookupResult(TypedDict, total=False):
+ ip: Required[str]
+ country_iso: str
+ country_name: str
+ city_name: str
+ latitude: str
+ longitude: str
+ time_zone: str
+ postal_code: str
+ asn: int | None
+ organization: str
+
+
BASE_URL = "https://download.maxmind.com/app/geoip_download"
EDITIONS = {
"City": "GeoLite2-City",
@@ -25,161 +45,184 @@ EDITIONS = {
class GeoIPHelper:
def __init__(
self,
- dest_dir="./geoip",
- license_key=None,
- editions=None,
- max_age_days=8,
- timeout=60.0,
+ dest_dir: str | Path = Path("./geoip"),
+ license_key: str | None = None,
+ editions: list[str] | None = None,
+ max_age_days: int = 8,
+ timeout: float = 60.0,
):
- self.dest_dir = dest_dir
+ self.dest_dir = Path(dest_dir).expanduser()
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
- self.editions = editions or ["City", "ASN"]
+ self.editions = list(editions or ["City", "ASN"])
self.max_age_days = max_age_days
self.timeout = timeout
- self._readers = {}
+ self._readers: dict[str, maxminddb.Reader] = {}
+ self._update_lock = asyncio.Lock()
@staticmethod
- def _safe_extract(tar: tarfile.TarFile, path: str):
- base = Path(path).resolve()
- for m in tar.getmembers():
- target = (base / m.name).resolve()
- if not str(target).startswith(str(base)):
+ def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
+ base = path.resolve()
+ for member in tar.getmembers():
+ target = (base / member.name).resolve()
+ if not target.is_relative_to(base): # py312
raise RuntimeError("Unsafe path in tar file")
- tar.extractall(path=path, filter="data")
+ tar.extractall(path=base, filter="data")
- def _download_and_extract(self, edition_id: str) -> str:
- """
- 下载并解压 mmdb 文件到 dest_dir,仅保留 .mmdb
- - 跟随 302 重定向
- - 流式下载到临时文件
- - 临时目录退出后自动清理
- """
+ @staticmethod
+ def _as_mapping(value: Any) -> dict[str, Any]:
+ return value if isinstance(value, dict) else {}
+
+ @staticmethod
+ def _as_str(value: Any, default: str = "") -> str:
+ if isinstance(value, str):
+ return value
+ if value is None:
+ return default
+ return str(value)
+
+ @staticmethod
+ def _as_int(value: Any) -> int | None:
+ return value if isinstance(value, int) else None
+
+ @staticmethod
+ def _extract_tarball(src: Path, dest: Path) -> None:
+ with tarfile.open(src, "r:gz") as tar:
+ GeoIPHelper._safe_extract(tar, dest)
+
+ @staticmethod
+ def _find_mmdb(root: Path) -> Path | None:
+ for candidate in root.rglob("*.mmdb"):
+ return candidate
+ return None
+
+ def _latest_file_sync(self, edition_id: str) -> Path | None:
+ directory = self.dest_dir
+ if not directory.is_dir():
+ return None
+ candidates = list(directory.glob(f"{edition_id}*.mmdb"))
+ if not candidates:
+ return None
+ return max(candidates, key=lambda p: p.stat().st_mtime)
+
+ async def _latest_file(self, edition_id: str) -> Path | None:
+ return await asyncio.to_thread(self._latest_file_sync, edition_id)
+
+ async def _download_and_extract(self, edition_id: str) -> Path:
if not self.license_key:
raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")
url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
+ tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))
- with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
- with client.stream("GET", url) as resp:
+ try:
+ tgz_path = tmp_dir / "db.tgz"
+ async with (
+ httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
+ client.stream("GET", url) as resp,
+ ):
resp.raise_for_status()
- with tempfile.TemporaryDirectory() as tmpd:
- tgz_path = os.path.join(tmpd, "db.tgz")
- # 流式写入
- with open(tgz_path, "wb") as f:
- for chunk in resp.iter_bytes():
- if chunk:
- f.write(chunk)
+ async with aiofiles.open(tgz_path, "wb") as download_file:
+ async for chunk in resp.aiter_bytes():
+ if chunk:
+ await download_file.write(chunk)
- # 解压并只移动 .mmdb
- with tarfile.open(tgz_path, "r:gz") as tar:
- # 先安全检查与解压
- self._safe_extract(tar, tmpd)
+ await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
+ mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
+ if mmdb_path is None:
+ raise RuntimeError("未在压缩包中找到 .mmdb 文件")
- # 递归找 .mmdb
- mmdb_path = None
- for root, _, files in os.walk(tmpd):
- for fn in files:
- if fn.endswith(".mmdb"):
- mmdb_path = os.path.join(root, fn)
- break
- if mmdb_path:
- break
+ await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
+ dst = self.dest_dir / mmdb_path.name
+ await asyncio.to_thread(shutil.move, mmdb_path, dst)
+ return dst
+ finally:
+ await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)
- if not mmdb_path:
- raise RuntimeError("未在压缩包中找到 .mmdb 文件")
+ async def update(self, force: bool = False) -> None:
+ async with self._update_lock:
+ for edition in self.editions:
+ edition_id = EDITIONS[edition]
+ path = await self._latest_file(edition_id)
+ need_download = force or path is None
- os.makedirs(self.dest_dir, exist_ok=True)
- dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path))
- shutil.move(mmdb_path, dst)
- return dst
-
- def _latest_file(self, edition_id: str):
- if not os.path.isdir(self.dest_dir):
- return None
- files = [
- os.path.join(self.dest_dir, f)
- for f in os.listdir(self.dest_dir)
- if f.startswith(edition_id) and f.endswith(".mmdb")
- ]
- return max(files, key=os.path.getmtime) if files else None
-
- def update(self, force=False):
- from app.log import logger
-
- for ed in self.editions:
- eid = EDITIONS[ed]
- path = self._latest_file(eid)
- need = force or not path
-
- if path:
- age_days = (time.time() - os.path.getmtime(path)) / 86400
- if age_days >= self.max_age_days:
- need = True
- logger.info(
- f"{eid} database is {age_days:.1f} days old "
- f"(max: {self.max_age_days}), will download new version"
- )
+ if path:
+ mtime = await asyncio.to_thread(path.stat)
+ age_days = (time.time() - mtime.st_mtime) / 86400
+ if age_days >= self.max_age_days:
+ need_download = True
+ logger.info(
+ f"{edition_id} database is {age_days:.1f} days old "
+ f"(max: {self.max_age_days}), will download new version"
+ )
+ else:
+ logger.info(
+ f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
+ )
else:
- logger.info(f"{eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})")
- else:
- logger.info(f"{eid} database not found, will download")
+ logger.info(f"{edition_id} database not found, will download")
- if need:
- logger.info(f"Downloading {eid} database...")
- path = self._download_and_extract(eid)
- logger.info(f"{eid} database downloaded successfully")
- else:
- logger.info(f"Using existing {eid} database")
+ if need_download:
+ logger.info(f"Downloading {edition_id} database...")
+ path = await self._download_and_extract(edition_id)
+ logger.info(f"{edition_id} database downloaded successfully")
+ else:
+ logger.info(f"Using existing {edition_id} database")
- old = self._readers.get(ed)
- if old:
- try:
- old.close()
- except Exception:
- pass
- if path is not None:
- self._readers[ed] = maxminddb.open_database(path)
+ old_reader = self._readers.get(edition)
+ if old_reader:
+ with suppress(Exception):
+ old_reader.close()
+ if path is not None:
+ self._readers[edition] = maxminddb.open_database(str(path))
- def lookup(self, ip: str):
- res = {"ip": ip}
- # City
- city_r = self._readers.get("City")
- if city_r:
- data = city_r.get(ip)
- if data:
- country = data.get("country") or {}
- res["country_iso"] = country.get("iso_code") or ""
- res["country_name"] = (country.get("names") or {}).get("en", "")
- city = data.get("city") or {}
- res["city_name"] = (city.get("names") or {}).get("en", "")
- loc = data.get("location") or {}
- res["latitude"] = str(loc.get("latitude") or "")
- res["longitude"] = str(loc.get("longitude") or "")
- res["time_zone"] = str(loc.get("time_zone") or "")
- postal = data.get("postal") or {}
- if "code" in postal:
- res["postal_code"] = postal["code"]
- # ASN
- asn_r = self._readers.get("ASN")
- if asn_r:
- data = asn_r.get(ip)
- if data:
- res["asn"] = data.get("autonomous_system_number")
- res["organization"] = data.get("autonomous_system_organization")
+ def lookup(self, ip: str) -> GeoIPLookupResult:
+ res: GeoIPLookupResult = {"ip": ip}
+ city_reader = self._readers.get("City")
+ if city_reader:
+ data = city_reader.get(ip)
+ if isinstance(data, dict):
+ country = self._as_mapping(data.get("country"))
+ res["country_iso"] = self._as_str(country.get("iso_code"))
+ country_names = self._as_mapping(country.get("names"))
+ res["country_name"] = self._as_str(country_names.get("en"))
+
+ city = self._as_mapping(data.get("city"))
+ city_names = self._as_mapping(city.get("names"))
+ res["city_name"] = self._as_str(city_names.get("en"))
+
+ location = self._as_mapping(data.get("location"))
+ latitude = location.get("latitude")
+ longitude = location.get("longitude")
+ res["latitude"] = str(latitude) if latitude is not None else ""
+ res["longitude"] = str(longitude) if longitude is not None else ""
+ res["time_zone"] = self._as_str(location.get("time_zone"))
+
+ postal = self._as_mapping(data.get("postal"))
+ postal_code = postal.get("code")
+ if postal_code is not None:
+ res["postal_code"] = self._as_str(postal_code)
+
+ asn_reader = self._readers.get("ASN")
+ if asn_reader:
+ data = asn_reader.get(ip)
+ if isinstance(data, dict):
+ res["asn"] = self._as_int(data.get("autonomous_system_number"))
+ res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
return res
- def close(self):
- for r in self._readers.values():
- try:
- r.close()
- except Exception:
- pass
+ def close(self) -> None:
+ for reader in self._readers.values():
+ with suppress(Exception):
+ reader.close()
self._readers = {}
if __name__ == "__main__":
- # 示例用法
- geo = GeoIPHelper(dest_dir="./geoip", license_key="")
- geo.update()
- print(geo.lookup("8.8.8.8"))
- geo.close()
+
+ async def _demo() -> None:
+ geo = GeoIPHelper(dest_dir="./geoip", license_key="")
+ await geo.update()
+ print(geo.lookup("8.8.8.8"))
+ geo.close()
+
+ asyncio.run(_demo())
diff --git a/app/log.py b/app/log.py
index ba0ef95..d2c5060 100644
--- a/app/log.py
+++ b/app/log.py
@@ -97,9 +97,7 @@ class InterceptHandler(logging.Handler):
status_color = "green"
elif 300 <= status < 400:
status_color = "yellow"
- elif 400 <= status < 500:
- status_color = "red"
- elif 500 <= status < 600:
+ elif 400 <= status < 500 or 500 <= status < 600:
status_color = "red"
return (
diff --git a/app/middleware/verify_session.py b/app/middleware/verify_session.py
index 58ed754..dee6332 100644
--- a/app/middleware/verify_session.py
+++ b/app/middleware/verify_session.py
@@ -82,7 +82,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# 启动验证流程
- return await self._initiate_verification(request, session_state)
+ return await self._initiate_verification(session_state)
def _should_skip_verification(self, request: Request) -> bool:
"""检查是否应该跳过验证"""
@@ -93,10 +93,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
return True
# 非API请求跳过
- if not path.startswith("/api/"):
- return True
-
- return False
+ return bool(not path.startswith("/api/"))
def _requires_verification(self, request: Request, user: User) -> bool:
"""检查是否需要验证"""
@@ -177,7 +174,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
logger.error(f"Error getting session state: {e}")
return None
- async def _initiate_verification(self, request: Request, state: SessionState) -> Response:
+ async def _initiate_verification(self, state: SessionState) -> Response:
"""启动验证流程"""
try:
method = await state.get_method()
diff --git a/app/models/extended_auth.py b/app/models/extended_auth.py
index b3fc831..35a3752 100644
--- a/app/models/extended_auth.py
+++ b/app/models/extended_auth.py
@@ -11,7 +11,7 @@ class ExtendedTokenResponse(BaseModel):
"""扩展的令牌响应,支持二次验证状态"""
access_token: str | None = None
- token_type: str = "Bearer"
+ token_type: str = "Bearer" # noqa: S105
expires_in: int | None = None
refresh_token: str | None = None
scope: str | None = None
@@ -20,14 +20,3 @@ class ExtendedTokenResponse(BaseModel):
requires_second_factor: bool = False
verification_message: str | None = None
user_id: int | None = None # 用于二次验证的用户ID
-
-
-class SessionState(BaseModel):
- """会话状态"""
-
- user_id: int
- username: str
- email: str
- requires_verification: bool
- session_token: str | None = None
- verification_sent: bool = False
diff --git a/app/models/notification.py b/app/models/notification.py
index cc95e4c..ceef3b0 100644
--- a/app/models/notification.py
+++ b/app/models/notification.py
@@ -1,3 +1,4 @@
+# ruff: noqa: ARG002
from __future__ import annotations
from abc import abstractmethod
diff --git a/app/models/oauth.py b/app/models/oauth.py
index f3db41f..ce4cabf 100644
--- a/app/models/oauth.py
+++ b/app/models/oauth.py
@@ -22,7 +22,7 @@ class TokenRequest(BaseModel):
class TokenResponse(BaseModel):
access_token: str
- token_type: str = "Bearer"
+ token_type: str = "Bearer" # noqa: S105
expires_in: int
refresh_token: str
scope: str = "*"
@@ -67,7 +67,7 @@ class RegistrationRequestErrors(BaseModel):
class OAuth2ClientCredentialsBearer(OAuth2):
def __init__(
self,
- tokenUrl: Annotated[
+ tokenUrl: Annotated[ # noqa: N803
str,
Doc(
"""
@@ -75,7 +75,7 @@ class OAuth2ClientCredentialsBearer(OAuth2):
"""
),
],
- refreshUrl: Annotated[
+ refreshUrl: Annotated[ # noqa: N803
str | None,
Doc(
"""
diff --git a/app/models/v1_user.py b/app/models/v1_user.py
index 9868260..ea7d177 100644
--- a/app/models/v1_user.py
+++ b/app/models/v1_user.py
@@ -46,10 +46,10 @@ class PlayerStatsResponse(BaseModel):
class PlayerEventItem(BaseModel):
"""玩家事件项目"""
- userId: int
+ userId: int # noqa: N815
name: str
- mapId: int | None = None
- setId: int | None = None
+ mapId: int | None = None # noqa: N815
+ setId: int | None = None # noqa: N815
artist: str | None = None
title: str | None = None
version: str | None = None
@@ -88,7 +88,7 @@ class PlayerInfo(BaseModel):
custom_badge_icon: str
custom_badge_color: str
userpage_content: str
- recentFailed: int
+ recentFailed: int # noqa: N815
social_discord: str | None = None
social_youtube: str | None = None
social_twitter: str | None = None
diff --git a/app/router/auth.py b/app/router/auth.py
index ff163b3..1512642 100644
--- a/app/router/auth.py
+++ b/app/router/auth.py
@@ -126,21 +126,22 @@ async def register_user(
try:
# 获取客户端 IP 并查询地理位置
- country_code = "CN" # 默认国家代码
+ country_code = None # 默认国家代码
try:
# 查询 IP 地理位置
geo_info = geoip.lookup(client_ip)
- if geo_info and geo_info.get("country_iso"):
- country_code = geo_info["country_iso"]
+ if geo_info and (country_code := geo_info.get("country_iso")):
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
+ if country_code is None:
+ country_code = "CN"
# 创建新用户
- # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
+ # 确保 AUTO_INCREMENT 值从3开始(ID=2是BanchoBot)
result = await db.execute(
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
@@ -157,7 +158,7 @@ async def register_user(
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
- country_code=country_code, # 根据 IP 地理位置设置国家
+ country_code=country_code,
join_date=utcnow(),
last_visit=utcnow(),
is_supporter=settings.enable_supporter_for_all_users,
@@ -386,7 +387,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
- token_type="Bearer",
+ token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=scope,
@@ -439,7 +440,7 @@ async def oauth_token(
)
return TokenResponse(
access_token=access_token,
- token_type="Bearer",
+ token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=new_refresh_token,
scope=scope,
@@ -509,7 +510,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
- token_type="Bearer",
+ token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
@@ -554,7 +555,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
- token_type="Bearer",
+ token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
diff --git a/app/router/lio.py b/app/router/lio.py
index 969b214..93e0088 100644
--- a/app/router/lio.py
+++ b/app/router/lio.py
@@ -130,7 +130,7 @@ def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_us
"allowed_mods": item_data.get("allowed_mods", []),
"expired": bool(item_data.get("expired", False)),
"playlist_order": item_data.get("playlist_order", default_order),
- "played_at": item_data.get("played_at", None),
+ "played_at": item_data.get("played_at"),
"freestyle": bool(item_data.get("freestyle", True)),
"beatmap_checksum": item_data.get("beatmap_checksum", ""),
"star_rating": item_data.get("star_rating", 0.0),
diff --git a/app/router/notification/banchobot.py b/app/router/notification/banchobot.py
index 7b01347..a491b7d 100644
--- a/app/router/notification/banchobot.py
+++ b/app/router/notification/banchobot.py
@@ -157,10 +157,7 @@ async def _help(user: User, args: list[str], _session: AsyncSession, channel: Ch
@bot.command("roll")
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
- if len(args) > 0 and args[0].isdigit():
- r = random.randint(1, int(args[0]))
- else:
- r = random.randint(1, 100)
+ r = random.randint(1, int(args[0])) if len(args) > 0 and args[0].isdigit() else random.randint(1, 100)
return f"{user.username} rolls {r} point(s)"
@@ -179,10 +176,7 @@ async def _stats(user: User, args: list[str], session: AsyncSession, channel: Ch
if gamemode is None:
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
- if last_score is not None:
- gamemode = last_score.gamemode
- else:
- gamemode = target_user.playmode
+ gamemode = last_score.gamemode if last_score is not None else target_user.playmode
statistics = (
await session.exec(
diff --git a/app/router/notification/server.py b/app/router/notification/server.py
index 9940801..0732124 100644
--- a/app/router/notification/server.py
+++ b/app/router/notification/server.py
@@ -313,10 +313,7 @@ async def chat_websocket(
# 优先使用查询参数中的token,支持token或access_token参数名
auth_token = token or access_token
if not auth_token and authorization:
- if authorization.startswith("Bearer "):
- auth_token = authorization[7:]
- else:
- auth_token = authorization
+ auth_token = authorization.removeprefix("Bearer ")
if not auth_token:
await websocket.close(code=1008, reason="Missing authentication token")
diff --git a/app/router/redirect.py b/app/router/redirect.py
index 7805fc4..bec9eca 100644
--- a/app/router/redirect.py
+++ b/app/router/redirect.py
@@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse
redirect_router = APIRouter(include_in_schema=False)
-@redirect_router.get("/users/{path:path}")
+@redirect_router.get("/users/{path:path}") # noqa: FAST003
@redirect_router.get("/teams/{team_id}")
@redirect_router.get("/u/{user_id}")
@redirect_router.get("/b/{beatmap_id}")
diff --git a/app/router/v1/beatmap.py b/app/router/v1/beatmap.py
index b723713..6ca3775 100644
--- a/app/router/v1/beatmap.py
+++ b/app/router/v1/beatmap.py
@@ -168,10 +168,7 @@ async def get_beatmaps(
elif beatmapset_id is not None:
beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id)
await beatmapset.awaitable_attrs.beatmaps
- if len(beatmapset.beatmaps) > limit:
- beatmaps = beatmapset.beatmaps[:limit]
- else:
- beatmaps = beatmapset.beatmaps
+ beatmaps = beatmapset.beatmaps[:limit] if len(beatmapset.beatmaps) > limit else beatmapset.beatmaps
elif user is not None:
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
diff --git a/app/router/v2/beatmap.py b/app/router/v2/beatmap.py
index 024a542..152e8f0 100644
--- a/app/router/v2/beatmap.py
+++ b/app/router/v2/beatmap.py
@@ -158,7 +158,10 @@ async def get_beatmap_attributes(
if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode
- key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
+ key = (
+ f"beatmap:{beatmap_id}:{ruleset}:"
+ f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
+ )
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py
index c4f2561..3cfcdb1 100644
--- a/app/router/v2/beatmapset.py
+++ b/app/router/v2/beatmapset.py
@@ -46,7 +46,6 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
response_model=SearchBeatmapsetsResp,
)
async def search_beatmapset(
- db: Database,
query: Annotated[SearchQueryModel, Query(...)],
request: Request,
background_tasks: BackgroundTasks,
@@ -104,7 +103,7 @@ async def search_beatmapset(
if cached_result:
sets = SearchBeatmapsetsResp(**cached_result)
# 处理资源代理
- processed_sets = await process_response_assets(sets, request)
+ processed_sets = await process_response_assets(sets)
return processed_sets
try:
@@ -115,7 +114,7 @@ async def search_beatmapset(
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
# 处理资源代理
- processed_sets = await process_response_assets(sets, request)
+ processed_sets = await process_response_assets(sets)
return processed_sets
except HTTPError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -140,7 +139,7 @@ async def lookup_beatmapset(
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
if cached_resp:
# 处理资源代理
- processed_resp = await process_response_assets(cached_resp, request)
+ processed_resp = await process_response_assets(cached_resp)
return processed_resp
try:
@@ -151,7 +150,7 @@ async def lookup_beatmapset(
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
# 处理资源代理
- processed_resp = await process_response_assets(resp, request)
+ processed_resp = await process_response_assets(resp)
return processed_resp
except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmap not found") from exc
@@ -176,7 +175,7 @@ async def get_beatmapset(
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
if cached_resp:
# 处理资源代理
- processed_resp = await process_response_assets(cached_resp, request)
+ processed_resp = await process_response_assets(cached_resp)
return processed_resp
try:
@@ -187,7 +186,7 @@ async def get_beatmapset(
await cache_service.cache_beatmapset(resp)
# 处理资源代理
- processed_resp = await process_response_assets(resp, request)
+ processed_resp = await process_response_assets(resp)
return processed_resp
except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmapset not found") from exc
diff --git a/app/router/v2/room.py b/app/router/v2/room.py
index decf936..34fb6f7 100644
--- a/app/router/v2/room.py
+++ b/app/router/v2/room.py
@@ -166,7 +166,6 @@ async def get_room(
db: Database,
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
- redis: Redis,
category: Annotated[
str,
Query(
diff --git a/app/router/v2/score.py b/app/router/v2/score.py
index 2b47c78..12be45d 100644
--- a/app/router/v2/score.py
+++ b/app/router/v2/score.py
@@ -847,10 +847,7 @@ async def reorder_score_pin(
detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail)
- if after_score_id:
- target_order = reference_score.pinned_order + 1
- else:
- target_order = reference_score.pinned_order
+ target_order = reference_score.pinned_order + 1 if after_score_id else reference_score.pinned_order
current_order = score_record.pinned_order
diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py
index 424b988..add3d70 100644
--- a/app/router/v2/session_verify.py
+++ b/app/router/v2/session_verify.py
@@ -40,7 +40,7 @@ class SessionReissueResponse(BaseModel):
message: str
-class VerifyFailed(Exception):
+class VerifyFailedError(Exception):
def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
super().__init__(message)
self.reason = reason
@@ -93,10 +93,7 @@ async def verify_session(
# 智能选择验证方法(参考osu-web实现)
# API版本较老或用户未设置TOTP时强制使用邮件验证
# print(api_version, totp_key)
- if api_version < 20240101 or totp_key is None:
- verify_method = "mail"
- else:
- verify_method = "totp"
+ verify_method = "mail" if api_version < 20240101 or totp_key is None else "totp"
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
login_method = verify_method
@@ -109,7 +106,7 @@ async def verify_session(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
)
verify_method = "mail"
- raise VerifyFailed("用户TOTP已被删除,已切换到邮件验证")
+ raise VerifyFailedError("用户TOTP已被删除,已切换到邮件验证")
# 如果未开启邮箱验证,则直接认为认证通过
# 正常不会进入到这里
@@ -120,16 +117,16 @@ async def verify_session(
else:
# 记录详细的验证失败原因(参考osu-web的错误处理)
if len(verification_key) != 6:
- raise VerifyFailed("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
+ raise VerifyFailedError("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
elif not verification_key.isdigit():
- raise VerifyFailed("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
+ raise VerifyFailedError("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
else:
# 可能是密钥错误或者重放攻击
- raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
+ raise VerifyFailedError("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
else:
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
if not success:
- raise VerifyFailed(f"邮件验证失败: {message}")
+ raise VerifyFailedError(f"邮件验证失败: {message}")
await LoginLogService.record_login(
db=db,
@@ -144,7 +141,7 @@ async def verify_session(
await db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
- except VerifyFailed as e:
+ except VerifyFailedError as e:
await LoginLogService.record_failed_login(
db=db,
request=request,
@@ -171,7 +168,9 @@ async def verify_session(
)
error_response["reissued"] = True
except Exception:
- pass # 忽略重发邮件失败的错误
+ log("Verification").exception(
+ f"Failed to resend verification email to user {current_user.id} (token: {token_id})"
+ )
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)
diff --git a/app/router/v2/tags.py b/app/router/v2/tags.py
index 644cd77..810656d 100644
--- a/app/router/v2/tags.py
+++ b/app/router/v2/tags.py
@@ -44,9 +44,7 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
.where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode))
)
).first()
- if user_beatmap_score is None:
- return False
- return True
+ return user_beatmap_score is not None
@router.put(
@@ -75,10 +73,9 @@ async def vote_beatmap_tags(
.where(BeatmapTagVote.user_id == current_user.id)
)
).first()
- if previous_votes is None:
- if check_user_can_vote(current_user, beatmap_id, session):
- new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
- session.add(new_vote)
+ if previous_votes is None and check_user_can_vote(current_user, beatmap_id, session):
+ new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
+ session.add(new_vote)
await session.commit()
except ValueError:
raise HTTPException(400, "Tag is not found")
diff --git a/app/router/v2/user.py b/app/router/v2/user.py
index 98a0b93..e4e41ea 100644
--- a/app/router/v2/user.py
+++ b/app/router/v2/user.py
@@ -91,7 +91,7 @@ async def get_users(
# 处理资源代理
response = BatchUserResponse(users=cached_users)
- processed_response = await process_response_assets(response, request)
+ processed_response = await process_response_assets(response)
return processed_response
else:
searched_users = (await session.exec(select(User).limit(50))).all()
@@ -109,7 +109,7 @@ async def get_users(
# 处理资源代理
response = BatchUserResponse(users=users)
- processed_response = await process_response_assets(response, request)
+ processed_response = await process_response_assets(response)
return processed_response
@@ -240,7 +240,7 @@ async def get_user_info(
cached_user = await cache_service.get_user_from_cache(user_id_int)
if cached_user:
# 处理资源代理
- processed_user = await process_response_assets(cached_user, request)
+ processed_user = await process_response_assets(cached_user)
return processed_user
searched_user = (
@@ -263,7 +263,7 @@ async def get_user_info(
background_task.add_task(cache_service.cache_user, user_resp)
# 处理资源代理
- processed_user = await process_response_assets(user_resp, request)
+ processed_user = await process_response_assets(user_resp)
return processed_user
@@ -381,7 +381,7 @@ async def get_user_scores(
user_id, type, include_fails, mode, limit, offset, is_legacy_api
)
if cached_scores is not None:
- processed_scores = await process_response_assets(cached_scores, request)
+ processed_scores = await process_response_assets(cached_scores)
return processed_scores
db_user = await session.get(User, user_id)
@@ -438,5 +438,5 @@ async def get_user_scores(
)
# 处理资源代理
- processed_scores = await process_response_assets(score_responses, request)
+ processed_scores = await process_response_assets(score_responses)
return processed_scores
diff --git a/app/service/asset_proxy_helper.py b/app/service/asset_proxy_helper.py
index c654821..c41e77c 100644
--- a/app/service/asset_proxy_helper.py
+++ b/app/service/asset_proxy_helper.py
@@ -12,7 +12,7 @@ from app.service.asset_proxy_service import get_asset_proxy_service
from fastapi import Request
-async def process_response_assets(data: Any, request: Request) -> Any:
+async def process_response_assets(data: Any) -> Any:
"""
根据配置处理响应数据中的资源URL
@@ -72,7 +72,7 @@ def asset_proxy_response(func):
# 如果有request对象且启用了资源代理,则处理响应
if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path):
- result = await process_response_assets(result, request)
+ result = await process_response_assets(result)
return result
diff --git a/app/service/beatmap_cache_service.py b/app/service/beatmap_cache_service.py
index 3a81686..1a76195 100644
--- a/app/service/beatmap_cache_service.py
+++ b/app/service/beatmap_cache_service.py
@@ -113,6 +113,7 @@ class BeatmapCacheService:
if size:
total_size += size
except Exception:
+ logger.debug(f"Failed to get size for key {key}")
continue
return {
diff --git a/app/service/beatmapset_cache_service.py b/app/service/beatmapset_cache_service.py
index df23f20..f255c8c 100644
--- a/app/service/beatmapset_cache_service.py
+++ b/app/service/beatmapset_cache_service.py
@@ -36,11 +36,8 @@ def safe_json_dumps(data) -> str:
def generate_hash(data) -> str:
"""生成数据的MD5哈希值"""
- if isinstance(data, str):
- content = data
- else:
- content = safe_json_dumps(data)
- return hashlib.md5(content.encode()).hexdigest()
+ content = data if isinstance(data, str) else safe_json_dumps(data)
+ return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
class BeatmapsetCacheService:
diff --git a/app/service/beatmapset_update_service.py b/app/service/beatmapset_update_service.py
index 8852146..0cc8704 100644
--- a/app/service/beatmapset_update_service.py
+++ b/app/service/beatmapset_update_service.py
@@ -110,9 +110,7 @@ class ProcessingBeatmapset:
changed_beatmaps = []
for bm in self.beatmapset.beatmaps:
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
- if not saved:
- changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
- elif saved["is_deleted"]:
+ if not saved or saved["is_deleted"]:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
elif saved["md5"] != bm.checksum:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
@@ -285,7 +283,7 @@ class BeatmapsetUpdateService:
async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
async with with_db() as session:
db_beatmapset = await session.get(Beatmapset, beatmapset.id)
- new_beatmapset = await Beatmapset.from_resp_no_save(session, beatmapset)
+ new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
if db_beatmapset:
await session.merge(new_beatmapset)
await session.commit()
@@ -356,5 +354,7 @@ def init_beatmapset_update_service(fetcher: "Fetcher") -> BeatmapsetUpdateServic
def get_beatmapset_update_service() -> BeatmapsetUpdateService:
+ if service is None:
+ raise ValueError("BeatmapsetUpdateService is not initialized")
assert service is not None, "BeatmapsetUpdateService is not initialized"
return service
diff --git a/app/service/login_log_service.py b/app/service/login_log_service.py
index 6fa2f1a..0570493 100644
--- a/app/service/login_log_service.py
+++ b/app/service/login_log_service.py
@@ -128,7 +128,11 @@ class LoginLogService:
login_success=False,
login_method=login_method,
user_agent=user_agent,
- notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
+ notes=(
+ f"Failed login attempt on user {attempted_username}: {notes}"
+ if attempted_username
+ else "Failed login attempt"
+ ),
)
diff --git a/app/service/password_reset_service.py b/app/service/password_reset_service.py
index 5d831d6..429840f 100644
--- a/app/service/password_reset_service.py
+++ b/app/service/password_reset_service.py
@@ -120,7 +120,7 @@ class PasswordResetService:
await redis.delete(reset_code_key)
await redis.delete(rate_limit_key)
except Exception:
- pass
+ logger.warning("Failed to clean up Redis data after error")
logger.exception("Redis operation failed")
return False, "服务暂时不可用,请稍后重试"
diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py
index 8b4f3cf..0d9dd37 100644
--- a/app/service/ranking_cache_service.py
+++ b/app/service/ranking_cache_service.py
@@ -593,10 +593,7 @@ class RankingCacheService:
async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None:
"""使地区排行榜缓存失效"""
try:
- if ruleset:
- pattern = f"country_ranking:{ruleset}:*"
- else:
- pattern = "country_ranking:*"
+ pattern = f"country_ranking:{ruleset}:*" if ruleset else "country_ranking:*"
keys = await self.redis.keys(pattern)
if keys:
@@ -608,10 +605,7 @@ class RankingCacheService:
async def invalidate_team_cache(self, ruleset: GameMode | None = None) -> None:
"""使战队排行榜缓存失效"""
try:
- if ruleset:
- pattern = f"team_ranking:{ruleset}:*"
- else:
- pattern = "team_ranking:*"
+ pattern = f"team_ranking:{ruleset}:*" if ruleset else "team_ranking:*"
keys = await self.redis.keys(pattern)
if keys:
@@ -637,6 +631,7 @@ class RankingCacheService:
if size:
total_size += size
except Exception:
+ logger.warning(f"Failed to get memory usage for key {key}")
continue
return {
diff --git a/app/service/subscribers/__init__.py b/app/service/subscribers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/service/subscribers/chat.py b/app/service/subscribers/chat.py
index 9512f9c..9241a7f 100644
--- a/app/service/subscribers/chat.py
+++ b/app/service/subscribers/chat.py
@@ -35,19 +35,19 @@ class ChatSubscriber(RedisSubscriber):
self.add_handler(ON_NOTIFICATION, self.on_notification)
self.start()
- async def on_join_room(self, c: str, s: str):
+ async def on_join_room(self, c: str, s: str): # noqa: ARG002
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.join_room_channel(int(channel_id), int(user_id))
- async def on_leave_room(self, c: str, s: str):
+ async def on_leave_room(self, c: str, s: str): # noqa: ARG002
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.leave_room_channel(int(channel_id), int(user_id))
- async def on_notification(self, c: str, s: str):
+ async def on_notification(self, c: str, s: str): # noqa: ARG002
try:
detail = TypeAdapter(NotificationDetails).validate_json(s)
except ValueError:
diff --git a/app/service/user_cache_service.py b/app/service/user_cache_service.py
index ae0022e..4a40229 100644
--- a/app/service/user_cache_service.py
+++ b/app/service/user_cache_service.py
@@ -357,6 +357,7 @@ class UserCacheService:
if size:
total_size += size
except Exception:
+ logger.warning(f"Failed to get memory usage for key {key}")
continue
return {
diff --git a/app/service/verification_service.py b/app/service/verification_service.py
index d1f2ee4..f7cc0ea 100644
--- a/app/service/verification_service.py
+++ b/app/service/verification_service.py
@@ -288,10 +288,6 @@ This email was sent automatically, please do not reply.
redis: Redis,
user_id: int,
code: str,
- ip_address: str | None = None,
- user_agent: str | None = None,
- client_id: int | None = None,
- country_code: str | None = None,
) -> tuple[bool, str]:
"""验证邮箱验证码"""
try:
diff --git a/app/tasks/cache.py b/app/tasks/cache.py
index 4a684f6..21934c9 100644
--- a/app/tasks/cache.py
+++ b/app/tasks/cache.py
@@ -41,7 +41,7 @@ async def warmup_cache() -> None:
logger.info("Beatmap cache warmup completed successfully")
except Exception as e:
- logger.error("Beatmap cache warmup failed: %s", e)
+ logger.error(f"Beatmap cache warmup failed: {e}")
async def refresh_ranking_cache() -> None:
@@ -59,7 +59,7 @@ async def refresh_ranking_cache() -> None:
logger.info("Ranking cache refresh completed successfully")
except Exception as e:
- logger.error("Ranking cache refresh failed: %s", e)
+ logger.error(f"Ranking cache refresh failed: {e}")
async def schedule_user_cache_preload_task() -> None:
@@ -93,14 +93,14 @@ async def schedule_user_cache_preload_task() -> None:
if active_user_ids:
user_ids = [row[0] for row in active_user_ids]
await cache_service.preload_user_cache(session, user_ids)
- logger.info("Preloaded cache for %s active users", len(user_ids))
+ logger.info(f"Preloaded cache for {len(user_ids)} active users")
else:
logger.info("No active users found for cache preload")
logger.info("User cache preload task completed successfully")
except Exception as e:
- logger.error("User cache preload task failed: %s", e)
+ logger.error(f"User cache preload task failed: {e}")
async def schedule_user_cache_warmup_task() -> None:
@@ -131,18 +131,18 @@ async def schedule_user_cache_warmup_task() -> None:
if top_users:
user_ids = list(top_users)
await cache_service.preload_user_cache(session, user_ids)
- logger.info("Warmed cache for top 100 users in %s", mode)
+ logger.info(f"Warmed cache for top 100 users in {mode}")
await asyncio.sleep(1)
except Exception as e:
- logger.error("Failed to warm cache for %s: %s", mode, e)
+ logger.error(f"Failed to warm cache for {mode}: {e}")
continue
logger.info("User cache warmup task completed successfully")
except Exception as e:
- logger.error("User cache warmup task failed: %s", e)
+ logger.error(f"User cache warmup task failed: {e}")
async def schedule_user_cache_cleanup_task() -> None:
@@ -155,11 +155,11 @@ async def schedule_user_cache_cleanup_task() -> None:
cache_service = get_user_cache_service(redis)
stats = await cache_service.get_cache_stats()
- logger.info("User cache stats: %s", stats)
+ logger.info(f"User cache stats: {stats}")
logger.info("User cache cleanup task completed successfully")
except Exception as e:
- logger.error("User cache cleanup task failed: %s", e)
+ logger.error(f"User cache cleanup task failed: {e}")
async def warmup_user_cache() -> None:
@@ -167,7 +167,7 @@ async def warmup_user_cache() -> None:
try:
await schedule_user_cache_warmup_task()
except Exception as e:
- logger.error("User cache warmup failed: %s", e)
+ logger.error(f"User cache warmup failed: {e}")
async def preload_user_cache() -> None:
@@ -175,7 +175,7 @@ async def preload_user_cache() -> None:
try:
await schedule_user_cache_preload_task()
except Exception as e:
- logger.error("User cache preload failed: %s", e)
+ logger.error(f"User cache preload failed: {e}")
async def cleanup_user_cache() -> None:
@@ -183,7 +183,7 @@ async def cleanup_user_cache() -> None:
try:
await schedule_user_cache_cleanup_task()
except Exception as e:
- logger.error("User cache cleanup failed: %s", e)
+ logger.error(f"User cache cleanup failed: {e}")
def register_cache_jobs() -> None:
diff --git a/app/tasks/geoip.py b/app/tasks/geoip.py
index 0d22ed8..2868346 100644
--- a/app/tasks/geoip.py
+++ b/app/tasks/geoip.py
@@ -5,8 +5,6 @@ Periodically update the MaxMind GeoIP database
from __future__ import annotations
-import asyncio
-
from app.config import settings
from app.dependencies.geoip import get_geoip_helper
from app.dependencies.scheduler import get_scheduler
@@ -28,14 +26,10 @@ async def update_geoip_database():
try:
logger.info("Starting scheduled GeoIP database update...")
geoip = get_geoip_helper()
-
- # Run the synchronous update method in a background thread
- loop = asyncio.get_event_loop()
- await loop.run_in_executor(None, lambda: geoip.update(force=False))
-
+ await geoip.update(force=False)
logger.info("Scheduled GeoIP database update completed successfully")
- except Exception as e:
- logger.error(f"Scheduled GeoIP database update failed: {e}")
+ except Exception as exc:
+ logger.error(f"Scheduled GeoIP database update failed: {exc}")
async def init_geoip():
@@ -45,13 +39,8 @@ async def init_geoip():
try:
geoip = get_geoip_helper()
logger.info("Initializing GeoIP database...")
-
- # Run the synchronous update method in a background thread
- # force=False means only download if files don't exist or are expired
- loop = asyncio.get_event_loop()
- await loop.run_in_executor(None, lambda: geoip.update(force=False))
-
+ await geoip.update(force=False)
logger.info("GeoIP database initialization completed")
- except Exception as e:
- logger.error(f"GeoIP database initialization failed: {e}")
+ except Exception as exc:
+ logger.error(f"GeoIP database initialization failed: {exc}")
# Do not raise an exception to avoid blocking application startup
diff --git a/app/tasks/osu_rx_statistics.py b/app/tasks/osu_rx_statistics.py
index 732d727..9b2a796 100644
--- a/app/tasks/osu_rx_statistics.py
+++ b/app/tasks/osu_rx_statistics.py
@@ -16,7 +16,7 @@ async def create_rx_statistics():
async with with_db() as session:
users = (await session.exec(select(User.id))).all()
total_users = len(users)
- logger.info("Ensuring RX/AP statistics exist for %s users", total_users)
+ logger.info(f"Ensuring RX/AP statistics exist for {total_users} users")
rx_created = 0
ap_created = 0
for i in users:
@@ -57,7 +57,5 @@ async def create_rx_statistics():
await session.commit()
if rx_created or ap_created:
logger.success(
- "Created %s RX statistics rows and %s AP statistics rows during backfill",
- rx_created,
- ap_created,
+ f"Created {rx_created} RX statistics rows and {ap_created} AP statistics rows during backfill"
)
diff --git a/app/utils.py b/app/utils.py
index 9a610e0..69b3dd7 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -258,10 +258,7 @@ class BackgroundTasks:
self.tasks = set(tasks) if tasks else set()
def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
- if is_async_callable(func):
- coro = func(*args, **kwargs)
- else:
- coro = run_in_threadpool(func, *args, **kwargs)
+ coro = func(*args, **kwargs) if is_async_callable(func) else run_in_threadpool(func, *args, **kwargs)
task = asyncio.create_task(coro)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
diff --git a/main.py b/main.py
index 282fd18..27089b1 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from contextlib import asynccontextmanager
+import json
from pathlib import Path
from app.config import settings
@@ -50,7 +51,7 @@ import sentry_sdk
@asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: FastAPI): # noqa: ARG001
# on startup
init_mods()
init_ranked_mods()
@@ -223,26 +224,26 @@ async def health_check():
@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(request: Request, exc: RequestValidationError):
+async def validation_exception_handler(request: Request, exc: RequestValidationError): # noqa: ARG001
return JSONResponse(
status_code=422,
content={
- "error": exc.errors(),
+ "error": json.dumps(exc.errors()),
},
)
@app.exception_handler(HTTPException)
-async def http_exception_handler(requst: Request, exc: HTTPException):
+async def http_exception_handler(request: Request, exc: HTTPException): # noqa: ARG001
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
-if settings.secret_key == "your_jwt_secret_here":
+if settings.secret_key == "your_jwt_secret_here": # noqa: S105
system_logger("Security").opt(colors=True).warning(
"jwt_secret_key is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 32."
)
-if settings.osu_web_client_secret == "your_osu_web_client_secret_here":
+if settings.osu_web_client_secret == "your_osu_web_client_secret_here": # noqa: S105
system_logger("Security").opt(colors=True).warning(
"osu_web_client_secret is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 40."
diff --git a/migrations/versions/2025-08-22_d103d442dc24_add_password_reset_table.py b/migrations/versions/2025-08-22_d103d442dc24_add_password_reset_table.py
index 9e15a12..aeba732 100644
--- a/migrations/versions/2025-08-22_d103d442dc24_add_password_reset_table.py
+++ b/migrations/versions/2025-08-22_d103d442dc24_add_password_reset_table.py
@@ -1,3 +1,4 @@
+# ruff: noqa
"""add_password_reset_table
Revision ID: d103d442dc24
diff --git a/pyproject.toml b/pyproject.toml
index 9862e00..2f99cda 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,12 +55,20 @@ select = [
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"T10", # flake8-debugger
- # "T20", # flake8-print
"PYI", # flake8-pyi
"PT", # flake8-pytest-style
"Q", # flake8-quotes
"TID", # flake8-tidy-imports
"RUF", # Ruff-specific rules
+ "FAST", # FastAPI
+ "YTT", # flake8-2020
+ "S", # flake8-bandit
+ "INP", # flake8-no-pep420
+ "SIM", # flake8-simplify
+ "ARG", # flake8-unused-arguments
+ "PTH", # flake8-use-pathlib
+ "N", # pep8-naming
+ "FURB" # refurb
]
ignore = [
"E402", # module-import-not-at-top-of-file
@@ -68,10 +76,17 @@ ignore = [
"RUF001", # ambiguous-unicode-character-string
"RUF002", # ambiguous-unicode-character-docstring
"RUF003", # ambiguous-unicode-character-comment
+ "S101", # assert
+ "S311", # suspicious-non-cryptographic-random-usage
]
[tool.ruff.lint.extend-per-file-ignores]
"app/database/**/*.py" = ["I002"]
+"tools/*.py" = ["PTH", "INP001"]
+"migrations/**/*.py" = ["INP001"]
+".github/**/*.py" = ["INP001"]
+"app/achievements/*.py" = ["INP001", "ARG"]
+"app/router/**/*.py" = ["ARG001"]
[tool.ruff.lint.isort]
force-sort-within-sections = true
diff --git a/tools/fix_user_rank_event.py b/tools/fix_user_rank_event.py
index 45b098a..67eab51 100644
--- a/tools/fix_user_rank_event.py
+++ b/tools/fix_user_rank_event.py
@@ -163,13 +163,19 @@ async def main():
# Show specific changes
changes = []
- if "scorerank" in original_payload and "scorerank" in fixed_payload:
- if original_payload["scorerank"] != fixed_payload["scorerank"]:
- changes.append(f"scorerank: {original_payload['scorerank']} → {fixed_payload['scorerank']}")
+ if (
+ "scorerank" in original_payload
+ and "scorerank" in fixed_payload
+ and original_payload["scorerank"] != fixed_payload["scorerank"]
+ ):
+ changes.append(f"scorerank: {original_payload['scorerank']} → {fixed_payload['scorerank']}")
- if "mode" in original_payload and "mode" in fixed_payload:
- if original_payload["mode"] != fixed_payload["mode"]:
- changes.append(f"mode: {original_payload['mode']} → {fixed_payload['mode']}")
+ if (
+ "mode" in original_payload
+ and "mode" in fixed_payload
+ and original_payload["mode"] != fixed_payload["mode"]
+ ):
+ changes.append(f"mode: {original_payload['mode']} → {fixed_payload['mode']}")
if changes:
print(f" Changes: {', '.join(changes)}")