diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index 9d54889..c96a4d4 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -86,6 +86,19 @@ def read_nvbench_json_root(filename: str) -> Mapping[str, Any]: f"NVBench JSON file {filename!r} is missing required root key(s): {missing}" ) + for key in ("devices", "benchmarks"): + value = root[key] + if not isinstance(value, list): + raise ValueError( + f"NVBench JSON file {filename!r} root key {key!r} must be an array" + ) + for index, entry in enumerate(value): + if not isinstance(entry, Mapping): + raise ValueError( + f"NVBench JSON file {filename!r} root key {key!r} entry " + f"{index} must be an object" + ) + return root @@ -131,7 +144,7 @@ COMPARISON_THRESHOLD_PRESET_VALUES = { "same_center_relative": 0.01, "same_overlap_fraction": 0.25, "same_relative_dispersion_ceiling": 0.05, - "bulk_same_sample_coverage": 0.98, + "bulk_same_sample_coverage": 0.90, "bulk_same_support_coverage": 0.60, "bulk_support_rare_sample_fraction": 0.001, "bulk_support_max_removed_sample_fraction": 0.02, diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index 99a3498..8ef9a05 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -2234,6 +2234,37 @@ def test_main_reports_missing_required_root_keys(monkeypatch, capsys, nvbench_co assert "'benchmarks'" in output +def test_main_rejects_non_array_root_keys(monkeypatch, capsys, nvbench_compare): + monkeypatch.setattr( + nvbench_compare.reader, + "read_file", + lambda _: {"devices": {}, "benchmarks": []}, + ) + monkeypatch.setattr(sys, "argv", ["nvbench_compare", "ref.json", "cmp.json"]) + + assert nvbench_compare.main() == 1 + output = capsys.readouterr().out + assert "NVBench JSON file 'ref.json' root key 'devices' must be an array" in output + + +def test_main_rejects_non_object_root_array_entries( + monkeypatch, capsys, nvbench_compare +): + monkeypatch.setattr( + nvbench_compare.reader, + "read_file", + lambda _: {"devices": [None], "benchmarks": []}, + ) + monkeypatch.setattr(sys, "argv", ["nvbench_compare", "ref.json", "cmp.json"]) + + assert nvbench_compare.main() == 1 + output = capsys.readouterr().out + assert ( + "NVBench JSON file 'ref.json' root key 'devices' entry 0 must be an object" + in output + ) + + def test_main_prints_bulk_debug_python_to_stdout(monkeypatch, capsys, nvbench_compare): devices = [{"id": 0, "name": "Test GPU"}] root = {