diff --git a/python/scripts/nvbench_plot.py b/python/scripts/nvbench_plot.py index 249abfd..63330a0 100644 --- a/python/scripts/nvbench_plot.py +++ b/python/scripts/nvbench_plot.py @@ -30,6 +30,11 @@ def parse_files(): default=None, help="Optional plot title", ) + parser.add_argument( + "--dark", + action="store_true", + help="Use dark theme (black background, white text)", + ) parser.add_argument( "-a", "--axis", @@ -96,6 +101,7 @@ def parse_axis_filters(axis_args): value = value.strip() if not name or not value: raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg)) + display_value = value if name.endswith("[pow2]"): name = name[: -len("[pow2]")].strip() if not name: @@ -109,7 +115,14 @@ def parse_axis_filters(axis_args): "Axis filter [pow2] value must be integer: {}".format(axis_arg) ) from exc value = str(2**exponent) - filters.append((name, value)) + display_value = "2^{}".format(exponent) + filters.append( + { + "name": name, + "value": value, + "display": "{}={}".format(name, display_value), + } + ) return filters @@ -118,7 +131,9 @@ def matches_axis_filters(state, axis_filters): return True axis_values = state.get("axis_values") or [] - for filter_name, filter_value in axis_filters: + for axis_filter in axis_filters: + filter_name = axis_filter["name"] + filter_value = axis_filter["value"] matched = False for axis_value in axis_values: if axis_value.get("name") != filter_name: @@ -134,6 +149,20 @@ def matches_axis_filters(state, axis_filters): return True +def strip_axis_filters_from_state_name(state_name, axis_filters): + if not axis_filters: + return state_name + + tokens = state_name.split() + tokens_to_remove = set( + axis_filter["display"] + for axis_filter in axis_filters + if " " not in axis_filter["display"] + ) + tokens = [token for token in tokens if token not in tokens_to_remove] + return " ".join(tokens) + + def collect_entries(filename, axis_filters, benchmark_filters): json_root = reader.read_file(filename) entries = [] @@ -155,6 +184,7 @@ def collect_entries(filename, axis_filters, benchmark_filters): parts = state_name.split(" ", 1) if len(parts) == 2: state_name = parts[1] + state_name = strip_axis_filters_from_state_name(state_name, axis_filters) label = "{} | {}".format(bench_name, state_name) device_name = devices.get(state.get("device")) if device_name: @@ -164,7 +194,7 @@ def collect_entries(filename, axis_filters, benchmark_filters): return entries, device_names -def plot_entries(entries, title=None, output=None): +def plot_entries(entries, title=None, output=None, dark=False): if not entries: print("No utilization data found.") return 1 @@ -184,6 +214,15 @@ def plot_entries(entries, title=None, output=None): fig_height = max(4.0, 0.3 * len(entries) + 1.5) fig, ax = plt.subplots(figsize=(10, fig_height)) + if dark: + fig.patch.set_facecolor("black") + ax.set_facecolor("black") + ax.tick_params(colors="white") + ax.xaxis.label.set_color("white") + ax.yaxis.label.set_color("white") + ax.title.set_color("white") + for spine in ax.spines.values(): + spine.set_color("white") y_pos = range(len(labels)) ax.barh(y_pos, values, color=colors) @@ -192,7 +231,6 @@ def plot_entries(entries, title=None, output=None): ax.invert_yaxis() ax.set_ylim(len(labels) - 0.5, -0.5) - ax.set_xlabel("Global BW Utilization") ax.xaxis.set_major_formatter(PercentFormatter(1.0)) if title: @@ -231,8 +269,11 @@ def main(): title = "%SOL Bandwidth" if len(device_names) == 1: title = "{} - {}".format(title, next(iter(device_names))) + if axis_filters: + axis_label = ", ".join(axis_filter["display"] for axis_filter in axis_filters) + title = "{} ({})".format(title, axis_label) - return plot_entries(entries, title=title, output=args.output) + return plot_entries(entries, title=title, output=args.output, dark=args.dark) if __name__ == "__main__":