diff --git a/python/scripts/nvbench_plot.py b/python/scripts/nvbench_plot_bwutil.py similarity index 94% rename from python/scripts/nvbench_plot.py rename to python/scripts/nvbench_plot_bwutil.py index f8a34df..a7e7148 100644 --- a/python/scripts/nvbench_plot.py +++ b/python/scripts/nvbench_plot_bwutil.py @@ -12,7 +12,6 @@ try: except ImportError: from scripts.nvbench_json import reader - UTILIZATION_TAG = "nv/cold/bw/global/utilization" @@ -55,12 +54,11 @@ def parse_files(): for file_or_dir in files_or_dirs: if os.path.isdir(file_or_dir): for f in os.listdir(file_or_dir): - if os.path.splitext(f)[1] != ".json": - continue filename = os.path.join(file_or_dir, f) if os.path.isfile(filename) and os.path.getsize(filename) > 0: filenames.append(filename) else: + assert os.path.isfile(file_or_dir) filenames.append(file_or_dir) filenames.sort() @@ -107,7 +105,9 @@ def parse_axis_filters(axis_args): if value.startswith("[") and value.endswith("]"): inner = value[1:-1].strip() if inner: - values = [item.strip() for item in inner.split(",") if item.strip()] + values = [ + stripped for item in inner.split(",") if (stripped := item.strip()) + ] else: values = [] else: @@ -189,7 +189,9 @@ def strip_axis_filters_from_state_name(state_name, axis_filters): return " ".join(tokens) -def collect_entries(filename, axis_filters, benchmark_filters): +def collect_entries( + filename: str, axis_filters: list[dict], benchmark_filters: list[str] +) -> tuple[list[tuple[str, float, str]], set[str]]: json_root = reader.read_file(filename) entries = [] device_names = set() @@ -225,14 +227,8 @@ def plot_entries(entries, title=None, output=None, dark=False): print("No utilization data found.") return 1 - labels = [entry[0] for entry in entries] - values = [entry[1] for entry in entries] - bench_names = [entry[2] for entry in entries] - - unique_benches = [] - for bench in bench_names: - if bench not in unique_benches: - unique_benches.append(bench) + labels, values, bench_names = map(list, zip(*entries)) + unique_benches = list(set(bench_names)) cmap = plt.get_cmap("tab20", max(len(unique_benches), 1)) bench_colors = {bench: cmap(index) for index, bench in enumerate(unique_benches)}