Files
nvbench/python/scripts/nvbench_plot_bwutil.py
Oleksandr Pavlyk a7d11959f6 Load script tooling dependencies lazily
Add a shared nvbench_tooling_deps helper for importing packages required
by NVBench console tools. Missing tooling packages now raise a dedicated
error with an install recipe instead of failing with a raw ImportError.

Update script imports to work both as installed package modules and as
direct source-tree scripts by using the __package__ import pattern for
nvbench_json and the new tooling helper.

Defer nvbench-compare dependencies to the points where they are needed:
NumPy/colorama during normal comparison setup, tabulate during table
rendering, jsondiff only for device mismatch reporting, and plotting
packages only for plot modes.

Update tests to initialize compare tooling when calling internals
directly and add coverage for the tooling dependency loader.

Closes #384
2026-06-29 12:18:15 -05:00

333 lines
9.6 KiB
Python

#!/usr/bin/env python
import argparse
import os
import sys
if __package__:
from .nvbench_json import reader
from .nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
else:
from nvbench_json import reader
from nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
# Handles to the plotting tooling; both stay ``None`` until
# load_nvbench_plot_bwutil_tooling() has resolved them.
plt = None
PercentFormatter = None


def load_nvbench_plot_bwutil_tooling():
    """Populate the module-level ``plt`` and ``PercentFormatter`` handles.

    Each handle is resolved through ``require_tooling_dependency`` only while
    it is still ``None``, so repeating the call after a successful load does
    no further work.

    NOTE(review): ``main`` catches ``MissingToolingDependencyError`` around
    this call, so the helper presumably raises it when matplotlib is not
    installed -- confirm against nvbench_tooling_deps.
    """
    global PercentFormatter, plt

    tool = "nvbench-plot-bwutil"

    if plt is None:
        pyplot_dependency = ToolingDependency(
            "matplotlib.pyplot", "matplotlib", "bandwidth plot rendering"
        )
        plt = require_tooling_dependency(pyplot_dependency, tool_name=tool)

    if PercentFormatter is None:
        ticker_dependency = ToolingDependency(
            "matplotlib.ticker", "matplotlib", "plot axis formatting"
        )
        ticker_module = require_tooling_dependency(ticker_dependency, tool_name=tool)
        PercentFormatter = ticker_module.PercentFormatter
# Tag of the summary entry that extract_utilization() looks up in each state.
UTILIZATION_TAG = "nv/cold/bw/global/utilization"
def parse_files():
    """Parse the command line and expand positional inputs into file names.

    Options are read from ``sys.argv``; every argument argparse does not
    recognise is treated as an input path.  A directory contributes each
    non-empty regular file directly inside it (no recursion); any other path
    must be an existing regular file.

    Returns:
        ``(args, filenames)`` where ``args`` is the parsed option namespace
        and ``filenames`` is the sorted list of input files.

    Raises:
        SystemExit: with status 0 after printing the help text when no input
            file was found, or with status 2 (via ``parser.error``) when a
            positional argument is neither a file nor a directory.
    """
    help_text = "%(prog)s [nvbench.out.json | dir/] ..."
    parser = argparse.ArgumentParser(prog="nvbench_plot", usage=help_text)
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        help="Save plot to this file instead of showing it",
    )
    parser.add_argument(
        "--title",
        default=None,
        help="Optional plot title",
    )
    parser.add_argument(
        "--dark",
        action="store_true",
        help="Use dark theme (black background, white text)",
    )
    parser.add_argument(
        "-a",
        "--axis",
        action="append",
        default=[],
        help="Filter on axis value, e.g. -a T{ct}=I8 (can repeat)",
    )
    parser.add_argument(
        "-b",
        "--benchmark",
        action="append",
        default=[],
        help="Filter by benchmark name (can repeat)",
    )
    args, files_or_dirs = parser.parse_known_args()

    filenames = []
    for file_or_dir in files_or_dirs:
        if os.path.isdir(file_or_dir):
            for entry in os.listdir(file_or_dir):
                candidate = os.path.join(file_or_dir, entry)
                # Skip sub-directories and empty files.
                if os.path.isfile(candidate) and os.path.getsize(candidate) > 0:
                    filenames.append(candidate)
        elif os.path.isfile(file_or_dir):
            filenames.append(file_or_dir)
        else:
            # Report bad paths (and mistyped options, which land here because
            # of parse_known_args) as a usage error; an ``assert`` would be
            # stripped under ``python -O`` and otherwise shows a traceback.
            parser.error("not a file or directory: {}".format(file_or_dir))
    filenames.sort()

    if not filenames:
        parser.print_help()
        sys.exit(0)
    return args, filenames
def extract_utilization(state, tag=None):
    """Return the utilization value stored in *state*, or ``None``.

    The first entry of ``state["summaries"]`` whose ``"tag"`` equals *tag* is
    selected, and within its ``"data"`` list the first item named ``"value"``
    is converted with ``float``.

    Args:
        state: Mapping describing one benchmark state.
        tag: Summary tag to look for; defaults to ``UTILIZATION_TAG``.

    Returns:
        The value as a ``float``, or ``None`` when the summary, its data
        item, or the value itself is missing.

    Raises:
        ValueError: if the stored value cannot be converted to ``float``.
    """
    if tag is None:
        tag = UTILIZATION_TAG

    summaries = state.get("summaries") or []
    # ``.get`` lookups keep entries without the expected keys from raising
    # KeyError; such entries are simply skipped.
    summary = next(
        (entry for entry in summaries if entry.get("tag") == tag),
        None,
    )
    if not summary:
        return None

    value_data = next(
        (item for item in summary.get("data") or [] if item.get("name") == "value"),
        None,
    )
    if not value_data:
        return None

    raw_value = value_data.get("value")
    if raw_value is None:
        return None
    return float(raw_value)
def _split_axis_filter_values(value):
    """Expand ``VALUE`` or ``[V1,V2,...]`` text into a list of value strings."""
    bracketed = value.startswith("[") and value.endswith("]")
    if not bracketed:
        return [value]
    stripped_items = (item.strip() for item in value[1:-1].split(","))
    return [item for item in stripped_items if item]


def parse_axis_filters(axis_args):
    """Turn ``NAME=VALUE`` command-line strings into axis filter dicts.

    ``VALUE`` is either a single value or a bracketed, comma-separated list.
    A name ending in ``[pow2]`` marks the values as integer exponents: the
    suffix is dropped from the name, each value becomes ``str(2**exponent)``
    and is displayed as ``2^exponent``.

    Returns:
        One dict per argument with keys ``"name"``, ``"values"`` (list of
        strings) and ``"display"`` (text form of the filter).

    Raises:
        ValueError: if an argument is not ``NAME=VALUE``, has no name before
            ``[pow2]``, carries a non-integer ``[pow2]`` value, or lists no
            values at all.
    """
    pow2_suffix = "[pow2]"
    parsed = []

    for axis_arg in axis_args:
        raw_name, separator, raw_value = axis_arg.partition("=")
        name = raw_name.strip()
        value = raw_value.strip()
        if not (separator and name and value):
            raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))

        values = _split_axis_filter_values(value)
        shown = list(values)

        if name.endswith(pow2_suffix):
            name = name[: -len(pow2_suffix)].strip()
            if not name:
                raise ValueError(
                    "Axis filter missing name before [pow2]: {}".format(axis_arg)
                )
            try:
                exponents = list(map(int, values))
            except ValueError as exc:
                raise ValueError(
                    "Axis filter [pow2] value must be integer: {}".format(axis_arg)
                ) from exc
            values = [str(2**exponent) for exponent in exponents]
            shown = ["2^{}".format(exponent) for exponent in exponents]

        if not values:
            raise ValueError(
                "Axis filter must specify at least one value: {}".format(axis_arg)
            )

        if len(shown) == 1:
            display = "{}={}".format(name, shown[0])
        else:
            display = "{}=[{}]".format(name, ",".join(shown))

        parsed.append({"name": name, "values": values, "display": display})

    return parsed
def matches_axis_filters(state, axis_filters):
    """Tell whether *state* satisfies every filter in *axis_filters*.

    A filter is satisfied when some entry of ``state["axis_values"]`` has the
    filter's name and a non-``None`` value whose ``str()`` form is one of the
    filter's accepted values.  An empty filter list accepts every state.
    """
    if not axis_filters:
        return True

    axis_values = state.get("axis_values") or []

    def is_satisfied(axis_filter):
        wanted_name = axis_filter["name"]
        accepted = axis_filter["values"]
        return any(
            str(entry.get("value")) in accepted
            for entry in axis_values
            if entry.get("name") == wanted_name and entry.get("value") is not None
        )

    return all(is_satisfied(axis_filter) for axis_filter in axis_filters)
def strip_axis_filters_from_state_name(state_name, axis_filters):
    """Drop ``NAME=...`` tokens for single-valued filters from *state_name*.

    Without filters the name is returned untouched.  Otherwise the name is
    split on whitespace, every token that starts with ``NAME=`` for a filter
    holding exactly one value is removed, and the remaining tokens are joined
    back with single spaces.
    """
    if not axis_filters:
        return state_name

    # str.startswith accepts a tuple of candidates; an empty tuple matches
    # nothing, so multi-valued-only filter lists leave every token in place.
    single_value_prefixes = tuple(
        {
            "{}=".format(axis_filter["name"])
            for axis_filter in axis_filters
            if len(axis_filter["values"]) == 1
        }
    )
    kept_tokens = (
        token
        for token in state_name.split()
        if not token.startswith(single_value_prefixes)
    )
    return " ".join(kept_tokens)
def collect_entries(
    filename: str, axis_filters: list[dict], benchmark_filters: list[str]
) -> tuple[list[tuple[str, float, str]], set[str]]:
    """Gather the plottable utilization entries found in one result file.

    The file is loaded with ``reader.read_file``.  Benchmarks whose name is
    not in *benchmark_filters* are skipped (an empty list keeps all of them),
    as are states rejected by ``matches_axis_filters`` or for which
    ``extract_utilization`` returns ``None``.

    Returns:
        ``(entries, device_names)``: ``entries`` holds one
        ``(label, utilization, benchmark_name)`` tuple per kept state, and
        ``device_names`` the names of the devices those states refer to.
    """
    json_root = reader.read_file(filename)
    collected = []
    seen_devices = set()
    device_name_by_id = {
        device["id"]: device["name"] for device in json_root.get("devices", [])
    }

    for bench in json_root["benchmarks"]:
        bench_name = bench["name"]
        if benchmark_filters and bench_name not in benchmark_filters:
            continue

        for state in bench["states"]:
            utilization = (
                extract_utilization(state)
                if matches_axis_filters(state, axis_filters)
                else None
            )
            if utilization is None:
                continue

            state_name = state["name"]
            if state_name.startswith("Device="):
                # Drop the leading "Device=..." token when more text follows.
                _, space, remainder = state_name.partition(" ")
                if space:
                    state_name = remainder
            state_name = strip_axis_filters_from_state_name(state_name, axis_filters)
            label = "{} | {}".format(bench_name, state_name)

            device_name = device_name_by_id.get(state.get("device"))
            if device_name:
                seen_devices.add(device_name)
            collected.append((label, utilization, bench_name))

    return collected, seen_devices
def plot_entries(entries, title=None, output=None, dark=False):
    """Draw *entries* as a horizontal bar chart, one bar per entry.

    The module-level ``plt`` and ``PercentFormatter`` handles must already be
    populated (see ``load_nvbench_plot_bwutil_tooling``) unless *entries* is
    empty.

    Args:
        entries: Sequence of ``(label, value, benchmark_name)`` tuples; bars
            belonging to the same benchmark share a colour.
        title: Optional axes title.
        output: Path to save the figure to; when falsy the plot is shown
            with ``plt.show()`` instead.
        dark: Use matplotlib's ``dark_background`` style.

    Returns:
        0 after plotting, 1 (after printing a notice) when *entries* is empty.
    """
    if not entries:
        print("No utilization data found.")
        return 1

    labels, values, bench_names = map(list, zip(*entries))

    # De-duplicate in first-seen order.  Iterating a set of strings depends
    # on hash randomisation, which made the benchmark -> colour assignment
    # differ from one run to the next.
    unique_benches = list(dict.fromkeys(bench_names))
    cmap = plt.get_cmap("tab20", max(len(unique_benches), 1))
    bench_colors = {bench: cmap(index) for index, bench in enumerate(unique_benches)}
    colors = [bench_colors[bench] for bench in bench_names]

    # Grow the figure with the number of bars so labels stay legible.
    fig_height = max(4.0, 0.3 * len(entries) + 1.5)
    style = "dark_background" if dark else "default"
    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=(10, fig_height))
        y_pos = range(len(labels))
        ax.barh(y_pos, values, color=colors)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels)
        # Put the first entry at the top with half a bar of padding each side.
        ax.invert_yaxis()
        ax.set_ylim(len(labels) - 0.5, -0.5)
        # Values are fractions: 1.0 is rendered as 100%.
        ax.xaxis.set_major_formatter(PercentFormatter(1.0))
        if title:
            ax.set_title(title)
        fig.tight_layout()
        if output:
            fig.savefig(output, dpi=150)
            # Release the figure once it is on disk; pyplot otherwise keeps
            # it registered for the lifetime of the process.
            plt.close(fig)
        else:
            plt.show()
    return 0
def main():
    """Command-line entry point: collect utilization entries and plot them.

    Returns:
        The process exit status: 1 when a tooling dependency is missing or an
        axis filter is malformed, otherwise the status returned by
        ``plot_entries``.
    """
    args, filenames = parse_files()

    try:
        load_nvbench_plot_bwutil_tooling()
    except MissingToolingDependencyError as exc:
        print(str(exc), file=sys.stderr)
        return 1

    try:
        axis_filters = parse_axis_filters(args.axis)
    except ValueError as exc:
        # Diagnostics go to stderr, matching the dependency error above.
        print(str(exc), file=sys.stderr)
        return 1

    entries = []
    device_names = set()
    for filename in filenames:
        file_entries, file_device_names = collect_entries(
            filename,
            axis_filters,
            args.benchmark,
        )
        entries.extend(file_entries)
        device_names.update(file_device_names)

    title = args.title
    if title is None:
        # Build the default title; a title given with --title is used as is.
        title = "%SOL Bandwidth"
        if len(device_names) == 1:
            title = "{} - {}".format(title, next(iter(device_names)))
        if axis_filters:
            axis_label = ", ".join(
                axis_filter["display"] for axis_filter in axis_filters
            )
            title = "{} ({})".format(title, axis_label)

    return plot_entries(entries, title=title, output=args.output, dark=args.dark)


if __name__ == "__main__":
    sys.exit(main())