diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index c637033..209e0d1 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -5,6 +5,7 @@ import math import os import sys from enum import StrEnum +from itertools import islice import jsondiff import tabulate @@ -347,11 +348,18 @@ def compare_benches( for cmp_device_id in cmp_device_ids: rows = [] plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}} + counters = {} for cmp_state in cmp_states: cmp_state_name = cmp_state["name"] + counters[cmp_state_name] = counters.get(cmp_state_name, 0) + 1 ref_state = next( - filter(lambda st: st["name"] == cmp_state_name, ref_states), None + islice( + filter(lambda st: st["name"] == cmp_state_name, ref_states), + counters[cmp_state_name] - 1, + None, + ), + None, ) if not ref_state: continue @@ -424,15 +432,15 @@ def compare_benches( if ref_noise and cmp_noise: ref_noise = float(ref_noise) cmp_noise = float(cmp_noise) - min_noise = min(ref_noise, cmp_noise) + max_noise = max(ref_noise, cmp_noise) elif ref_noise: ref_noise = float(ref_noise) - min_noise = ref_noise + max_noise = ref_noise elif cmp_noise: cmp_noise = float(cmp_noise) - min_noise = cmp_noise + max_noise = cmp_noise else: - min_noise = None # Noise is inf + max_noise = None # Noise is inf if plot_along: axis_name = [] @@ -461,11 +469,11 @@ def compare_benches( global failure_count config_count += 1 - if not min_noise: + if max_noise is None: unknown_count += 1 status_label = "????" status = colorize(status_label, Fore.YELLOW, Emoji.YELLOW, no_color) - elif abs(frac_diff) <= min_noise: + elif abs(frac_diff) <= max_noise: pass_count += 1 status_label = "SAME" status = colorize(status_label, Fore.BLUE, Emoji.BLUE, no_color) @@ -695,9 +703,9 @@ def main(): print("# Summary\n") print("- Total Matches: %d" % config_count) - print(" - Pass (diff <= min_noise): %d" % pass_count) + print(" - Pass (diff <= max_noise): %d" % pass_count) print(" - Unknown (infinite noise): %d" % unknown_count) - print(" - Failure (diff > min_noise): %d" % failure_count) + print(" - Failure (diff > max_noise): %d" % failure_count) return failure_count