From b0a46f44c21c344e2c338225d724eed60d912fe2 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Tue, 14 Apr 2026 08:09:44 -0500 Subject: [PATCH] Modularize color handling (#336) * Introduce function colorize to modularize colorization/no-color handling * Use sns.set_theme instead of deprecated sns.set() * Use str.format instead of legacy % syntax * Simplified iteration over list Use f-string (supported since Python 3.6) instead of str.format for better readability and performance --- python/scripts/nvbench_compare.py | 92 +++++++++++++++---------------- 1 file changed, 44 insertions(+), 48 deletions(-) diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index 585cdde..c637033 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -4,6 +4,7 @@ import argparse import math import os import sys +from enum import StrEnum import jsondiff import tabulate @@ -30,6 +31,24 @@ failure_count = 0 pass_count = 0 +class Emoji(StrEnum): + YELLOW = "\U0001f7e1" + BLUE = "\U0001f535" + GREEN = "\U0001f7e2" + RED = "\U0001f534" + NONE = "" + + +def colorize(msg: str, fore: Fore, emoji: Emoji, no_color: bool) -> str: + if no_color: + prefix = "" + if emoji_s := str(emoji): + prefix = f"{emoji_s} " + return f"{prefix}{msg}" + else: + return f"{fore}{msg}{Fore.RESET}" + + def find_matching_bench(needle, haystack): for hay in haystack: if hay["name"] == needle["name"]: @@ -89,12 +108,12 @@ def parse_axis_filters(axis_args): filters = [] for axis_arg in axis_args: if "=" not in axis_arg: - raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg)) + raise ValueError(f"Axis filter must be NAME=VALUE: {axis_arg}") name, value = axis_arg.split("=", 1) name = name.strip() value = value.strip() if not name or not value: - raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg)) + raise ValueError(f"Axis filter must be NAME=VALUE: {axis_arg}") values = [] if value.startswith("[") and value.endswith("]"): @@ -109,22 +128,18 @@ def parse_axis_filters(axis_args): if name.endswith("[pow2]"): name = name[: -len("[pow2]")].strip() if not name: - raise ValueError( - "Axis filter missing name before [pow2]: {}".format(axis_arg) - ) + raise ValueError(f"Axis filter missing name before [pow2]: {axis_arg}") try: exponents = [int(v) for v in values] except ValueError as exc: raise ValueError( - "Axis filter [pow2] value must be integer: {}".format(axis_arg) + f"Axis filter [pow2] value must be integer: {axis_arg}" ) from exc values = [str(2**exponent) for exponent in exponents] - display_values = ["2^{}".format(exponent) for exponent in exponents] + display_values = [f"2^{exponent}" for exponent in exponents] if not values: - raise ValueError( - "Axis filter must specify at least one value: {}".format(axis_arg) - ) + raise ValueError(f"Axis filter must specify at least one value: {axis_arg}") display = make_display(name, display_values) filters.append( @@ -270,7 +285,7 @@ def plot_comparison_entries(entries, title=None, dark=False): if not os.environ.get("DISPLAY"): output = "nvbench_compare.png" fig.savefig(output, dpi=150) - print("Saved comparison plot to {}".format(output)) + print(f"Saved comparison plot to {output}") else: plt.show() return 0 @@ -291,7 +306,7 @@ def compare_benches( import matplotlib.pyplot as plt import seaborn as sns - sns.set() + sns.set_theme() comparison_entries = [] comparison_device_names = set() @@ -302,7 +317,7 @@ def compare_benches( if benchmark_filters and cmp_bench["name"] not in benchmark_filters: continue - print("# %s\n" % (cmp_bench["name"])) + print(f"""# {cmp_bench["name"]}\n""") cmp_device_ids = cmp_bench["devices"] axes = cmp_bench["axes"] @@ -422,15 +437,11 @@ def compare_benches( if plot_along: axis_name = [] axis_value = "--" - for aid in range(len(axis_values)): - if axis_values[aid]["name"] != plot_along: - axis_name.append( - "{} = {}".format( - axis_values[aid]["name"], axis_values[aid]["value"] - ) - ) + for av in axis_values: + if av["name"] != plot_along: + axis_name.append(f"""{av["name"]} = {av["value"]}""") else: - axis_value = float(axis_values[aid]["value"]) + axis_value = float(av["value"]) axis_name = ", ".join(axis_name) if axis_name not in plot_data["cmp"]: @@ -453,31 +464,19 @@ def compare_benches( if not min_noise: unknown_count += 1 status_label = "????" - if no_color: - status = f"\U0001f7e1 {status_label}" - else: - status = f"{Fore.YELLOW}{status_label}{Fore.RESET}" + status = colorize(status_label, Fore.YELLOW, Emoji.YELLOW, no_color) elif abs(frac_diff) <= min_noise: pass_count += 1 status_label = "SAME" - if no_color: - status = f"\U0001f535 {status_label}" - else: - status = f"{Fore.BLUE}{status_label}{Fore.RESET}" + status = colorize(status_label, Fore.BLUE, Emoji.BLUE, no_color) elif diff < 0: failure_count += 1 status_label = "FAST" - if no_color: - status = f"\U0001f7e2 {status_label}" - else: - status = f"{Fore.GREEN}{status_label}{Fore.RESET}" + status = colorize(status_label, Fore.GREEN, Emoji.GREEN, no_color) else: failure_count += 1 status_label = "SLOW" - if no_color: - status = f"\U0001f534 {status_label}" - else: - status = f"{Fore.RED}{status_label}{Fore.RESET}" + status = colorize(status_label, Fore.RED, Emoji.RED, no_color) if abs(frac_diff) >= threshold: row.append(format_duration(ref_time)) @@ -492,7 +491,7 @@ def compare_benches( if plot: axis_label = format_axis_values(axis_values, axes, axis_filters) if axis_label: - label = "{} | {}".format(cmp_bench["name"], axis_label) + label = f"""{cmp_bench["name"]} | {axis_label}""" else: label = cmp_bench["name"] cmp_device = find_device_by_id( @@ -563,7 +562,7 @@ def compare_benches( if plot: title = "%SOL Bandwidth change" if len(comparison_device_names) == 1: - title = "{} - {}".format(title, next(iter(comparison_device_names))) + title = f"{title} - {next(iter(comparison_device_names))}" if axis_filters: axis_label = ", ".join( axis_filter["display"] @@ -571,7 +570,7 @@ def compare_benches( if len(axis_filter["values"]) == 1 ) if axis_label: - title = "{} ({})".format(title, axis_label) + title = f"{title} ({axis_label})" plot_comparison_entries(comparison_entries, title=title, dark=dark) @@ -669,14 +668,11 @@ def main(): all_cmp_devices = cmp_root["devices"] if ref_root["devices"] != cmp_root["devices"]: - if args.no_color: - print("Device sections do not match:") - else: - print( - (Fore.YELLOW if args.ignore_devices else Fore.RED) - + "Device sections do not match:" - + Fore.RESET - ) + warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED + msg_text = "Device sections do not match" + print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="") + print(": ", end="") + print( jsondiff.diff( ref_root["devices"], cmp_root["devices"], syntax="symmetric"