chore(linter): update ruff rules
This commit is contained in:
@@ -126,21 +126,22 @@ async def register_user(
|
||||
|
||||
try:
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
country_code = "CN" # 默认国家代码
|
||||
country_code = None # 默认国家代码
|
||||
|
||||
try:
|
||||
# 查询 IP 地理位置
|
||||
geo_info = geoip.lookup(client_ip)
|
||||
if geo_info and geo_info.get("country_iso"):
|
||||
country_code = geo_info["country_iso"]
|
||||
if geo_info and (country_code := geo_info.get("country_iso")):
|
||||
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
|
||||
else:
|
||||
logger.warning(f"Could not determine country for IP {client_ip}")
|
||||
except Exception as e:
|
||||
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
|
||||
if country_code is None:
|
||||
country_code = "CN"
|
||||
|
||||
# 创建新用户
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=2是BanchoBot)
|
||||
result = await db.execute(
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
@@ -157,7 +158,7 @@ async def register_user(
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1, # 普通用户权限
|
||||
country_code=country_code, # 根据 IP 地理位置设置国家
|
||||
country_code=country_code,
|
||||
join_date=utcnow(),
|
||||
last_visit=utcnow(),
|
||||
is_supporter=settings.enable_supporter_for_all_users,
|
||||
@@ -386,7 +387,7 @@ async def oauth_token(
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
token_type="Bearer", # noqa: S106
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=scope,
|
||||
@@ -439,7 +440,7 @@ async def oauth_token(
|
||||
)
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
token_type="Bearer", # noqa: S106
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=new_refresh_token,
|
||||
scope=scope,
|
||||
@@ -509,7 +510,7 @@ async def oauth_token(
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
token_type="Bearer", # noqa: S106
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=" ".join(scopes),
|
||||
@@ -554,7 +555,7 @@ async def oauth_token(
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
token_type="Bearer", # noqa: S106
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=" ".join(scopes),
|
||||
|
||||
@@ -130,7 +130,7 @@ def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_us
|
||||
"allowed_mods": item_data.get("allowed_mods", []),
|
||||
"expired": bool(item_data.get("expired", False)),
|
||||
"playlist_order": item_data.get("playlist_order", default_order),
|
||||
"played_at": item_data.get("played_at", None),
|
||||
"played_at": item_data.get("played_at"),
|
||||
"freestyle": bool(item_data.get("freestyle", True)),
|
||||
"beatmap_checksum": item_data.get("beatmap_checksum", ""),
|
||||
"star_rating": item_data.get("star_rating", 0.0),
|
||||
|
||||
@@ -157,10 +157,7 @@ async def _help(user: User, args: list[str], _session: AsyncSession, channel: Ch
|
||||
|
||||
@bot.command("roll")
|
||||
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
|
||||
if len(args) > 0 and args[0].isdigit():
|
||||
r = random.randint(1, int(args[0]))
|
||||
else:
|
||||
r = random.randint(1, 100)
|
||||
r = random.randint(1, int(args[0])) if len(args) > 0 and args[0].isdigit() else random.randint(1, 100)
|
||||
return f"{user.username} rolls {r} point(s)"
|
||||
|
||||
|
||||
@@ -179,10 +176,7 @@ async def _stats(user: User, args: list[str], session: AsyncSession, channel: Ch
|
||||
if gamemode is None:
|
||||
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
|
||||
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
|
||||
if last_score is not None:
|
||||
gamemode = last_score.gamemode
|
||||
else:
|
||||
gamemode = target_user.playmode
|
||||
gamemode = last_score.gamemode if last_score is not None else target_user.playmode
|
||||
|
||||
statistics = (
|
||||
await session.exec(
|
||||
|
||||
@@ -313,10 +313,7 @@ async def chat_websocket(
|
||||
# 优先使用查询参数中的token,支持token或access_token参数名
|
||||
auth_token = token or access_token
|
||||
if not auth_token and authorization:
|
||||
if authorization.startswith("Bearer "):
|
||||
auth_token = authorization[7:]
|
||||
else:
|
||||
auth_token = authorization
|
||||
auth_token = authorization.removeprefix("Bearer ")
|
||||
|
||||
if not auth_token:
|
||||
await websocket.close(code=1008, reason="Missing authentication token")
|
||||
|
||||
@@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse
|
||||
redirect_router = APIRouter(include_in_schema=False)
|
||||
|
||||
|
||||
@redirect_router.get("/users/{path:path}")
|
||||
@redirect_router.get("/users/{path:path}") # noqa: FAST003
|
||||
@redirect_router.get("/teams/{team_id}")
|
||||
@redirect_router.get("/u/{user_id}")
|
||||
@redirect_router.get("/b/{beatmap_id}")
|
||||
|
||||
@@ -168,10 +168,7 @@ async def get_beatmaps(
|
||||
elif beatmapset_id is not None:
|
||||
beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id)
|
||||
await beatmapset.awaitable_attrs.beatmaps
|
||||
if len(beatmapset.beatmaps) > limit:
|
||||
beatmaps = beatmapset.beatmaps[:limit]
|
||||
else:
|
||||
beatmaps = beatmapset.beatmaps
|
||||
beatmaps = beatmapset.beatmaps[:limit] if len(beatmapset.beatmaps) > limit else beatmapset.beatmaps
|
||||
elif user is not None:
|
||||
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
|
||||
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
|
||||
|
||||
@@ -158,7 +158,10 @@ async def get_beatmap_attributes(
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||
ruleset = beatmap_db.mode
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
key = (
|
||||
f"beatmap:{beatmap_id}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
|
||||
)
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
try:
|
||||
|
||||
@@ -46,7 +46,6 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
|
||||
response_model=SearchBeatmapsetsResp,
|
||||
)
|
||||
async def search_beatmapset(
|
||||
db: Database,
|
||||
query: Annotated[SearchQueryModel, Query(...)],
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
@@ -104,7 +103,7 @@ async def search_beatmapset(
|
||||
if cached_result:
|
||||
sets = SearchBeatmapsetsResp(**cached_result)
|
||||
# 处理资源代理
|
||||
processed_sets = await process_response_assets(sets, request)
|
||||
processed_sets = await process_response_assets(sets)
|
||||
return processed_sets
|
||||
|
||||
try:
|
||||
@@ -115,7 +114,7 @@ async def search_beatmapset(
|
||||
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
|
||||
|
||||
# 处理资源代理
|
||||
processed_sets = await process_response_assets(sets, request)
|
||||
processed_sets = await process_response_assets(sets)
|
||||
return processed_sets
|
||||
except HTTPError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
@@ -140,7 +139,7 @@ async def lookup_beatmapset(
|
||||
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
|
||||
if cached_resp:
|
||||
# 处理资源代理
|
||||
processed_resp = await process_response_assets(cached_resp, request)
|
||||
processed_resp = await process_response_assets(cached_resp)
|
||||
return processed_resp
|
||||
|
||||
try:
|
||||
@@ -151,7 +150,7 @@ async def lookup_beatmapset(
|
||||
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
|
||||
|
||||
# 处理资源代理
|
||||
processed_resp = await process_response_assets(resp, request)
|
||||
processed_resp = await process_response_assets(resp)
|
||||
return processed_resp
|
||||
except HTTPError as exc:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found") from exc
|
||||
@@ -176,7 +175,7 @@ async def get_beatmapset(
|
||||
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
|
||||
if cached_resp:
|
||||
# 处理资源代理
|
||||
processed_resp = await process_response_assets(cached_resp, request)
|
||||
processed_resp = await process_response_assets(cached_resp)
|
||||
return processed_resp
|
||||
|
||||
try:
|
||||
@@ -187,7 +186,7 @@ async def get_beatmapset(
|
||||
await cache_service.cache_beatmapset(resp)
|
||||
|
||||
# 处理资源代理
|
||||
processed_resp = await process_response_assets(resp, request)
|
||||
processed_resp = await process_response_assets(resp)
|
||||
return processed_resp
|
||||
except HTTPError as exc:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found") from exc
|
||||
|
||||
@@ -166,7 +166,6 @@ async def get_room(
|
||||
db: Database,
|
||||
room_id: Annotated[int, Path(..., description="房间 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
|
||||
redis: Redis,
|
||||
category: Annotated[
|
||||
str,
|
||||
Query(
|
||||
|
||||
@@ -847,10 +847,7 @@ async def reorder_score_pin(
|
||||
detail = "After score not found" if after_score_id else "Before score not found"
|
||||
raise HTTPException(status_code=404, detail=detail)
|
||||
|
||||
if after_score_id:
|
||||
target_order = reference_score.pinned_order + 1
|
||||
else:
|
||||
target_order = reference_score.pinned_order
|
||||
target_order = reference_score.pinned_order + 1 if after_score_id else reference_score.pinned_order
|
||||
|
||||
current_order = score_record.pinned_order
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class SessionReissueResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class VerifyFailed(Exception):
|
||||
class VerifyFailedError(Exception):
|
||||
def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
|
||||
super().__init__(message)
|
||||
self.reason = reason
|
||||
@@ -93,10 +93,7 @@ async def verify_session(
|
||||
# 智能选择验证方法(参考osu-web实现)
|
||||
# API版本较老或用户未设置TOTP时强制使用邮件验证
|
||||
# print(api_version, totp_key)
|
||||
if api_version < 20240101 or totp_key is None:
|
||||
verify_method = "mail"
|
||||
else:
|
||||
verify_method = "totp"
|
||||
verify_method = "mail" if api_version < 20240101 or totp_key is None else "totp"
|
||||
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
|
||||
login_method = verify_method
|
||||
|
||||
@@ -109,7 +106,7 @@ async def verify_session(
|
||||
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||
)
|
||||
verify_method = "mail"
|
||||
raise VerifyFailed("用户TOTP已被删除,已切换到邮件验证")
|
||||
raise VerifyFailedError("用户TOTP已被删除,已切换到邮件验证")
|
||||
# 如果未开启邮箱验证,则直接认为认证通过
|
||||
# 正常不会进入到这里
|
||||
|
||||
@@ -120,16 +117,16 @@ async def verify_session(
|
||||
else:
|
||||
# 记录详细的验证失败原因(参考osu-web的错误处理)
|
||||
if len(verification_key) != 6:
|
||||
raise VerifyFailed("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
|
||||
raise VerifyFailedError("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
|
||||
elif not verification_key.isdigit():
|
||||
raise VerifyFailed("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
|
||||
raise VerifyFailedError("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
|
||||
else:
|
||||
# 可能是密钥错误或者重放攻击
|
||||
raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
|
||||
raise VerifyFailedError("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
|
||||
else:
|
||||
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
|
||||
if not success:
|
||||
raise VerifyFailed(f"邮件验证失败: {message}")
|
||||
raise VerifyFailedError(f"邮件验证失败: {message}")
|
||||
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
@@ -144,7 +141,7 @@ async def verify_session(
|
||||
await db.commit()
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
except VerifyFailed as e:
|
||||
except VerifyFailedError as e:
|
||||
await LoginLogService.record_failed_login(
|
||||
db=db,
|
||||
request=request,
|
||||
@@ -171,7 +168,9 @@ async def verify_session(
|
||||
)
|
||||
error_response["reissued"] = True
|
||||
except Exception:
|
||||
pass # 忽略重发邮件失败的错误
|
||||
log("Verification").exception(
|
||||
f"Failed to resend verification email to user {current_user.id} (token: {token_id})"
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)
|
||||
|
||||
|
||||
@@ -44,9 +44,7 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
|
||||
.where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode))
|
||||
)
|
||||
).first()
|
||||
if user_beatmap_score is None:
|
||||
return False
|
||||
return True
|
||||
return user_beatmap_score is not None
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -75,10 +73,9 @@ async def vote_beatmap_tags(
|
||||
.where(BeatmapTagVote.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
if previous_votes is None:
|
||||
if check_user_can_vote(current_user, beatmap_id, session):
|
||||
new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
|
||||
session.add(new_vote)
|
||||
if previous_votes is None and check_user_can_vote(current_user, beatmap_id, session):
|
||||
new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
|
||||
session.add(new_vote)
|
||||
await session.commit()
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Tag is not found")
|
||||
|
||||
@@ -91,7 +91,7 @@ async def get_users(
|
||||
|
||||
# 处理资源代理
|
||||
response = BatchUserResponse(users=cached_users)
|
||||
processed_response = await process_response_assets(response, request)
|
||||
processed_response = await process_response_assets(response)
|
||||
return processed_response
|
||||
else:
|
||||
searched_users = (await session.exec(select(User).limit(50))).all()
|
||||
@@ -109,7 +109,7 @@ async def get_users(
|
||||
|
||||
# 处理资源代理
|
||||
response = BatchUserResponse(users=users)
|
||||
processed_response = await process_response_assets(response, request)
|
||||
processed_response = await process_response_assets(response)
|
||||
return processed_response
|
||||
|
||||
|
||||
@@ -240,7 +240,7 @@ async def get_user_info(
|
||||
cached_user = await cache_service.get_user_from_cache(user_id_int)
|
||||
if cached_user:
|
||||
# 处理资源代理
|
||||
processed_user = await process_response_assets(cached_user, request)
|
||||
processed_user = await process_response_assets(cached_user)
|
||||
return processed_user
|
||||
|
||||
searched_user = (
|
||||
@@ -263,7 +263,7 @@ async def get_user_info(
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
# 处理资源代理
|
||||
processed_user = await process_response_assets(user_resp, request)
|
||||
processed_user = await process_response_assets(user_resp)
|
||||
return processed_user
|
||||
|
||||
|
||||
@@ -381,7 +381,7 @@ async def get_user_scores(
|
||||
user_id, type, include_fails, mode, limit, offset, is_legacy_api
|
||||
)
|
||||
if cached_scores is not None:
|
||||
processed_scores = await process_response_assets(cached_scores, request)
|
||||
processed_scores = await process_response_assets(cached_scores)
|
||||
return processed_scores
|
||||
|
||||
db_user = await session.get(User, user_id)
|
||||
@@ -438,5 +438,5 @@ async def get_user_scores(
|
||||
)
|
||||
|
||||
# 处理资源代理
|
||||
processed_scores = await process_response_assets(score_responses, request)
|
||||
processed_scores = await process_response_assets(score_responses)
|
||||
return processed_scores
|
||||
|
||||
Reference in New Issue
Block a user