diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index fd4fdb0..13ed14d 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -1121,6 +1121,16 @@ def is_nonnegative_finite(value): return value is not None and math.isfinite(value) and value >= 0.0 +def parse_plot_axis_value(axis_name, axis_value): + try: + return float(axis_value) + except (TypeError, ValueError) as exc: + raise ValueError( + f"--plot-along requires numeric axis values; " + f"axis {axis_name!r} has value {axis_value!r}" + ) from exc + + def make_timing_interval(lower, upper, center): if ( not is_positive_finite(lower) @@ -2783,7 +2793,7 @@ def compare_benches( if av["name"] != plot_along: axis_name_parts.append(f"""{av["name"]} = {av["value"]}""") else: - axis_value = float(av["value"]) + axis_value = parse_plot_axis_value(av["name"], av["value"]) if axis_value is not None: axis_name = ", ".join(axis_name_parts) diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index 1d2e2fc..9e1b3ee 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -1594,6 +1594,35 @@ def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_comp assert run_data.stats.unknown_count == 0 +def test_plot_along_rejects_non_numeric_axis_values(monkeypatch, nvbench_compare): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark([make_state(nvbench_compare, "state", axis_value="F32")]) + ] + cmp_benches = [ + make_benchmark([make_state(nvbench_compare, "state", axis_value="F32")]) + ] + ref_benches[0]["axes"] = [{"name": "A", "type": "type", "flags": ""}] + cmp_benches[0]["axes"] = [{"name": "A", "type": "type", "flags": ""}] + + with pytest.raises( + ValueError, + match="--plot-along requires numeric axis values; axis 'A' has value 'F32'", + ): + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along="A", + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + def test_device_filter_parser_accepts_all_and_duplicate_ids(nvbench_compare): assert nvbench_compare.parse_device_filter(" all ", "--reference-devices") is None assert nvbench_compare.parse_device_filter("0", "--reference-devices") == [0]