diff --git a/python/scripts/nvbench_plot.py b/python/scripts/nvbench_plot.py index 4cdf4bd..8ef6086 100644 --- a/python/scripts/nvbench_plot.py +++ b/python/scripts/nvbench_plot.py @@ -239,35 +239,28 @@ def plot_entries(entries, title=None, output=None, dark=False): colors = [bench_colors[bench] for bench in bench_names] fig_height = max(4.0, 0.3 * len(entries) + 1.5) - fig, ax = plt.subplots(figsize=(10, fig_height)) - if dark: - fig.patch.set_facecolor("black") - ax.set_facecolor("black") - ax.tick_params(colors="white") - ax.xaxis.label.set_color("white") - ax.yaxis.label.set_color("white") - ax.title.set_color("white") - for spine in ax.spines.values(): - spine.set_color("white") + style = "dark_background" if dark else None + with plt.style.context(style) if style else plt.style.context("default"): + fig, ax = plt.subplots(figsize=(10, fig_height)) - y_pos = range(len(labels)) - ax.barh(y_pos, values, color=colors) - ax.set_yticks(y_pos) - ax.set_yticklabels(labels) - ax.invert_yaxis() - ax.set_ylim(len(labels) - 0.5, -0.5) + y_pos = range(len(labels)) + ax.barh(y_pos, values, color=colors) + ax.set_yticks(y_pos) + ax.set_yticklabels(labels) + ax.invert_yaxis() + ax.set_ylim(len(labels) - 0.5, -0.5) - ax.xaxis.set_major_formatter(PercentFormatter(1.0)) + ax.xaxis.set_major_formatter(PercentFormatter(1.0)) - if title: - ax.set_title(title) + if title: + ax.set_title(title) - fig.tight_layout() + fig.tight_layout() - if output: - fig.savefig(output, dpi=150) - else: - plt.show() + if output: + fig.savefig(output, dpi=150) + else: + plt.show() return 0