diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index 7ad67d1..4b81d03 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -79,6 +79,87 @@ def format_axis_value(axis_name, axis_value, axes): return format_string_axis_value(axis_name, axis_value, axes) +def make_display(name: str, display_values: [list[str]]) -> str: + open_bracket, close_bracket = ("[", "]") if len(display_values) > 1 else ("", "") + display_values = ",".join(display_values) + return f"{name}={open_bracket}{display_values}{close_bracket}" + + +def parse_axis_filters(axis_args): + filters = [] + for axis_arg in axis_args: + if "=" not in axis_arg: + raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg)) + name, value = axis_arg.split("=", 1) + name = name.strip() + value = value.strip() + if not name or not value: + raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg)) + + values = [] + if value.startswith("[") and value.endswith("]"): + inner = value[1:-1].strip() + values = [ + stripped for item in inner.split(",") if (stripped := item.strip()) + ] + else: + values = [value] + display_values = list(values) + + if name.endswith("[pow2]"): + name = name[: -len("[pow2]")].strip() + if not name: + raise ValueError( + "Axis filter missing name before [pow2]: {}".format(axis_arg) + ) + try: + exponents = [int(v) for v in values] + except ValueError as exc: + raise ValueError( + "Axis filter [pow2] value must be integer: {}".format(axis_arg) + ) from exc + values = [str(2**exponent) for exponent in exponents] + display_values = ["2^{}".format(exponent) for exponent in exponents] + + if not values: + raise ValueError( + "Axis filter must specify at least one value: {}".format(axis_arg) + ) + + display = make_display(name, display_values) + filters.append( + { + "name": name, + "values": values, + "display": display, + } + ) + return filters + + +def matches_axis_filters(state, axis_filters): + if not axis_filters: + return True + + axis_values = state.get("axis_values") or [] + for axis_filter in axis_filters: + filter_name = axis_filter["name"] + filter_values = axis_filter["values"] + matched = False + for axis_value in axis_values: + if axis_value.get("name") != filter_name: + continue + value = axis_value.get("value") + if value is None: + continue + if str(value) in filter_values: + matched = True + break + if not matched: + return False + return True + + def format_duration(seconds): if seconds >= 1: multiplier = 1.0 @@ -104,17 +185,121 @@ def format_percentage(percentage): return "%0.2f%%" % (percentage * 100.0) -def compare_benches(ref_benches, cmp_benches, threshold, plot): - if plot: +def format_axis_values(axis_values, axes, axis_filters=None): + if not axis_values: + return "" + filtered_names = set() + if axis_filters: + filtered_names = { + axis_filter["name"] + for axis_filter in axis_filters + if len(axis_filter["values"]) == 1 + } + parts = [] + for axis_value in axis_values: + axis_name = axis_value["name"] + if axis_name in filtered_names: + continue + formatted = format_axis_value(axis_name, axis_value, axes) + parts.append(f"{axis_name}={formatted}") + return " ".join(parts) + + +def plot_comparison_entries(entries, title=None, dark=False): + if not entries: + print("No comparison data to plot.") + return 1 + + if not os.environ.get("DISPLAY"): + import matplotlib + + matplotlib.use("Agg") + + import matplotlib.pyplot as plt + from matplotlib.ticker import PercentFormatter + + labels, values, statuses, bench_names = map(list, zip(*entries)) + + status_colors = { + "SLOW": "red", + "FAST": "green", + "SAME": "blue", + } + colors = [status_colors.get(status, "gray") for status in statuses] + + fig_height = max(4.0, 0.3 * len(entries) + 1.5) + fig, ax = plt.subplots(figsize=(10, fig_height)) + if dark: + fig.patch.set_facecolor("black") + ax.set_facecolor("black") + ax.tick_params(colors="white") + ax.xaxis.label.set_color("white") + ax.yaxis.label.set_color("white") + ax.title.set_color("white") + for spine in ax.spines.values(): + spine.set_color("white") + + y_pos = range(len(labels)) + ax.barh(y_pos, values, color=colors) + ax.set_yticks(y_pos) + ax.set_yticklabels(labels) + ax.invert_yaxis() + ax.set_ylim(len(labels) - 0.5, -0.5) + + separator_color = "white" if dark else "gray" + ax.axvline(0, color=separator_color, linewidth=1, alpha=0.6) + for index in range(1, len(bench_names)): + if bench_names[index] != bench_names[index - 1]: + ax.axhline(index - 0.5, color=separator_color, linewidth=0.6, alpha=0.4) + ax.xaxis.set_major_formatter(PercentFormatter(1.0)) + + if title: + ax.set_title(title) + + min_val = min(values) + max_val = max(values) + if min_val == max_val: + pad = 0.05 if min_val == 0 else abs(min_val) * 0.1 + ax.set_xlim(min_val - pad, max_val + pad) + else: + pad = (max_val - min_val) * 0.1 + ax.set_xlim(min_val - pad, max_val + pad) + + fig.tight_layout() + + if not os.environ.get("DISPLAY"): + output = "nvbench_compare.png" + fig.savefig(output, dpi=150) + print("Saved comparison plot to {}".format(output)) + else: + plt.show() + return 0 + + +def compare_benches( + ref_benches, + cmp_benches, + threshold, + plot_along, + plot, + dark, + axis_filters, + benchmark_filters, +): + if plot_along: import matplotlib.pyplot as plt import seaborn as sns sns.set() + comparison_entries = [] + comparison_device_names = set() for cmp_bench in cmp_benches: ref_bench = find_matching_bench(cmp_bench, ref_benches) if not ref_bench: continue + if benchmark_filters and cmp_bench["name"] not in benchmark_filters: + continue print("# %s\n" % (cmp_bench["name"])) @@ -154,6 +339,8 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): ) if not ref_state: continue + if not matches_axis_filters(cmp_state, axis_filters): + continue axis_values = cmp_state["axis_values"] if not axis_values: @@ -231,11 +418,11 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): else: min_noise = None # Noise is inf - if plot: + if plot_along: axis_name = [] axis_value = "--" for aid in range(len(axis_values)): - if axis_values[aid]["name"] != plot: + if axis_values[aid]["name"] != plot_along: axis_name.append( "{} = {}".format( axis_values[aid]["name"], axis_values[aid]["value"] @@ -264,16 +451,20 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): config_count += 1 if not min_noise: unknown_count += 1 - status = Fore.YELLOW + "????" + Fore.RESET + status_label = "????" + status = Fore.YELLOW + status_label + Fore.RESET elif abs(frac_diff) <= min_noise: pass_count += 1 - status = Fore.BLUE + "SAME" + Fore.RESET + status_label = "SAME" + status = Fore.BLUE + status_label + Fore.RESET elif diff < 0: failure_count += 1 - status = Fore.GREEN + "FAST" + Fore.RESET + status_label = "FAST" + status = Fore.GREEN + status_label + Fore.RESET else: failure_count += 1 - status = Fore.RED + "SLOW" + Fore.RESET + status_label = "SLOW" + status = Fore.RED + status_label + Fore.RESET if abs(frac_diff) >= threshold: row.append(format_duration(ref_time)) @@ -285,6 +476,20 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): row.append(status) rows.append(row) + if plot: + axis_label = format_axis_values(axis_values, axes, axis_filters) + if axis_label: + label = "{} | {}".format(cmp_bench["name"], axis_label) + else: + label = cmp_bench["name"] + cmp_device = find_device_by_id( + cmp_state["device"], all_cmp_devices + ) + if cmp_device: + comparison_device_names.add(cmp_device["name"]) + comparison_entries.append( + (label, frac_diff, status_label, cmp_bench["name"]) + ) if len(rows) == 0: continue @@ -316,10 +521,10 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): print("") - if plot: + if plot_along: plt.xscale("log") plt.yscale("log") - plt.xlabel(plot) + plt.xlabel(plot_along) plt.ylabel("time [s]") plt.title(cmp_device["name"]) @@ -342,6 +547,20 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): plt.legend() plt.show() + if plot: + title = "%SOL Bandwidth change" + if len(comparison_device_names) == 1: + title = "{} - {}".format(title, next(iter(comparison_device_names))) + if axis_filters: + axis_label = ", ".join( + axis_filter["display"] + for axis_filter in axis_filters + if len(axis_filter["values"]) == 1 + ) + if axis_label: + title = "{} ({})".format(title, axis_label) + plot_comparison_entries(comparison_entries, title=title, dark=dark) + def main(): help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]" @@ -361,11 +580,42 @@ def main(): help="only show benchmarks where percentage diff is >= THRESHOLD", ) parser.add_argument( - "--plot-along", type=str, dest="plot", default=None, help="plot results" + "--plot-along", type=str, dest="plot_along", default=None, help="plot results" + ) + parser.add_argument( + "--plot", + dest="plot", + default=False, + help="plot comparison summary", + action="store_true", + ) + parser.add_argument( + "--dark", + action="store_true", + help="Use dark theme (black background, white text)", + ) + parser.add_argument( + "-a", + "--axis", + action="append", + default=[], + help="Filter on axis value, e.g. -a Elements{io}=2^20 (can repeat)", + ) + parser.add_argument( + "-b", + "--benchmark", + action="append", + default=[], + help="Filter by benchmark name (can repeat)", ) args, files_or_dirs = parser.parse_known_args() print(files_or_dirs) + try: + axis_filters = parse_axis_filters(args.axis) + except ValueError as exc: + print(str(exc)) + sys.exit(1) if len(files_or_dirs) != 2: parser.print_help() @@ -414,7 +664,14 @@ def main(): sys.exit(1) compare_benches( - ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot + ref_root["benchmarks"], + cmp_root["benchmarks"], + args.threshold, + args.plot_along, + args.plot, + args.dark, + axis_filters, + args.benchmark, ) print("# Summary\n")