diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index 13ed14d..9d54889 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -1123,12 +1123,18 @@ def is_nonnegative_finite(value): def parse_plot_axis_value(axis_name, axis_value): try: - return float(axis_value) + value = float(axis_value) except (TypeError, ValueError) as exc: raise ValueError( f"--plot-along requires numeric axis values; " f"axis {axis_name!r} has value {axis_value!r}" ) from exc + if not is_positive_finite(value): + raise ValueError( + f"--plot-along requires positive finite axis values; " + f"axis {axis_name!r} has value {axis_value!r}" + ) + return value def make_timing_interval(lower, upper, center): @@ -2103,9 +2109,7 @@ def build_benchmark_filter_plan(filter_actions): def benchmark_is_selected(benchmark_name, filter_plan): - return not filter_plan.benchmark_scopes or any( - scope.benchmark_name == benchmark_name for scope in filter_plan.benchmark_scopes - ) + return bool(axis_filter_groups_for_benchmark(benchmark_name, filter_plan)) def axis_filter_groups_for_benchmark(benchmark_name, filter_plan): @@ -2117,10 +2121,15 @@ def axis_filter_groups_for_benchmark(benchmark_name, filter_plan): for scope in filter_plan.benchmark_scopes if scope.benchmark_name == benchmark_name ] - return [ - filter_plan.global_axis_filters + scope.axis_filters - for scope in matching_scopes - ] + + if matching_scopes: + return [ + filter_plan.global_axis_filters + scope.axis_filters + for scope in matching_scopes + ] + if filter_plan.global_axis_filters: + return [filter_plan.global_axis_filters] + return [] def matches_axis_filters(state, axis_filters): diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index ed4b203..4f2bc2c 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -1634,6 +1634,38 @@ def test_plot_along_rejects_non_numeric_axis_values(monkeypatch, nvbench_compare ) +@pytest.mark.parametrize("axis_value", ["0", "-1", "nan", "inf"]) +def test_plot_along_rejects_non_positive_or_non_finite_axis_values( + monkeypatch, nvbench_compare, axis_value +): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark([make_state(nvbench_compare, "state", axis_value=axis_value)]) + ] + cmp_benches = [ + make_benchmark([make_state(nvbench_compare, "state", axis_value=axis_value)]) + ] + ref_benches[0]["axes"] = [{"name": "A", "type": "float64", "flags": ""}] + cmp_benches[0]["axes"] = [{"name": "A", "type": "float64", "flags": ""}] + + with pytest.raises( + ValueError, + match="--plot-along requires positive finite axis values", + ): + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along="A", + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + def test_device_filter_parser_accepts_all_and_duplicate_ids(nvbench_compare): assert nvbench_compare.parse_device_filter(" all ", "--reference-devices") is None assert nvbench_compare.parse_device_filter("0", "--reference-devices") == [0] @@ -1782,6 +1814,67 @@ def test_axis_filter_applies_to_most_recent_benchmark(monkeypatch, nvbench_compa assert run_data.stats.unknown_count == 0 +def test_global_axis_filter_still_applies_after_benchmark_scope( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ], + name="bench1", + ), + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="3.0", axis_value=1), + make_state(nvbench_compare, "state", mean="4.0", axis_value=2), + ], + name="bench2", + ), + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ], + name="bench1", + ), + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="3.0", axis_value=1), + make_state(nvbench_compare, "state", mean="4.0", axis_value=2), + ], + name="bench2", + ), + ] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan( + nvbench_compare, + [("axis", "A=2"), ("benchmark", "bench1")], + ), + no_color=True, + ) + + assert run_data.stats.config_count == 2 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 0 + assert run_data.stats.undecided_count == 2 + assert run_data.stats.unknown_count == 0 + + def test_main_returns_success_exit_code_when_regressions_are_detected( monkeypatch, capsys, nvbench_compare ):