chore(linter): update ruff rules

This commit is contained in:
MingxuanGame
2025-10-03 15:46:53 +00:00
parent b10425ad91
commit d490239f46
59 changed files with 393 additions and 425 deletions

View File

@@ -32,11 +32,9 @@ async def process_streak(
).first()
if not stats:
return False
if streak <= stats.daily_streak_best < next_streak:
return True
elif next_streak == 0 and stats.daily_streak_best >= streak:
return True
return False
return bool(
streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak)
)
MEDALS = {

View File

@@ -68,9 +68,7 @@ async def to_the_core(
if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
return False
mods_ = mod_to_save(score.mods)
if "DT" not in mods_ or "NC" not in mods_:
return False
return True
return not ("DT" not in mods_ or "NC" not in mods_)
async def wysi(
@@ -83,9 +81,7 @@ async def wysi(
return False
if str(round(score.accuracy, ndigits=4))[3:] != "727":
return False
if "xi" not in beatmap.beatmapset.artist:
return False
return True
return "xi" in beatmap.beatmapset.artist
async def prepared(
@@ -97,9 +93,7 @@ async def prepared(
if score.rank != Rank.X and score.rank != Rank.XH:
return False
mods_ = mod_to_save(score.mods)
if "NF" not in mods_:
return False
return True
return "NF" in mods_
async def reckless_adandon(
@@ -117,9 +111,7 @@ async def reckless_adandon(
redis = get_redis()
mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < 3:
return False
return True
return not attribute.star_rating < 3
async def lights_out(
@@ -413,11 +405,10 @@ async def by_the_skin_of_the_teeth(
return False
for mod in score.mods:
if mod.get("acronym") == "AC":
if "settings" in mod and "minimum_accuracy" in mod["settings"]:
target_accuracy = mod["settings"]["minimum_accuracy"]
if isinstance(target_accuracy, int | float):
return abs(score.accuracy - float(target_accuracy)) < 0.0001
if mod.get("acronym") == "AC" and "settings" in mod and "minimum_accuracy" in mod["settings"]:
target_accuracy = mod["settings"]["minimum_accuracy"]
if isinstance(target_accuracy, int | float):
return abs(score.accuracy - float(target_accuracy)) < 0.0001
return False

View File

@@ -19,9 +19,7 @@ async def process_mod(
return False
if not beatmap.beatmap_status.has_leaderboard():
return False
if len(score.mods) != 1 or score.mods[0]["acronym"] != mod:
return False
return True
return not (len(score.mods) != 1 or score.mods[0]["acronym"] != mod)
async def process_category_mod(

View File

@@ -22,11 +22,7 @@ async def process_combo(
return False
if next_combo != 0 and combo >= next_combo:
return False
if combo <= score.max_combo < next_combo:
return True
elif next_combo == 0 and score.max_combo >= combo:
return True
return False
return bool(combo <= score.max_combo < next_combo or (next_combo == 0 and score.max_combo >= combo))
MEDALS: Medals = {

View File

@@ -35,11 +35,7 @@ async def process_playcount(
).first()
if not stats:
return False
if pc <= stats.play_count < next_pc:
return True
elif next_pc == 0 and stats.play_count >= pc:
return True
return False
return bool(pc <= stats.play_count < next_pc or (next_pc == 0 and stats.play_count >= pc))
MEDALS: Medals = {

View File

@@ -47,9 +47,7 @@ async def process_skill(
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < star or attribute.star_rating >= star + 1:
return False
if type == "fc" and not score.is_perfect_combo:
return False
return True
return not (type == "fc" and not score.is_perfect_combo)
MEDALS: Medals = {

View File

@@ -35,11 +35,7 @@ async def process_tth(
).first()
if not stats:
return False
if tth <= stats.total_hits < next_tth:
return True
elif next_tth == 0 and stats.play_count >= tth:
return True
return False
return bool(tth <= stats.total_hits < next_tth or (next_tth == 0 and stats.play_count >= tth))
MEDALS: Medals = {

View File

@@ -69,7 +69,7 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
2. MD5哈希 -> bcrypt验证
"""
# 1. 明文密码转 MD5
pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()
pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode() # noqa: S324
# 2. 检查缓存
if bcrypt_hash in bcrypt_cache:
@@ -103,7 +103,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def get_password_hash(password: str) -> str:
"""生成密码哈希 - 使用 osu! 的方式"""
# 1. 明文密码 -> MD5
pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()
pw_md5 = hashlib.md5(password.encode()).hexdigest().encode() # noqa: S324
# 2. MD5 -> bcrypt
pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt())
return pw_bcrypt.decode()
@@ -114,7 +114,7 @@ async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -
验证用户身份 - 使用类似 from_login 的逻辑
"""
# 1. 明文密码转 MD5
pw_md5 = hashlib.md5(password.encode()).hexdigest()
pw_md5 = hashlib.md5(password.encode()).hexdigest() # noqa: S324
# 2. 根据用户名查找用户
user = None
@@ -325,12 +325,7 @@ def _generate_totp_account_label(user: User) -> str:
根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性
"""
if settings.totp_use_username_in_label:
# 使用用户名作为主要标识
primary_identifier = user.username
else:
# 使用邮箱作为标识
primary_identifier = user.email
primary_identifier = user.username if settings.totp_use_username_in_label else user.email
# 如果配置了服务名称,添加到标签中以便在认证器中区分
if settings.totp_service_name:

View File

@@ -419,9 +419,8 @@ def too_dense(hit_objects: list[HitObject], per_1s: int, per_10s: int) -> bool:
if len(hit_objects) > i + per_1s:
if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000:
return True
elif len(hit_objects) > i + per_10s:
if hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
return True
elif len(hit_objects) > i + per_10s and hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
return True
return False
@@ -448,10 +447,7 @@ def slider_is_sus(hit_objects: list[HitObject]) -> bool:
def is_2b(hit_objects: list[HitObject]) -> bool:
for i in range(0, len(hit_objects) - 1):
if hit_objects[i] == hit_objects[i + 1].start_time:
return True
return False
return any(hit_objects[i] == hit_objects[i + 1].start_time for i in range(0, len(hit_objects) - 1))
def is_suspicious_beatmap(content: str) -> bool:

View File

@@ -217,7 +217,7 @@ STORAGE_SETTINGS='{
# 服务器设置
host: Annotated[
str,
Field(default="0.0.0.0", description="服务器监听地址"),
Field(default="0.0.0.0", description="服务器监听地址"), # noqa: S104
"服务器设置",
]
port: Annotated[
@@ -609,26 +609,26 @@ STORAGE_SETTINGS='{
]
@field_validator("fetcher_scopes", mode="before")
@classmethod
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
if isinstance(v, str):
return v.split(",")
return v
@field_validator("storage_settings", mode="after")
@classmethod
def validate_storage_settings(
cls,
v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
info: ValidationInfo,
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
if not isinstance(v, CloudflareR2Settings):
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
elif info.data.get("storage_service") == StorageServiceType.LOCAL:
if not isinstance(v, LocalStorageSettings):
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
if not isinstance(v, AWSS3StorageSettings):
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
service = info.data.get("storage_service")
if service == StorageServiceType.CLOUDFLARE_R2 and not isinstance(v, CloudflareR2Settings):
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
if service == StorageServiceType.LOCAL and not isinstance(v, LocalStorageSettings):
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
if service == StorageServiceType.AWS_S3 and not isinstance(v, AWSS3StorageSettings):
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
return v

View File

@@ -71,10 +71,10 @@ class Beatmap(BeatmapBase, table=True):
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump()
del d["beatmapset"]
beatmap = Beatmap.model_validate(
beatmap = cls.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
@@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True):
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
session.add(beatmap)
await session.commit()
beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
return beatmap
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
@classmethod
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
@@ -250,7 +249,7 @@ async def calculate_beatmap_attributes(
redis: Redis,
fetcher: "Fetcher",
):
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key))
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)

View File

@@ -130,7 +130,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
d = resp.model_dump()
if resp.nominations:
d["nominations_required"] = resp.nominations.required
@@ -158,10 +158,15 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
return beatmapset
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
async def from_resp(
cls,
session: AsyncSession,
resp: "BeatmapsetResp",
from_: int = 0,
) -> "Beatmapset":
from .beatmap import Beatmap
beatmapset = await cls.from_resp_no_save(session, resp, from_=from_)
beatmapset = await cls.from_resp_no_save(resp)
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
session.add(beatmapset)
await session.commit()

View File

@@ -105,17 +105,11 @@ class ChatChannelResp(ChatChannelBase):
)
).first()
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
if last_msg and last_msg.isdigit():
last_msg = int(last_msg)
else:
last_msg = None
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
if last_read_id and last_read_id.isdigit():
last_read_id = int(last_read_id)
else:
last_read_id = last_msg
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
if silence is not None:
attribute = ChatUserAttributes(

View File

@@ -520,12 +520,11 @@ async def _score_where(
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
else:
return None
elif type == LeaderboardType.TEAM:
if user:
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
elif type == LeaderboardType.TEAM and user:
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
if mods:
if user and user.is_supporter:
wheres.append(

View File

@@ -256,8 +256,6 @@ class UserResp(UserBase):
session: AsyncSession,
include: list[str] = [],
ruleset: GameMode | None = None,
*,
token_id: int | None = None,
) -> "UserResp":
from app.dependencies.database import get_redis
@@ -310,16 +308,16 @@ class UserResp(UserBase):
).all()
]
if "team" in include:
if team_membership := await obj.awaitable_attrs.team_membership:
u.team = team_membership.team
if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership):
u.team = team_membership.team
if "account_history" in include:
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
if "daily_challenge_user_stats":
if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
if "daily_challenge_user_stats" in include and (
daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats
):
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
if "statistics" in include:
current_stattistics = None
@@ -443,7 +441,7 @@ class MeResp(UserResp):
from app.dependencies.database import get_redis
from app.service.verification_service import LoginSessionService
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id)
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset)
u.session_verified = (
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
if token_id

View File

@@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, ValidationError
def BodyOrForm[T: BaseModel](model: type[T]):
def BodyOrForm[T: BaseModel](model: type[T]): # noqa: N802
async def dependency(
request: Request,
) -> T:

View File

@@ -119,10 +119,7 @@ async def get_client_user(
if verify_method is None:
# 智能选择验证方式有TOTP优先TOTP
totp_key = await user.awaitable_attrs.totp_key
if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
verify_method = "totp"
else:
verify_method = "mail"
verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail"
# 设置选择的验证方法到Redis中避免重复选择
if api_version >= 20250913:

View File

View File

@@ -116,7 +116,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 序列化为 JSON 并生成 MD5 哈希
cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
cache_hash = hashlib.md5(cache_json.encode(), usedforsecurity=False).hexdigest()
logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
@@ -160,10 +160,10 @@ class BeatmapsetFetcher(BaseFetcher):
cached_data = json.loads(cached_result)
return SearchBeatmapsetsResp.model_validate(cached_data)
except Exception as e:
logger.opt(colors=True).warning(f"Cache data invalid, fetching from API: {e}")
logger.warning(f"Cache data invalid, fetching from API: {e}")
# 缓存未命中,从 API 获取数据
logger.opt(colors=True).debug("Cache miss, fetching from API")
logger.debug("Cache miss, fetching from API")
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
@@ -203,7 +203,7 @@ class BeatmapsetFetcher(BaseFetcher):
try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
except RateLimitError:
logger.opt(colors=True).info("Prefetch skipped due to rate limit")
logger.info("Prefetch skipped due to rate limit")
bg_tasks.add_task(delayed_prefetch)
@@ -227,14 +227,14 @@ class BeatmapsetFetcher(BaseFetcher):
# 使用当前 cursor 请求下一页
next_query = query.model_copy()
logger.opt(colors=True).debug(f"Prefetching page {page + 1}")
logger.debug(f"Prefetching page {page + 1}")
# 生成下一页的缓存键
next_cache_key = self._generate_cache_key(next_query, cursor)
# 检查是否已经缓存
if await redis_client.exists(next_cache_key):
logger.opt(colors=True).debug(f"Page {page + 1} already cached")
logger.debug(f"Page {page + 1} already cached")
# 尝试从缓存获取cursor继续预取
cached_data = await redis_client.get(next_cache_key)
if cached_data:
@@ -244,7 +244,7 @@ class BeatmapsetFetcher(BaseFetcher):
cursor = data["cursor"]
continue
except Exception:
pass
logger.warning("Failed to parse cached data for cursor")
break
# 在预取页面之间添加延迟,避免突发请求
@@ -279,18 +279,18 @@ class BeatmapsetFetcher(BaseFetcher):
ex=prefetch_ttl,
)
logger.opt(colors=True).debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
except RateLimitError:
logger.opt(colors=True).info("Prefetch stopped due to rate limit")
logger.info("Prefetch stopped due to rate limit")
except Exception as e:
logger.opt(colors=True).warning(f"Prefetch failed: {e}")
logger.warning(f"Prefetch failed: {e}")
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
"""预热主页缓存"""
homepage_queries = self._get_homepage_queries()
logger.opt(colors=True).info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
logger.info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
for i, (query, cursor) in enumerate(homepage_queries):
try:
@@ -302,7 +302,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 检查是否已经缓存
if await redis_client.exists(cache_key):
logger.opt(colors=True).debug(f"Query {query.sort} already cached")
logger.debug(f"Query {query.sort} already cached")
continue
# 请求并缓存
@@ -325,15 +325,15 @@ class BeatmapsetFetcher(BaseFetcher):
ex=cache_ttl,
)
logger.opt(colors=True).info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
if api_response.get("cursor"):
try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
except RateLimitError:
logger.opt(colors=True).info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
except RateLimitError:
logger.opt(colors=True).warning(f"Warmup skipped for {query.sort} due to rate limit")
logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
except Exception as e:
logger.opt(colors=True).error(f"Failed to warmup cache for {query.sort}: {e}")
logger.error(f"Failed to warmup cache for {query.sort}: {e}")

0
app/helpers/__init__.py Normal file
View File

View File

@@ -1,19 +1,39 @@
"""
GeoLite2 Helper Class
GeoLite2 Helper Class (asynchronous)
"""
from __future__ import annotations
import asyncio
from contextlib import suppress
import os
from pathlib import Path
import shutil
import tarfile
import tempfile
import time
from typing import Any, Required, TypedDict
from app.log import logger
import aiofiles
import httpx
import maxminddb
class GeoIPLookupResult(TypedDict, total=False):
ip: Required[str]
country_iso: str
country_name: str
city_name: str
latitude: str
longitude: str
time_zone: str
postal_code: str
asn: int | None
organization: str
BASE_URL = "https://download.maxmind.com/app/geoip_download"
EDITIONS = {
"City": "GeoLite2-City",
@@ -25,161 +45,184 @@ EDITIONS = {
class GeoIPHelper:
def __init__(
self,
dest_dir="./geoip",
license_key=None,
editions=None,
max_age_days=8,
timeout=60.0,
dest_dir: str | Path = Path("./geoip"),
license_key: str | None = None,
editions: list[str] | None = None,
max_age_days: int = 8,
timeout: float = 60.0,
):
self.dest_dir = dest_dir
self.dest_dir = Path(dest_dir).expanduser()
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
self.editions = editions or ["City", "ASN"]
self.editions = list(editions or ["City", "ASN"])
self.max_age_days = max_age_days
self.timeout = timeout
self._readers = {}
self._readers: dict[str, maxminddb.Reader] = {}
self._update_lock = asyncio.Lock()
@staticmethod
def _safe_extract(tar: tarfile.TarFile, path: str):
base = Path(path).resolve()
for m in tar.getmembers():
target = (base / m.name).resolve()
if not str(target).startswith(str(base)):
def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
base = path.resolve()
for member in tar.getmembers():
target = (base / member.name).resolve()
if not target.is_relative_to(base): # py312
raise RuntimeError("Unsafe path in tar file")
tar.extractall(path=path, filter="data")
tar.extractall(path=base, filter="data")
def _download_and_extract(self, edition_id: str) -> str:
"""
下载并解压 mmdb 文件到 dest_dir仅保留 .mmdb
- 跟随 302 重定向
- 流式下载到临时文件
- 临时目录退出后自动清理
"""
@staticmethod
def _as_mapping(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
@staticmethod
def _as_str(value: Any, default: str = "") -> str:
if isinstance(value, str):
return value
if value is None:
return default
return str(value)
@staticmethod
def _as_int(value: Any) -> int | None:
return value if isinstance(value, int) else None
@staticmethod
def _extract_tarball(src: Path, dest: Path) -> None:
with tarfile.open(src, "r:gz") as tar:
GeoIPHelper._safe_extract(tar, dest)
@staticmethod
def _find_mmdb(root: Path) -> Path | None:
for candidate in root.rglob("*.mmdb"):
return candidate
return None
def _latest_file_sync(self, edition_id: str) -> Path | None:
directory = self.dest_dir
if not directory.is_dir():
return None
candidates = list(directory.glob(f"{edition_id}*.mmdb"))
if not candidates:
return None
return max(candidates, key=lambda p: p.stat().st_mtime)
async def _latest_file(self, edition_id: str) -> Path | None:
return await asyncio.to_thread(self._latest_file_sync, edition_id)
async def _download_and_extract(self, edition_id: str) -> Path:
if not self.license_key:
raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")
url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
with client.stream("GET", url) as resp:
try:
tgz_path = tmp_dir / "db.tgz"
async with (
httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
client.stream("GET", url) as resp,
):
resp.raise_for_status()
with tempfile.TemporaryDirectory() as tmpd:
tgz_path = os.path.join(tmpd, "db.tgz")
# 流式写入
with open(tgz_path, "wb") as f:
for chunk in resp.iter_bytes():
if chunk:
f.write(chunk)
async with aiofiles.open(tgz_path, "wb") as download_file:
async for chunk in resp.aiter_bytes():
if chunk:
await download_file.write(chunk)
# 解压并只移动 .mmdb
with tarfile.open(tgz_path, "r:gz") as tar:
# 先安全检查与解压
self._safe_extract(tar, tmpd)
await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
if mmdb_path is None:
raise RuntimeError("未在压缩包中找到 .mmdb 文件")
# 递归找 .mmdb
mmdb_path = None
for root, _, files in os.walk(tmpd):
for fn in files:
if fn.endswith(".mmdb"):
mmdb_path = os.path.join(root, fn)
break
if mmdb_path:
break
await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
dst = self.dest_dir / mmdb_path.name
await asyncio.to_thread(shutil.move, mmdb_path, dst)
return dst
finally:
await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)
if not mmdb_path:
raise RuntimeError("未在压缩包中找到 .mmdb 文件")
async def update(self, force: bool = False) -> None:
async with self._update_lock:
for edition in self.editions:
edition_id = EDITIONS[edition]
path = await self._latest_file(edition_id)
need_download = force or path is None
os.makedirs(self.dest_dir, exist_ok=True)
dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path))
shutil.move(mmdb_path, dst)
return dst
def _latest_file(self, edition_id: str):
if not os.path.isdir(self.dest_dir):
return None
files = [
os.path.join(self.dest_dir, f)
for f in os.listdir(self.dest_dir)
if f.startswith(edition_id) and f.endswith(".mmdb")
]
return max(files, key=os.path.getmtime) if files else None
def update(self, force=False):
from app.log import logger
for ed in self.editions:
eid = EDITIONS[ed]
path = self._latest_file(eid)
need = force or not path
if path:
age_days = (time.time() - os.path.getmtime(path)) / 86400
if age_days >= self.max_age_days:
need = True
logger.info(
f"{eid} database is {age_days:.1f} days old "
f"(max: {self.max_age_days}), will download new version"
)
if path:
mtime = await asyncio.to_thread(path.stat)
age_days = (time.time() - mtime.st_mtime) / 86400
if age_days >= self.max_age_days:
need_download = True
logger.info(
f"{edition_id} database is {age_days:.1f} days old "
f"(max: {self.max_age_days}), will download new version"
)
else:
logger.info(
f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
)
else:
logger.info(f"{eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})")
else:
logger.info(f"{eid} database not found, will download")
logger.info(f"{edition_id} database not found, will download")
if need:
logger.info(f"Downloading {eid} database...")
path = self._download_and_extract(eid)
logger.info(f"{eid} database downloaded successfully")
else:
logger.info(f"Using existing {eid} database")
if need_download:
logger.info(f"Downloading {edition_id} database...")
path = await self._download_and_extract(edition_id)
logger.info(f"{edition_id} database downloaded successfully")
else:
logger.info(f"Using existing {edition_id} database")
old = self._readers.get(ed)
if old:
try:
old.close()
except Exception:
pass
if path is not None:
self._readers[ed] = maxminddb.open_database(path)
old_reader = self._readers.get(edition)
if old_reader:
with suppress(Exception):
old_reader.close()
if path is not None:
self._readers[edition] = maxminddb.open_database(str(path))
def lookup(self, ip: str):
res = {"ip": ip}
# City
city_r = self._readers.get("City")
if city_r:
data = city_r.get(ip)
if data:
country = data.get("country") or {}
res["country_iso"] = country.get("iso_code") or ""
res["country_name"] = (country.get("names") or {}).get("en", "")
city = data.get("city") or {}
res["city_name"] = (city.get("names") or {}).get("en", "")
loc = data.get("location") or {}
res["latitude"] = str(loc.get("latitude") or "")
res["longitude"] = str(loc.get("longitude") or "")
res["time_zone"] = str(loc.get("time_zone") or "")
postal = data.get("postal") or {}
if "code" in postal:
res["postal_code"] = postal["code"]
# ASN
asn_r = self._readers.get("ASN")
if asn_r:
data = asn_r.get(ip)
if data:
res["asn"] = data.get("autonomous_system_number")
res["organization"] = data.get("autonomous_system_organization")
def lookup(self, ip: str) -> GeoIPLookupResult:
res: GeoIPLookupResult = {"ip": ip}
city_reader = self._readers.get("City")
if city_reader:
data = city_reader.get(ip)
if isinstance(data, dict):
country = self._as_mapping(data.get("country"))
res["country_iso"] = self._as_str(country.get("iso_code"))
country_names = self._as_mapping(country.get("names"))
res["country_name"] = self._as_str(country_names.get("en"))
city = self._as_mapping(data.get("city"))
city_names = self._as_mapping(city.get("names"))
res["city_name"] = self._as_str(city_names.get("en"))
location = self._as_mapping(data.get("location"))
latitude = location.get("latitude")
longitude = location.get("longitude")
res["latitude"] = str(latitude) if latitude is not None else ""
res["longitude"] = str(longitude) if longitude is not None else ""
res["time_zone"] = self._as_str(location.get("time_zone"))
postal = self._as_mapping(data.get("postal"))
postal_code = postal.get("code")
if postal_code is not None:
res["postal_code"] = self._as_str(postal_code)
asn_reader = self._readers.get("ASN")
if asn_reader:
data = asn_reader.get(ip)
if isinstance(data, dict):
res["asn"] = self._as_int(data.get("autonomous_system_number"))
res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
return res
def close(self):
for r in self._readers.values():
try:
r.close()
except Exception:
pass
def close(self) -> None:
for reader in self._readers.values():
with suppress(Exception):
reader.close()
self._readers = {}
if __name__ == "__main__":
# 示例用法
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
geo.update()
print(geo.lookup("8.8.8.8"))
geo.close()
async def _demo() -> None:
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
await geo.update()
print(geo.lookup("8.8.8.8"))
geo.close()
asyncio.run(_demo())

View File

@@ -97,9 +97,7 @@ class InterceptHandler(logging.Handler):
status_color = "green"
elif 300 <= status < 400:
status_color = "yellow"
elif 400 <= status < 500:
status_color = "red"
elif 500 <= status < 600:
elif 400 <= status < 500 or 500 <= status < 600:
status_color = "red"
return (

View File

@@ -82,7 +82,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# 启动验证流程
return await self._initiate_verification(request, session_state)
return await self._initiate_verification(session_state)
def _should_skip_verification(self, request: Request) -> bool:
"""检查是否应该跳过验证"""
@@ -93,10 +93,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
return True
# 非API请求跳过
if not path.startswith("/api/"):
return True
return False
return bool(not path.startswith("/api/"))
def _requires_verification(self, request: Request, user: User) -> bool:
"""检查是否需要验证"""
@@ -177,7 +174,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
logger.error(f"Error getting session state: {e}")
return None
async def _initiate_verification(self, request: Request, state: SessionState) -> Response:
async def _initiate_verification(self, state: SessionState) -> Response:
"""启动验证流程"""
try:
method = await state.get_method()

View File

@@ -11,7 +11,7 @@ class ExtendedTokenResponse(BaseModel):
"""扩展的令牌响应,支持二次验证状态"""
access_token: str | None = None
token_type: str = "Bearer"
token_type: str = "Bearer" # noqa: S105
expires_in: int | None = None
refresh_token: str | None = None
scope: str | None = None
@@ -20,14 +20,3 @@ class ExtendedTokenResponse(BaseModel):
requires_second_factor: bool = False
verification_message: str | None = None
user_id: int | None = None # 用于二次验证的用户ID
class SessionState(BaseModel):
"""会话状态"""
user_id: int
username: str
email: str
requires_verification: bool
session_token: str | None = None
verification_sent: bool = False

View File

@@ -1,3 +1,4 @@
# ruff: noqa: ARG002
from __future__ import annotations
from abc import abstractmethod

View File

@@ -22,7 +22,7 @@ class TokenRequest(BaseModel):
class TokenResponse(BaseModel):
access_token: str
token_type: str = "Bearer"
token_type: str = "Bearer" # noqa: S105
expires_in: int
refresh_token: str
scope: str = "*"
@@ -67,7 +67,7 @@ class RegistrationRequestErrors(BaseModel):
class OAuth2ClientCredentialsBearer(OAuth2):
def __init__(
self,
tokenUrl: Annotated[
tokenUrl: Annotated[ # noqa: N803
str,
Doc(
"""
@@ -75,7 +75,7 @@ class OAuth2ClientCredentialsBearer(OAuth2):
"""
),
],
refreshUrl: Annotated[
refreshUrl: Annotated[ # noqa: N803
str | None,
Doc(
"""

View File

@@ -46,10 +46,10 @@ class PlayerStatsResponse(BaseModel):
class PlayerEventItem(BaseModel):
"""玩家事件项目"""
userId: int
userId: int # noqa: N815
name: str
mapId: int | None = None
setId: int | None = None
mapId: int | None = None # noqa: N815
setId: int | None = None # noqa: N815
artist: str | None = None
title: str | None = None
version: str | None = None
@@ -88,7 +88,7 @@ class PlayerInfo(BaseModel):
custom_badge_icon: str
custom_badge_color: str
userpage_content: str
recentFailed: int
recentFailed: int # noqa: N815
social_discord: str | None = None
social_youtube: str | None = None
social_twitter: str | None = None

View File

@@ -126,21 +126,22 @@ async def register_user(
try:
# 获取客户端 IP 并查询地理位置
country_code = "CN" # 默认国家代码
country_code = None # 默认国家代码
try:
# 查询 IP 地理位置
geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"]
if geo_info and (country_code := geo_info.get("country_iso")):
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
if country_code is None:
country_code = "CN"
# 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
# 确保 AUTO_INCREMENT 值从3开始ID=2是BanchoBot
result = await db.execute(
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
@@ -157,7 +158,7 @@ async def register_user(
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
country_code=country_code, # 根据 IP 地理位置设置国家
country_code=country_code,
join_date=utcnow(),
last_visit=utcnow(),
is_supporter=settings.enable_supporter_for_all_users,
@@ -386,7 +387,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=scope,
@@ -439,7 +440,7 @@ async def oauth_token(
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=new_refresh_token,
scope=scope,
@@ -509,7 +510,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
@@ -554,7 +555,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),

View File

@@ -130,7 +130,7 @@ def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_us
"allowed_mods": item_data.get("allowed_mods", []),
"expired": bool(item_data.get("expired", False)),
"playlist_order": item_data.get("playlist_order", default_order),
"played_at": item_data.get("played_at", None),
"played_at": item_data.get("played_at"),
"freestyle": bool(item_data.get("freestyle", True)),
"beatmap_checksum": item_data.get("beatmap_checksum", ""),
"star_rating": item_data.get("star_rating", 0.0),

View File

@@ -157,10 +157,7 @@ async def _help(user: User, args: list[str], _session: AsyncSession, channel: Ch
@bot.command("roll")
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
if len(args) > 0 and args[0].isdigit():
r = random.randint(1, int(args[0]))
else:
r = random.randint(1, 100)
r = random.randint(1, int(args[0])) if len(args) > 0 and args[0].isdigit() else random.randint(1, 100)
return f"{user.username} rolls {r} point(s)"
@@ -179,10 +176,7 @@ async def _stats(user: User, args: list[str], session: AsyncSession, channel: Ch
if gamemode is None:
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
if last_score is not None:
gamemode = last_score.gamemode
else:
gamemode = target_user.playmode
gamemode = last_score.gamemode if last_score is not None else target_user.playmode
statistics = (
await session.exec(

View File

@@ -313,10 +313,7 @@ async def chat_websocket(
# 优先使用查询参数中的token支持token或access_token参数名
auth_token = token or access_token
if not auth_token and authorization:
if authorization.startswith("Bearer "):
auth_token = authorization[7:]
else:
auth_token = authorization
auth_token = authorization.removeprefix("Bearer ")
if not auth_token:
await websocket.close(code=1008, reason="Missing authentication token")

View File

@@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse
redirect_router = APIRouter(include_in_schema=False)
@redirect_router.get("/users/{path:path}")
@redirect_router.get("/users/{path:path}") # noqa: FAST003
@redirect_router.get("/teams/{team_id}")
@redirect_router.get("/u/{user_id}")
@redirect_router.get("/b/{beatmap_id}")

View File

@@ -168,10 +168,7 @@ async def get_beatmaps(
elif beatmapset_id is not None:
beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id)
await beatmapset.awaitable_attrs.beatmaps
if len(beatmapset.beatmaps) > limit:
beatmaps = beatmapset.beatmaps[:limit]
else:
beatmaps = beatmapset.beatmaps
beatmaps = beatmapset.beatmaps[:limit] if len(beatmapset.beatmaps) > limit else beatmapset.beatmaps
elif user is not None:
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()

View File

@@ -158,7 +158,10 @@ async def get_beatmap_attributes(
if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
)
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:

View File

@@ -46,7 +46,6 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
response_model=SearchBeatmapsetsResp,
)
async def search_beatmapset(
db: Database,
query: Annotated[SearchQueryModel, Query(...)],
request: Request,
background_tasks: BackgroundTasks,
@@ -104,7 +103,7 @@ async def search_beatmapset(
if cached_result:
sets = SearchBeatmapsetsResp(**cached_result)
# 处理资源代理
processed_sets = await process_response_assets(sets, request)
processed_sets = await process_response_assets(sets)
return processed_sets
try:
@@ -115,7 +114,7 @@ async def search_beatmapset(
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
# 处理资源代理
processed_sets = await process_response_assets(sets, request)
processed_sets = await process_response_assets(sets)
return processed_sets
except HTTPError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -140,7 +139,7 @@ async def lookup_beatmapset(
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
if cached_resp:
# 处理资源代理
processed_resp = await process_response_assets(cached_resp, request)
processed_resp = await process_response_assets(cached_resp)
return processed_resp
try:
@@ -151,7 +150,7 @@ async def lookup_beatmapset(
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
# 处理资源代理
processed_resp = await process_response_assets(resp, request)
processed_resp = await process_response_assets(resp)
return processed_resp
except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmap not found") from exc
@@ -176,7 +175,7 @@ async def get_beatmapset(
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
if cached_resp:
# 处理资源代理
processed_resp = await process_response_assets(cached_resp, request)
processed_resp = await process_response_assets(cached_resp)
return processed_resp
try:
@@ -187,7 +186,7 @@ async def get_beatmapset(
await cache_service.cache_beatmapset(resp)
# 处理资源代理
processed_resp = await process_response_assets(resp, request)
processed_resp = await process_response_assets(resp)
return processed_resp
except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmapset not found") from exc

View File

@@ -166,7 +166,6 @@ async def get_room(
db: Database,
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
redis: Redis,
category: Annotated[
str,
Query(

View File

@@ -847,10 +847,7 @@ async def reorder_score_pin(
detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail)
if after_score_id:
target_order = reference_score.pinned_order + 1
else:
target_order = reference_score.pinned_order
target_order = reference_score.pinned_order + 1 if after_score_id else reference_score.pinned_order
current_order = score_record.pinned_order

View File

@@ -40,7 +40,7 @@ class SessionReissueResponse(BaseModel):
message: str
class VerifyFailed(Exception):
class VerifyFailedError(Exception):
def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
super().__init__(message)
self.reason = reason
@@ -93,10 +93,7 @@ async def verify_session(
# 智能选择验证方法参考osu-web实现
# API版本较老或用户未设置TOTP时强制使用邮件验证
# print(api_version, totp_key)
if api_version < 20240101 or totp_key is None:
verify_method = "mail"
else:
verify_method = "totp"
verify_method = "mail" if api_version < 20240101 or totp_key is None else "totp"
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
login_method = verify_method
@@ -109,7 +106,7 @@ async def verify_session(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
)
verify_method = "mail"
raise VerifyFailed("用户TOTP已被删除已切换到邮件验证")
raise VerifyFailedError("用户TOTP已被删除已切换到邮件验证")
# 如果未开启邮箱验证,则直接认为认证通过
# 正常不会进入到这里
@@ -120,16 +117,16 @@ async def verify_session(
else:
# 记录详细的验证失败原因参考osu-web的错误处理
if len(verification_key) != 6:
raise VerifyFailed("TOTP验证码长度错误应为6位数字", reason="incorrect_length")
raise VerifyFailedError("TOTP验证码长度错误应为6位数字", reason="incorrect_length")
elif not verification_key.isdigit():
raise VerifyFailed("TOTP验证码格式错误应为纯数字", reason="incorrect_format")
raise VerifyFailedError("TOTP验证码格式错误应为纯数字", reason="incorrect_format")
else:
# 可能是密钥错误或者重放攻击
raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
raise VerifyFailedError("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
else:
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
if not success:
raise VerifyFailed(f"邮件验证失败: {message}")
raise VerifyFailedError(f"邮件验证失败: {message}")
await LoginLogService.record_login(
db=db,
@@ -144,7 +141,7 @@ async def verify_session(
await db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
except VerifyFailed as e:
except VerifyFailedError as e:
await LoginLogService.record_failed_login(
db=db,
request=request,
@@ -171,7 +168,9 @@ async def verify_session(
)
error_response["reissued"] = True
except Exception:
pass # 忽略重发邮件失败的错误
log("Verification").exception(
f"Failed to resend verification email to user {current_user.id} (token: {token_id})"
)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)

View File

@@ -44,9 +44,7 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
.where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode))
)
).first()
if user_beatmap_score is None:
return False
return True
return user_beatmap_score is not None
@router.put(
@@ -75,10 +73,9 @@ async def vote_beatmap_tags(
.where(BeatmapTagVote.user_id == current_user.id)
)
).first()
if previous_votes is None:
if check_user_can_vote(current_user, beatmap_id, session):
new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
session.add(new_vote)
if previous_votes is None and check_user_can_vote(current_user, beatmap_id, session):
new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
session.add(new_vote)
await session.commit()
except ValueError:
raise HTTPException(400, "Tag is not found")

View File

@@ -91,7 +91,7 @@ async def get_users(
# 处理资源代理
response = BatchUserResponse(users=cached_users)
processed_response = await process_response_assets(response, request)
processed_response = await process_response_assets(response)
return processed_response
else:
searched_users = (await session.exec(select(User).limit(50))).all()
@@ -109,7 +109,7 @@ async def get_users(
# 处理资源代理
response = BatchUserResponse(users=users)
processed_response = await process_response_assets(response, request)
processed_response = await process_response_assets(response)
return processed_response
@@ -240,7 +240,7 @@ async def get_user_info(
cached_user = await cache_service.get_user_from_cache(user_id_int)
if cached_user:
# 处理资源代理
processed_user = await process_response_assets(cached_user, request)
processed_user = await process_response_assets(cached_user)
return processed_user
searched_user = (
@@ -263,7 +263,7 @@ async def get_user_info(
background_task.add_task(cache_service.cache_user, user_resp)
# 处理资源代理
processed_user = await process_response_assets(user_resp, request)
processed_user = await process_response_assets(user_resp)
return processed_user
@@ -381,7 +381,7 @@ async def get_user_scores(
user_id, type, include_fails, mode, limit, offset, is_legacy_api
)
if cached_scores is not None:
processed_scores = await process_response_assets(cached_scores, request)
processed_scores = await process_response_assets(cached_scores)
return processed_scores
db_user = await session.get(User, user_id)
@@ -438,5 +438,5 @@ async def get_user_scores(
)
# 处理资源代理
processed_scores = await process_response_assets(score_responses, request)
processed_scores = await process_response_assets(score_responses)
return processed_scores

View File

@@ -12,7 +12,7 @@ from app.service.asset_proxy_service import get_asset_proxy_service
from fastapi import Request
async def process_response_assets(data: Any, request: Request) -> Any:
async def process_response_assets(data: Any) -> Any:
"""
根据配置处理响应数据中的资源URL
@@ -72,7 +72,7 @@ def asset_proxy_response(func):
# 如果有request对象且启用了资源代理则处理响应
if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path):
result = await process_response_assets(result, request)
result = await process_response_assets(result)
return result

View File

@@ -113,6 +113,7 @@ class BeatmapCacheService:
if size:
total_size += size
except Exception:
logger.debug(f"Failed to get size for key {key}")
continue
return {

View File

@@ -36,11 +36,8 @@ def safe_json_dumps(data) -> str:
def generate_hash(data) -> str:
"""生成数据的MD5哈希值"""
if isinstance(data, str):
content = data
else:
content = safe_json_dumps(data)
return hashlib.md5(content.encode()).hexdigest()
content = data if isinstance(data, str) else safe_json_dumps(data)
return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
class BeatmapsetCacheService:

View File

@@ -110,9 +110,7 @@ class ProcessingBeatmapset:
changed_beatmaps = []
for bm in self.beatmapset.beatmaps:
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
if not saved:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
elif saved["is_deleted"]:
if not saved or saved["is_deleted"]:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
elif saved["md5"] != bm.checksum:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
@@ -285,7 +283,7 @@ class BeatmapsetUpdateService:
async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
async with with_db() as session:
db_beatmapset = await session.get(Beatmapset, beatmapset.id)
new_beatmapset = await Beatmapset.from_resp_no_save(session, beatmapset)
new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
if db_beatmapset:
await session.merge(new_beatmapset)
await session.commit()
@@ -356,5 +354,7 @@ def init_beatmapset_update_service(fetcher: "Fetcher") -> BeatmapsetUpdateServic
def get_beatmapset_update_service() -> BeatmapsetUpdateService:
if service is None:
raise ValueError("BeatmapsetUpdateService is not initialized")
assert service is not None, "BeatmapsetUpdateService is not initialized"
return service

View File

@@ -128,7 +128,11 @@ class LoginLogService:
login_success=False,
login_method=login_method,
user_agent=user_agent,
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
notes=(
f"Failed login attempt on user {attempted_username}: {notes}"
if attempted_username
else "Failed login attempt"
),
)

View File

@@ -120,7 +120,7 @@ class PasswordResetService:
await redis.delete(reset_code_key)
await redis.delete(rate_limit_key)
except Exception:
pass
logger.warning("Failed to clean up Redis data after error")
logger.exception("Redis operation failed")
return False, "服务暂时不可用,请稍后重试"

View File

@@ -593,10 +593,7 @@ class RankingCacheService:
async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None:
"""使地区排行榜缓存失效"""
try:
if ruleset:
pattern = f"country_ranking:{ruleset}:*"
else:
pattern = "country_ranking:*"
pattern = f"country_ranking:{ruleset}:*" if ruleset else "country_ranking:*"
keys = await self.redis.keys(pattern)
if keys:
@@ -608,10 +605,7 @@ class RankingCacheService:
async def invalidate_team_cache(self, ruleset: GameMode | None = None) -> None:
"""使战队排行榜缓存失效"""
try:
if ruleset:
pattern = f"team_ranking:{ruleset}:*"
else:
pattern = "team_ranking:*"
pattern = f"team_ranking:{ruleset}:*" if ruleset else "team_ranking:*"
keys = await self.redis.keys(pattern)
if keys:
@@ -637,6 +631,7 @@ class RankingCacheService:
if size:
total_size += size
except Exception:
logger.warning(f"Failed to get memory usage for key {key}")
continue
return {

View File

View File

@@ -35,19 +35,19 @@ class ChatSubscriber(RedisSubscriber):
self.add_handler(ON_NOTIFICATION, self.on_notification)
self.start()
async def on_join_room(self, c: str, s: str):
async def on_join_room(self, c: str, s: str): # noqa: ARG002
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.join_room_channel(int(channel_id), int(user_id))
async def on_leave_room(self, c: str, s: str):
async def on_leave_room(self, c: str, s: str): # noqa: ARG002
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.leave_room_channel(int(channel_id), int(user_id))
async def on_notification(self, c: str, s: str):
async def on_notification(self, c: str, s: str): # noqa: ARG002
try:
detail = TypeAdapter(NotificationDetails).validate_json(s)
except ValueError:

View File

@@ -357,6 +357,7 @@ class UserCacheService:
if size:
total_size += size
except Exception:
logger.warning(f"Failed to get memory usage for key {key}")
continue
return {

View File

@@ -288,10 +288,6 @@ This email was sent automatically, please do not reply.
redis: Redis,
user_id: int,
code: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
) -> tuple[bool, str]:
"""验证邮箱验证码"""
try:

View File

@@ -41,7 +41,7 @@ async def warmup_cache() -> None:
logger.info("Beatmap cache warmup completed successfully")
except Exception as e:
logger.error("Beatmap cache warmup failed: %s", e)
logger.error(f"Beatmap cache warmup failed: {e}")
async def refresh_ranking_cache() -> None:
@@ -59,7 +59,7 @@ async def refresh_ranking_cache() -> None:
logger.info("Ranking cache refresh completed successfully")
except Exception as e:
logger.error("Ranking cache refresh failed: %s", e)
logger.error(f"Ranking cache refresh failed: {e}")
async def schedule_user_cache_preload_task() -> None:
@@ -93,14 +93,14 @@ async def schedule_user_cache_preload_task() -> None:
if active_user_ids:
user_ids = [row[0] for row in active_user_ids]
await cache_service.preload_user_cache(session, user_ids)
logger.info("Preloaded cache for %s active users", len(user_ids))
logger.info(f"Preloaded cache for {len(user_ids)} active users")
else:
logger.info("No active users found for cache preload")
logger.info("User cache preload task completed successfully")
except Exception as e:
logger.error("User cache preload task failed: %s", e)
logger.error(f"User cache preload task failed: {e}")
async def schedule_user_cache_warmup_task() -> None:
@@ -131,18 +131,18 @@ async def schedule_user_cache_warmup_task() -> None:
if top_users:
user_ids = list(top_users)
await cache_service.preload_user_cache(session, user_ids)
logger.info("Warmed cache for top 100 users in %s", mode)
logger.info(f"Warmed cache for top 100 users in {mode}")
await asyncio.sleep(1)
except Exception as e:
logger.error("Failed to warm cache for %s: %s", mode, e)
logger.error(f"Failed to warm cache for {mode}: {e}")
continue
logger.info("User cache warmup task completed successfully")
except Exception as e:
logger.error("User cache warmup task failed: %s", e)
logger.error(f"User cache warmup task failed: {e}")
async def schedule_user_cache_cleanup_task() -> None:
@@ -155,11 +155,11 @@ async def schedule_user_cache_cleanup_task() -> None:
cache_service = get_user_cache_service(redis)
stats = await cache_service.get_cache_stats()
logger.info("User cache stats: %s", stats)
logger.info(f"User cache stats: {stats}")
logger.info("User cache cleanup task completed successfully")
except Exception as e:
logger.error("User cache cleanup task failed: %s", e)
logger.error(f"User cache cleanup task failed: {e}")
async def warmup_user_cache() -> None:
@@ -167,7 +167,7 @@ async def warmup_user_cache() -> None:
try:
await schedule_user_cache_warmup_task()
except Exception as e:
logger.error("User cache warmup failed: %s", e)
logger.error(f"User cache warmup failed: {e}")
async def preload_user_cache() -> None:
@@ -175,7 +175,7 @@ async def preload_user_cache() -> None:
try:
await schedule_user_cache_preload_task()
except Exception as e:
logger.error("User cache preload failed: %s", e)
logger.error(f"User cache preload failed: {e}")
async def cleanup_user_cache() -> None:
@@ -183,7 +183,7 @@ async def cleanup_user_cache() -> None:
try:
await schedule_user_cache_cleanup_task()
except Exception as e:
logger.error("User cache cleanup failed: %s", e)
logger.error(f"User cache cleanup failed: {e}")
def register_cache_jobs() -> None:

View File

@@ -5,8 +5,6 @@ Periodically update the MaxMind GeoIP database
from __future__ import annotations
import asyncio
from app.config import settings
from app.dependencies.geoip import get_geoip_helper
from app.dependencies.scheduler import get_scheduler
@@ -28,14 +26,10 @@ async def update_geoip_database():
try:
logger.info("Starting scheduled GeoIP database update...")
geoip = get_geoip_helper()
# Run the synchronous update method in a background thread
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: geoip.update(force=False))
await geoip.update(force=False)
logger.info("Scheduled GeoIP database update completed successfully")
except Exception as e:
logger.error(f"Scheduled GeoIP database update failed: {e}")
except Exception as exc:
logger.error(f"Scheduled GeoIP database update failed: {exc}")
async def init_geoip():
@@ -45,13 +39,8 @@ async def init_geoip():
try:
geoip = get_geoip_helper()
logger.info("Initializing GeoIP database...")
# Run the synchronous update method in a background thread
# force=False means only download if files don't exist or are expired
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: geoip.update(force=False))
await geoip.update(force=False)
logger.info("GeoIP database initialization completed")
except Exception as e:
logger.error(f"GeoIP database initialization failed: {e}")
except Exception as exc:
logger.error(f"GeoIP database initialization failed: {exc}")
# Do not raise an exception to avoid blocking application startup

View File

@@ -16,7 +16,7 @@ async def create_rx_statistics():
async with with_db() as session:
users = (await session.exec(select(User.id))).all()
total_users = len(users)
logger.info("Ensuring RX/AP statistics exist for %s users", total_users)
logger.info(f"Ensuring RX/AP statistics exist for {total_users} users")
rx_created = 0
ap_created = 0
for i in users:
@@ -57,7 +57,5 @@ async def create_rx_statistics():
await session.commit()
if rx_created or ap_created:
logger.success(
"Created %s RX statistics rows and %s AP statistics rows during backfill",
rx_created,
ap_created,
f"Created {rx_created} RX statistics rows and {ap_created} AP statistics rows during backfill"
)

View File

@@ -258,10 +258,7 @@ class BackgroundTasks:
self.tasks = set(tasks) if tasks else set()
def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
if is_async_callable(func):
coro = func(*args, **kwargs)
else:
coro = run_in_threadpool(func, *args, **kwargs)
coro = func(*args, **kwargs) if is_async_callable(func) else run_in_threadpool(func, *args, **kwargs)
task = asyncio.create_task(coro)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)