diff --git a/scripts/nvbench_compare.py b/scripts/nvbench_compare.py index 8a25b69..50f1fc3 100755 --- a/scripts/nvbench_compare.py +++ b/scripts/nvbench_compare.py @@ -1,10 +1,12 @@ #!/usr/bin/env python -from colorama import Fore +import argparse import json import math import sys +from colorama import Fore + import tabulate # Parse version string into tuple, "x.y.z" -> (x, y, z) @@ -13,22 +15,7 @@ def version_tuple(v): tabulate_version = version_tuple(tabulate.__version__) -if len(sys.argv) != 3: - print("Usage: %s reference.json compare.json\n" % sys.argv[0]) - sys.exit(1) - -with open(sys.argv[1], "r") as ref_file: - ref_root = json.load(ref_file) - -with open(sys.argv[2], "r") as cmp_file: - cmp_root = json.load(cmp_file) - -# This is blunt but works for now: -if ref_root["devices"] != cmp_root["devices"]: - print("Device sections do not match.") - sys.exit(1) - -all_devices = cmp_root["devices"] +all_devices = [] config_count = 0 unknown_count = 0 failure_count = 0 @@ -250,12 +237,35 @@ def compare_benches(ref_benches, cmp_benches): print("") -compare_benches(ref_root["benchmarks"], cmp_root["benchmarks"]) +def main(): + if len(sys.argv) != 3: + print("Usage: %s reference.json compare.json\n" % sys.argv[0]) + sys.exit(1) -print("# Summary\n") -print("- Total Matches: %d" % config_count) -print(" - Pass (diff <= min_noise): %d" % pass_count) -print(" - Unknown (infinite noise): %d" % unknown_count) -print(" - Failure (diff > min_noise): %d" % failure_count) + with open(sys.argv[1], "r") as ref_file: + ref_root = json.load(ref_file) -sys.exit(failure_count) + with open(sys.argv[2], "r") as cmp_file: + cmp_root = json.load(cmp_file) + + global all_devices + all_devices = cmp_root["devices"] + + # This is blunt but works for now: + if ref_root["devices"] != cmp_root["devices"]: + print("Device sections do not match.") + sys.exit(1) + + compare_benches(ref_root["benchmarks"], cmp_root["benchmarks"]) + + print("# Summary\n") + print("- Total Matches: %d" % config_count) + print(" - Pass (diff <= min_noise): %d" % pass_count) + print(" - Unknown (infinite noise): %d" % unknown_count) + print(" - Failure (diff > min_noise): %d" % failure_count) + + return failure_count + + +if __name__ == '__main__': + sys.exit(main())