#!/usr/bin/env python
# File: python/scripts/nvbench_plot.py
"""Plot global-memory bandwidth utilization from NVBench JSON output.

Usage: nvbench_plot [nvbench.out.json | dir/] ...

Every benchmark state that carries a summary tagged with
``UTILIZATION_TAG`` becomes one horizontal bar.  Bars are colored per
benchmark and the x axis is formatted as a percentage.
"""

import argparse
import os
import sys

import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

# The reader module is importable under two different paths; try the flat
# module name first and fall back to the ``scripts`` package path.
try:
    from nvbench_json import reader
except ImportError:
    from scripts.nvbench_json import reader


# Tag of the per-state summary that holds the value being plotted.
UTILIZATION_TAG = "nv/cold/bw/global/utilization"


def parse_files():
    """Parse the command line and expand the positional inputs.

    Everything argparse does not recognize as an option is treated as a file
    or directory.  Directories are expanded (non-recursively) to the
    non-empty regular ``*.json`` files they contain; other paths are passed
    through unchanged.

    Prints the help text and exits with status 0 when no input files remain.

    Returns:
        A tuple ``(args, filenames)``: the parsed option namespace and the
        sorted list of input file paths.
    """
    help_text = "%(prog)s [nvbench.out.json | dir/] ..."
    parser = argparse.ArgumentParser(prog="nvbench_plot", usage=help_text)
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        help="Save plot to this file instead of showing it",
    )
    parser.add_argument(
        "--title",
        default=None,
        help="Optional plot title",
    )
    parser.add_argument(
        "-a",
        "--axis",
        action="append",
        default=[],
        help="Filter on axis value, e.g. -a T{ct}=I8 (can repeat)",
    )
    parser.add_argument(
        "-b",
        "--benchmark",
        action="append",
        default=[],
        help="Filter by benchmark name (can repeat)",
    )
    # No positional argument is declared, so the inputs arrive as the
    # "unknown" remainder of parse_known_args().
    args, files_or_dirs = parser.parse_known_args()

    filenames = []
    for file_or_dir in files_or_dirs:
        if not os.path.isdir(file_or_dir):
            filenames.append(file_or_dir)
            continue
        for entry in os.listdir(file_or_dir):
            if os.path.splitext(entry)[1] != ".json":
                continue
            filename = os.path.join(file_or_dir, entry)
            # Skip sub-directories and empty files.
            if os.path.isfile(filename) and os.path.getsize(filename) > 0:
                filenames.append(filename)

    filenames.sort()

    if not filenames:
        parser.print_help()
        sys.exit(0)

    return args, filenames


def extract_utilization(state):
    """Return the utilization value recorded for *state*, or ``None``.

    Looks for the summary whose ``"tag"`` equals ``UTILIZATION_TAG`` and,
    inside its ``"data"`` list, the entry whose ``"name"`` is ``"value"``.

    Args:
        state: A benchmark state mapping from the NVBench JSON.

    Returns:
        The value converted to ``float``, or ``None`` when the summary, its
        data entry, or the value itself is missing.

    Raises:
        ValueError: If the stored value cannot be converted to ``float``.
    """
    summaries = state.get("summaries") or []
    summary = next(
        (s for s in summaries if s.get("tag") == UTILIZATION_TAG),
        None,
    )
    if not summary:
        return None

    value_data = next(
        (v for v in summary.get("data") or [] if v.get("name") == "value"),
        None,
    )
    # A malformed entry is treated the same as a missing one so that a single
    # bad state does not abort the whole plot.
    if not value_data or value_data.get("value") is None:
        return None

    return float(value_data["value"])


def parse_axis_filters(axis_args):
    """Convert ``NAME=VALUE`` strings into ``(name, value)`` filter tuples.

    A name ending in ``[pow2]`` marks the value as an integer exponent: the
    suffix is removed and the value is replaced by ``str(2**exponent)``.

    Args:
        axis_args: Iterable of raw ``-a/--axis`` argument strings.

    Returns:
        A list of ``(name, value)`` tuples of stripped strings.

    Raises:
        ValueError: If an argument has no ``=``, an empty name or value, or
            a non-integer ``[pow2]`` value.
    """
    filters = []
    for axis_arg in axis_args:
        if "=" not in axis_arg:
            raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
        name, value = axis_arg.split("=", 1)
        name = name.strip()
        value = value.strip()
        if not name or not value:
            raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
        if name.endswith("[pow2]"):
            name = name[: -len("[pow2]")].strip()
            if not name:
                raise ValueError(
                    "Axis filter missing name before [pow2]: {}".format(axis_arg)
                )
            try:
                exponent = int(value)
            except ValueError as exc:
                raise ValueError(
                    "Axis filter [pow2] value must be integer: {}".format(axis_arg)
                ) from exc
            value = str(2**exponent)
        filters.append((name, value))
    return filters


def matches_axis_filters(state, axis_filters):
    """Return ``True`` when *state* satisfies every axis filter.

    A filter ``(name, value)`` is satisfied when the state's
    ``"axis_values"`` list holds an entry with that ``"name"`` whose
    ``"value"``, converted with ``str()``, equals ``value``.  An empty
    filter list matches every state.
    """
    if not axis_filters:
        return True

    axis_values = state.get("axis_values") or []

    def satisfied(filter_name, filter_value):
        for axis_value in axis_values:
            if axis_value.get("name") != filter_name:
                continue
            value = axis_value.get("value")
            if value is not None and str(value) == filter_value:
                return True
        return False

    return all(satisfied(name, value) for name, value in axis_filters)


def collect_entries(filename, axis_filters, benchmark_filters):
    """Read one NVBench JSON file and collect its plottable states.

    Args:
        filename: Path handed to ``reader.read_file``.
        axis_filters: ``(name, value)`` tuples as produced by
            ``parse_axis_filters``; states not matching all of them are
            skipped.
        benchmark_filters: Benchmark names to keep; an empty collection
            keeps every benchmark.

    Returns:
        A tuple ``(entries, device_names)``.  ``entries`` is a list of
        ``(label, utilization, bench_name)`` tuples in file order and
        ``device_names`` is the set of device names referenced by the kept
        states.
    """
    json_root = reader.read_file(filename)
    entries = []
    device_names = set()
    devices = {device["id"]: device["name"] for device in json_root.get("devices", [])}
    for bench in json_root["benchmarks"]:
        bench_name = bench["name"]
        if benchmark_filters and bench_name not in benchmark_filters:
            continue
        for state in bench["states"]:
            if not matches_axis_filters(state, axis_filters):
                continue
            utilization = extract_utilization(state)
            if utilization is None:
                continue

            state_name = state["name"]
            # Drop a leading "Device=..." token from the label when more
            # text follows it; the device is reported via device_names.
            if state_name.startswith("Device="):
                parts = state_name.split(" ", 1)
                if len(parts) == 2:
                    state_name = parts[1]
            label = "{} | {}".format(bench_name, state_name)
            device_name = devices.get(state.get("device"))
            if device_name:
                device_names.add(device_name)
            entries.append((label, utilization, bench_name))

    return entries, device_names


def plot_entries(entries, title=None, output=None):
    """Draw a horizontal bar chart of the collected utilization values.

    Args:
        entries: List of ``(label, utilization, bench_name)`` tuples.
        title: Optional axes title.
        output: Optional path; when given the figure is saved there at
            150 dpi, otherwise it is shown interactively.

    Returns:
        ``0`` on success, ``1`` when *entries* is empty.
    """
    if not entries:
        # Diagnostics go to stderr so they are not mistaken for output.
        print("No utilization data found.", file=sys.stderr)
        return 1

    labels = [entry[0] for entry in entries]
    values = [entry[1] for entry in entries]
    bench_names = [entry[2] for entry in entries]

    # Order-preserving de-duplication so colors follow first appearance.
    unique_benches = list(dict.fromkeys(bench_names))

    cmap = plt.get_cmap("tab20", max(len(unique_benches), 1))
    bench_colors = {bench: cmap(index) for index, bench in enumerate(unique_benches)}
    colors = [bench_colors[bench] for bench in bench_names]

    # Grow the figure with the number of bars so labels stay readable.
    fig_height = max(4.0, 0.3 * len(entries) + 1.5)
    fig, ax = plt.subplots(figsize=(10, fig_height))

    y_pos = range(len(labels))
    ax.barh(y_pos, values, color=colors)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()
    # Limits are given high-to-low so the first entry is drawn at the top.
    ax.set_ylim(len(labels) - 0.5, -0.5)

    ax.set_xlabel("Global BW Utilization")
    # xmax=1.0: a data value of 1.0 is rendered as "100%".
    ax.xaxis.set_major_formatter(PercentFormatter(1.0))

    if title:
        ax.set_title(title)

    fig.tight_layout()

    if output:
        fig.savefig(output, dpi=150)
        # Release the figure once it is on disk.
        plt.close(fig)
    else:
        plt.show()

    return 0


def main():
    """Run the command-line tool and return its exit status.

    Returns:
        ``1`` when an axis filter is malformed or no utilization data is
        found, ``0`` otherwise.
    """
    args, filenames = parse_files()
    try:
        axis_filters = parse_axis_filters(args.axis)
    except ValueError as exc:
        print(str(exc), file=sys.stderr)
        return 1

    entries = []
    device_names = set()
    for filename in filenames:
        file_entries, file_device_names = collect_entries(
            filename,
            axis_filters,
            args.benchmark,
        )
        entries.extend(file_entries)
        device_names.update(file_device_names)

    title = args.title
    if title is None:
        title = "%SOL Bandwidth"
        # Name the device only when all plotted states share a single one.
        if len(device_names) == 1:
            title = "{} - {}".format(title, next(iter(device_names)))

    return plot_entries(entries, title=title, output=args.output)


if __name__ == "__main__":
    sys.exit(main())