mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-05-12 09:15:47 +00:00
Modularize color handling (#336)
* Introduce function colorize to modularize colorization/no-color handling * Use sns.set_theme instead of deprecated sns.set() * Use str.format instead of legacy % syntax * Simplified iteration over list Use f-string (supported since Python 3.6) instead of str.format for better readability and performance
This commit is contained in:
@@ -4,6 +4,7 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from enum import StrEnum
|
||||
|
||||
import jsondiff
|
||||
import tabulate
|
||||
@@ -30,6 +31,24 @@ failure_count = 0
|
||||
pass_count = 0
|
||||
|
||||
|
||||
class Emoji(StrEnum):
|
||||
YELLOW = "\U0001f7e1"
|
||||
BLUE = "\U0001f535"
|
||||
GREEN = "\U0001f7e2"
|
||||
RED = "\U0001f534"
|
||||
NONE = ""
|
||||
|
||||
|
||||
def colorize(msg: str, fore: Fore, emoji: Emoji, no_color: bool) -> str:
|
||||
if no_color:
|
||||
prefix = ""
|
||||
if emoji_s := str(emoji):
|
||||
prefix = f"{emoji_s} "
|
||||
return f"{prefix}{msg}"
|
||||
else:
|
||||
return f"{fore}{msg}{Fore.RESET}"
|
||||
|
||||
|
||||
def find_matching_bench(needle, haystack):
|
||||
for hay in haystack:
|
||||
if hay["name"] == needle["name"]:
|
||||
@@ -89,12 +108,12 @@ def parse_axis_filters(axis_args):
|
||||
filters = []
|
||||
for axis_arg in axis_args:
|
||||
if "=" not in axis_arg:
|
||||
raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
|
||||
raise ValueError(f"Axis filter must be NAME=VALUE: {axis_arg}")
|
||||
name, value = axis_arg.split("=", 1)
|
||||
name = name.strip()
|
||||
value = value.strip()
|
||||
if not name or not value:
|
||||
raise ValueError("Axis filter must be NAME=VALUE: {}".format(axis_arg))
|
||||
raise ValueError(f"Axis filter must be NAME=VALUE: {axis_arg}")
|
||||
|
||||
values = []
|
||||
if value.startswith("[") and value.endswith("]"):
|
||||
@@ -109,22 +128,18 @@ def parse_axis_filters(axis_args):
|
||||
if name.endswith("[pow2]"):
|
||||
name = name[: -len("[pow2]")].strip()
|
||||
if not name:
|
||||
raise ValueError(
|
||||
"Axis filter missing name before [pow2]: {}".format(axis_arg)
|
||||
)
|
||||
raise ValueError(f"Axis filter missing name before [pow2]: {axis_arg}")
|
||||
try:
|
||||
exponents = [int(v) for v in values]
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"Axis filter [pow2] value must be integer: {}".format(axis_arg)
|
||||
f"Axis filter [pow2] value must be integer: {axis_arg}"
|
||||
) from exc
|
||||
values = [str(2**exponent) for exponent in exponents]
|
||||
display_values = ["2^{}".format(exponent) for exponent in exponents]
|
||||
display_values = [f"2^{exponent}" for exponent in exponents]
|
||||
|
||||
if not values:
|
||||
raise ValueError(
|
||||
"Axis filter must specify at least one value: {}".format(axis_arg)
|
||||
)
|
||||
raise ValueError(f"Axis filter must specify at least one value: {axis_arg}")
|
||||
|
||||
display = make_display(name, display_values)
|
||||
filters.append(
|
||||
@@ -270,7 +285,7 @@ def plot_comparison_entries(entries, title=None, dark=False):
|
||||
if not os.environ.get("DISPLAY"):
|
||||
output = "nvbench_compare.png"
|
||||
fig.savefig(output, dpi=150)
|
||||
print("Saved comparison plot to {}".format(output))
|
||||
print(f"Saved comparison plot to {output}")
|
||||
else:
|
||||
plt.show()
|
||||
return 0
|
||||
@@ -291,7 +306,7 @@ def compare_benches(
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
sns.set()
|
||||
sns.set_theme()
|
||||
|
||||
comparison_entries = []
|
||||
comparison_device_names = set()
|
||||
@@ -302,7 +317,7 @@ def compare_benches(
|
||||
if benchmark_filters and cmp_bench["name"] not in benchmark_filters:
|
||||
continue
|
||||
|
||||
print("# %s\n" % (cmp_bench["name"]))
|
||||
print(f"""# {cmp_bench["name"]}\n""")
|
||||
|
||||
cmp_device_ids = cmp_bench["devices"]
|
||||
axes = cmp_bench["axes"]
|
||||
@@ -422,15 +437,11 @@ def compare_benches(
|
||||
if plot_along:
|
||||
axis_name = []
|
||||
axis_value = "--"
|
||||
for aid in range(len(axis_values)):
|
||||
if axis_values[aid]["name"] != plot_along:
|
||||
axis_name.append(
|
||||
"{} = {}".format(
|
||||
axis_values[aid]["name"], axis_values[aid]["value"]
|
||||
)
|
||||
)
|
||||
for av in axis_values:
|
||||
if av["name"] != plot_along:
|
||||
axis_name.append(f"""{av["name"]} = {av["value"]}""")
|
||||
else:
|
||||
axis_value = float(axis_values[aid]["value"])
|
||||
axis_value = float(av["value"])
|
||||
axis_name = ", ".join(axis_name)
|
||||
|
||||
if axis_name not in plot_data["cmp"]:
|
||||
@@ -453,31 +464,19 @@ def compare_benches(
|
||||
if not min_noise:
|
||||
unknown_count += 1
|
||||
status_label = "????"
|
||||
if no_color:
|
||||
status = f"\U0001f7e1 {status_label}"
|
||||
else:
|
||||
status = f"{Fore.YELLOW}{status_label}{Fore.RESET}"
|
||||
status = colorize(status_label, Fore.YELLOW, Emoji.YELLOW, no_color)
|
||||
elif abs(frac_diff) <= min_noise:
|
||||
pass_count += 1
|
||||
status_label = "SAME"
|
||||
if no_color:
|
||||
status = f"\U0001f535 {status_label}"
|
||||
else:
|
||||
status = f"{Fore.BLUE}{status_label}{Fore.RESET}"
|
||||
status = colorize(status_label, Fore.BLUE, Emoji.BLUE, no_color)
|
||||
elif diff < 0:
|
||||
failure_count += 1
|
||||
status_label = "FAST"
|
||||
if no_color:
|
||||
status = f"\U0001f7e2 {status_label}"
|
||||
else:
|
||||
status = f"{Fore.GREEN}{status_label}{Fore.RESET}"
|
||||
status = colorize(status_label, Fore.GREEN, Emoji.GREEN, no_color)
|
||||
else:
|
||||
failure_count += 1
|
||||
status_label = "SLOW"
|
||||
if no_color:
|
||||
status = f"\U0001f534 {status_label}"
|
||||
else:
|
||||
status = f"{Fore.RED}{status_label}{Fore.RESET}"
|
||||
status = colorize(status_label, Fore.RED, Emoji.RED, no_color)
|
||||
|
||||
if abs(frac_diff) >= threshold:
|
||||
row.append(format_duration(ref_time))
|
||||
@@ -492,7 +491,7 @@ def compare_benches(
|
||||
if plot:
|
||||
axis_label = format_axis_values(axis_values, axes, axis_filters)
|
||||
if axis_label:
|
||||
label = "{} | {}".format(cmp_bench["name"], axis_label)
|
||||
label = f"""{cmp_bench["name"]} | {axis_label}"""
|
||||
else:
|
||||
label = cmp_bench["name"]
|
||||
cmp_device = find_device_by_id(
|
||||
@@ -563,7 +562,7 @@ def compare_benches(
|
||||
if plot:
|
||||
title = "%SOL Bandwidth change"
|
||||
if len(comparison_device_names) == 1:
|
||||
title = "{} - {}".format(title, next(iter(comparison_device_names)))
|
||||
title = f"{title} - {next(iter(comparison_device_names))}"
|
||||
if axis_filters:
|
||||
axis_label = ", ".join(
|
||||
axis_filter["display"]
|
||||
@@ -571,7 +570,7 @@ def compare_benches(
|
||||
if len(axis_filter["values"]) == 1
|
||||
)
|
||||
if axis_label:
|
||||
title = "{} ({})".format(title, axis_label)
|
||||
title = f"{title} ({axis_label})"
|
||||
plot_comparison_entries(comparison_entries, title=title, dark=dark)
|
||||
|
||||
|
||||
@@ -669,14 +668,11 @@ def main():
|
||||
all_cmp_devices = cmp_root["devices"]
|
||||
|
||||
if ref_root["devices"] != cmp_root["devices"]:
|
||||
if args.no_color:
|
||||
print("Device sections do not match:")
|
||||
else:
|
||||
print(
|
||||
(Fore.YELLOW if args.ignore_devices else Fore.RED)
|
||||
+ "Device sections do not match:"
|
||||
+ Fore.RESET
|
||||
)
|
||||
warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED
|
||||
msg_text = "Device sections do not match"
|
||||
print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="")
|
||||
print(": ", end="")
|
||||
|
||||
print(
|
||||
jsondiff.diff(
|
||||
ref_root["devices"], cmp_root["devices"], syntax="symmetric"
|
||||
|
||||
Reference in New Issue
Block a user