Add scoped filtering and device pairing to nvbench_compare

Teach nvbench_compare to keep the order of --benchmark and --axis arguments so
axis filters can apply either globally or to the most recent benchmark. Build a
filter plan from the ordered CLI arguments and apply the same plan to table
output and plotting labels.

Add explicit --reference-devices and --compare-devices filters. The filters
accept all, a single device id, or a comma-separated list of ids; ordered lists
and duplicates are preserved so selected reference and compare devices can be
paired by position. Device-section mismatches remain fatal for unfiltered
all-vs-all comparisons, but become warnings when the user explicitly selects
devices and the selected device counts match.

Match duplicate benchmark states by occurrence within each filtered device
section instead of matching only by state name across the whole benchmark. This
keeps repeated axis values and filtered duplicate states aligned between the
reference and compare inputs, and reports mismatched occurrence counts instead
of silently dropping extra states.

Add Python tests for duplicate-state matching, axis filtering before matching,
device filter parsing and validation, explicit cross-device pairing, and
benchmark-scoped axis filters.

Original commit messages folded into this change:

Tweaks for nvbench_compare

1. When JSON files contain multiple entries with the same name and axis values,
   make sure that scripts compares corresponding entries.

   Previous logic would extract the first entry from ref data, and would compare
   measurements for each state in cmp against the first entry from ref. The
   change introduces a counter to know which nth entry we process for a
   particular axis value, and retrieve corresponding entry in ref.

Scope occurrence matching by device.

Device pairing in nvbench_compare.py is strictly index-based under
--ignore-devices, reused IDs in a different order no longer pair against the
wrong reference device.

Require devices in ref and cmp to have the same cardinality

Handle mismatch when number of duplicates in ref data is not same as in cmp data

Use pytest monkeypatch fixture to pretend third-party package dependencies are
available during test run for nvbench_compare without introducing test-time
dependency

Added the happy-path test and fixed its direct-call setup by initializing the
device globals that main() normally populates.

Fix to filter-before-matching.

 - compare_benches() now pairs devices by selected position instead of taking a
   device id.
 - For each device pair, compare_benches() now builds:
     - ref_device_states: matching reference device and axis filters
     - cmp_device_states: matching compare device and axis filters
 - State occurrence counts and duplicate occurrence matching now operate only
   on those filtered per-device lists.
 - Removed the later matches_axis_filters() skip inside the compare-state loop
   because filtering now happens before matching.

Added a regression test where ref/cmp have duplicate state names in opposite
order, and --axis keeps only one of them. The test verifies the kept compare
state is matched against the kept reference state, not the first unfiltered
occurrence.

Introduce device filtering in nvbench_compare

 - --reference-devices all|ID|ID,ID,...
 - --compare-devices all|ID|ID,ID,...
 - Integer lists preserve order and duplicates.
 - Requested IDs are validated against the file-level device list.
 - Filtered reference/compare device counts must match before comparison.
 - compare_benches() pairs selected reference and compare devices by position.
 - Each benchmark validates that requested device IDs are present in its own
   devices list.

Implemented benchmark-scoped --axis handling.

  - --axis and --benchmark now share an ordered argparse action, so their
    relative CLI order is preserved.
  - -a before any -b becomes a global axis filter.
  - -a after -b <name> applies to that most recent benchmark only.
  - Repeated -b entries are treated as separate filter scopes and combined as
    alternatives for that benchmark.
  - Device filtering remains global and is applied independently.

Allow non-matching devices for explicit device selection

Now the device-section equality check remains fatal only for unfiltered
all-vs-all comparisons. If either --reference-devices or --compare-devices is
explicit, mismatched selected device metadata is printed as a warning, but
comparison proceeds after the selected device counts have been validated.

Fix for resolve_benchmark_device_ids, add comments

The return value of resolve_benchmark_device_ids now always owns its list.

Use monkeypatch class in set_test_devices helper

Stricted device id validation

Test for device id validation
This commit is contained in:
Oleksandr Pavlyk
2026-05-27 12:21:40 -05:00
parent 865d8ef8d0
commit d3abc541a5
2 changed files with 602 additions and 51 deletions

View File

@@ -7,6 +7,7 @@ import argparse
import math
import os
import sys
from collections import Counter
from dataclasses import dataclass
from enum import Enum
@@ -66,6 +67,121 @@ class TimeEstimate:
relative_dispersion: float | None
@dataclass(frozen=True)
class BenchmarkFilterScope:
benchmark_name: str
axis_filters: list[dict]
@dataclass(frozen=True)
class BenchmarkFilterPlan:
global_axis_filters: list[dict]
benchmark_scopes: list[BenchmarkFilterScope]
class OrderedBenchmarkFilterAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
actions = getattr(namespace, self.dest, None)
actions = [] if actions is None else list(actions)
action_kind = "axis" if option_string in {"-a", "--axis"} else "benchmark"
actions.append((action_kind, values))
setattr(namespace, self.dest, actions)
def state_match_key(state):
device_prefix = f"Device={state['device']}"
state_name = state["name"]
if state_name == device_prefix:
return ""
if state_name.startswith(f"{device_prefix} "):
return state_name[len(device_prefix) + 1 :]
return state_name
def group_states_by_match_key(states):
grouped = {}
for state in states:
grouped.setdefault(state_match_key(state), []).append(state)
return grouped
def state_group_counts(grouped_states):
return Counter(
{state_name: len(states) for state_name, states in grouped_states.items()}
)
def format_device_ids(device_ids):
return ", ".join(str(device_id) for device_id in device_ids)
def parse_device_filter(device_arg, option_name):
device_arg = device_arg.strip()
if device_arg.lower() == "all":
return None
values = [value.strip() for value in device_arg.split(",")]
if not all(values):
raise ValueError(
f"{option_name} must be 'all', a non-negative integer, "
"or comma-separated non-negative integers"
)
try:
device_ids = [int(value) for value in values]
except ValueError as exc:
raise ValueError(
f"{option_name} must be 'all', a non-negative integer, "
"or comma-separated non-negative integers"
) from exc
if any(device_id < 0 for device_id in device_ids):
raise ValueError(
f"{option_name} must be 'all', a non-negative integer, "
"or comma-separated non-negative integers"
)
return device_ids
def select_devices(all_devices, device_filter, option_name):
if device_filter is None:
return list(all_devices)
devices_by_id = {device["id"]: device for device in all_devices}
missing_ids = [
device_id for device_id in device_filter if device_id not in devices_by_id
]
if missing_ids:
raise ValueError(
f"{option_name} requested device id(s) not present in input: "
f"{format_device_ids(missing_ids)}"
)
return [devices_by_id[device_id] for device_id in device_filter]
def resolve_benchmark_device_ids(bench, device_filter, option_name):
if device_filter is None:
return list(bench["devices"])
benchmark_device_ids = set(bench["devices"])
missing_ids = [
device_id
for device_id in device_filter
if device_id not in benchmark_device_ids
]
if missing_ids:
raise ValueError(
f"benchmark {bench['name']!r} does not contain {option_name} "
f"device id(s): {format_device_ids(missing_ids)}"
)
return device_filter
def require_matching_device_sections(reference_device_filter, compare_device_filter):
return reference_device_filter is None and compare_device_filter is None
# TODO(opavlyk): replace with Emoji(StrEnum) after EOL of Python 3.10
class Emoji(str, Enum):
YELLOW = "\U0001f7e1"
@@ -328,6 +444,53 @@ def parse_axis_filters(axis_args):
return filters
def build_benchmark_filter_plan(filter_actions):
global_axis_args = []
benchmark_scopes = []
current_scope = None
for action_kind, action_value in filter_actions or []:
if action_kind == "benchmark":
current_scope = {"benchmark_name": action_value, "axis_args": []}
benchmark_scopes.append(current_scope)
elif current_scope is None:
global_axis_args.append(action_value)
else:
current_scope["axis_args"].append(action_value)
return BenchmarkFilterPlan(
global_axis_filters=parse_axis_filters(global_axis_args),
benchmark_scopes=[
BenchmarkFilterScope(
benchmark_name=scope["benchmark_name"],
axis_filters=parse_axis_filters(scope["axis_args"]),
)
for scope in benchmark_scopes
],
)
def benchmark_is_selected(benchmark_name, filter_plan):
return not filter_plan.benchmark_scopes or any(
scope.benchmark_name == benchmark_name for scope in filter_plan.benchmark_scopes
)
def axis_filter_groups_for_benchmark(benchmark_name, filter_plan):
if not filter_plan.benchmark_scopes:
return [filter_plan.global_axis_filters]
matching_scopes = [
scope
for scope in filter_plan.benchmark_scopes
if scope.benchmark_name == benchmark_name
]
return [
filter_plan.global_axis_filters + scope.axis_filters
for scope in matching_scopes
]
def matches_axis_filters(state, axis_filters):
if not axis_filters:
return True
@@ -351,6 +514,23 @@ def matches_axis_filters(state, axis_filters):
return True
def matches_axis_filter_groups(state, axis_filter_groups):
return any(
matches_axis_filters(state, axis_filters) for axis_filters in axis_filter_groups
)
def matching_axis_filters(state, axis_filter_groups):
return next(
(
axis_filters
for axis_filters in axis_filter_groups
if matches_axis_filters(state, axis_filters)
),
[],
)
def format_duration(seconds):
if seconds >= 1:
multiplier = 1.0
@@ -479,9 +659,10 @@ def compare_benches(
plot_along,
plot,
dark,
axis_filters,
benchmark_filters,
filter_plan,
no_color,
reference_device_filter=None,
compare_device_filter=None,
):
if plot_along:
import matplotlib.pyplot as plt
@@ -495,12 +676,28 @@ def compare_benches(
ref_bench = find_matching_bench(cmp_bench, ref_benches)
if not ref_bench:
continue
if benchmark_filters and cmp_bench["name"] not in benchmark_filters:
if not benchmark_is_selected(cmp_bench["name"], filter_plan):
continue
axis_filter_groups = axis_filter_groups_for_benchmark(
cmp_bench["name"], filter_plan
)
cmp_device_ids = resolve_benchmark_device_ids(
cmp_bench, compare_device_filter, "--compare-devices"
)
ref_device_ids = resolve_benchmark_device_ids(
ref_bench, reference_device_filter, "--reference-devices"
)
if len(cmp_device_ids) != len(ref_device_ids):
raise ValueError(
f"benchmark {cmp_bench['name']!r} has {len(ref_device_ids)} "
f"reference device(s) but {len(cmp_device_ids)} compare device(s); "
"nvbench_compare pairs devices by position, so each compared "
"benchmark must contain the same number of devices"
)
print(f"""# {cmp_bench["name"]}\n""")
cmp_device_ids = cmp_bench["devices"]
axes = cmp_bench["axes"]
ref_states = ref_bench["states"]
cmp_states = cmp_bench["states"]
@@ -525,20 +722,43 @@ def compare_benches(
headers.append("Status")
colalign.append("center")
for cmp_device_id in cmp_device_ids:
for cmp_device_index, cmp_device_id in enumerate(cmp_device_ids):
ref_device_id = ref_device_ids[cmp_device_index]
ref_device_states = [
state
for state in ref_states
if state["device"] == ref_device_id
and matches_axis_filter_groups(state, axis_filter_groups)
]
cmp_device_states = [
state
for state in cmp_states
if state["device"] == cmp_device_id
and matches_axis_filter_groups(state, axis_filter_groups)
]
ref_states_by_name = group_states_by_match_key(ref_device_states)
cmp_states_by_name = group_states_by_match_key(cmp_device_states)
ref_state_counts = state_group_counts(ref_states_by_name)
cmp_state_counts = state_group_counts(cmp_states_by_name)
if ref_state_counts != cmp_state_counts:
raise ValueError(
f"benchmark {cmp_bench['name']!r} device pair "
f"ref={ref_device_id} cmp={cmp_device_id} has mismatched "
f"state occurrences: ref={dict(ref_state_counts)}, "
f"cmp={dict(cmp_state_counts)}"
)
rows = []
plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}}
counters = {}
for cmp_state in cmp_states:
cmp_state_name = cmp_state["name"]
ref_state = next(
filter(lambda st: st["name"] == cmp_state_name, ref_states), None
)
if not ref_state:
continue
if not matches_axis_filters(cmp_state, axis_filters):
continue
for cmp_state in cmp_device_states:
cmp_state_name = state_match_key(cmp_state)
occurrence = counters.get(cmp_state_name, 0)
counters[cmp_state_name] = occurrence + 1
# Duplicate state names are matched by occurrence order within
# the filtered device section.
ref_state = ref_states_by_name[cmp_state_name][occurrence]
axis_values = cmp_state["axis_values"]
if not axis_values:
axis_values = []
@@ -632,6 +852,7 @@ def compare_benches(
status = colorize(status_label, Fore.RED, Emoji.RED, no_color)
if abs(frac_diff) >= threshold:
axis_filters = matching_axis_filters(cmp_state, axis_filter_groups)
row.append(format_duration(ref_time))
row.append(format_percentage(ref_noise))
row.append(format_duration(cmp_time))
@@ -660,7 +881,12 @@ def compare_benches(
continue
cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices)
ref_device = find_device_by_id(ref_state["device"], all_ref_devices)
ref_device = find_device_by_id(ref_device_id, all_ref_devices)
if ref_device is None or cmp_device is None:
raise ValueError(
f"benchmark {cmp_bench['name']!r} references device pair "
f"ref={ref_device_id} cmp={cmp_device_id}, but device metadata is missing"
)
if cmp_device == ref_device:
print(f"## [{cmp_device['id']}] {cmp_device['name']}\n")
@@ -756,10 +982,10 @@ def compare_benches(
title = "%SOL Bandwidth change"
if len(comparison_device_names) == 1:
title = f"{title} - {next(iter(comparison_device_names))}"
if axis_filters:
if filter_plan.global_axis_filters:
axis_label = ", ".join(
axis_filter["display"]
for axis_filter in axis_filters
for axis_filter in filter_plan.global_axis_filters
if len(axis_filter["values"]) == 1
)
if axis_label:
@@ -812,24 +1038,44 @@ def main() -> int:
action="store_true",
help="Use emoji instead of ANSI color codes (useful for GitHub issues/PRs)",
)
parser.add_argument(
"--reference-devices",
default="all",
help="Reference devices to compare: all, a non-negative integer id, or comma-separated ids",
)
parser.add_argument(
"--compare-devices",
default="all",
help="Compare devices to compare: all, a non-negative integer id, or comma-separated ids",
)
parser.add_argument(
"-a",
"--axis",
action="append",
default=[],
help="Filter on axis value, e.g. -a Elements{io}=2^20 (can repeat)",
dest="filter_actions",
action=OrderedBenchmarkFilterAction,
help=(
"Filter on axis value, e.g. -a Elements{io}=2^20. Applies to the "
"most recent --benchmark, or all benchmarks if specified before any "
"--benchmark arguments."
),
)
parser.add_argument(
"-b",
"--benchmark",
action="append",
default=[],
dest="filter_actions",
action=OrderedBenchmarkFilterAction,
help="Filter by benchmark name (can repeat)",
)
args, files_or_dirs = parser.parse_known_args()
try:
axis_filters = parse_axis_filters(args.axis)
filter_plan = build_benchmark_filter_plan(args.filter_actions)
reference_device_filter = parse_device_filter(
args.reference_devices, "--reference-devices"
)
compare_device_filter = parse_device_filter(
args.compare_devices, "--compare-devices"
)
except ValueError as exc:
print(str(exc))
return 1
@@ -863,21 +1109,34 @@ def main() -> int:
global all_ref_devices
global all_cmp_devices
all_ref_devices = ref_root["devices"]
all_cmp_devices = cmp_root["devices"]
try:
all_ref_devices = select_devices(
ref_root["devices"], reference_device_filter, "--reference-devices"
)
all_cmp_devices = select_devices(
cmp_root["devices"], compare_device_filter, "--compare-devices"
)
except ValueError as exc:
print(str(exc))
return 1
if ref_root["devices"] != cmp_root["devices"]:
if len(all_ref_devices) != len(all_cmp_devices):
print(
f"--reference-devices selected {len(all_ref_devices)} device(s), "
f"but --compare-devices selected {len(all_cmp_devices)} device(s)"
)
return 1
if all_ref_devices != all_cmp_devices:
warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED
msg_text = "Device sections do not match"
print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="")
print(": ", end="")
print(
jsondiff.diff(
ref_root["devices"], cmp_root["devices"], syntax="symmetric"
)
)
if not args.ignore_devices:
print(jsondiff.diff(all_ref_devices, all_cmp_devices, syntax="symmetric"))
if not args.ignore_devices and require_matching_device_sections(
reference_device_filter, compare_device_filter
):
return 1
try:
@@ -888,9 +1147,10 @@ def main() -> int:
args.plot_along,
args.plot,
args.dark,
axis_filters,
args.benchmark,
filter_plan,
args.no_color,
reference_device_filter,
compare_device_filter,
)
except ValueError as exc:
print(str(exc))