mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
Apply reviewer feedback
This commit is contained in:
@@ -12,7 +12,6 @@ try:
|
||||
except ImportError:
|
||||
from scripts.nvbench_json import reader
|
||||
|
||||
|
||||
UTILIZATION_TAG = "nv/cold/bw/global/utilization"
|
||||
|
||||
|
||||
@@ -55,12 +54,11 @@ def parse_files():
|
||||
for file_or_dir in files_or_dirs:
|
||||
if os.path.isdir(file_or_dir):
|
||||
for f in os.listdir(file_or_dir):
|
||||
if os.path.splitext(f)[1] != ".json":
|
||||
continue
|
||||
filename = os.path.join(file_or_dir, f)
|
||||
if os.path.isfile(filename) and os.path.getsize(filename) > 0:
|
||||
filenames.append(filename)
|
||||
else:
|
||||
assert os.path.isfile(file_or_dir)
|
||||
filenames.append(file_or_dir)
|
||||
|
||||
filenames.sort()
|
||||
@@ -107,7 +105,9 @@ def parse_axis_filters(axis_args):
|
||||
if value.startswith("[") and value.endswith("]"):
|
||||
inner = value[1:-1].strip()
|
||||
if inner:
|
||||
values = [item.strip() for item in inner.split(",") if item.strip()]
|
||||
values = [
|
||||
stripped for item in inner.split(",") if (stripped := item.strip())
|
||||
]
|
||||
else:
|
||||
values = []
|
||||
else:
|
||||
@@ -189,7 +189,9 @@ def strip_axis_filters_from_state_name(state_name, axis_filters):
|
||||
return " ".join(tokens)
|
||||
|
||||
|
||||
def collect_entries(filename, axis_filters, benchmark_filters):
|
||||
def collect_entries(
|
||||
filename: str, axis_filters: list[dict], benchmark_filters: list[str]
|
||||
) -> tuple[list[tuple[str, float, str]], set[str]]:
|
||||
json_root = reader.read_file(filename)
|
||||
entries = []
|
||||
device_names = set()
|
||||
@@ -225,14 +227,8 @@ def plot_entries(entries, title=None, output=None, dark=False):
|
||||
print("No utilization data found.")
|
||||
return 1
|
||||
|
||||
labels = [entry[0] for entry in entries]
|
||||
values = [entry[1] for entry in entries]
|
||||
bench_names = [entry[2] for entry in entries]
|
||||
|
||||
unique_benches = []
|
||||
for bench in bench_names:
|
||||
if bench not in unique_benches:
|
||||
unique_benches.append(bench)
|
||||
labels, values, bench_names = map(list, zip(*entries))
|
||||
unique_benches = list(set(bench_names))
|
||||
|
||||
cmap = plt.get_cmap("tab20", max(len(unique_benches), 1))
|
||||
bench_colors = {bench: cmap(index) for index, bench in enumerate(unique_benches)}
|
||||
Reference in New Issue
Block a user