Add a script to plot benchmark results

This commit is contained in:
Bernhard Manfred Gruber
2026-02-05 10:36:52 +01:00
parent dc59f98ecd
commit 0be190b407

View File

@@ -0,0 +1,239 @@
#!/usr/bin/env python
import argparse
import os
import sys
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
try:
from nvbench_json import reader
except ImportError:
from scripts.nvbench_json import reader
UTILIZATION_TAG = "nv/cold/bw/global/utilization"
def parse_files():
help_text = "%(prog)s [nvbench.out.json | dir/] ..."
parser = argparse.ArgumentParser(prog="nvbench_plot", usage=help_text)
parser.add_argument(
"-o",
"--output",
default=None,
help="Save plot to this file instead of showing it",
)
parser.add_argument(
"--title",
default=None,
help="Optional plot title",
)
parser.add_argument(
"-a",
"--axis",
action="append",
default=[],
help="Filter on axis value, e.g. -a T{ct}=I8 (can repeat)",
)
parser.add_argument(
"-b",
"--benchmark",
action="append",
default=[],
help="Filter by benchmark name (can repeat)",
)
args, files_or_dirs = parser.parse_known_args()
filenames = []
for file_or_dir in files_or_dirs:
if os.path.isdir(file_or_dir):
for f in os.listdir(file_or_dir):
if os.path.splitext(f)[1] != ".json":
continue
filename = os.path.join(file_or_dir, f)
if os.path.isfile(filename) and os.path.getsize(filename) > 0:
filenames.append(filename)
else:
filenames.append(file_or_dir)
filenames.sort()
if not filenames:
parser.print_help()
sys.exit(0)
return args, filenames
def extract_utilization(state):
summaries = state.get("summaries") or []
summary = next(
filter(lambda s: s["tag"] == UTILIZATION_TAG, summaries),
None,
)
if not summary:
return None
value_data = next(
filter(lambda v: v["name"] == "value", summary["data"]),
None,
)
if not value_data:
return None
return float(value_data["value"])
def parse_axis_filters(axis_args):
filters = []
for axis_arg in axis_args:
if "=" not in axis_arg:
raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
name, value = axis_arg.split("=", 1)
name = name.strip()
value = value.strip()
if not name or not value:
raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
if name.endswith("[pow2]"):
name = name[: -len("[pow2]")].strip()
if not name:
raise ValueError(
"Axis filter missing name before [pow2]: {}".format(axis_arg)
)
try:
exponent = int(value)
except ValueError as exc:
raise ValueError(
"Axis filter [pow2] value must be integer: {}".format(axis_arg)
) from exc
value = str(2**exponent)
filters.append((name, value))
return filters
def matches_axis_filters(state, axis_filters):
if not axis_filters:
return True
axis_values = state.get("axis_values") or []
for filter_name, filter_value in axis_filters:
matched = False
for axis_value in axis_values:
if axis_value.get("name") != filter_name:
continue
value = axis_value.get("value")
if value is None:
continue
if str(value) == filter_value:
matched = True
break
if not matched:
return False
return True
def collect_entries(filename, axis_filters, benchmark_filters):
json_root = reader.read_file(filename)
entries = []
device_names = set()
devices = {device["id"]: device["name"] for device in json_root.get("devices", [])}
for bench in json_root["benchmarks"]:
bench_name = bench["name"]
if benchmark_filters and bench_name not in benchmark_filters:
continue
for state in bench["states"]:
if not matches_axis_filters(state, axis_filters):
continue
utilization = extract_utilization(state)
if utilization is None:
continue
state_name = state["name"]
if state_name.startswith("Device="):
parts = state_name.split(" ", 1)
if len(parts) == 2:
state_name = parts[1]
label = "{} | {}".format(bench_name, state_name)
device_name = devices.get(state.get("device"))
if device_name:
device_names.add(device_name)
entries.append((label, utilization, bench_name))
return entries, device_names
def plot_entries(entries, title=None, output=None):
if not entries:
print("No utilization data found.")
return 1
labels = [entry[0] for entry in entries]
values = [entry[1] for entry in entries]
bench_names = [entry[2] for entry in entries]
unique_benches = []
for bench in bench_names:
if bench not in unique_benches:
unique_benches.append(bench)
cmap = plt.get_cmap("tab20", max(len(unique_benches), 1))
bench_colors = {bench: cmap(index) for index, bench in enumerate(unique_benches)}
colors = [bench_colors[bench] for bench in bench_names]
fig_height = max(4.0, 0.3 * len(entries) + 1.5)
fig, ax = plt.subplots(figsize=(10, fig_height))
y_pos = range(len(labels))
ax.barh(y_pos, values, color=colors)
ax.set_yticks(y_pos)
ax.set_yticklabels(labels)
ax.invert_yaxis()
ax.set_ylim(len(labels) - 0.5, -0.5)
ax.set_xlabel("Global BW Utilization")
ax.xaxis.set_major_formatter(PercentFormatter(1.0))
if title:
ax.set_title(title)
fig.tight_layout()
if output:
fig.savefig(output, dpi=150)
else:
plt.show()
return 0
def main():
args, filenames = parse_files()
try:
axis_filters = parse_axis_filters(args.axis)
except ValueError as exc:
print(str(exc))
return 1
entries = []
device_names = set()
for filename in filenames:
file_entries, file_device_names = collect_entries(
filename,
axis_filters,
args.benchmark,
)
entries.extend(file_entries)
device_names.update(file_device_names)
title = args.title
if title is None:
title = "%SOL Bandwidth"
if len(device_names) == 1:
title = "{} - {}".format(title, next(iter(device_names)))
return plot_entries(entries, title=title, output=args.output)
if __name__ == "__main__":
sys.exit(main())