Reject non-numeric --plot-along axes

Add explicit validation for plot-axis values so string/type axes fail with a
clear CLI error instead of a raw float conversion exception. Add regression
coverage for a type axis.
This commit is contained in:
Oleksandr Pavlyk
2026-06-25 16:52:32 -05:00
parent a81c1adc00
commit 75fa3062ce
2 changed files with 40 additions and 1 deletions

View File

@@ -1121,6 +1121,16 @@ def is_nonnegative_finite(value):
return value is not None and math.isfinite(value) and value >= 0.0
def parse_plot_axis_value(axis_name, axis_value):
try:
return float(axis_value)
except (TypeError, ValueError) as exc:
raise ValueError(
f"--plot-along requires numeric axis values; "
f"axis {axis_name!r} has value {axis_value!r}"
) from exc
def make_timing_interval(lower, upper, center):
if (
not is_positive_finite(lower)
@@ -2783,7 +2793,7 @@ def compare_benches(
if av["name"] != plot_along:
axis_name_parts.append(f"""{av["name"]} = {av["value"]}""")
else:
axis_value = float(av["value"])
axis_value = parse_plot_axis_value(av["name"], av["value"])
if axis_value is not None:
axis_name = ", ".join(axis_name_parts)

View File

@@ -1594,6 +1594,35 @@ def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_comp
assert run_data.stats.unknown_count == 0
def test_plot_along_rejects_non_numeric_axis_values(monkeypatch, nvbench_compare):
run_data = make_comparison_run_data(nvbench_compare)
ref_benches = [
make_benchmark([make_state(nvbench_compare, "state", axis_value="F32")])
]
cmp_benches = [
make_benchmark([make_state(nvbench_compare, "state", axis_value="F32")])
]
ref_benches[0]["axes"] = [{"name": "A", "type": "type", "flags": ""}]
cmp_benches[0]["axes"] = [{"name": "A", "type": "type", "flags": ""}]
with pytest.raises(
ValueError,
match="--plot-along requires numeric axis values; axis 'A' has value 'F32'",
):
nvbench_compare.compare_benches(
run_data,
ref_benches,
cmp_benches,
threshold=0.0,
plot_along="A",
plot=False,
dark=False,
filter_plan=make_filter_plan(nvbench_compare),
no_color=True,
)
def test_device_filter_parser_accepts_all_and_duplicate_ids(nvbench_compare):
assert nvbench_compare.parse_device_filter(" all ", "--reference-devices") is None
assert nvbench_compare.parse_device_filter("0", "--reference-devices") == [0]