From 75fa3062ceecfea1bc6bfa71bfd583d5243d1e87 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Thu, 25 Jun 2026 16:52:32 -0500 Subject: [PATCH] Reject non-numeric --plot-along axes Add explicit validation for plot-axis values so string/type axes fail with a clear CLI error instead of a raw float conversion exception. Add regression coverage for a type axis. --- python/scripts/nvbench_compare.py | 12 +++++++++++- python/test/test_nvbench_compare.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index fd4fdb0..13ed14d 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -1121,6 +1121,16 @@ def is_nonnegative_finite(value): return value is not None and math.isfinite(value) and value >= 0.0 +def parse_plot_axis_value(axis_name, axis_value): + try: + return float(axis_value) + except (TypeError, ValueError) as exc: + raise ValueError( + f"--plot-along requires numeric axis values; " + f"axis {axis_name!r} has value {axis_value!r}" + ) from exc + + def make_timing_interval(lower, upper, center): if ( not is_positive_finite(lower) @@ -2783,7 +2793,7 @@ def compare_benches( if av["name"] != plot_along: axis_name_parts.append(f"""{av["name"]} = {av["value"]}""") else: - axis_value = float(av["value"]) + axis_value = parse_plot_axis_value(av["name"], av["value"]) if axis_value is not None: axis_name = ", ".join(axis_name_parts) diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index 1d2e2fc..9e1b3ee 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -1594,6 +1594,35 @@ def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_comp assert run_data.stats.unknown_count == 0 +def test_plot_along_rejects_non_numeric_axis_values(monkeypatch, nvbench_compare): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark([make_state(nvbench_compare, "state", axis_value="F32")]) + ] + cmp_benches = [ + make_benchmark([make_state(nvbench_compare, "state", axis_value="F32")]) + ] + ref_benches[0]["axes"] = [{"name": "A", "type": "type", "flags": ""}] + cmp_benches[0]["axes"] = [{"name": "A", "type": "type", "flags": ""}] + + with pytest.raises( + ValueError, + match="--plot-along requires numeric axis values; axis 'A' has value 'F32'", + ): + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along="A", + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + def test_device_filter_parser_accepts_all_and_duplicate_ids(nvbench_compare): assert nvbench_compare.parse_device_filter(" all ", "--reference-devices") is None assert nvbench_compare.parse_device_filter("0", "--reference-devices") == [0]