chore(linter): update ruff rules

This commit is contained in:
MingxuanGame
2025-10-03 15:46:53 +00:00
parent b10425ad91
commit d490239f46
59 changed files with 393 additions and 425 deletions

View File

@@ -32,11 +32,9 @@ async def process_streak(
).first() ).first()
if not stats: if not stats:
return False return False
if streak <= stats.daily_streak_best < next_streak: return bool(
return True streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak)
elif next_streak == 0 and stats.daily_streak_best >= streak: )
return True
return False
MEDALS = { MEDALS = {

View File

@@ -68,9 +68,7 @@ async def to_the_core(
if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist: if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
return False return False
mods_ = mod_to_save(score.mods) mods_ = mod_to_save(score.mods)
if "DT" not in mods_ or "NC" not in mods_: return not ("DT" not in mods_ or "NC" not in mods_)
return False
return True
async def wysi( async def wysi(
@@ -83,9 +81,7 @@ async def wysi(
return False return False
if str(round(score.accuracy, ndigits=4))[3:] != "727": if str(round(score.accuracy, ndigits=4))[3:] != "727":
return False return False
if "xi" not in beatmap.beatmapset.artist: return "xi" in beatmap.beatmapset.artist
return False
return True
async def prepared( async def prepared(
@@ -97,9 +93,7 @@ async def prepared(
if score.rank != Rank.X and score.rank != Rank.XH: if score.rank != Rank.X and score.rank != Rank.XH:
return False return False
mods_ = mod_to_save(score.mods) mods_ = mod_to_save(score.mods)
if "NF" not in mods_: return "NF" in mods_
return False
return True
async def reckless_adandon( async def reckless_adandon(
@@ -117,9 +111,7 @@ async def reckless_adandon(
redis = get_redis() redis = get_redis()
mods_ = score.mods.copy() mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < 3: return not attribute.star_rating < 3
return False
return True
async def lights_out( async def lights_out(
@@ -413,11 +405,10 @@ async def by_the_skin_of_the_teeth(
return False return False
for mod in score.mods: for mod in score.mods:
if mod.get("acronym") == "AC": if mod.get("acronym") == "AC" and "settings" in mod and "minimum_accuracy" in mod["settings"]:
if "settings" in mod and "minimum_accuracy" in mod["settings"]: target_accuracy = mod["settings"]["minimum_accuracy"]
target_accuracy = mod["settings"]["minimum_accuracy"] if isinstance(target_accuracy, int | float):
if isinstance(target_accuracy, int | float): return abs(score.accuracy - float(target_accuracy)) < 0.0001
return abs(score.accuracy - float(target_accuracy)) < 0.0001
return False return False

View File

@@ -19,9 +19,7 @@ async def process_mod(
return False return False
if not beatmap.beatmap_status.has_leaderboard(): if not beatmap.beatmap_status.has_leaderboard():
return False return False
if len(score.mods) != 1 or score.mods[0]["acronym"] != mod: return not (len(score.mods) != 1 or score.mods[0]["acronym"] != mod)
return False
return True
async def process_category_mod( async def process_category_mod(

View File

@@ -22,11 +22,7 @@ async def process_combo(
return False return False
if next_combo != 0 and combo >= next_combo: if next_combo != 0 and combo >= next_combo:
return False return False
if combo <= score.max_combo < next_combo: return bool(combo <= score.max_combo < next_combo or (next_combo == 0 and score.max_combo >= combo))
return True
elif next_combo == 0 and score.max_combo >= combo:
return True
return False
MEDALS: Medals = { MEDALS: Medals = {

View File

@@ -35,11 +35,7 @@ async def process_playcount(
).first() ).first()
if not stats: if not stats:
return False return False
if pc <= stats.play_count < next_pc: return bool(pc <= stats.play_count < next_pc or (next_pc == 0 and stats.play_count >= pc))
return True
elif next_pc == 0 and stats.play_count >= pc:
return True
return False
MEDALS: Medals = { MEDALS: Medals = {

View File

@@ -47,9 +47,7 @@ async def process_skill(
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < star or attribute.star_rating >= star + 1: if attribute.star_rating < star or attribute.star_rating >= star + 1:
return False return False
if type == "fc" and not score.is_perfect_combo: return not (type == "fc" and not score.is_perfect_combo)
return False
return True
MEDALS: Medals = { MEDALS: Medals = {

View File

@@ -35,11 +35,7 @@ async def process_tth(
).first() ).first()
if not stats: if not stats:
return False return False
if tth <= stats.total_hits < next_tth: return bool(tth <= stats.total_hits < next_tth or (next_tth == 0 and stats.play_count >= tth))
return True
elif next_tth == 0 and stats.play_count >= tth:
return True
return False
MEDALS: Medals = { MEDALS: Medals = {

View File

@@ -69,7 +69,7 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
2. MD5哈希 -> bcrypt验证 2. MD5哈希 -> bcrypt验证
""" """
# 1. 明文密码转 MD5 # 1. 明文密码转 MD5
pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode() pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode() # noqa: S324
# 2. 检查缓存 # 2. 检查缓存
if bcrypt_hash in bcrypt_cache: if bcrypt_hash in bcrypt_cache:
@@ -103,7 +103,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
"""生成密码哈希 - 使用 osu! 的方式""" """生成密码哈希 - 使用 osu! 的方式"""
# 1. 明文密码 -> MD5 # 1. 明文密码 -> MD5
pw_md5 = hashlib.md5(password.encode()).hexdigest().encode() pw_md5 = hashlib.md5(password.encode()).hexdigest().encode() # noqa: S324
# 2. MD5 -> bcrypt # 2. MD5 -> bcrypt
pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt()) pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt())
return pw_bcrypt.decode() return pw_bcrypt.decode()
@@ -114,7 +114,7 @@ async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -
验证用户身份 - 使用类似 from_login 的逻辑 验证用户身份 - 使用类似 from_login 的逻辑
""" """
# 1. 明文密码转 MD5 # 1. 明文密码转 MD5
pw_md5 = hashlib.md5(password.encode()).hexdigest() pw_md5 = hashlib.md5(password.encode()).hexdigest() # noqa: S324
# 2. 根据用户名查找用户 # 2. 根据用户名查找用户
user = None user = None
@@ -325,12 +325,7 @@ def _generate_totp_account_label(user: User) -> str:
根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性 根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性
""" """
if settings.totp_use_username_in_label: primary_identifier = user.username if settings.totp_use_username_in_label else user.email
# 使用用户名作为主要标识
primary_identifier = user.username
else:
# 使用邮箱作为标识
primary_identifier = user.email
# 如果配置了服务名称,添加到标签中以便在认证器中区分 # 如果配置了服务名称,添加到标签中以便在认证器中区分
if settings.totp_service_name: if settings.totp_service_name:

View File

@@ -419,9 +419,8 @@ def too_dense(hit_objects: list[HitObject], per_1s: int, per_10s: int) -> bool:
if len(hit_objects) > i + per_1s: if len(hit_objects) > i + per_1s:
if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000: if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000:
return True return True
elif len(hit_objects) > i + per_10s: elif len(hit_objects) > i + per_10s and hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
if hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000: return True
return True
return False return False
@@ -448,10 +447,7 @@ def slider_is_sus(hit_objects: list[HitObject]) -> bool:
def is_2b(hit_objects: list[HitObject]) -> bool: def is_2b(hit_objects: list[HitObject]) -> bool:
for i in range(0, len(hit_objects) - 1): return any(hit_objects[i] == hit_objects[i + 1].start_time for i in range(0, len(hit_objects) - 1))
if hit_objects[i] == hit_objects[i + 1].start_time:
return True
return False
def is_suspicious_beatmap(content: str) -> bool: def is_suspicious_beatmap(content: str) -> bool:

View File

@@ -217,7 +217,7 @@ STORAGE_SETTINGS='{
# 服务器设置 # 服务器设置
host: Annotated[ host: Annotated[
str, str,
Field(default="0.0.0.0", description="服务器监听地址"), Field(default="0.0.0.0", description="服务器监听地址"), # noqa: S104
"服务器设置", "服务器设置",
] ]
port: Annotated[ port: Annotated[
@@ -609,26 +609,26 @@ STORAGE_SETTINGS='{
] ]
@field_validator("fetcher_scopes", mode="before") @field_validator("fetcher_scopes", mode="before")
@classmethod
def validate_fetcher_scopes(cls, v: Any) -> list[str]: def validate_fetcher_scopes(cls, v: Any) -> list[str]:
if isinstance(v, str): if isinstance(v, str):
return v.split(",") return v.split(",")
return v return v
@field_validator("storage_settings", mode="after") @field_validator("storage_settings", mode="after")
@classmethod
def validate_storage_settings( def validate_storage_settings(
cls, cls,
v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings, v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
info: ValidationInfo, info: ValidationInfo,
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings: ) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2: service = info.data.get("storage_service")
if not isinstance(v, CloudflareR2Settings): if service == StorageServiceType.CLOUDFLARE_R2 and not isinstance(v, CloudflareR2Settings):
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings") raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
elif info.data.get("storage_service") == StorageServiceType.LOCAL: if service == StorageServiceType.LOCAL and not isinstance(v, LocalStorageSettings):
if not isinstance(v, LocalStorageSettings): raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings") if service == StorageServiceType.AWS_S3 and not isinstance(v, AWSS3StorageSettings):
elif info.data.get("storage_service") == StorageServiceType.AWS_S3: raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
if not isinstance(v, AWSS3StorageSettings):
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
return v return v

View File

@@ -71,10 +71,10 @@ class Beatmap(BeatmapBase, table=True):
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}) failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod @classmethod
async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap": async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump() d = resp.model_dump()
del d["beatmapset"] del d["beatmapset"]
beatmap = Beatmap.model_validate( beatmap = cls.model_validate(
{ {
**d, **d,
"beatmapset_id": resp.beatmapset_id, "beatmapset_id": resp.beatmapset_id,
@@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True):
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first(): if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
session.add(beatmap) session.add(beatmap)
await session.commit() await session.commit()
beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one() return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
return beatmap
@classmethod @classmethod
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]: async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
@@ -250,7 +249,7 @@ async def calculate_beatmap_attributes(
redis: Redis, redis: Redis,
fetcher: "Fetcher", fetcher: "Fetcher",
): ):
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key): if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) return BeatmapAttributes.model_validate_json(await redis.get(key))
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)

View File

@@ -130,7 +130,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod @classmethod
async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset": async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
d = resp.model_dump() d = resp.model_dump()
if resp.nominations: if resp.nominations:
d["nominations_required"] = resp.nominations.required d["nominations_required"] = resp.nominations.required
@@ -158,10 +158,15 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
return beatmapset return beatmapset
@classmethod @classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset": async def from_resp(
cls,
session: AsyncSession,
resp: "BeatmapsetResp",
from_: int = 0,
) -> "Beatmapset":
from .beatmap import Beatmap from .beatmap import Beatmap
beatmapset = await cls.from_resp_no_save(session, resp, from_=from_) beatmapset = await cls.from_resp_no_save(resp)
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first(): if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
session.add(beatmapset) session.add(beatmapset)
await session.commit() await session.commit()

View File

@@ -105,17 +105,11 @@ class ChatChannelResp(ChatChannelBase):
) )
).first() ).first()
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg") last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
if last_msg and last_msg.isdigit(): last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
last_msg = int(last_msg)
else:
last_msg = None
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
if last_read_id and last_read_id.isdigit(): last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
last_read_id = int(last_read_id)
else:
last_read_id = last_msg
if silence is not None: if silence is not None:
attribute = ChatUserAttributes( attribute = ChatUserAttributes(

View File

@@ -520,12 +520,11 @@ async def _score_where(
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code)) wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
else: else:
return None return None
elif type == LeaderboardType.TEAM: elif type == LeaderboardType.TEAM and user:
if user: team_membership = await user.awaitable_attrs.team_membership
team_membership = await user.awaitable_attrs.team_membership if team_membership:
if team_membership: team_id = team_membership.team_id
team_id = team_membership.team_id wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
if mods: if mods:
if user and user.is_supporter: if user and user.is_supporter:
wheres.append( wheres.append(

View File

@@ -256,8 +256,6 @@ class UserResp(UserBase):
session: AsyncSession, session: AsyncSession,
include: list[str] = [], include: list[str] = [],
ruleset: GameMode | None = None, ruleset: GameMode | None = None,
*,
token_id: int | None = None,
) -> "UserResp": ) -> "UserResp":
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
@@ -310,16 +308,16 @@ class UserResp(UserBase):
).all() ).all()
] ]
if "team" in include: if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership):
if team_membership := await obj.awaitable_attrs.team_membership: u.team = team_membership.team
u.team = team_membership.team
if "account_history" in include: if "account_history" in include:
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history] u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
if "daily_challenge_user_stats": if "daily_challenge_user_stats" in include and (
if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats: daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats) ):
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
if "statistics" in include: if "statistics" in include:
current_stattistics = None current_stattistics = None
@@ -443,7 +441,7 @@ class MeResp(UserResp):
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
from app.service.verification_service import LoginSessionService from app.service.verification_service import LoginSessionService
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id) u = await super().from_db(obj, session, ALL_INCLUDED, ruleset)
u.session_verified = ( u.session_verified = (
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id) not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
if token_id if token_id

View File

@@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
def BodyOrForm[T: BaseModel](model: type[T]): def BodyOrForm[T: BaseModel](model: type[T]): # noqa: N802
async def dependency( async def dependency(
request: Request, request: Request,
) -> T: ) -> T:

View File

@@ -119,10 +119,7 @@ async def get_client_user(
if verify_method is None: if verify_method is None:
# 智能选择验证方式有TOTP优先TOTP # 智能选择验证方式有TOTP优先TOTP
totp_key = await user.awaitable_attrs.totp_key totp_key = await user.awaitable_attrs.totp_key
if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER: verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail"
verify_method = "totp"
else:
verify_method = "mail"
# 设置选择的验证方法到Redis中避免重复选择 # 设置选择的验证方法到Redis中避免重复选择
if api_version >= 20250913: if api_version >= 20250913:

View File

View File

@@ -116,7 +116,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 序列化为 JSON 并生成 MD5 哈希 # 序列化为 JSON 并生成 MD5 哈希
cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":")) cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
cache_hash = hashlib.md5(cache_json.encode()).hexdigest() cache_hash = hashlib.md5(cache_json.encode(), usedforsecurity=False).hexdigest()
logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}") logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
@@ -160,10 +160,10 @@ class BeatmapsetFetcher(BaseFetcher):
cached_data = json.loads(cached_result) cached_data = json.loads(cached_result)
return SearchBeatmapsetsResp.model_validate(cached_data) return SearchBeatmapsetsResp.model_validate(cached_data)
except Exception as e: except Exception as e:
logger.opt(colors=True).warning(f"Cache data invalid, fetching from API: {e}") logger.warning(f"Cache data invalid, fetching from API: {e}")
# 缓存未命中,从 API 获取数据 # 缓存未命中,从 API 获取数据
logger.opt(colors=True).debug("Cache miss, fetching from API") logger.debug("Cache miss, fetching from API")
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True) params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
@@ -203,7 +203,7 @@ class BeatmapsetFetcher(BaseFetcher):
try: try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1) await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
except RateLimitError: except RateLimitError:
logger.opt(colors=True).info("Prefetch skipped due to rate limit") logger.info("Prefetch skipped due to rate limit")
bg_tasks.add_task(delayed_prefetch) bg_tasks.add_task(delayed_prefetch)
@@ -227,14 +227,14 @@ class BeatmapsetFetcher(BaseFetcher):
# 使用当前 cursor 请求下一页 # 使用当前 cursor 请求下一页
next_query = query.model_copy() next_query = query.model_copy()
logger.opt(colors=True).debug(f"Prefetching page {page + 1}") logger.debug(f"Prefetching page {page + 1}")
# 生成下一页的缓存键 # 生成下一页的缓存键
next_cache_key = self._generate_cache_key(next_query, cursor) next_cache_key = self._generate_cache_key(next_query, cursor)
# 检查是否已经缓存 # 检查是否已经缓存
if await redis_client.exists(next_cache_key): if await redis_client.exists(next_cache_key):
logger.opt(colors=True).debug(f"Page {page + 1} already cached") logger.debug(f"Page {page + 1} already cached")
# 尝试从缓存获取cursor继续预取 # 尝试从缓存获取cursor继续预取
cached_data = await redis_client.get(next_cache_key) cached_data = await redis_client.get(next_cache_key)
if cached_data: if cached_data:
@@ -244,7 +244,7 @@ class BeatmapsetFetcher(BaseFetcher):
cursor = data["cursor"] cursor = data["cursor"]
continue continue
except Exception: except Exception:
pass logger.warning("Failed to parse cached data for cursor")
break break
# 在预取页面之间添加延迟,避免突发请求 # 在预取页面之间添加延迟,避免突发请求
@@ -279,18 +279,18 @@ class BeatmapsetFetcher(BaseFetcher):
ex=prefetch_ttl, ex=prefetch_ttl,
) )
logger.opt(colors=True).debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)") logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
except RateLimitError: except RateLimitError:
logger.opt(colors=True).info("Prefetch stopped due to rate limit") logger.info("Prefetch stopped due to rate limit")
except Exception as e: except Exception as e:
logger.opt(colors=True).warning(f"Prefetch failed: {e}") logger.warning(f"Prefetch failed: {e}")
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None: async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
"""预热主页缓存""" """预热主页缓存"""
homepage_queries = self._get_homepage_queries() homepage_queries = self._get_homepage_queries()
logger.opt(colors=True).info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)") logger.info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
for i, (query, cursor) in enumerate(homepage_queries): for i, (query, cursor) in enumerate(homepage_queries):
try: try:
@@ -302,7 +302,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 检查是否已经缓存 # 检查是否已经缓存
if await redis_client.exists(cache_key): if await redis_client.exists(cache_key):
logger.opt(colors=True).debug(f"Query {query.sort} already cached") logger.debug(f"Query {query.sort} already cached")
continue continue
# 请求并缓存 # 请求并缓存
@@ -325,15 +325,15 @@ class BeatmapsetFetcher(BaseFetcher):
ex=cache_ttl, ex=cache_ttl,
) )
logger.opt(colors=True).info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)") logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
if api_response.get("cursor"): if api_response.get("cursor"):
try: try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2) await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
except RateLimitError: except RateLimitError:
logger.opt(colors=True).info(f"Warmup prefetch skipped for {query.sort} due to rate limit") logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
except RateLimitError: except RateLimitError:
logger.opt(colors=True).warning(f"Warmup skipped for {query.sort} due to rate limit") logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
except Exception as e: except Exception as e:
logger.opt(colors=True).error(f"Failed to warmup cache for {query.sort}: {e}") logger.error(f"Failed to warmup cache for {query.sort}: {e}")

0
app/helpers/__init__.py Normal file
View File

View File

@@ -1,19 +1,39 @@
""" """
GeoLite2 Helper Class GeoLite2 Helper Class (asynchronous)
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
from contextlib import suppress
import os import os
from pathlib import Path from pathlib import Path
import shutil import shutil
import tarfile import tarfile
import tempfile import tempfile
import time import time
from typing import Any, Required, TypedDict
from app.log import logger
import aiofiles
import httpx import httpx
import maxminddb import maxminddb
class GeoIPLookupResult(TypedDict, total=False):
ip: Required[str]
country_iso: str
country_name: str
city_name: str
latitude: str
longitude: str
time_zone: str
postal_code: str
asn: int | None
organization: str
BASE_URL = "https://download.maxmind.com/app/geoip_download" BASE_URL = "https://download.maxmind.com/app/geoip_download"
EDITIONS = { EDITIONS = {
"City": "GeoLite2-City", "City": "GeoLite2-City",
@@ -25,161 +45,184 @@ EDITIONS = {
class GeoIPHelper: class GeoIPHelper:
def __init__( def __init__(
self, self,
dest_dir="./geoip", dest_dir: str | Path = Path("./geoip"),
license_key=None, license_key: str | None = None,
editions=None, editions: list[str] | None = None,
max_age_days=8, max_age_days: int = 8,
timeout=60.0, timeout: float = 60.0,
): ):
self.dest_dir = dest_dir self.dest_dir = Path(dest_dir).expanduser()
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY") self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
self.editions = editions or ["City", "ASN"] self.editions = list(editions or ["City", "ASN"])
self.max_age_days = max_age_days self.max_age_days = max_age_days
self.timeout = timeout self.timeout = timeout
self._readers = {} self._readers: dict[str, maxminddb.Reader] = {}
self._update_lock = asyncio.Lock()
@staticmethod @staticmethod
def _safe_extract(tar: tarfile.TarFile, path: str): def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
base = Path(path).resolve() base = path.resolve()
for m in tar.getmembers(): for member in tar.getmembers():
target = (base / m.name).resolve() target = (base / member.name).resolve()
if not str(target).startswith(str(base)): if not target.is_relative_to(base): # py312
raise RuntimeError("Unsafe path in tar file") raise RuntimeError("Unsafe path in tar file")
tar.extractall(path=path, filter="data") tar.extractall(path=base, filter="data")
def _download_and_extract(self, edition_id: str) -> str: @staticmethod
""" def _as_mapping(value: Any) -> dict[str, Any]:
下载并解压 mmdb 文件到 dest_dir仅保留 .mmdb return value if isinstance(value, dict) else {}
- 跟随 302 重定向
- 流式下载到临时文件 @staticmethod
- 临时目录退出后自动清理 def _as_str(value: Any, default: str = "") -> str:
""" if isinstance(value, str):
return value
if value is None:
return default
return str(value)
@staticmethod
def _as_int(value: Any) -> int | None:
return value if isinstance(value, int) else None
@staticmethod
def _extract_tarball(src: Path, dest: Path) -> None:
with tarfile.open(src, "r:gz") as tar:
GeoIPHelper._safe_extract(tar, dest)
@staticmethod
def _find_mmdb(root: Path) -> Path | None:
for candidate in root.rglob("*.mmdb"):
return candidate
return None
def _latest_file_sync(self, edition_id: str) -> Path | None:
directory = self.dest_dir
if not directory.is_dir():
return None
candidates = list(directory.glob(f"{edition_id}*.mmdb"))
if not candidates:
return None
return max(candidates, key=lambda p: p.stat().st_mtime)
async def _latest_file(self, edition_id: str) -> Path | None:
return await asyncio.to_thread(self._latest_file_sync, edition_id)
async def _download_and_extract(self, edition_id: str) -> Path:
if not self.license_key: if not self.license_key:
raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.") raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")
url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz" url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client: try:
with client.stream("GET", url) as resp: tgz_path = tmp_dir / "db.tgz"
async with (
httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
client.stream("GET", url) as resp,
):
resp.raise_for_status() resp.raise_for_status()
with tempfile.TemporaryDirectory() as tmpd: async with aiofiles.open(tgz_path, "wb") as download_file:
tgz_path = os.path.join(tmpd, "db.tgz") async for chunk in resp.aiter_bytes():
# 流式写入 if chunk:
with open(tgz_path, "wb") as f: await download_file.write(chunk)
for chunk in resp.iter_bytes():
if chunk:
f.write(chunk)
# 解压并只移动 .mmdb await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
with tarfile.open(tgz_path, "r:gz") as tar: mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
# 先安全检查与解压 if mmdb_path is None:
self._safe_extract(tar, tmpd) raise RuntimeError("未在压缩包中找到 .mmdb 文件")
# 递归找 .mmdb await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
mmdb_path = None dst = self.dest_dir / mmdb_path.name
for root, _, files in os.walk(tmpd): await asyncio.to_thread(shutil.move, mmdb_path, dst)
for fn in files: return dst
if fn.endswith(".mmdb"): finally:
mmdb_path = os.path.join(root, fn) await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)
break
if mmdb_path:
break
if not mmdb_path: async def update(self, force: bool = False) -> None:
raise RuntimeError("未在压缩包中找到 .mmdb 文件") async with self._update_lock:
for edition in self.editions:
edition_id = EDITIONS[edition]
path = await self._latest_file(edition_id)
need_download = force or path is None
os.makedirs(self.dest_dir, exist_ok=True) if path:
dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path)) mtime = await asyncio.to_thread(path.stat)
shutil.move(mmdb_path, dst) age_days = (time.time() - mtime.st_mtime) / 86400
return dst if age_days >= self.max_age_days:
need_download = True
def _latest_file(self, edition_id: str): logger.info(
if not os.path.isdir(self.dest_dir): f"{edition_id} database is {age_days:.1f} days old "
return None f"(max: {self.max_age_days}), will download new version"
files = [ )
os.path.join(self.dest_dir, f) else:
for f in os.listdir(self.dest_dir) logger.info(
if f.startswith(edition_id) and f.endswith(".mmdb") f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
] )
return max(files, key=os.path.getmtime) if files else None
def update(self, force=False):
from app.log import logger
for ed in self.editions:
eid = EDITIONS[ed]
path = self._latest_file(eid)
need = force or not path
if path:
age_days = (time.time() - os.path.getmtime(path)) / 86400
if age_days >= self.max_age_days:
need = True
logger.info(
f"{eid} database is {age_days:.1f} days old "
f"(max: {self.max_age_days}), will download new version"
)
else: else:
logger.info(f"{eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})") logger.info(f"{edition_id} database not found, will download")
else:
logger.info(f"{eid} database not found, will download")
if need: if need_download:
logger.info(f"Downloading {eid} database...") logger.info(f"Downloading {edition_id} database...")
path = self._download_and_extract(eid) path = await self._download_and_extract(edition_id)
logger.info(f"{eid} database downloaded successfully") logger.info(f"{edition_id} database downloaded successfully")
else: else:
logger.info(f"Using existing {eid} database") logger.info(f"Using existing {edition_id} database")
old = self._readers.get(ed) old_reader = self._readers.get(edition)
if old: if old_reader:
try: with suppress(Exception):
old.close() old_reader.close()
except Exception: if path is not None:
pass self._readers[edition] = maxminddb.open_database(str(path))
if path is not None:
self._readers[ed] = maxminddb.open_database(path)
def lookup(self, ip: str): def lookup(self, ip: str) -> GeoIPLookupResult:
res = {"ip": ip} res: GeoIPLookupResult = {"ip": ip}
# City city_reader = self._readers.get("City")
city_r = self._readers.get("City") if city_reader:
if city_r: data = city_reader.get(ip)
data = city_r.get(ip) if isinstance(data, dict):
if data: country = self._as_mapping(data.get("country"))
country = data.get("country") or {} res["country_iso"] = self._as_str(country.get("iso_code"))
res["country_iso"] = country.get("iso_code") or "" country_names = self._as_mapping(country.get("names"))
res["country_name"] = (country.get("names") or {}).get("en", "") res["country_name"] = self._as_str(country_names.get("en"))
city = data.get("city") or {}
res["city_name"] = (city.get("names") or {}).get("en", "") city = self._as_mapping(data.get("city"))
loc = data.get("location") or {} city_names = self._as_mapping(city.get("names"))
res["latitude"] = str(loc.get("latitude") or "") res["city_name"] = self._as_str(city_names.get("en"))
res["longitude"] = str(loc.get("longitude") or "")
res["time_zone"] = str(loc.get("time_zone") or "") location = self._as_mapping(data.get("location"))
postal = data.get("postal") or {} latitude = location.get("latitude")
if "code" in postal: longitude = location.get("longitude")
res["postal_code"] = postal["code"] res["latitude"] = str(latitude) if latitude is not None else ""
# ASN res["longitude"] = str(longitude) if longitude is not None else ""
asn_r = self._readers.get("ASN") res["time_zone"] = self._as_str(location.get("time_zone"))
if asn_r:
data = asn_r.get(ip) postal = self._as_mapping(data.get("postal"))
if data: postal_code = postal.get("code")
res["asn"] = data.get("autonomous_system_number") if postal_code is not None:
res["organization"] = data.get("autonomous_system_organization") res["postal_code"] = self._as_str(postal_code)
asn_reader = self._readers.get("ASN")
if asn_reader:
data = asn_reader.get(ip)
if isinstance(data, dict):
res["asn"] = self._as_int(data.get("autonomous_system_number"))
res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
return res return res
def close(self): def close(self) -> None:
for r in self._readers.values(): for reader in self._readers.values():
try: with suppress(Exception):
r.close() reader.close()
except Exception:
pass
self._readers = {} self._readers = {}
if __name__ == "__main__": if __name__ == "__main__":
# 示例用法
geo = GeoIPHelper(dest_dir="./geoip", license_key="") async def _demo() -> None:
geo.update() geo = GeoIPHelper(dest_dir="./geoip", license_key="")
print(geo.lookup("8.8.8.8")) await geo.update()
geo.close() print(geo.lookup("8.8.8.8"))
geo.close()
asyncio.run(_demo())

View File

@@ -97,9 +97,7 @@ class InterceptHandler(logging.Handler):
status_color = "green" status_color = "green"
elif 300 <= status < 400: elif 300 <= status < 400:
status_color = "yellow" status_color = "yellow"
elif 400 <= status < 500: elif 400 <= status < 500 or 500 <= status < 600:
status_color = "red"
elif 500 <= status < 600:
status_color = "red" status_color = "red"
return ( return (

View File

@@ -82,7 +82,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
return await call_next(request) return await call_next(request)
# 启动验证流程 # 启动验证流程
return await self._initiate_verification(request, session_state) return await self._initiate_verification(session_state)
def _should_skip_verification(self, request: Request) -> bool: def _should_skip_verification(self, request: Request) -> bool:
"""检查是否应该跳过验证""" """检查是否应该跳过验证"""
@@ -93,10 +93,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
return True return True
# 非API请求跳过 # 非API请求跳过
if not path.startswith("/api/"): return bool(not path.startswith("/api/"))
return True
return False
def _requires_verification(self, request: Request, user: User) -> bool: def _requires_verification(self, request: Request, user: User) -> bool:
"""检查是否需要验证""" """检查是否需要验证"""
@@ -177,7 +174,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
logger.error(f"Error getting session state: {e}") logger.error(f"Error getting session state: {e}")
return None return None
async def _initiate_verification(self, request: Request, state: SessionState) -> Response: async def _initiate_verification(self, state: SessionState) -> Response:
"""启动验证流程""" """启动验证流程"""
try: try:
method = await state.get_method() method = await state.get_method()

View File

@@ -11,7 +11,7 @@ class ExtendedTokenResponse(BaseModel):
"""扩展的令牌响应,支持二次验证状态""" """扩展的令牌响应,支持二次验证状态"""
access_token: str | None = None access_token: str | None = None
token_type: str = "Bearer" token_type: str = "Bearer" # noqa: S105
expires_in: int | None = None expires_in: int | None = None
refresh_token: str | None = None refresh_token: str | None = None
scope: str | None = None scope: str | None = None
@@ -20,14 +20,3 @@ class ExtendedTokenResponse(BaseModel):
requires_second_factor: bool = False requires_second_factor: bool = False
verification_message: str | None = None verification_message: str | None = None
user_id: int | None = None # 用于二次验证的用户ID user_id: int | None = None # 用于二次验证的用户ID
class SessionState(BaseModel):
"""会话状态"""
user_id: int
username: str
email: str
requires_verification: bool
session_token: str | None = None
verification_sent: bool = False

View File

@@ -1,3 +1,4 @@
# ruff: noqa: ARG002
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod

View File

@@ -22,7 +22,7 @@ class TokenRequest(BaseModel):
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
access_token: str access_token: str
token_type: str = "Bearer" token_type: str = "Bearer" # noqa: S105
expires_in: int expires_in: int
refresh_token: str refresh_token: str
scope: str = "*" scope: str = "*"
@@ -67,7 +67,7 @@ class RegistrationRequestErrors(BaseModel):
class OAuth2ClientCredentialsBearer(OAuth2): class OAuth2ClientCredentialsBearer(OAuth2):
def __init__( def __init__(
self, self,
tokenUrl: Annotated[ tokenUrl: Annotated[ # noqa: N803
str, str,
Doc( Doc(
""" """
@@ -75,7 +75,7 @@ class OAuth2ClientCredentialsBearer(OAuth2):
""" """
), ),
], ],
refreshUrl: Annotated[ refreshUrl: Annotated[ # noqa: N803
str | None, str | None,
Doc( Doc(
""" """

View File

@@ -46,10 +46,10 @@ class PlayerStatsResponse(BaseModel):
class PlayerEventItem(BaseModel): class PlayerEventItem(BaseModel):
"""玩家事件项目""" """玩家事件项目"""
userId: int userId: int # noqa: N815
name: str name: str
mapId: int | None = None mapId: int | None = None # noqa: N815
setId: int | None = None setId: int | None = None # noqa: N815
artist: str | None = None artist: str | None = None
title: str | None = None title: str | None = None
version: str | None = None version: str | None = None
@@ -88,7 +88,7 @@ class PlayerInfo(BaseModel):
custom_badge_icon: str custom_badge_icon: str
custom_badge_color: str custom_badge_color: str
userpage_content: str userpage_content: str
recentFailed: int recentFailed: int # noqa: N815
social_discord: str | None = None social_discord: str | None = None
social_youtube: str | None = None social_youtube: str | None = None
social_twitter: str | None = None social_twitter: str | None = None

View File

@@ -126,21 +126,22 @@ async def register_user(
try: try:
# 获取客户端 IP 并查询地理位置 # 获取客户端 IP 并查询地理位置
country_code = "CN" # 默认国家代码 country_code = None # 默认国家代码
try: try:
# 查询 IP 地理位置 # 查询 IP 地理位置
geo_info = geoip.lookup(client_ip) geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"): if geo_info and (country_code := geo_info.get("country_iso")):
country_code = geo_info["country_iso"]
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}") logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else: else:
logger.warning(f"Could not determine country for IP {client_ip}") logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e: except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}") logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
if country_code is None:
country_code = "CN"
# 创建新用户 # 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy # 确保 AUTO_INCREMENT 值从3开始ID=2是BanchoBot
result = await db.execute( result = await db.execute(
text( text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES " "SELECT AUTO_INCREMENT FROM information_schema.TABLES "
@@ -157,7 +158,7 @@ async def register_user(
email=user_email, email=user_email,
pw_bcrypt=get_password_hash(user_password), pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限 priv=1, # 普通用户权限
country_code=country_code, # 根据 IP 地理位置设置国家 country_code=country_code,
join_date=utcnow(), join_date=utcnow(),
last_visit=utcnow(), last_visit=utcnow(),
is_supporter=settings.enable_supporter_for_all_users, is_supporter=settings.enable_supporter_for_all_users,
@@ -386,7 +387,7 @@ async def oauth_token(
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token,
token_type="Bearer", token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60, expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str, refresh_token=refresh_token_str,
scope=scope, scope=scope,
@@ -439,7 +440,7 @@ async def oauth_token(
) )
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token,
token_type="Bearer", token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60, expires_in=settings.access_token_expire_minutes * 60,
refresh_token=new_refresh_token, refresh_token=new_refresh_token,
scope=scope, scope=scope,
@@ -509,7 +510,7 @@ async def oauth_token(
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token,
token_type="Bearer", token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60, expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str, refresh_token=refresh_token_str,
scope=" ".join(scopes), scope=" ".join(scopes),
@@ -554,7 +555,7 @@ async def oauth_token(
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token,
token_type="Bearer", token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60, expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str, refresh_token=refresh_token_str,
scope=" ".join(scopes), scope=" ".join(scopes),

View File

@@ -130,7 +130,7 @@ def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_us
"allowed_mods": item_data.get("allowed_mods", []), "allowed_mods": item_data.get("allowed_mods", []),
"expired": bool(item_data.get("expired", False)), "expired": bool(item_data.get("expired", False)),
"playlist_order": item_data.get("playlist_order", default_order), "playlist_order": item_data.get("playlist_order", default_order),
"played_at": item_data.get("played_at", None), "played_at": item_data.get("played_at"),
"freestyle": bool(item_data.get("freestyle", True)), "freestyle": bool(item_data.get("freestyle", True)),
"beatmap_checksum": item_data.get("beatmap_checksum", ""), "beatmap_checksum": item_data.get("beatmap_checksum", ""),
"star_rating": item_data.get("star_rating", 0.0), "star_rating": item_data.get("star_rating", 0.0),

View File

@@ -157,10 +157,7 @@ async def _help(user: User, args: list[str], _session: AsyncSession, channel: Ch
@bot.command("roll") @bot.command("roll")
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str: def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
if len(args) > 0 and args[0].isdigit(): r = random.randint(1, int(args[0])) if len(args) > 0 and args[0].isdigit() else random.randint(1, 100)
r = random.randint(1, int(args[0]))
else:
r = random.randint(1, 100)
return f"{user.username} rolls {r} point(s)" return f"{user.username} rolls {r} point(s)"
@@ -179,10 +176,7 @@ async def _stats(user: User, args: list[str], session: AsyncSession, channel: Ch
if gamemode is None: if gamemode is None:
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery() subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first() last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
if last_score is not None: gamemode = last_score.gamemode if last_score is not None else target_user.playmode
gamemode = last_score.gamemode
else:
gamemode = target_user.playmode
statistics = ( statistics = (
await session.exec( await session.exec(

View File

@@ -313,10 +313,7 @@ async def chat_websocket(
# 优先使用查询参数中的token支持token或access_token参数名 # 优先使用查询参数中的token支持token或access_token参数名
auth_token = token or access_token auth_token = token or access_token
if not auth_token and authorization: if not auth_token and authorization:
if authorization.startswith("Bearer "): auth_token = authorization.removeprefix("Bearer ")
auth_token = authorization[7:]
else:
auth_token = authorization
if not auth_token: if not auth_token:
await websocket.close(code=1008, reason="Missing authentication token") await websocket.close(code=1008, reason="Missing authentication token")

View File

@@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse
redirect_router = APIRouter(include_in_schema=False) redirect_router = APIRouter(include_in_schema=False)
@redirect_router.get("/users/{path:path}") @redirect_router.get("/users/{path:path}") # noqa: FAST003
@redirect_router.get("/teams/{team_id}") @redirect_router.get("/teams/{team_id}")
@redirect_router.get("/u/{user_id}") @redirect_router.get("/u/{user_id}")
@redirect_router.get("/b/{beatmap_id}") @redirect_router.get("/b/{beatmap_id}")

View File

@@ -168,10 +168,7 @@ async def get_beatmaps(
elif beatmapset_id is not None: elif beatmapset_id is not None:
beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id) beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id)
await beatmapset.awaitable_attrs.beatmaps await beatmapset.awaitable_attrs.beatmaps
if len(beatmapset.beatmaps) > limit: beatmaps = beatmapset.beatmaps[:limit] if len(beatmapset.beatmaps) > limit else beatmapset.beatmaps
beatmaps = beatmapset.beatmaps[:limit]
else:
beatmaps = beatmapset.beatmaps
elif user is not None: elif user is not None:
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all() beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()

View File

@@ -158,7 +158,10 @@ async def get_beatmap_attributes(
if ruleset is None: if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id) beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode ruleset = beatmap_db.mode
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
)
if await redis.exists(key): if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try: try:

View File

@@ -46,7 +46,6 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
response_model=SearchBeatmapsetsResp, response_model=SearchBeatmapsetsResp,
) )
async def search_beatmapset( async def search_beatmapset(
db: Database,
query: Annotated[SearchQueryModel, Query(...)], query: Annotated[SearchQueryModel, Query(...)],
request: Request, request: Request,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
@@ -104,7 +103,7 @@ async def search_beatmapset(
if cached_result: if cached_result:
sets = SearchBeatmapsetsResp(**cached_result) sets = SearchBeatmapsetsResp(**cached_result)
# 处理资源代理 # 处理资源代理
processed_sets = await process_response_assets(sets, request) processed_sets = await process_response_assets(sets)
return processed_sets return processed_sets
try: try:
@@ -115,7 +114,7 @@ async def search_beatmapset(
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump()) await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
# 处理资源代理 # 处理资源代理
processed_sets = await process_response_assets(sets, request) processed_sets = await process_response_assets(sets)
return processed_sets return processed_sets
except HTTPError as e: except HTTPError as e:
raise HTTPException(status_code=500, detail=str(e)) from e raise HTTPException(status_code=500, detail=str(e)) from e
@@ -140,7 +139,7 @@ async def lookup_beatmapset(
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id) cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
if cached_resp: if cached_resp:
# 处理资源代理 # 处理资源代理
processed_resp = await process_response_assets(cached_resp, request) processed_resp = await process_response_assets(cached_resp)
return processed_resp return processed_resp
try: try:
@@ -151,7 +150,7 @@ async def lookup_beatmapset(
await cache_service.cache_beatmap_lookup(beatmap_id, resp) await cache_service.cache_beatmap_lookup(beatmap_id, resp)
# 处理资源代理 # 处理资源代理
processed_resp = await process_response_assets(resp, request) processed_resp = await process_response_assets(resp)
return processed_resp return processed_resp
except HTTPError as exc: except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmap not found") from exc raise HTTPException(status_code=404, detail="Beatmap not found") from exc
@@ -176,7 +175,7 @@ async def get_beatmapset(
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id) cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
if cached_resp: if cached_resp:
# 处理资源代理 # 处理资源代理
processed_resp = await process_response_assets(cached_resp, request) processed_resp = await process_response_assets(cached_resp)
return processed_resp return processed_resp
try: try:
@@ -187,7 +186,7 @@ async def get_beatmapset(
await cache_service.cache_beatmapset(resp) await cache_service.cache_beatmapset(resp)
# 处理资源代理 # 处理资源代理
processed_resp = await process_response_assets(resp, request) processed_resp = await process_response_assets(resp)
return processed_resp return processed_resp
except HTTPError as exc: except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmapset not found") from exc raise HTTPException(status_code=404, detail="Beatmapset not found") from exc

View File

@@ -166,7 +166,6 @@ async def get_room(
db: Database, db: Database,
room_id: Annotated[int, Path(..., description="房间 ID")], room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])], current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
redis: Redis,
category: Annotated[ category: Annotated[
str, str,
Query( Query(

View File

@@ -847,10 +847,7 @@ async def reorder_score_pin(
detail = "After score not found" if after_score_id else "Before score not found" detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail) raise HTTPException(status_code=404, detail=detail)
if after_score_id: target_order = reference_score.pinned_order + 1 if after_score_id else reference_score.pinned_order
target_order = reference_score.pinned_order + 1
else:
target_order = reference_score.pinned_order
current_order = score_record.pinned_order current_order = score_record.pinned_order

View File

@@ -40,7 +40,7 @@ class SessionReissueResponse(BaseModel):
message: str message: str
class VerifyFailed(Exception): class VerifyFailedError(Exception):
def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False): def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
super().__init__(message) super().__init__(message)
self.reason = reason self.reason = reason
@@ -93,10 +93,7 @@ async def verify_session(
# 智能选择验证方法参考osu-web实现 # 智能选择验证方法参考osu-web实现
# API版本较老或用户未设置TOTP时强制使用邮件验证 # API版本较老或用户未设置TOTP时强制使用邮件验证
# print(api_version, totp_key) # print(api_version, totp_key)
if api_version < 20240101 or totp_key is None: verify_method = "mail" if api_version < 20240101 or totp_key is None else "totp"
verify_method = "mail"
else:
verify_method = "totp"
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis) await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
login_method = verify_method login_method = verify_method
@@ -109,7 +106,7 @@ async def verify_session(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
) )
verify_method = "mail" verify_method = "mail"
raise VerifyFailed("用户TOTP已被删除已切换到邮件验证") raise VerifyFailedError("用户TOTP已被删除已切换到邮件验证")
# 如果未开启邮箱验证,则直接认为认证通过 # 如果未开启邮箱验证,则直接认为认证通过
# 正常不会进入到这里 # 正常不会进入到这里
@@ -120,16 +117,16 @@ async def verify_session(
else: else:
# 记录详细的验证失败原因参考osu-web的错误处理 # 记录详细的验证失败原因参考osu-web的错误处理
if len(verification_key) != 6: if len(verification_key) != 6:
raise VerifyFailed("TOTP验证码长度错误应为6位数字", reason="incorrect_length") raise VerifyFailedError("TOTP验证码长度错误应为6位数字", reason="incorrect_length")
elif not verification_key.isdigit(): elif not verification_key.isdigit():
raise VerifyFailed("TOTP验证码格式错误应为纯数字", reason="incorrect_format") raise VerifyFailedError("TOTP验证码格式错误应为纯数字", reason="incorrect_format")
else: else:
# 可能是密钥错误或者重放攻击 # 可能是密钥错误或者重放攻击
raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key") raise VerifyFailedError("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
else: else:
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key) success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
if not success: if not success:
raise VerifyFailed(f"邮件验证失败: {message}") raise VerifyFailedError(f"邮件验证失败: {message}")
await LoginLogService.record_login( await LoginLogService.record_login(
db=db, db=db,
@@ -144,7 +141,7 @@ async def verify_session(
await db.commit() await db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT) return Response(status_code=status.HTTP_204_NO_CONTENT)
except VerifyFailed as e: except VerifyFailedError as e:
await LoginLogService.record_failed_login( await LoginLogService.record_failed_login(
db=db, db=db,
request=request, request=request,
@@ -171,7 +168,9 @@ async def verify_session(
) )
error_response["reissued"] = True error_response["reissued"] = True
except Exception: except Exception:
pass # 忽略重发邮件失败的错误 log("Verification").exception(
f"Failed to resend verification email to user {current_user.id} (token: {token_id})"
)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response) return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)

View File

@@ -44,9 +44,7 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
.where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode)) .where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode))
) )
).first() ).first()
if user_beatmap_score is None: return user_beatmap_score is not None
return False
return True
@router.put( @router.put(
@@ -75,10 +73,9 @@ async def vote_beatmap_tags(
.where(BeatmapTagVote.user_id == current_user.id) .where(BeatmapTagVote.user_id == current_user.id)
) )
).first() ).first()
if previous_votes is None: if previous_votes is None and check_user_can_vote(current_user, beatmap_id, session):
if check_user_can_vote(current_user, beatmap_id, session): new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id) session.add(new_vote)
session.add(new_vote)
await session.commit() await session.commit()
except ValueError: except ValueError:
raise HTTPException(400, "Tag is not found") raise HTTPException(400, "Tag is not found")

View File

@@ -91,7 +91,7 @@ async def get_users(
# 处理资源代理 # 处理资源代理
response = BatchUserResponse(users=cached_users) response = BatchUserResponse(users=cached_users)
processed_response = await process_response_assets(response, request) processed_response = await process_response_assets(response)
return processed_response return processed_response
else: else:
searched_users = (await session.exec(select(User).limit(50))).all() searched_users = (await session.exec(select(User).limit(50))).all()
@@ -109,7 +109,7 @@ async def get_users(
# 处理资源代理 # 处理资源代理
response = BatchUserResponse(users=users) response = BatchUserResponse(users=users)
processed_response = await process_response_assets(response, request) processed_response = await process_response_assets(response)
return processed_response return processed_response
@@ -240,7 +240,7 @@ async def get_user_info(
cached_user = await cache_service.get_user_from_cache(user_id_int) cached_user = await cache_service.get_user_from_cache(user_id_int)
if cached_user: if cached_user:
# 处理资源代理 # 处理资源代理
processed_user = await process_response_assets(cached_user, request) processed_user = await process_response_assets(cached_user)
return processed_user return processed_user
searched_user = ( searched_user = (
@@ -263,7 +263,7 @@ async def get_user_info(
background_task.add_task(cache_service.cache_user, user_resp) background_task.add_task(cache_service.cache_user, user_resp)
# 处理资源代理 # 处理资源代理
processed_user = await process_response_assets(user_resp, request) processed_user = await process_response_assets(user_resp)
return processed_user return processed_user
@@ -381,7 +381,7 @@ async def get_user_scores(
user_id, type, include_fails, mode, limit, offset, is_legacy_api user_id, type, include_fails, mode, limit, offset, is_legacy_api
) )
if cached_scores is not None: if cached_scores is not None:
processed_scores = await process_response_assets(cached_scores, request) processed_scores = await process_response_assets(cached_scores)
return processed_scores return processed_scores
db_user = await session.get(User, user_id) db_user = await session.get(User, user_id)
@@ -438,5 +438,5 @@ async def get_user_scores(
) )
# 处理资源代理 # 处理资源代理
processed_scores = await process_response_assets(score_responses, request) processed_scores = await process_response_assets(score_responses)
return processed_scores return processed_scores

View File

@@ -12,7 +12,7 @@ from app.service.asset_proxy_service import get_asset_proxy_service
from fastapi import Request from fastapi import Request
async def process_response_assets(data: Any, request: Request) -> Any: async def process_response_assets(data: Any) -> Any:
""" """
根据配置处理响应数据中的资源URL 根据配置处理响应数据中的资源URL
@@ -72,7 +72,7 @@ def asset_proxy_response(func):
# 如果有request对象且启用了资源代理则处理响应 # 如果有request对象且启用了资源代理则处理响应
if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path): if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path):
result = await process_response_assets(result, request) result = await process_response_assets(result)
return result return result

View File

@@ -113,6 +113,7 @@ class BeatmapCacheService:
if size: if size:
total_size += size total_size += size
except Exception: except Exception:
logger.debug(f"Failed to get size for key {key}")
continue continue
return { return {

View File

@@ -36,11 +36,8 @@ def safe_json_dumps(data) -> str:
def generate_hash(data) -> str: def generate_hash(data) -> str:
"""生成数据的MD5哈希值""" """生成数据的MD5哈希值"""
if isinstance(data, str): content = data if isinstance(data, str) else safe_json_dumps(data)
content = data return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
else:
content = safe_json_dumps(data)
return hashlib.md5(content.encode()).hexdigest()
class BeatmapsetCacheService: class BeatmapsetCacheService:

View File

@@ -110,9 +110,7 @@ class ProcessingBeatmapset:
changed_beatmaps = [] changed_beatmaps = []
for bm in self.beatmapset.beatmaps: for bm in self.beatmapset.beatmaps:
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None) saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
if not saved: if not saved or saved["is_deleted"]:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
elif saved["is_deleted"]:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED)) changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
elif saved["md5"] != bm.checksum: elif saved["md5"] != bm.checksum:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED)) changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
@@ -285,7 +283,7 @@ class BeatmapsetUpdateService:
async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp): async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
async with with_db() as session: async with with_db() as session:
db_beatmapset = await session.get(Beatmapset, beatmapset.id) db_beatmapset = await session.get(Beatmapset, beatmapset.id)
new_beatmapset = await Beatmapset.from_resp_no_save(session, beatmapset) new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
if db_beatmapset: if db_beatmapset:
await session.merge(new_beatmapset) await session.merge(new_beatmapset)
await session.commit() await session.commit()
@@ -356,5 +354,7 @@ def init_beatmapset_update_service(fetcher: "Fetcher") -> BeatmapsetUpdateServic
def get_beatmapset_update_service() -> BeatmapsetUpdateService: def get_beatmapset_update_service() -> BeatmapsetUpdateService:
if service is None:
raise ValueError("BeatmapsetUpdateService is not initialized")
assert service is not None, "BeatmapsetUpdateService is not initialized" assert service is not None, "BeatmapsetUpdateService is not initialized"
return service return service

View File

@@ -128,7 +128,11 @@ class LoginLogService:
login_success=False, login_success=False,
login_method=login_method, login_method=login_method,
user_agent=user_agent, user_agent=user_agent,
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt", notes=(
f"Failed login attempt on user {attempted_username}: {notes}"
if attempted_username
else "Failed login attempt"
),
) )

View File

@@ -120,7 +120,7 @@ class PasswordResetService:
await redis.delete(reset_code_key) await redis.delete(reset_code_key)
await redis.delete(rate_limit_key) await redis.delete(rate_limit_key)
except Exception: except Exception:
pass logger.warning("Failed to clean up Redis data after error")
logger.exception("Redis operation failed") logger.exception("Redis operation failed")
return False, "服务暂时不可用,请稍后重试" return False, "服务暂时不可用,请稍后重试"

View File

@@ -593,10 +593,7 @@ class RankingCacheService:
async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None: async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None:
"""使地区排行榜缓存失效""" """使地区排行榜缓存失效"""
try: try:
if ruleset: pattern = f"country_ranking:{ruleset}:*" if ruleset else "country_ranking:*"
pattern = f"country_ranking:{ruleset}:*"
else:
pattern = "country_ranking:*"
keys = await self.redis.keys(pattern) keys = await self.redis.keys(pattern)
if keys: if keys:
@@ -608,10 +605,7 @@ class RankingCacheService:
async def invalidate_team_cache(self, ruleset: GameMode | None = None) -> None: async def invalidate_team_cache(self, ruleset: GameMode | None = None) -> None:
"""使战队排行榜缓存失效""" """使战队排行榜缓存失效"""
try: try:
if ruleset: pattern = f"team_ranking:{ruleset}:*" if ruleset else "team_ranking:*"
pattern = f"team_ranking:{ruleset}:*"
else:
pattern = "team_ranking:*"
keys = await self.redis.keys(pattern) keys = await self.redis.keys(pattern)
if keys: if keys:
@@ -637,6 +631,7 @@ class RankingCacheService:
if size: if size:
total_size += size total_size += size
except Exception: except Exception:
logger.warning(f"Failed to get memory usage for key {key}")
continue continue
return { return {

View File

View File

@@ -35,19 +35,19 @@ class ChatSubscriber(RedisSubscriber):
self.add_handler(ON_NOTIFICATION, self.on_notification) self.add_handler(ON_NOTIFICATION, self.on_notification)
self.start() self.start()
async def on_join_room(self, c: str, s: str): async def on_join_room(self, c: str, s: str): # noqa: ARG002
channel_id, user_id = s.split(":") channel_id, user_id = s.split(":")
if self.chat_server is None: if self.chat_server is None:
return return
await self.chat_server.join_room_channel(int(channel_id), int(user_id)) await self.chat_server.join_room_channel(int(channel_id), int(user_id))
async def on_leave_room(self, c: str, s: str): async def on_leave_room(self, c: str, s: str): # noqa: ARG002
channel_id, user_id = s.split(":") channel_id, user_id = s.split(":")
if self.chat_server is None: if self.chat_server is None:
return return
await self.chat_server.leave_room_channel(int(channel_id), int(user_id)) await self.chat_server.leave_room_channel(int(channel_id), int(user_id))
async def on_notification(self, c: str, s: str): async def on_notification(self, c: str, s: str): # noqa: ARG002
try: try:
detail = TypeAdapter(NotificationDetails).validate_json(s) detail = TypeAdapter(NotificationDetails).validate_json(s)
except ValueError: except ValueError:

View File

@@ -357,6 +357,7 @@ class UserCacheService:
if size: if size:
total_size += size total_size += size
except Exception: except Exception:
logger.warning(f"Failed to get memory usage for key {key}")
continue continue
return { return {

View File

@@ -288,10 +288,6 @@ This email was sent automatically, please do not reply.
redis: Redis, redis: Redis,
user_id: int, user_id: int,
code: str, code: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""验证邮箱验证码""" """验证邮箱验证码"""
try: try:

View File

@@ -41,7 +41,7 @@ async def warmup_cache() -> None:
logger.info("Beatmap cache warmup completed successfully") logger.info("Beatmap cache warmup completed successfully")
except Exception as e: except Exception as e:
logger.error("Beatmap cache warmup failed: %s", e) logger.error(f"Beatmap cache warmup failed: {e}")
async def refresh_ranking_cache() -> None: async def refresh_ranking_cache() -> None:
@@ -59,7 +59,7 @@ async def refresh_ranking_cache() -> None:
logger.info("Ranking cache refresh completed successfully") logger.info("Ranking cache refresh completed successfully")
except Exception as e: except Exception as e:
logger.error("Ranking cache refresh failed: %s", e) logger.error(f"Ranking cache refresh failed: {e}")
async def schedule_user_cache_preload_task() -> None: async def schedule_user_cache_preload_task() -> None:
@@ -93,14 +93,14 @@ async def schedule_user_cache_preload_task() -> None:
if active_user_ids: if active_user_ids:
user_ids = [row[0] for row in active_user_ids] user_ids = [row[0] for row in active_user_ids]
await cache_service.preload_user_cache(session, user_ids) await cache_service.preload_user_cache(session, user_ids)
logger.info("Preloaded cache for %s active users", len(user_ids)) logger.info(f"Preloaded cache for {len(user_ids)} active users")
else: else:
logger.info("No active users found for cache preload") logger.info("No active users found for cache preload")
logger.info("User cache preload task completed successfully") logger.info("User cache preload task completed successfully")
except Exception as e: except Exception as e:
logger.error("User cache preload task failed: %s", e) logger.error(f"User cache preload task failed: {e}")
async def schedule_user_cache_warmup_task() -> None: async def schedule_user_cache_warmup_task() -> None:
@@ -131,18 +131,18 @@ async def schedule_user_cache_warmup_task() -> None:
if top_users: if top_users:
user_ids = list(top_users) user_ids = list(top_users)
await cache_service.preload_user_cache(session, user_ids) await cache_service.preload_user_cache(session, user_ids)
logger.info("Warmed cache for top 100 users in %s", mode) logger.info(f"Warmed cache for top 100 users in {mode}")
await asyncio.sleep(1) await asyncio.sleep(1)
except Exception as e: except Exception as e:
logger.error("Failed to warm cache for %s: %s", mode, e) logger.error(f"Failed to warm cache for {mode}: {e}")
continue continue
logger.info("User cache warmup task completed successfully") logger.info("User cache warmup task completed successfully")
except Exception as e: except Exception as e:
logger.error("User cache warmup task failed: %s", e) logger.error(f"User cache warmup task failed: {e}")
async def schedule_user_cache_cleanup_task() -> None: async def schedule_user_cache_cleanup_task() -> None:
@@ -155,11 +155,11 @@ async def schedule_user_cache_cleanup_task() -> None:
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
stats = await cache_service.get_cache_stats() stats = await cache_service.get_cache_stats()
logger.info("User cache stats: %s", stats) logger.info(f"User cache stats: {stats}")
logger.info("User cache cleanup task completed successfully") logger.info("User cache cleanup task completed successfully")
except Exception as e: except Exception as e:
logger.error("User cache cleanup task failed: %s", e) logger.error(f"User cache cleanup task failed: {e}")
async def warmup_user_cache() -> None: async def warmup_user_cache() -> None:
@@ -167,7 +167,7 @@ async def warmup_user_cache() -> None:
try: try:
await schedule_user_cache_warmup_task() await schedule_user_cache_warmup_task()
except Exception as e: except Exception as e:
logger.error("User cache warmup failed: %s", e) logger.error(f"User cache warmup failed: {e}")
async def preload_user_cache() -> None: async def preload_user_cache() -> None:
@@ -175,7 +175,7 @@ async def preload_user_cache() -> None:
try: try:
await schedule_user_cache_preload_task() await schedule_user_cache_preload_task()
except Exception as e: except Exception as e:
logger.error("User cache preload failed: %s", e) logger.error(f"User cache preload failed: {e}")
async def cleanup_user_cache() -> None: async def cleanup_user_cache() -> None:
@@ -183,7 +183,7 @@ async def cleanup_user_cache() -> None:
try: try:
await schedule_user_cache_cleanup_task() await schedule_user_cache_cleanup_task()
except Exception as e: except Exception as e:
logger.error("User cache cleanup failed: %s", e) logger.error(f"User cache cleanup failed: {e}")
def register_cache_jobs() -> None: def register_cache_jobs() -> None:

View File

@@ -5,8 +5,6 @@ Periodically update the MaxMind GeoIP database
from __future__ import annotations from __future__ import annotations
import asyncio
from app.config import settings from app.config import settings
from app.dependencies.geoip import get_geoip_helper from app.dependencies.geoip import get_geoip_helper
from app.dependencies.scheduler import get_scheduler from app.dependencies.scheduler import get_scheduler
@@ -28,14 +26,10 @@ async def update_geoip_database():
try: try:
logger.info("Starting scheduled GeoIP database update...") logger.info("Starting scheduled GeoIP database update...")
geoip = get_geoip_helper() geoip = get_geoip_helper()
await geoip.update(force=False)
# Run the synchronous update method in a background thread
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: geoip.update(force=False))
logger.info("Scheduled GeoIP database update completed successfully") logger.info("Scheduled GeoIP database update completed successfully")
except Exception as e: except Exception as exc:
logger.error(f"Scheduled GeoIP database update failed: {e}") logger.error(f"Scheduled GeoIP database update failed: {exc}")
async def init_geoip(): async def init_geoip():
@@ -45,13 +39,8 @@ async def init_geoip():
try: try:
geoip = get_geoip_helper() geoip = get_geoip_helper()
logger.info("Initializing GeoIP database...") logger.info("Initializing GeoIP database...")
await geoip.update(force=False)
# Run the synchronous update method in a background thread
# force=False means only download if files don't exist or are expired
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: geoip.update(force=False))
logger.info("GeoIP database initialization completed") logger.info("GeoIP database initialization completed")
except Exception as e: except Exception as exc:
logger.error(f"GeoIP database initialization failed: {e}") logger.error(f"GeoIP database initialization failed: {exc}")
# Do not raise an exception to avoid blocking application startup # Do not raise an exception to avoid blocking application startup

View File

@@ -16,7 +16,7 @@ async def create_rx_statistics():
async with with_db() as session: async with with_db() as session:
users = (await session.exec(select(User.id))).all() users = (await session.exec(select(User.id))).all()
total_users = len(users) total_users = len(users)
logger.info("Ensuring RX/AP statistics exist for %s users", total_users) logger.info(f"Ensuring RX/AP statistics exist for {total_users} users")
rx_created = 0 rx_created = 0
ap_created = 0 ap_created = 0
for i in users: for i in users:
@@ -57,7 +57,5 @@ async def create_rx_statistics():
await session.commit() await session.commit()
if rx_created or ap_created: if rx_created or ap_created:
logger.success( logger.success(
"Created %s RX statistics rows and %s AP statistics rows during backfill", f"Created {rx_created} RX statistics rows and {ap_created} AP statistics rows during backfill"
rx_created,
ap_created,
) )

View File

@@ -258,10 +258,7 @@ class BackgroundTasks:
self.tasks = set(tasks) if tasks else set() self.tasks = set(tasks) if tasks else set()
def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
if is_async_callable(func): coro = func(*args, **kwargs) if is_async_callable(func) else run_in_threadpool(func, *args, **kwargs)
coro = func(*args, **kwargs)
else:
coro = run_in_threadpool(func, *args, **kwargs)
task = asyncio.create_task(coro) task = asyncio.create_task(coro)
self.tasks.add(task) self.tasks.add(task)
task.add_done_callback(self.tasks.discard) task.add_done_callback(self.tasks.discard)

13
main.py
View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import json
from pathlib import Path from pathlib import Path
from app.config import settings from app.config import settings
@@ -50,7 +51,7 @@ import sentry_sdk
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI): # noqa: ARG001
# on startup # on startup
init_mods() init_mods()
init_ranked_mods() init_ranked_mods()
@@ -223,26 +224,26 @@ async def health_check():
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError): async def validation_exception_handler(request: Request, exc: RequestValidationError): # noqa: ARG001
return JSONResponse( return JSONResponse(
status_code=422, status_code=422,
content={ content={
"error": exc.errors(), "error": json.dumps(exc.errors()),
}, },
) )
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)
async def http_exception_handler(requst: Request, exc: HTTPException): async def http_exception_handler(request: Request, exc: HTTPException): # noqa: ARG001
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail}) return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
if settings.secret_key == "your_jwt_secret_here": if settings.secret_key == "your_jwt_secret_here": # noqa: S105
system_logger("Security").opt(colors=True).warning( system_logger("Security").opt(colors=True).warning(
"<y>jwt_secret_key</y> is unset. Your server is unsafe. " "<y>jwt_secret_key</y> is unset. Your server is unsafe. "
"Use this command to generate: <blue>openssl rand -hex 32</blue>." "Use this command to generate: <blue>openssl rand -hex 32</blue>."
) )
if settings.osu_web_client_secret == "your_osu_web_client_secret_here": if settings.osu_web_client_secret == "your_osu_web_client_secret_here": # noqa: S105
system_logger("Security").opt(colors=True).warning( system_logger("Security").opt(colors=True).warning(
"<y>osu_web_client_secret</y> is unset. Your server is unsafe. " "<y>osu_web_client_secret</y> is unset. Your server is unsafe. "
"Use this command to generate: <blue>openssl rand -hex 40</blue>." "Use this command to generate: <blue>openssl rand -hex 40</blue>."

View File

@@ -1,3 +1,4 @@
# ruff: noqa
"""add_password_reset_table """add_password_reset_table
Revision ID: d103d442dc24 Revision ID: d103d442dc24

View File

@@ -55,12 +55,20 @@ select = [
"ASYNC", # flake8-async "ASYNC", # flake8-async
"C4", # flake8-comprehensions "C4", # flake8-comprehensions
"T10", # flake8-debugger "T10", # flake8-debugger
# "T20", # flake8-print
"PYI", # flake8-pyi "PYI", # flake8-pyi
"PT", # flake8-pytest-style "PT", # flake8-pytest-style
"Q", # flake8-quotes "Q", # flake8-quotes
"TID", # flake8-tidy-imports "TID", # flake8-tidy-imports
"RUF", # Ruff-specific rules "RUF", # Ruff-specific rules
"FAST", # FastAPI
"YTT", # flake8-2020
"S", # flake8-bandit
"INP", # flake8-no-pep420
"SIM", # flake8-simplify
"ARG", # flake8-unused-arguments
"PTH", # flake8-use-pathlib
"N", # pep8-naming
"FURB" # refurb
] ]
ignore = [ ignore = [
"E402", # module-import-not-at-top-of-file "E402", # module-import-not-at-top-of-file
@@ -68,10 +76,17 @@ ignore = [
"RUF001", # ambiguous-unicode-character-string "RUF001", # ambiguous-unicode-character-string
"RUF002", # ambiguous-unicode-character-docstring "RUF002", # ambiguous-unicode-character-docstring
"RUF003", # ambiguous-unicode-character-comment "RUF003", # ambiguous-unicode-character-comment
"S101", # assert
"S311", # suspicious-non-cryptographic-random-usage
] ]
[tool.ruff.lint.extend-per-file-ignores] [tool.ruff.lint.extend-per-file-ignores]
"app/database/**/*.py" = ["I002"] "app/database/**/*.py" = ["I002"]
"tools/*.py" = ["PTH", "INP001"]
"migrations/**/*.py" = ["INP001"]
".github/**/*.py" = ["INP001"]
"app/achievements/*.py" = ["INP001", "ARG"]
"app/router/**/*.py" = ["ARG001"]
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
force-sort-within-sections = true force-sort-within-sections = true

View File

@@ -163,13 +163,19 @@ async def main():
# Show specific changes # Show specific changes
changes = [] changes = []
if "scorerank" in original_payload and "scorerank" in fixed_payload: if (
if original_payload["scorerank"] != fixed_payload["scorerank"]: "scorerank" in original_payload
changes.append(f"scorerank: {original_payload['scorerank']}{fixed_payload['scorerank']}") and "scorerank" in fixed_payload
and original_payload["scorerank"] != fixed_payload["scorerank"]
):
changes.append(f"scorerank: {original_payload['scorerank']}{fixed_payload['scorerank']}")
if "mode" in original_payload and "mode" in fixed_payload: if (
if original_payload["mode"] != fixed_payload["mode"]: "mode" in original_payload
changes.append(f"mode: {original_payload['mode']}{fixed_payload['mode']}") and "mode" in fixed_payload
and original_payload["mode"] != fixed_payload["mode"]
):
changes.append(f"mode: {original_payload['mode']}{fixed_payload['mode']}")
if changes: if changes:
print(f" Changes: {', '.join(changes)}") print(f" Changes: {', '.join(changes)}")