mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-03-14 20:27:24 +00:00
Merge pull request #298 from bernhardmgruber/ignore_device
Allow to by-pass device section check and compare different devices
This commit is contained in:
@@ -18,7 +18,8 @@ def version_tuple(v):
|
||||
|
||||
tabulate_version = version_tuple(tabulate.__version__)
|
||||
|
||||
all_devices = []
|
||||
all_ref_devices = []
|
||||
all_cmp_devices = []
|
||||
config_count = 0
|
||||
unknown_count = 0
|
||||
failure_count = 0
|
||||
@@ -32,7 +33,7 @@ def find_matching_bench(needle, haystack):
|
||||
return None
|
||||
|
||||
|
||||
def find_device_by_id(device_id):
|
||||
def find_device_by_id(device_id, all_devices):
|
||||
for device in all_devices:
|
||||
if device["id"] == device_id:
|
||||
return device
|
||||
@@ -113,7 +114,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
||||
|
||||
print("# %s\n" % (cmp_bench["name"]))
|
||||
|
||||
device_ids = cmp_bench["devices"]
|
||||
cmp_device_ids = cmp_bench["devices"]
|
||||
axes = cmp_bench["axes"]
|
||||
ref_states = ref_bench["states"]
|
||||
cmp_states = cmp_bench["states"]
|
||||
@@ -138,7 +139,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
||||
headers.append("Status")
|
||||
colalign.append("center")
|
||||
|
||||
for device_id in device_ids:
|
||||
for cmp_device_id in cmp_device_ids:
|
||||
rows = []
|
||||
plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}}
|
||||
|
||||
@@ -284,8 +285,21 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
||||
if len(rows) == 0:
|
||||
continue
|
||||
|
||||
device = find_device_by_id(device_id)
|
||||
print("## [%d] %s\n" % (device["id"], device["name"]))
|
||||
cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices)
|
||||
ref_device = find_device_by_id(ref_state["device"], all_ref_devices)
|
||||
|
||||
if cmp_device == ref_device:
|
||||
print("## [%d] %s\n" % (cmp_device["id"], cmp_device["name"]))
|
||||
else:
|
||||
print(
|
||||
"## [%d] %s vs. [%d] %s\n"
|
||||
% (
|
||||
ref_device["id"],
|
||||
ref_device["name"],
|
||||
cmp_device["id"],
|
||||
cmp_device["name"],
|
||||
)
|
||||
)
|
||||
# colalign and github format require tabulate 0.8.3
|
||||
if tabulate_version >= (0, 8, 3):
|
||||
print(
|
||||
@@ -303,7 +317,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
||||
plt.yscale("log")
|
||||
plt.xlabel(plot)
|
||||
plt.ylabel("time [s]")
|
||||
plt.title(device["name"])
|
||||
plt.title(cmp_device["name"])
|
||||
|
||||
def plot_line(key, shape, label):
|
||||
x = [float(x) for x in plot_data[key][axis].keys()]
|
||||
@@ -328,6 +342,13 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
|
||||
def main():
|
||||
help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
|
||||
parser = argparse.ArgumentParser(prog="nvbench_compare", usage=help_text)
|
||||
parser.add_argument(
|
||||
"--ignore-devices",
|
||||
dest="ignore_devices",
|
||||
default=False,
|
||||
help="Ignore differences in the device sections and compare anyway",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold-diff",
|
||||
type=float,
|
||||
@@ -369,17 +390,24 @@ def main():
|
||||
ref_root = reader.read_file(ref)
|
||||
cmp_root = reader.read_file(comp)
|
||||
|
||||
global all_devices
|
||||
all_devices = cmp_root["devices"]
|
||||
global all_ref_devices
|
||||
global all_cmp_devices
|
||||
all_ref_devices = ref_root["devices"]
|
||||
all_cmp_devices = cmp_root["devices"]
|
||||
|
||||
if ref_root["devices"] != cmp_root["devices"]:
|
||||
print("Device sections do not match.")
|
||||
print(
|
||||
(Fore.YELLOW if args.ignore_devices else Fore.RED)
|
||||
+ "Device sections do not match:"
|
||||
+ Fore.RESET
|
||||
)
|
||||
print(
|
||||
jsondiff.diff(
|
||||
ref_root["devices"], cmp_root["devices"], syntax="symmetric"
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
if not args.ignore_devices:
|
||||
sys.exit(1)
|
||||
|
||||
compare_benches(
|
||||
ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot
|
||||
|
||||
Reference in New Issue
Block a user