From 732d227be136712cfd70eeffbf5f016a4696a6bd Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Thu, 4 Jun 2026 09:55:58 -0500 Subject: [PATCH] Add TOML configuration for nvbench-compare thresholds Add versioned TOML configuration support for nvbench-compare threshold settings. The new --config option reads grouped settings for clear-gap, same-result, bulk coverage, and rare-support filtering thresholds. The parser validates the schema strictly so unknown tables, unknown keys, invalid types, unsupported versions, and out-of-range values fail early. Add --dump-config to print the effective configuration without requiring input JSON files. This makes the currently selected preset and resolved threshold values discoverable and gives users a starting point for custom configuration. Preset resolution is: - default is used when neither TOML nor CLI selects a preset - [preset] name = "..." in TOML selects the base preset - --preset ... overrides the TOML preset selection - explicit threshold values in TOML override whichever base preset was selected For example: - nvbench-compare --dump-config Prints the built-in default settings as grouped TOML. - nvbench-compare --preset permissive --dump-config Prints the permissive preset values as TOML. - nvbench-compare --config compare.toml ref.json cmp.json Compares using the preset named in compare.toml, plus any explicit TOML threshold overrides. - nvbench-compare --config compare.toml --preset strict ref.json cmp.json Uses the strict preset as the base, while preserving explicit threshold overrides from compare.toml. Keep TOML parsing lazy: Python 3.11+ uses tomllib, while Python 3.10 only requires tomli when --config is used. Add focused tests for grouped config dumping, strict validation, preset/override precedence, and CLI dump behavior. --- python/scripts/nvbench_compare.py | 262 +++++++++++++++++++++++++++- python/test/test_nvbench_compare.py | 175 +++++++++++++++++++ 2 files changed, 433 insertions(+), 4 deletions(-) diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index b86be01..f4bc220 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -9,10 +9,11 @@ import os import sys import warnings from collections import Counter -from dataclasses import dataclass, field +from collections.abc import Mapping +from dataclasses import dataclass, field, replace from enum import Enum from functools import cached_property -from typing import Any, Callable, Mapping +from typing import Any, BinaryIO, Callable, Protocol import jsondiff import numpy as np @@ -51,6 +52,15 @@ SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs" Float32Reader = Callable[[str], object] +class TomlModule(Protocol): + # TOML support is imported lazily. This protocol documents the narrow + # tomllib/tomli module surface used by this script. + @property + def TOMLDecodeError(self) -> type[BaseException]: ... + + def load(self, fp: BinaryIO, /) -> dict[str, Any]: ... + + def read_float32_file(filename: str) -> object: return np.fromfile(filename, dtype=" ComparisonThresholds: try: @@ -117,6 +164,192 @@ def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds: raise ValueError(f"unknown comparison preset {preset_name!r}") from exc +def load_toml_module() -> TomlModule: + try: + import tomllib + + return tomllib + except ModuleNotFoundError: + try: + import tomli + + return tomli + except ModuleNotFoundError as exc: + raise ValueError( + "TOML config support requires Python 3.11+ or the tomli package" + ) from exc + + +def validate_config_table(value: object, table_name: str) -> None: + if not isinstance(value, Mapping): + raise ValueError(f"config table [{table_name}] must be a TOML table") + + +def validate_config_float(value: object, key: str, field_name: str) -> float: + if isinstance(value, bool) or not isinstance(value, int | float): + raise ValueError(f"config value {key!r} must be a finite number") + + value = float(value) + if not math.isfinite(value): + raise ValueError(f"config value {key!r} must be finite") + + minimum, maximum = COMPARISON_THRESHOLD_RANGES[field_name] + if value < minimum: + raise ValueError(f"config value {key!r} must be >= {minimum:g}") + if maximum is not None and value > maximum: + raise ValueError(f"config value {key!r} must be <= {maximum:g}") + return value + + +def parse_config_section( + table: Mapping[str, Any], section_name: str +) -> dict[str, float]: + validate_config_table(table, section_name) + known_keys = COMPARISON_CONFIG_KEYS[section_name] + unknown_keys = set(table) - set(known_keys) + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise ValueError(f"unknown config key(s) in [{section_name}]: {unknown}") + + overrides = {} + for key, field_name in known_keys.items(): + if key not in table: + continue + full_key = f"{section_name}.{key}" + overrides[field_name] = validate_config_float(table[key], full_key, field_name) + return overrides + + +def parse_comparison_config_data( + config_data: Mapping[str, Any], +) -> tuple[str | None, dict[str, float]]: + if not isinstance(config_data, Mapping): + raise ValueError("comparison config must be a TOML table") + + unknown_top_level = set(config_data) - ({"version"} | COMPARISON_CONFIG_TABLES) + if unknown_top_level: + unknown = ", ".join(sorted(unknown_top_level)) + raise ValueError(f"unknown top-level config key(s): {unknown}") + + version = config_data.get("version") + if isinstance(version, bool) or not isinstance(version, int): + raise ValueError( + f"comparison config must specify integer version = {COMPARISON_CONFIG_VERSION}" + ) + if version != COMPARISON_CONFIG_VERSION: + raise ValueError( + f"unsupported comparison config version {version!r}; " + f"expected {COMPARISON_CONFIG_VERSION}" + ) + + preset_name = None + if "preset" in config_data: + preset_table = config_data["preset"] + validate_config_table(preset_table, "preset") + unknown_keys = set(preset_table) - {"name"} + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise ValueError(f"unknown config key(s) in [preset]: {unknown}") + if "name" in preset_table: + preset_name = preset_table["name"] + if not isinstance(preset_name, str): + raise ValueError("config value 'preset.name' must be a string") + get_comparison_thresholds(preset_name) + + overrides = {} + for section_name in ("clear_gap", "same"): + if section_name in config_data: + overrides.update( + parse_config_section(config_data[section_name], section_name) + ) + + if "bulk" in config_data: + bulk_table = config_data["bulk"] + validate_config_table(bulk_table, "bulk") + known_bulk_keys = set(COMPARISON_CONFIG_KEYS["bulk"]) | {"rare_support"} + unknown_keys = set(bulk_table) - known_bulk_keys + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise ValueError(f"unknown config key(s) in [bulk]: {unknown}") + + bulk_values = { + key: value for key, value in bulk_table.items() if key != "rare_support" + } + overrides.update(parse_config_section(bulk_values, "bulk")) + if "rare_support" in bulk_table: + overrides.update( + parse_config_section(bulk_table["rare_support"], "bulk.rare_support") + ) + + return preset_name, overrides + + +def read_comparison_config_file( + config_path: str | os.PathLike[str], +) -> tuple[str | None, dict[str, float]]: + toml_module = load_toml_module() + try: + with open(config_path, "rb") as config_file: + config_data = toml_module.load(config_file) + except toml_module.TOMLDecodeError as exc: + raise ValueError( + f"failed to parse comparison config {config_path!r}: {exc}" + ) from exc + except OSError as exc: + raise ValueError( + f"failed to read comparison config {config_path!r}: {exc}" + ) from exc + + return parse_comparison_config_data(config_data) + + +def resolve_comparison_thresholds( + cli_preset_name: str | None = None, + config_path: str | os.PathLike[str] | None = None, +) -> tuple[str, ComparisonThresholds]: + config_preset_name = None + config_overrides: dict[str, float] = {} + if config_path is not None: + config_preset_name, config_overrides = read_comparison_config_file(config_path) + + preset_name = cli_preset_name or config_preset_name or COMPARISON_DEFAULT_PRESET + thresholds = replace(get_comparison_thresholds(preset_name), **config_overrides) + return preset_name, thresholds + + +def format_toml_float(value: float) -> str: + return repr(float(value)) + + +def dump_comparison_config(preset_name: str, thresholds: ComparisonThresholds) -> str: + lines = [ + f"version = {COMPARISON_CONFIG_VERSION}", + "", + "[preset]", + f'name = "{preset_name}"', + "", + "[clear_gap]", + f"relative = {format_toml_float(thresholds.clear_gap_relative)}", + "", + "[same]", + f"center_relative = {format_toml_float(thresholds.same_center_relative)}", + f"overlap_fraction = {format_toml_float(thresholds.same_overlap_fraction)}", + "relative_dispersion_ceiling = " + f"{format_toml_float(thresholds.same_relative_dispersion_ceiling)}", + "", + "[bulk]", + f"sample_coverage = {format_toml_float(thresholds.bulk_same_sample_coverage)}", + f"support_coverage = {format_toml_float(thresholds.bulk_same_support_coverage)}", + "", + "[bulk.rare_support]", + "sample_fraction = " + f"{format_toml_float(thresholds.bulk_support_rare_sample_fraction)}", + "max_removed_sample_fraction = " + f"{format_toml_float(thresholds.bulk_support_max_removed_sample_fraction)}", + ] + return "\n".join(lines) + "\n" + + @dataclass(frozen=True) class SupportFilterInfo: activated: bool @@ -2109,9 +2342,19 @@ def main() -> int: parser.add_argument( "--preset", choices=sorted(COMPARISON_THRESHOLD_PRESETS), - default="default", + default=None, help="comparison threshold preset", ) + parser.add_argument( + "--config", + default=None, + help="comparison threshold TOML config", + ) + parser.add_argument( + "--dump-config", + action="store_true", + help="print the effective comparison threshold config and exit", + ) parser.add_argument( "--display", choices=["intervals", "legacy", "explain"], @@ -2169,6 +2412,18 @@ def main() -> int: ) args, files_or_dirs = parser.parse_known_args() + try: + comparison_preset, comparison_thresholds = resolve_comparison_thresholds( + args.preset, args.config + ) + except ValueError as exc: + print(str(exc)) + return 1 + + if args.dump_config: + print(dump_comparison_config(comparison_preset, comparison_thresholds), end="") + return 0 + try: filter_plan = build_benchmark_filter_plan(args.filter_actions) reference_device_filter = parse_device_filter( @@ -2177,7 +2432,6 @@ def main() -> int: compare_device_filter = parse_device_filter( args.compare_devices, "--compare-devices" ) - comparison_thresholds = get_comparison_thresholds(args.preset) except ValueError as exc: print(str(exc)) return 1 diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index e982b9d..3ae755a 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -1380,6 +1380,181 @@ def test_get_comparison_thresholds_returns_named_presets(nvbench_compare): assert permissive.bulk_same_support_coverage < default.bulk_same_support_coverage +def test_dump_comparison_config_uses_grouped_toml(nvbench_compare): + config = nvbench_compare.dump_comparison_config( + "default", nvbench_compare.get_comparison_thresholds("default") + ) + + assert "version = 1\n" in config + assert '[preset]\nname = "default"\n' in config + assert "[clear_gap]\nrelative = 0.005\n" in config + assert "[same]\n" in config + assert "[bulk]\n" in config + assert "sample_coverage = 0.97\n" in config + assert "[bulk.rare_support]\n" in config + + +def test_resolve_comparison_thresholds_applies_config_overrides( + monkeypatch, nvbench_compare +): + def read_config(_): + return ( + "strict", + { + "bulk_same_sample_coverage": 0.93, + "bulk_support_max_removed_sample_fraction": 0.02, + }, + ) + + monkeypatch.setattr(nvbench_compare, "read_comparison_config_file", read_config) + + preset, thresholds = nvbench_compare.resolve_comparison_thresholds( + None, "settings.toml" + ) + assert preset == "strict" + assert thresholds.clear_gap_relative == pytest.approx( + nvbench_compare.get_comparison_thresholds("strict").clear_gap_relative + ) + assert thresholds.bulk_same_sample_coverage == pytest.approx(0.93) + assert thresholds.bulk_support_max_removed_sample_fraction == pytest.approx(0.02) + + preset, thresholds = nvbench_compare.resolve_comparison_thresholds( + "permissive", "settings.toml" + ) + assert preset == "permissive" + assert thresholds.clear_gap_relative == pytest.approx( + nvbench_compare.get_comparison_thresholds("permissive").clear_gap_relative + ) + assert thresholds.bulk_same_sample_coverage == pytest.approx(0.93) + assert thresholds.bulk_support_max_removed_sample_fraction == pytest.approx(0.02) + + +def test_parse_comparison_config_data_validates_grouped_thresholds(nvbench_compare): + preset, overrides = nvbench_compare.parse_comparison_config_data( + { + "version": 1, + "preset": {"name": "strict"}, + "clear_gap": {"relative": 0.01}, + "same": { + "center_relative": 0.002, + "overlap_fraction": 0.75, + "relative_dispersion_ceiling": 0.02, + }, + "bulk": { + "sample_coverage": 0.99, + "support_coverage": 0.8, + "rare_support": { + "sample_fraction": 0.001, + "max_removed_sample_fraction": 0.01, + }, + }, + } + ) + + assert preset == "strict" + assert overrides == { + "clear_gap_relative": 0.01, + "same_center_relative": 0.002, + "same_overlap_fraction": 0.75, + "same_relative_dispersion_ceiling": 0.02, + "bulk_same_sample_coverage": 0.99, + "bulk_same_support_coverage": 0.8, + "bulk_support_rare_sample_fraction": 0.001, + "bulk_support_max_removed_sample_fraction": 0.01, + } + + +@pytest.mark.parametrize( + "config_data, match", + [ + ({}, "version"), + ({"version": 2}, "unsupported"), + ({"version": 1, "rare_support": {}}, "unknown top-level"), + ({"version": 1, "bulk": {"unknown": 0.1}}, r"\[bulk\]"), + ({"version": 1, "clear_gap": {"rare_support": {}}}, r"\[clear_gap\]"), + ({"version": 1, "bulk": {"sample_coverage": 1.5}}, "<= 1"), + ({"version": 1, "same": {"center_relative": "tight"}}, "finite number"), + ({"version": 1, "preset": {"name": "aggressive"}}, "unknown comparison preset"), + ], +) +def test_parse_comparison_config_data_rejects_invalid_config( + nvbench_compare, config_data, match +): + with pytest.raises(ValueError, match=match): + nvbench_compare.parse_comparison_config_data(config_data) + + +def test_read_comparison_config_file_parses_toml_when_parser_is_available( + tmp_path, nvbench_compare +): + parser_module = "tomllib" if sys.version_info >= (3, 11) else "tomli" + pytest.importorskip(parser_module) + config_path = tmp_path / "settings.toml" + config_path.write_text( + """ +version = 1 + +[preset] +name = "strict" + +[bulk] +sample_coverage = 0.93 +""", + encoding="utf-8", + ) + + preset, overrides = nvbench_compare.read_comparison_config_file(config_path) + + assert preset == "strict" + assert overrides == {"bulk_same_sample_coverage": 0.93} + + +def test_main_dump_config_does_not_require_input_files( + monkeypatch, capsys, nvbench_compare +): + def read_file(_): + raise AssertionError("dump-config should not read JSON files") + + monkeypatch.setattr(nvbench_compare.reader, "read_file", read_file) + monkeypatch.setattr( + sys, + "argv", + ["nvbench_compare", "--preset", "strict", "--dump-config"], + ) + + assert nvbench_compare.main() == 0 + output = capsys.readouterr().out + assert 'name = "strict"' in output + assert "[bulk.rare_support]" in output + + +def test_main_dump_config_merges_config_and_cli_preset( + monkeypatch, capsys, nvbench_compare +): + def read_config(_): + return ("strict", {"bulk_same_sample_coverage": 0.93}) + + monkeypatch.setattr(nvbench_compare, "read_comparison_config_file", read_config) + monkeypatch.setattr( + sys, + "argv", + [ + "nvbench_compare", + "--config", + "settings.toml", + "--preset", + "permissive", + "--dump-config", + ], + ) + + assert nvbench_compare.main() == 0 + output = capsys.readouterr().out + assert 'name = "permissive"' in output + assert "relative = 0.0025" in output + assert "sample_coverage = 0.93" in output + + def test_compare_benches_defaults_to_interval_display(monkeypatch, nvbench_compare): run_data = make_comparison_run_data(nvbench_compare) captured = {}