Merge pull request #298 from bernhardmgruber/ignore_device

Allow to by-pass device section check and compare different devices
This commit is contained in:
Bernhard Manfred Gruber
2025-12-10 18:24:26 +01:00
committed by GitHub

View File

@@ -18,7 +18,8 @@ def version_tuple(v):
tabulate_version = version_tuple(tabulate.__version__)
all_devices = []
all_ref_devices = []
all_cmp_devices = []
config_count = 0
unknown_count = 0
failure_count = 0
@@ -32,7 +33,7 @@ def find_matching_bench(needle, haystack):
return None
def find_device_by_id(device_id):
def find_device_by_id(device_id, all_devices):
for device in all_devices:
if device["id"] == device_id:
return device
@@ -113,7 +114,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
print("# %s\n" % (cmp_bench["name"]))
device_ids = cmp_bench["devices"]
cmp_device_ids = cmp_bench["devices"]
axes = cmp_bench["axes"]
ref_states = ref_bench["states"]
cmp_states = cmp_bench["states"]
@@ -138,7 +139,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
headers.append("Status")
colalign.append("center")
for device_id in device_ids:
for cmp_device_id in cmp_device_ids:
rows = []
plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}}
@@ -284,8 +285,21 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
if len(rows) == 0:
continue
device = find_device_by_id(device_id)
print("## [%d] %s\n" % (device["id"], device["name"]))
cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices)
ref_device = find_device_by_id(ref_state["device"], all_ref_devices)
if cmp_device == ref_device:
print("## [%d] %s\n" % (cmp_device["id"], cmp_device["name"]))
else:
print(
"## [%d] %s vs. [%d] %s\n"
% (
ref_device["id"],
ref_device["name"],
cmp_device["id"],
cmp_device["name"],
)
)
# colalign and github format require tabulate 0.8.3
if tabulate_version >= (0, 8, 3):
print(
@@ -303,7 +317,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
plt.yscale("log")
plt.xlabel(plot)
plt.ylabel("time [s]")
plt.title(device["name"])
plt.title(cmp_device["name"])
def plot_line(key, shape, label):
x = [float(x) for x in plot_data[key][axis].keys()]
@@ -328,6 +342,13 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
def main():
help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
parser = argparse.ArgumentParser(prog="nvbench_compare", usage=help_text)
parser.add_argument(
"--ignore-devices",
dest="ignore_devices",
default=False,
help="Ignore differences in the device sections and compare anyway",
action="store_true",
)
parser.add_argument(
"--threshold-diff",
type=float,
@@ -369,17 +390,24 @@ def main():
ref_root = reader.read_file(ref)
cmp_root = reader.read_file(comp)
global all_devices
all_devices = cmp_root["devices"]
global all_ref_devices
global all_cmp_devices
all_ref_devices = ref_root["devices"]
all_cmp_devices = cmp_root["devices"]
if ref_root["devices"] != cmp_root["devices"]:
print("Device sections do not match.")
print(
(Fore.YELLOW if args.ignore_devices else Fore.RED)
+ "Device sections do not match:"
+ Fore.RESET
)
print(
jsondiff.diff(
ref_root["devices"], cmp_root["devices"], syntax="symmetric"
)
)
sys.exit(1)
if not args.ignore_devices:
sys.exit(1)
compare_benches(
ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot