mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-05-01 04:01:14 +00:00
Merge pull request #315 from bernhardmgruber/plot_diff_script
Extend `nvbench_compare.py` with `--plot`, axis/benchmark filtering, and dark mode
This commit is contained in:
@@ -79,6 +79,87 @@ def format_axis_value(axis_name, axis_value, axes):
|
|||||||
return format_string_axis_value(axis_name, axis_value, axes)
|
return format_string_axis_value(axis_name, axis_value, axes)
|
||||||
|
|
||||||
|
|
||||||
|
def make_display(name: str, display_values: [list[str]]) -> str:
|
||||||
|
open_bracket, close_bracket = ("[", "]") if len(display_values) > 1 else ("", "")
|
||||||
|
display_values = ",".join(display_values)
|
||||||
|
return f"{name}={open_bracket}{display_values}{close_bracket}"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_axis_filters(axis_args):
|
||||||
|
filters = []
|
||||||
|
for axis_arg in axis_args:
|
||||||
|
if "=" not in axis_arg:
|
||||||
|
raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
|
||||||
|
name, value = axis_arg.split("=", 1)
|
||||||
|
name = name.strip()
|
||||||
|
value = value.strip()
|
||||||
|
if not name or not value:
|
||||||
|
raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
|
||||||
|
|
||||||
|
values = []
|
||||||
|
if value.startswith("[") and value.endswith("]"):
|
||||||
|
inner = value[1:-1].strip()
|
||||||
|
values = [
|
||||||
|
stripped for item in inner.split(",") if (stripped := item.strip())
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
values = [value]
|
||||||
|
display_values = list(values)
|
||||||
|
|
||||||
|
if name.endswith("[pow2]"):
|
||||||
|
name = name[: -len("[pow2]")].strip()
|
||||||
|
if not name:
|
||||||
|
raise ValueError(
|
||||||
|
"Axis filter missing name before [pow2]: {}".format(axis_arg)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
exponents = [int(v) for v in values]
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(
|
||||||
|
"Axis filter [pow2] value must be integer: {}".format(axis_arg)
|
||||||
|
) from exc
|
||||||
|
values = [str(2**exponent) for exponent in exponents]
|
||||||
|
display_values = ["2^{}".format(exponent) for exponent in exponents]
|
||||||
|
|
||||||
|
if not values:
|
||||||
|
raise ValueError(
|
||||||
|
"Axis filter must specify at least one value: {}".format(axis_arg)
|
||||||
|
)
|
||||||
|
|
||||||
|
display = make_display(name, display_values)
|
||||||
|
filters.append(
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"values": values,
|
||||||
|
"display": display,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
|
def matches_axis_filters(state, axis_filters):
|
||||||
|
if not axis_filters:
|
||||||
|
return True
|
||||||
|
|
||||||
|
axis_values = state.get("axis_values") or []
|
||||||
|
for axis_filter in axis_filters:
|
||||||
|
filter_name = axis_filter["name"]
|
||||||
|
filter_values = axis_filter["values"]
|
||||||
|
matched = False
|
||||||
|
for axis_value in axis_values:
|
||||||
|
if axis_value.get("name") != filter_name:
|
||||||
|
continue
|
||||||
|
value = axis_value.get("value")
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
if str(value) in filter_values:
|
||||||
|
matched = True
|
||||||
|
break
|
||||||
|
if not matched:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def format_duration(seconds):
|
def format_duration(seconds):
|
||||||
if seconds >= 1:
|
if seconds >= 1:
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
@@ -104,17 +185,121 @@ def format_percentage(percentage):
|
|||||||
return "%0.2f%%" % (percentage * 100.0)
|
return "%0.2f%%" % (percentage * 100.0)
|
||||||
|
|
||||||
|
|
||||||
def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
def format_axis_values(axis_values, axes, axis_filters=None):
|
||||||
if plot:
|
if not axis_values:
|
||||||
|
return ""
|
||||||
|
filtered_names = set()
|
||||||
|
if axis_filters:
|
||||||
|
filtered_names = {
|
||||||
|
axis_filter["name"]
|
||||||
|
for axis_filter in axis_filters
|
||||||
|
if len(axis_filter["values"]) == 1
|
||||||
|
}
|
||||||
|
parts = []
|
||||||
|
for axis_value in axis_values:
|
||||||
|
axis_name = axis_value["name"]
|
||||||
|
if axis_name in filtered_names:
|
||||||
|
continue
|
||||||
|
formatted = format_axis_value(axis_name, axis_value, axes)
|
||||||
|
parts.append(f"{axis_name}={formatted}")
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_comparison_entries(entries, title=None, dark=False):
|
||||||
|
if not entries:
|
||||||
|
print("No comparison data to plot.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if not os.environ.get("DISPLAY"):
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.ticker import PercentFormatter
|
||||||
|
|
||||||
|
labels, values, statuses, bench_names = map(list, zip(*entries))
|
||||||
|
|
||||||
|
status_colors = {
|
||||||
|
"SLOW": "red",
|
||||||
|
"FAST": "green",
|
||||||
|
"SAME": "blue",
|
||||||
|
}
|
||||||
|
colors = [status_colors.get(status, "gray") for status in statuses]
|
||||||
|
|
||||||
|
fig_height = max(4.0, 0.3 * len(entries) + 1.5)
|
||||||
|
fig, ax = plt.subplots(figsize=(10, fig_height))
|
||||||
|
if dark:
|
||||||
|
fig.patch.set_facecolor("black")
|
||||||
|
ax.set_facecolor("black")
|
||||||
|
ax.tick_params(colors="white")
|
||||||
|
ax.xaxis.label.set_color("white")
|
||||||
|
ax.yaxis.label.set_color("white")
|
||||||
|
ax.title.set_color("white")
|
||||||
|
for spine in ax.spines.values():
|
||||||
|
spine.set_color("white")
|
||||||
|
|
||||||
|
y_pos = range(len(labels))
|
||||||
|
ax.barh(y_pos, values, color=colors)
|
||||||
|
ax.set_yticks(y_pos)
|
||||||
|
ax.set_yticklabels(labels)
|
||||||
|
ax.invert_yaxis()
|
||||||
|
ax.set_ylim(len(labels) - 0.5, -0.5)
|
||||||
|
|
||||||
|
separator_color = "white" if dark else "gray"
|
||||||
|
ax.axvline(0, color=separator_color, linewidth=1, alpha=0.6)
|
||||||
|
for index in range(1, len(bench_names)):
|
||||||
|
if bench_names[index] != bench_names[index - 1]:
|
||||||
|
ax.axhline(index - 0.5, color=separator_color, linewidth=0.6, alpha=0.4)
|
||||||
|
ax.xaxis.set_major_formatter(PercentFormatter(1.0))
|
||||||
|
|
||||||
|
if title:
|
||||||
|
ax.set_title(title)
|
||||||
|
|
||||||
|
min_val = min(values)
|
||||||
|
max_val = max(values)
|
||||||
|
if min_val == max_val:
|
||||||
|
pad = 0.05 if min_val == 0 else abs(min_val) * 0.1
|
||||||
|
ax.set_xlim(min_val - pad, max_val + pad)
|
||||||
|
else:
|
||||||
|
pad = (max_val - min_val) * 0.1
|
||||||
|
ax.set_xlim(min_val - pad, max_val + pad)
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
|
||||||
|
if not os.environ.get("DISPLAY"):
|
||||||
|
output = "nvbench_compare.png"
|
||||||
|
fig.savefig(output, dpi=150)
|
||||||
|
print("Saved comparison plot to {}".format(output))
|
||||||
|
else:
|
||||||
|
plt.show()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def compare_benches(
|
||||||
|
ref_benches,
|
||||||
|
cmp_benches,
|
||||||
|
threshold,
|
||||||
|
plot_along,
|
||||||
|
plot,
|
||||||
|
dark,
|
||||||
|
axis_filters,
|
||||||
|
benchmark_filters,
|
||||||
|
):
|
||||||
|
if plot_along:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
||||||
sns.set()
|
sns.set()
|
||||||
|
|
||||||
|
comparison_entries = []
|
||||||
|
comparison_device_names = set()
|
||||||
for cmp_bench in cmp_benches:
|
for cmp_bench in cmp_benches:
|
||||||
ref_bench = find_matching_bench(cmp_bench, ref_benches)
|
ref_bench = find_matching_bench(cmp_bench, ref_benches)
|
||||||
if not ref_bench:
|
if not ref_bench:
|
||||||
continue
|
continue
|
||||||
|
if benchmark_filters and cmp_bench["name"] not in benchmark_filters:
|
||||||
|
continue
|
||||||
|
|
||||||
print("# %s\n" % (cmp_bench["name"]))
|
print("# %s\n" % (cmp_bench["name"]))
|
||||||
|
|
||||||
@@ -154,6 +339,8 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
|||||||
)
|
)
|
||||||
if not ref_state:
|
if not ref_state:
|
||||||
continue
|
continue
|
||||||
|
if not matches_axis_filters(cmp_state, axis_filters):
|
||||||
|
continue
|
||||||
|
|
||||||
axis_values = cmp_state["axis_values"]
|
axis_values = cmp_state["axis_values"]
|
||||||
if not axis_values:
|
if not axis_values:
|
||||||
@@ -231,11 +418,11 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
|||||||
else:
|
else:
|
||||||
min_noise = None # Noise is inf
|
min_noise = None # Noise is inf
|
||||||
|
|
||||||
if plot:
|
if plot_along:
|
||||||
axis_name = []
|
axis_name = []
|
||||||
axis_value = "--"
|
axis_value = "--"
|
||||||
for aid in range(len(axis_values)):
|
for aid in range(len(axis_values)):
|
||||||
if axis_values[aid]["name"] != plot:
|
if axis_values[aid]["name"] != plot_along:
|
||||||
axis_name.append(
|
axis_name.append(
|
||||||
"{} = {}".format(
|
"{} = {}".format(
|
||||||
axis_values[aid]["name"], axis_values[aid]["value"]
|
axis_values[aid]["name"], axis_values[aid]["value"]
|
||||||
@@ -264,16 +451,20 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
|||||||
config_count += 1
|
config_count += 1
|
||||||
if not min_noise:
|
if not min_noise:
|
||||||
unknown_count += 1
|
unknown_count += 1
|
||||||
status = Fore.YELLOW + "????" + Fore.RESET
|
status_label = "????"
|
||||||
|
status = Fore.YELLOW + status_label + Fore.RESET
|
||||||
elif abs(frac_diff) <= min_noise:
|
elif abs(frac_diff) <= min_noise:
|
||||||
pass_count += 1
|
pass_count += 1
|
||||||
status = Fore.BLUE + "SAME" + Fore.RESET
|
status_label = "SAME"
|
||||||
|
status = Fore.BLUE + status_label + Fore.RESET
|
||||||
elif diff < 0:
|
elif diff < 0:
|
||||||
failure_count += 1
|
failure_count += 1
|
||||||
status = Fore.GREEN + "FAST" + Fore.RESET
|
status_label = "FAST"
|
||||||
|
status = Fore.GREEN + status_label + Fore.RESET
|
||||||
else:
|
else:
|
||||||
failure_count += 1
|
failure_count += 1
|
||||||
status = Fore.RED + "SLOW" + Fore.RESET
|
status_label = "SLOW"
|
||||||
|
status = Fore.RED + status_label + Fore.RESET
|
||||||
|
|
||||||
if abs(frac_diff) >= threshold:
|
if abs(frac_diff) >= threshold:
|
||||||
row.append(format_duration(ref_time))
|
row.append(format_duration(ref_time))
|
||||||
@@ -285,6 +476,20 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
|||||||
row.append(status)
|
row.append(status)
|
||||||
|
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
|
if plot:
|
||||||
|
axis_label = format_axis_values(axis_values, axes, axis_filters)
|
||||||
|
if axis_label:
|
||||||
|
label = "{} | {}".format(cmp_bench["name"], axis_label)
|
||||||
|
else:
|
||||||
|
label = cmp_bench["name"]
|
||||||
|
cmp_device = find_device_by_id(
|
||||||
|
cmp_state["device"], all_cmp_devices
|
||||||
|
)
|
||||||
|
if cmp_device:
|
||||||
|
comparison_device_names.add(cmp_device["name"])
|
||||||
|
comparison_entries.append(
|
||||||
|
(label, frac_diff, status_label, cmp_bench["name"])
|
||||||
|
)
|
||||||
|
|
||||||
if len(rows) == 0:
|
if len(rows) == 0:
|
||||||
continue
|
continue
|
||||||
@@ -316,10 +521,10 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
|||||||
|
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
if plot:
|
if plot_along:
|
||||||
plt.xscale("log")
|
plt.xscale("log")
|
||||||
plt.yscale("log")
|
plt.yscale("log")
|
||||||
plt.xlabel(plot)
|
plt.xlabel(plot_along)
|
||||||
plt.ylabel("time [s]")
|
plt.ylabel("time [s]")
|
||||||
plt.title(cmp_device["name"])
|
plt.title(cmp_device["name"])
|
||||||
|
|
||||||
@@ -342,6 +547,20 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
|||||||
plt.legend()
|
plt.legend()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
if plot:
|
||||||
|
title = "%SOL Bandwidth change"
|
||||||
|
if len(comparison_device_names) == 1:
|
||||||
|
title = "{} - {}".format(title, next(iter(comparison_device_names)))
|
||||||
|
if axis_filters:
|
||||||
|
axis_label = ", ".join(
|
||||||
|
axis_filter["display"]
|
||||||
|
for axis_filter in axis_filters
|
||||||
|
if len(axis_filter["values"]) == 1
|
||||||
|
)
|
||||||
|
if axis_label:
|
||||||
|
title = "{} ({})".format(title, axis_label)
|
||||||
|
plot_comparison_entries(comparison_entries, title=title, dark=dark)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
|
help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
|
||||||
@@ -361,11 +580,42 @@ def main():
|
|||||||
help="only show benchmarks where percentage diff is >= THRESHOLD",
|
help="only show benchmarks where percentage diff is >= THRESHOLD",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--plot-along", type=str, dest="plot", default=None, help="plot results"
|
"--plot-along", type=str, dest="plot_along", default=None, help="plot results"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plot",
|
||||||
|
dest="plot",
|
||||||
|
default=False,
|
||||||
|
help="plot comparison summary",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dark",
|
||||||
|
action="store_true",
|
||||||
|
help="Use dark theme (black background, white text)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-a",
|
||||||
|
"--axis",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="Filter on axis value, e.g. -a Elements{io}=2^20 (can repeat)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-b",
|
||||||
|
"--benchmark",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="Filter by benchmark name (can repeat)",
|
||||||
)
|
)
|
||||||
|
|
||||||
args, files_or_dirs = parser.parse_known_args()
|
args, files_or_dirs = parser.parse_known_args()
|
||||||
print(files_or_dirs)
|
print(files_or_dirs)
|
||||||
|
try:
|
||||||
|
axis_filters = parse_axis_filters(args.axis)
|
||||||
|
except ValueError as exc:
|
||||||
|
print(str(exc))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
if len(files_or_dirs) != 2:
|
if len(files_or_dirs) != 2:
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
@@ -414,7 +664,14 @@ def main():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
compare_benches(
|
compare_benches(
|
||||||
ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot
|
ref_root["benchmarks"],
|
||||||
|
cmp_root["benchmarks"],
|
||||||
|
args.threshold,
|
||||||
|
args.plot_along,
|
||||||
|
args.plot,
|
||||||
|
args.dark,
|
||||||
|
axis_filters,
|
||||||
|
args.benchmark,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("# Summary\n")
|
print("# Summary\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user