diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index cacb419..99d6485 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -7,6 +7,7 @@ import argparse import math import os import sys +from collections import Counter from dataclasses import dataclass from enum import Enum @@ -66,6 +67,121 @@ class TimeEstimate: relative_dispersion: float | None +@dataclass(frozen=True) +class BenchmarkFilterScope: + benchmark_name: str + axis_filters: list[dict] + + +@dataclass(frozen=True) +class BenchmarkFilterPlan: + global_axis_filters: list[dict] + benchmark_scopes: list[BenchmarkFilterScope] + + +class OrderedBenchmarkFilterAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + actions = getattr(namespace, self.dest, None) + actions = [] if actions is None else list(actions) + action_kind = "axis" if option_string in {"-a", "--axis"} else "benchmark" + actions.append((action_kind, values)) + setattr(namespace, self.dest, actions) + + +def state_match_key(state): + device_prefix = f"Device={state['device']}" + state_name = state["name"] + if state_name == device_prefix: + return "" + if state_name.startswith(f"{device_prefix} "): + return state_name[len(device_prefix) + 1 :] + return state_name + + +def group_states_by_match_key(states): + grouped = {} + for state in states: + grouped.setdefault(state_match_key(state), []).append(state) + return grouped + + +def state_group_counts(grouped_states): + return Counter( + {state_name: len(states) for state_name, states in grouped_states.items()} + ) + + +def format_device_ids(device_ids): + return ", ".join(str(device_id) for device_id in device_ids) + + +def parse_device_filter(device_arg, option_name): + device_arg = device_arg.strip() + if device_arg.lower() == "all": + return None + + values = [value.strip() for value in device_arg.split(",")] + if not all(values): + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) + + try: + device_ids = [int(value) for value in values] + except ValueError as exc: + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) from exc + if any(device_id < 0 for device_id in device_ids): + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) + return device_ids + + +def select_devices(all_devices, device_filter, option_name): + if device_filter is None: + return list(all_devices) + + devices_by_id = {device["id"]: device for device in all_devices} + missing_ids = [ + device_id for device_id in device_filter if device_id not in devices_by_id + ] + if missing_ids: + raise ValueError( + f"{option_name} requested device id(s) not present in input: " + f"{format_device_ids(missing_ids)}" + ) + + return [devices_by_id[device_id] for device_id in device_filter] + + +def resolve_benchmark_device_ids(bench, device_filter, option_name): + if device_filter is None: + return list(bench["devices"]) + + benchmark_device_ids = set(bench["devices"]) + missing_ids = [ + device_id + for device_id in device_filter + if device_id not in benchmark_device_ids + ] + if missing_ids: + raise ValueError( + f"benchmark {bench['name']!r} does not contain {option_name} " + f"device id(s): {format_device_ids(missing_ids)}" + ) + + return device_filter + + +def require_matching_device_sections(reference_device_filter, compare_device_filter): + return reference_device_filter is None and compare_device_filter is None + + # TODO(opavlyk): replace with Emoji(StrEnum) after EOL of Python 3.10 class Emoji(str, Enum): YELLOW = "\U0001f7e1" @@ -328,6 +444,53 @@ def parse_axis_filters(axis_args): return filters +def build_benchmark_filter_plan(filter_actions): + global_axis_args = [] + benchmark_scopes = [] + current_scope = None + + for action_kind, action_value in filter_actions or []: + if action_kind == "benchmark": + current_scope = {"benchmark_name": action_value, "axis_args": []} + benchmark_scopes.append(current_scope) + elif current_scope is None: + global_axis_args.append(action_value) + else: + current_scope["axis_args"].append(action_value) + + return BenchmarkFilterPlan( + global_axis_filters=parse_axis_filters(global_axis_args), + benchmark_scopes=[ + BenchmarkFilterScope( + benchmark_name=scope["benchmark_name"], + axis_filters=parse_axis_filters(scope["axis_args"]), + ) + for scope in benchmark_scopes + ], + ) + + +def benchmark_is_selected(benchmark_name, filter_plan): + return not filter_plan.benchmark_scopes or any( + scope.benchmark_name == benchmark_name for scope in filter_plan.benchmark_scopes + ) + + +def axis_filter_groups_for_benchmark(benchmark_name, filter_plan): + if not filter_plan.benchmark_scopes: + return [filter_plan.global_axis_filters] + + matching_scopes = [ + scope + for scope in filter_plan.benchmark_scopes + if scope.benchmark_name == benchmark_name + ] + return [ + filter_plan.global_axis_filters + scope.axis_filters + for scope in matching_scopes + ] + + def matches_axis_filters(state, axis_filters): if not axis_filters: return True @@ -351,6 +514,23 @@ def matches_axis_filters(state, axis_filters): return True +def matches_axis_filter_groups(state, axis_filter_groups): + return any( + matches_axis_filters(state, axis_filters) for axis_filters in axis_filter_groups + ) + + +def matching_axis_filters(state, axis_filter_groups): + return next( + ( + axis_filters + for axis_filters in axis_filter_groups + if matches_axis_filters(state, axis_filters) + ), + [], + ) + + def format_duration(seconds): if seconds >= 1: multiplier = 1.0 @@ -479,9 +659,10 @@ def compare_benches( plot_along, plot, dark, - axis_filters, - benchmark_filters, + filter_plan, no_color, + reference_device_filter=None, + compare_device_filter=None, ): if plot_along: import matplotlib.pyplot as plt @@ -495,12 +676,28 @@ def compare_benches( ref_bench = find_matching_bench(cmp_bench, ref_benches) if not ref_bench: continue - if benchmark_filters and cmp_bench["name"] not in benchmark_filters: + if not benchmark_is_selected(cmp_bench["name"], filter_plan): continue + axis_filter_groups = axis_filter_groups_for_benchmark( + cmp_bench["name"], filter_plan + ) + + cmp_device_ids = resolve_benchmark_device_ids( + cmp_bench, compare_device_filter, "--compare-devices" + ) + ref_device_ids = resolve_benchmark_device_ids( + ref_bench, reference_device_filter, "--reference-devices" + ) + if len(cmp_device_ids) != len(ref_device_ids): + raise ValueError( + f"benchmark {cmp_bench['name']!r} has {len(ref_device_ids)} " + f"reference device(s) but {len(cmp_device_ids)} compare device(s); " + "nvbench_compare pairs devices by position, so each compared " + "benchmark must contain the same number of devices" + ) print(f"""# {cmp_bench["name"]}\n""") - cmp_device_ids = cmp_bench["devices"] axes = cmp_bench["axes"] ref_states = ref_bench["states"] cmp_states = cmp_bench["states"] @@ -525,20 +722,43 @@ def compare_benches( headers.append("Status") colalign.append("center") - for cmp_device_id in cmp_device_ids: + for cmp_device_index, cmp_device_id in enumerate(cmp_device_ids): + ref_device_id = ref_device_ids[cmp_device_index] + ref_device_states = [ + state + for state in ref_states + if state["device"] == ref_device_id + and matches_axis_filter_groups(state, axis_filter_groups) + ] + cmp_device_states = [ + state + for state in cmp_states + if state["device"] == cmp_device_id + and matches_axis_filter_groups(state, axis_filter_groups) + ] + ref_states_by_name = group_states_by_match_key(ref_device_states) + cmp_states_by_name = group_states_by_match_key(cmp_device_states) + ref_state_counts = state_group_counts(ref_states_by_name) + cmp_state_counts = state_group_counts(cmp_states_by_name) + if ref_state_counts != cmp_state_counts: + raise ValueError( + f"benchmark {cmp_bench['name']!r} device pair " + f"ref={ref_device_id} cmp={cmp_device_id} has mismatched " + f"state occurrences: ref={dict(ref_state_counts)}, " + f"cmp={dict(cmp_state_counts)}" + ) + rows = [] plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}} + counters = {} - for cmp_state in cmp_states: - cmp_state_name = cmp_state["name"] - ref_state = next( - filter(lambda st: st["name"] == cmp_state_name, ref_states), None - ) - if not ref_state: - continue - if not matches_axis_filters(cmp_state, axis_filters): - continue - + for cmp_state in cmp_device_states: + cmp_state_name = state_match_key(cmp_state) + occurrence = counters.get(cmp_state_name, 0) + counters[cmp_state_name] = occurrence + 1 + # Duplicate state names are matched by occurrence order within + # the filtered device section. + ref_state = ref_states_by_name[cmp_state_name][occurrence] axis_values = cmp_state["axis_values"] if not axis_values: axis_values = [] @@ -632,6 +852,7 @@ def compare_benches( status = colorize(status_label, Fore.RED, Emoji.RED, no_color) if abs(frac_diff) >= threshold: + axis_filters = matching_axis_filters(cmp_state, axis_filter_groups) row.append(format_duration(ref_time)) row.append(format_percentage(ref_noise)) row.append(format_duration(cmp_time)) @@ -660,7 +881,12 @@ def compare_benches( continue cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices) - ref_device = find_device_by_id(ref_state["device"], all_ref_devices) + ref_device = find_device_by_id(ref_device_id, all_ref_devices) + if ref_device is None or cmp_device is None: + raise ValueError( + f"benchmark {cmp_bench['name']!r} references device pair " + f"ref={ref_device_id} cmp={cmp_device_id}, but device metadata is missing" + ) if cmp_device == ref_device: print(f"## [{cmp_device['id']}] {cmp_device['name']}\n") @@ -756,10 +982,10 @@ def compare_benches( title = "%SOL Bandwidth change" if len(comparison_device_names) == 1: title = f"{title} - {next(iter(comparison_device_names))}" - if axis_filters: + if filter_plan.global_axis_filters: axis_label = ", ".join( axis_filter["display"] - for axis_filter in axis_filters + for axis_filter in filter_plan.global_axis_filters if len(axis_filter["values"]) == 1 ) if axis_label: @@ -812,24 +1038,44 @@ def main() -> int: action="store_true", help="Use emoji instead of ANSI color codes (useful for GitHub issues/PRs)", ) + parser.add_argument( + "--reference-devices", + default="all", + help="Reference devices to compare: all, a non-negative integer id, or comma-separated ids", + ) + parser.add_argument( + "--compare-devices", + default="all", + help="Compare devices to compare: all, a non-negative integer id, or comma-separated ids", + ) parser.add_argument( "-a", "--axis", - action="append", - default=[], - help="Filter on axis value, e.g. -a Elements{io}=2^20 (can repeat)", + dest="filter_actions", + action=OrderedBenchmarkFilterAction, + help=( + "Filter on axis value, e.g. -a Elements{io}=2^20. Applies to the " + "most recent --benchmark, or all benchmarks if specified before any " + "--benchmark arguments." + ), ) parser.add_argument( "-b", "--benchmark", - action="append", - default=[], + dest="filter_actions", + action=OrderedBenchmarkFilterAction, help="Filter by benchmark name (can repeat)", ) args, files_or_dirs = parser.parse_known_args() try: - axis_filters = parse_axis_filters(args.axis) + filter_plan = build_benchmark_filter_plan(args.filter_actions) + reference_device_filter = parse_device_filter( + args.reference_devices, "--reference-devices" + ) + compare_device_filter = parse_device_filter( + args.compare_devices, "--compare-devices" + ) except ValueError as exc: print(str(exc)) return 1 @@ -863,21 +1109,34 @@ def main() -> int: global all_ref_devices global all_cmp_devices - all_ref_devices = ref_root["devices"] - all_cmp_devices = cmp_root["devices"] + try: + all_ref_devices = select_devices( + ref_root["devices"], reference_device_filter, "--reference-devices" + ) + all_cmp_devices = select_devices( + cmp_root["devices"], compare_device_filter, "--compare-devices" + ) + except ValueError as exc: + print(str(exc)) + return 1 - if ref_root["devices"] != cmp_root["devices"]: + if len(all_ref_devices) != len(all_cmp_devices): + print( + f"--reference-devices selected {len(all_ref_devices)} device(s), " + f"but --compare-devices selected {len(all_cmp_devices)} device(s)" + ) + return 1 + + if all_ref_devices != all_cmp_devices: warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED msg_text = "Device sections do not match" print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="") print(": ", end="") - print( - jsondiff.diff( - ref_root["devices"], cmp_root["devices"], syntax="symmetric" - ) - ) - if not args.ignore_devices: + print(jsondiff.diff(all_ref_devices, all_cmp_devices, syntax="symmetric")) + if not args.ignore_devices and require_matching_device_sections( + reference_device_filter, compare_device_filter + ): return 1 try: @@ -888,9 +1147,10 @@ def main() -> int: args.plot_along, args.plot, args.dark, - axis_filters, - args.benchmark, + filter_plan, args.no_color, + reference_device_filter, + compare_device_filter, ) except ValueError as exc: print(str(exc)) diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index 8d82acc..c6d8c14 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -115,10 +115,18 @@ def make_benchmark(states, *, name="bench"): } -def set_test_devices(monkeypatch, nvbench_compare): +def set_test_devices(monkeypatch, nvbench_compare, ref_devices=None, cmp_devices=None): devices = [{"id": 0, "name": "Test GPU"}] - monkeypatch.setattr(nvbench_compare, "all_ref_devices", devices) - monkeypatch.setattr(nvbench_compare, "all_cmp_devices", devices) + monkeypatch.setattr( + nvbench_compare, + "all_ref_devices", + devices if ref_devices is None else ref_devices, + ) + monkeypatch.setattr( + nvbench_compare, + "all_cmp_devices", + devices if cmp_devices is None else cmp_devices, + ) monkeypatch.setattr(nvbench_compare, "config_count", 0) monkeypatch.setattr(nvbench_compare, "pass_count", 0) monkeypatch.setattr(nvbench_compare, "improvement_count", 0) @@ -126,19 +134,132 @@ def set_test_devices(monkeypatch, nvbench_compare): monkeypatch.setattr(nvbench_compare, "unknown_count", 0) -def compare_benches(nvbench_compare, ref_benches, cmp_benches, **kwargs): +def make_filter_plan(nvbench_compare, filter_actions=None): + return nvbench_compare.build_benchmark_filter_plan(filter_actions or []) + + +def test_compare_benches_accepts_matching_duplicate_state_counts( + monkeypatch, nvbench_compare +): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1", mean="1.005"), + make_state(nvbench_compare, "state1", mean="1.005"), + make_state(nvbench_compare, "state2", mean="1.005"), + ] + ) + ] + nvbench_compare.compare_benches( ref_benches, cmp_benches, - threshold=kwargs.get("threshold", 0.0), - plot_along=kwargs.get("plot_along"), - plot=kwargs.get("plot", False), + threshold=0.0, + plot_along=None, + plot=False, dark=False, - axis_filters=kwargs.get("axis_filters", []), - benchmark_filters=kwargs.get("benchmark_filters", []), + filter_plan=make_filter_plan(nvbench_compare), no_color=True, ) + assert nvbench_compare.config_count == 3 + assert nvbench_compare.pass_count == 3 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + + +def test_compare_benches_rejects_swapped_duplicate_state_counts( + monkeypatch, nvbench_compare +): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + + with pytest.raises(ValueError, match="mismatched state occurrences"): + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + +def test_compare_benches_matches_duplicate_states_after_axis_filter( + monkeypatch, nvbench_compare +): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + ] + ) + ] + + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare, [("axis", "A=2")]), + no_color=True, + ) + + assert nvbench_compare.config_count == 1 + assert nvbench_compare.pass_count == 1 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + def test_compare_benches_skips_non_finite_centers(monkeypatch, nvbench_compare): set_test_devices(monkeypatch, nvbench_compare) @@ -162,7 +283,16 @@ def test_compare_benches_skips_non_finite_centers(monkeypatch, nvbench_compare): ) ] - compare_benches(nvbench_compare, ref_benches, cmp_benches) + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) assert nvbench_compare.config_count == 1 assert nvbench_compare.pass_count == 1 @@ -191,8 +321,15 @@ def test_compare_benches_prefers_median_and_iqr_when_available( ] ) - compare_benches( - nvbench_compare, [make_benchmark([ref_state])], [make_benchmark([cmp_state])] + nvbench_compare.compare_benches( + [make_benchmark([ref_state])], + [make_benchmark([cmp_state])], + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, ) assert nvbench_compare.config_count == 1 @@ -225,10 +362,15 @@ def test_compare_benches_marks_unavailable_noise_unknown(monkeypatch, nvbench_co make_summary(nvbench_compare, "GPU_TIME_STDEV_RELATIVE_TAG", None), ] - compare_benches( - nvbench_compare, + nvbench_compare.compare_benches( [make_benchmark([missing_noise_ref, null_noise_ref])], [make_benchmark([missing_noise_cmp, null_noise_cmp])], + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, ) assert nvbench_compare.config_count == 2 @@ -258,11 +400,15 @@ def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_comp ) ] - compare_benches( - nvbench_compare, + nvbench_compare.compare_benches( ref_benches, cmp_benches, + threshold=0.0, plot_along="A", + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, ) assert nvbench_compare.config_count == 2 @@ -272,6 +418,151 @@ def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_comp assert nvbench_compare.unknown_count == 0 +def test_device_filter_parser_accepts_all_and_duplicate_ids(nvbench_compare): + assert nvbench_compare.parse_device_filter(" all ", "--reference-devices") is None + assert nvbench_compare.parse_device_filter("0", "--reference-devices") == [0] + assert nvbench_compare.parse_device_filter("0, 2,0", "--reference-devices") == [ + 0, + 2, + 0, + ] + + +@pytest.mark.parametrize( + "device_arg", + [ + "", + " ", + "gpu", + "-1", + "0,gpu", + "0,-1", + "0,", + ",0", + ], +) +def test_device_filter_parser_rejects_invalid_values(nvbench_compare, device_arg): + with pytest.raises(ValueError, match="must be 'all'"): + nvbench_compare.parse_device_filter(device_arg, "--reference-devices") + + +def test_explicit_device_filters_downgrade_device_mismatch_to_warning(nvbench_compare): + assert nvbench_compare.require_matching_device_sections(None, None) + assert not nvbench_compare.require_matching_device_sections([0], None) + assert not nvbench_compare.require_matching_device_sections(None, [1]) + assert not nvbench_compare.require_matching_device_sections([0], [1]) + + +def test_compare_benches_pairs_filtered_devices_by_position( + monkeypatch, nvbench_compare +): + set_test_devices( + monkeypatch, + nvbench_compare, + ref_devices=[ + {"id": 0, "name": "Reference GPU 0"}, + {"id": 1, "name": "Reference GPU 1"}, + ], + cmp_devices=[ + {"id": 0, "name": "Compare GPU 0"}, + {"id": 1, "name": "Compare GPU 1"}, + ], + ) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "Device=0", mean="1.0", device=0), + make_state(nvbench_compare, "Device=1", mean="9.0", device=1), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "Device=0", mean="9.0", device=0), + make_state(nvbench_compare, "Device=1", mean="1.0", device=1), + ] + ) + ] + + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + reference_device_filter=[0], + compare_device_filter=[1], + ) + + assert nvbench_compare.config_count == 1 + assert nvbench_compare.pass_count == 1 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + + +def test_axis_filter_applies_to_most_recent_benchmark(monkeypatch, nvbench_compare): + set_test_devices(monkeypatch, nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ], + name="bench1", + ), + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="3.0", axis_value=1), + make_state(nvbench_compare, "state", mean="4.0", axis_value=2), + ], + name="bench2", + ), + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ], + name="bench1", + ), + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="3.0", axis_value=1), + make_state(nvbench_compare, "state", mean="4.0", axis_value=2), + ], + name="bench2", + ), + ] + + nvbench_compare.compare_benches( + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan( + nvbench_compare, + [("benchmark", "bench1"), ("axis", "A=2"), ("benchmark", "bench2")], + ), + no_color=True, + ) + + assert nvbench_compare.config_count == 3 + assert nvbench_compare.pass_count == 3 + assert nvbench_compare.improvement_count == 0 + assert nvbench_compare.regression_count == 0 + assert nvbench_compare.unknown_count == 0 + + def test_main_returns_success_exit_code_when_regressions_are_detected( monkeypatch, capsys, nvbench_compare ):