chore(linter): update ruff rules

This commit is contained in:
MingxuanGame
2025-10-03 15:46:53 +00:00
parent b10425ad91
commit d490239f46
59 changed files with 393 additions and 425 deletions

View File

@@ -126,21 +126,22 @@ async def register_user(
try:
# 获取客户端 IP 并查询地理位置
country_code = "CN" # 默认国家代码
country_code = None # 默认国家代码
try:
# 查询 IP 地理位置
geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"]
if geo_info and (country_code := geo_info.get("country_iso")):
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
if country_code is None:
country_code = "CN"
# 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
# 确保 AUTO_INCREMENT 值从3开始ID=2是BanchoBot
result = await db.execute(
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
@@ -157,7 +158,7 @@ async def register_user(
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
country_code=country_code, # 根据 IP 地理位置设置国家
country_code=country_code,
join_date=utcnow(),
last_visit=utcnow(),
is_supporter=settings.enable_supporter_for_all_users,
@@ -386,7 +387,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=scope,
@@ -439,7 +440,7 @@ async def oauth_token(
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=new_refresh_token,
scope=scope,
@@ -509,7 +510,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
@@ -554,7 +555,7 @@ async def oauth_token(
return TokenResponse(
access_token=access_token,
token_type="Bearer",
token_type="Bearer", # noqa: S106
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),

View File

@@ -130,7 +130,7 @@ def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_us
"allowed_mods": item_data.get("allowed_mods", []),
"expired": bool(item_data.get("expired", False)),
"playlist_order": item_data.get("playlist_order", default_order),
"played_at": item_data.get("played_at", None),
"played_at": item_data.get("played_at"),
"freestyle": bool(item_data.get("freestyle", True)),
"beatmap_checksum": item_data.get("beatmap_checksum", ""),
"star_rating": item_data.get("star_rating", 0.0),

View File

@@ -157,10 +157,7 @@ async def _help(user: User, args: list[str], _session: AsyncSession, channel: Ch
@bot.command("roll")
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
if len(args) > 0 and args[0].isdigit():
r = random.randint(1, int(args[0]))
else:
r = random.randint(1, 100)
r = random.randint(1, int(args[0])) if len(args) > 0 and args[0].isdigit() else random.randint(1, 100)
return f"{user.username} rolls {r} point(s)"
@@ -179,10 +176,7 @@ async def _stats(user: User, args: list[str], session: AsyncSession, channel: Ch
if gamemode is None:
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
if last_score is not None:
gamemode = last_score.gamemode
else:
gamemode = target_user.playmode
gamemode = last_score.gamemode if last_score is not None else target_user.playmode
statistics = (
await session.exec(

View File

@@ -313,10 +313,7 @@ async def chat_websocket(
# 优先使用查询参数中的token支持token或access_token参数名
auth_token = token or access_token
if not auth_token and authorization:
if authorization.startswith("Bearer "):
auth_token = authorization[7:]
else:
auth_token = authorization
auth_token = authorization.removeprefix("Bearer ")
if not auth_token:
await websocket.close(code=1008, reason="Missing authentication token")

View File

@@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse
redirect_router = APIRouter(include_in_schema=False)
@redirect_router.get("/users/{path:path}")
@redirect_router.get("/users/{path:path}") # noqa: FAST003
@redirect_router.get("/teams/{team_id}")
@redirect_router.get("/u/{user_id}")
@redirect_router.get("/b/{beatmap_id}")

View File

@@ -168,10 +168,7 @@ async def get_beatmaps(
elif beatmapset_id is not None:
beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id)
await beatmapset.awaitable_attrs.beatmaps
if len(beatmapset.beatmaps) > limit:
beatmaps = beatmapset.beatmaps[:limit]
else:
beatmaps = beatmapset.beatmaps
beatmaps = beatmapset.beatmaps[:limit] if len(beatmapset.beatmaps) > limit else beatmapset.beatmaps
elif user is not None:
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()

View File

@@ -158,7 +158,10 @@ async def get_beatmap_attributes(
if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
)
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:

View File

@@ -46,7 +46,6 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
response_model=SearchBeatmapsetsResp,
)
async def search_beatmapset(
db: Database,
query: Annotated[SearchQueryModel, Query(...)],
request: Request,
background_tasks: BackgroundTasks,
@@ -104,7 +103,7 @@ async def search_beatmapset(
if cached_result:
sets = SearchBeatmapsetsResp(**cached_result)
# 处理资源代理
processed_sets = await process_response_assets(sets, request)
processed_sets = await process_response_assets(sets)
return processed_sets
try:
@@ -115,7 +114,7 @@ async def search_beatmapset(
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
# 处理资源代理
processed_sets = await process_response_assets(sets, request)
processed_sets = await process_response_assets(sets)
return processed_sets
except HTTPError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -140,7 +139,7 @@ async def lookup_beatmapset(
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
if cached_resp:
# 处理资源代理
processed_resp = await process_response_assets(cached_resp, request)
processed_resp = await process_response_assets(cached_resp)
return processed_resp
try:
@@ -151,7 +150,7 @@ async def lookup_beatmapset(
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
# 处理资源代理
processed_resp = await process_response_assets(resp, request)
processed_resp = await process_response_assets(resp)
return processed_resp
except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmap not found") from exc
@@ -176,7 +175,7 @@ async def get_beatmapset(
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
if cached_resp:
# 处理资源代理
processed_resp = await process_response_assets(cached_resp, request)
processed_resp = await process_response_assets(cached_resp)
return processed_resp
try:
@@ -187,7 +186,7 @@ async def get_beatmapset(
await cache_service.cache_beatmapset(resp)
# 处理资源代理
processed_resp = await process_response_assets(resp, request)
processed_resp = await process_response_assets(resp)
return processed_resp
except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmapset not found") from exc

View File

@@ -166,7 +166,6 @@ async def get_room(
db: Database,
room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
redis: Redis,
category: Annotated[
str,
Query(

View File

@@ -847,10 +847,7 @@ async def reorder_score_pin(
detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail)
if after_score_id:
target_order = reference_score.pinned_order + 1
else:
target_order = reference_score.pinned_order
target_order = reference_score.pinned_order + 1 if after_score_id else reference_score.pinned_order
current_order = score_record.pinned_order

View File

@@ -40,7 +40,7 @@ class SessionReissueResponse(BaseModel):
message: str
class VerifyFailed(Exception):
class VerifyFailedError(Exception):
def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
super().__init__(message)
self.reason = reason
@@ -93,10 +93,7 @@ async def verify_session(
# 智能选择验证方法参考osu-web实现
# API版本较老或用户未设置TOTP时强制使用邮件验证
# print(api_version, totp_key)
if api_version < 20240101 or totp_key is None:
verify_method = "mail"
else:
verify_method = "totp"
verify_method = "mail" if api_version < 20240101 or totp_key is None else "totp"
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
login_method = verify_method
@@ -109,7 +106,7 @@ async def verify_session(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
)
verify_method = "mail"
raise VerifyFailed("用户TOTP已被删除已切换到邮件验证")
raise VerifyFailedError("用户TOTP已被删除已切换到邮件验证")
# 如果未开启邮箱验证,则直接认为认证通过
# 正常不会进入到这里
@@ -120,16 +117,16 @@ async def verify_session(
else:
# 记录详细的验证失败原因参考osu-web的错误处理
if len(verification_key) != 6:
raise VerifyFailed("TOTP验证码长度错误应为6位数字", reason="incorrect_length")
raise VerifyFailedError("TOTP验证码长度错误应为6位数字", reason="incorrect_length")
elif not verification_key.isdigit():
raise VerifyFailed("TOTP验证码格式错误应为纯数字", reason="incorrect_format")
raise VerifyFailedError("TOTP验证码格式错误应为纯数字", reason="incorrect_format")
else:
# 可能是密钥错误或者重放攻击
raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
raise VerifyFailedError("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
else:
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
if not success:
raise VerifyFailed(f"邮件验证失败: {message}")
raise VerifyFailedError(f"邮件验证失败: {message}")
await LoginLogService.record_login(
db=db,
@@ -144,7 +141,7 @@ async def verify_session(
await db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
except VerifyFailed as e:
except VerifyFailedError as e:
await LoginLogService.record_failed_login(
db=db,
request=request,
@@ -171,7 +168,9 @@ async def verify_session(
)
error_response["reissued"] = True
except Exception:
pass # 忽略重发邮件失败的错误
log("Verification").exception(
f"Failed to resend verification email to user {current_user.id} (token: {token_id})"
)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)

View File

@@ -44,9 +44,7 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
.where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode))
)
).first()
if user_beatmap_score is None:
return False
return True
return user_beatmap_score is not None
@router.put(
@@ -75,10 +73,9 @@ async def vote_beatmap_tags(
.where(BeatmapTagVote.user_id == current_user.id)
)
).first()
if previous_votes is None:
if check_user_can_vote(current_user, beatmap_id, session):
new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
session.add(new_vote)
if previous_votes is None and check_user_can_vote(current_user, beatmap_id, session):
new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
session.add(new_vote)
await session.commit()
except ValueError:
raise HTTPException(400, "Tag is not found")

View File

@@ -91,7 +91,7 @@ async def get_users(
# 处理资源代理
response = BatchUserResponse(users=cached_users)
processed_response = await process_response_assets(response, request)
processed_response = await process_response_assets(response)
return processed_response
else:
searched_users = (await session.exec(select(User).limit(50))).all()
@@ -109,7 +109,7 @@ async def get_users(
# 处理资源代理
response = BatchUserResponse(users=users)
processed_response = await process_response_assets(response, request)
processed_response = await process_response_assets(response)
return processed_response
@@ -240,7 +240,7 @@ async def get_user_info(
cached_user = await cache_service.get_user_from_cache(user_id_int)
if cached_user:
# 处理资源代理
processed_user = await process_response_assets(cached_user, request)
processed_user = await process_response_assets(cached_user)
return processed_user
searched_user = (
@@ -263,7 +263,7 @@ async def get_user_info(
background_task.add_task(cache_service.cache_user, user_resp)
# 处理资源代理
processed_user = await process_response_assets(user_resp, request)
processed_user = await process_response_assets(user_resp)
return processed_user
@@ -381,7 +381,7 @@ async def get_user_scores(
user_id, type, include_fails, mode, limit, offset, is_legacy_api
)
if cached_scores is not None:
processed_scores = await process_response_assets(cached_scores, request)
processed_scores = await process_response_assets(cached_scores)
return processed_scores
db_user = await session.get(User, user_id)
@@ -438,5 +438,5 @@ async def get_user_scores(
)
# 处理资源代理
processed_scores = await process_response_assets(score_responses, request)
processed_scores = await process_response_assets(score_responses)
return processed_scores