# Patch summary (scripts/nvbench_compare.py): keep the reference and the
# comparison device lists separately, and allow a comparison to proceed when
# the two device sections differ.
#
# - The module-level `all_devices` list is replaced by `all_ref_devices` and
#   `all_cmp_devices`; main() fills them from ref_root["devices"] and
#   cmp_root["devices"] respectively.
# - find_device_by_id() gains a second parameter, `all_devices` (the list to
#   search), instead of reading the module-level global.
# - compare_benches() now names the comparison benchmark's device ids
#   `cmp_device_ids` and loops over `cmp_device_id`. Before printing a device
#   section it looks up `cmp_device` in all_cmp_devices and `ref_device` in
#   all_ref_devices. If the two device entries compare equal it prints the
#   single "## [id] name" heading; otherwise it prints a "ref vs. cmp" heading
#   with both ids and names. The plot title uses the comparison device's name.
# - main() adds the `--ignore-devices` flag (store_true, default False). When
#   the device sections differ, the message and the jsondiff symmetric diff
#   are always printed; the message is yellow when the flag is set and red
#   otherwise, and the script calls sys.exit(1) only when the flag is not set.
#
# NOTE(review): `ref_state` is not assigned in any hunk shown here; presumably
#   it is the variable left bound by the per-state loop earlier in the device
#   loop -- confirm it is always bound (and not None) and refers to a state of
#   the device section being printed.
# NOTE(review): `Fore` is used in main() but no import appears in these hunks;
#   presumably it is already imported at the top of the file -- confirm.
# NOTE(review): the results of find_device_by_id() are subscripted without a
#   check; confirm the lookups cannot miss when --ignore-devices lets
#   mismatched device sections through.
diff --git a/scripts/nvbench_compare.py b/scripts/nvbench_compare.py index 4d4e4c4..bf07fe6 100755 --- a/scripts/nvbench_compare.py +++ b/scripts/nvbench_compare.py @@ -18,7 +18,8 @@ def version_tuple(v): tabulate_version = version_tuple(tabulate.__version__) -all_devices = [] +all_ref_devices = [] +all_cmp_devices = [] config_count = 0 unknown_count = 0 failure_count = 0 @@ -32,7 +33,7 @@ def find_matching_bench(needle, haystack): return None -def find_device_by_id(device_id): +def find_device_by_id(device_id, all_devices): for device in all_devices: if device["id"] == device_id: return device @@ -113,7 +114,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): print("# %s\n" % (cmp_bench["name"])) - device_ids = cmp_bench["devices"] + cmp_device_ids = cmp_bench["devices"] axes = cmp_bench["axes"] ref_states = ref_bench["states"] cmp_states = cmp_bench["states"] @@ -138,7 +139,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): headers.append("Status") colalign.append("center") - for device_id in device_ids: + for cmp_device_id in cmp_device_ids: rows = [] plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}} @@ -284,8 +285,21 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): if len(rows) == 0: continue - device = find_device_by_id(device_id) - print("## [%d] %s\n" % (device["id"], device["name"])) + cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices) + ref_device = find_device_by_id(ref_state["device"], all_ref_devices) + + if cmp_device == ref_device: + print("## [%d] %s\n" % (cmp_device["id"], cmp_device["name"])) + else: + print( + "## [%d] %s vs. 
[%d] %s\n" + % ( + ref_device["id"], + ref_device["name"], + cmp_device["id"], + cmp_device["name"], + ) + ) # colalign and github format require tabulate 0.8.3 if tabulate_version >= (0, 8, 3): print( @@ -303,7 +317,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): plt.yscale("log") plt.xlabel(plot) plt.ylabel("time [s]") - plt.title(device["name"]) + plt.title(cmp_device["name"]) def plot_line(key, shape, label): x = [float(x) for x in plot_data[key][axis].keys()] @@ -328,6 +342,13 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot): def main(): help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]" parser = argparse.ArgumentParser(prog="nvbench_compare", usage=help_text) + parser.add_argument( + "--ignore-devices", + dest="ignore_devices", + default=False, + help="Ignore differences in the device sections and compare anyway", + action="store_true", + ) parser.add_argument( "--threshold-diff", type=float, @@ -369,17 +390,24 @@ def main(): ref_root = reader.read_file(ref) cmp_root = reader.read_file(comp) - global all_devices - all_devices = cmp_root["devices"] + global all_ref_devices + global all_cmp_devices + all_ref_devices = ref_root["devices"] + all_cmp_devices = cmp_root["devices"] if ref_root["devices"] != cmp_root["devices"]: - print("Device sections do not match.") + print( + (Fore.YELLOW if args.ignore_devices else Fore.RED) + + "Device sections do not match:" + + Fore.RESET + ) print( jsondiff.diff( ref_root["devices"], cmp_root["devices"], syntax="symmetric" ) ) - sys.exit(1) + if not args.ignore_devices: + sys.exit(1) compare_benches( ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot