Add TOML configuration for nvbench-compare thresholds

Add versioned TOML configuration support for nvbench-compare threshold
settings. The new --config option reads grouped settings for clear-gap,
same-result, bulk coverage, and rare-support filtering thresholds. The parser
validates the schema strictly so unknown tables, unknown keys, invalid types,
unsupported versions, and out-of-range values fail early.

Add --dump-config to print the effective configuration without requiring input
JSON files. This makes the currently selected preset and resolved threshold
values discoverable and gives users a starting point for custom configuration.

Preset resolution is:
  - default is used when neither TOML nor CLI selects a preset
  - [preset] name = "..." in TOML selects the base preset
  - --preset ... overrides the TOML preset selection
  - explicit threshold values in TOML override whichever base preset was selected

For example:
  - nvbench-compare --dump-config
    Prints the built-in default settings as grouped TOML.

  - nvbench-compare --preset permissive --dump-config
    Prints the permissive preset values as TOML.

  - nvbench-compare --config compare.toml ref.json cmp.json
    Compares using the preset named in compare.toml, plus any explicit TOML
    threshold overrides.

  - nvbench-compare --config compare.toml --preset strict ref.json cmp.json
    Uses the strict preset as the base, while preserving explicit threshold
    overrides from compare.toml.

Keep TOML parsing lazy: Python 3.11+ uses the standard-library tomllib, while
on Python 3.10 the tomli package is needed only when --config is used. Add
focused tests for grouped config dumping, strict validation, preset/override
precedence, and CLI dump behavior.
This commit is contained in:
Oleksandr Pavlyk
2026-06-04 09:55:58 -05:00
parent 2585842cf5
commit 732d227be1
2 changed files with 433 additions and 4 deletions

View File

@@ -9,10 +9,11 @@ import os
import sys
import warnings
from collections import Counter
from dataclasses import dataclass, field
from collections.abc import Mapping
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import cached_property
from typing import Any, Callable, Mapping
from typing import Any, BinaryIO, Callable, Protocol
import jsondiff
import numpy as np
@@ -51,6 +52,15 @@ SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs"
Float32Reader = Callable[[str], object]
class TomlModule(Protocol):
    """Structural type for the lazily imported TOML parser module.

    Only the narrow part of the ``tomllib``/``tomli`` module surface that
    this script uses is described here.
    """

    @property
    def TOMLDecodeError(self) -> type[BaseException]:
        """Exception class the parser raises while loading a document."""
        ...

    def load(self, fp: BinaryIO, /) -> dict[str, Any]:
        """Parse a TOML document from a binary file object into a dict."""
        ...
def read_float32_file(filename: str) -> object:
    """Load the raw contents of *filename* as little-endian 32-bit floats."""
    little_endian_float32 = np.dtype("<f4")
    return np.fromfile(filename, dtype=little_endian_float32)
@@ -109,6 +119,43 @@ COMPARISON_THRESHOLD_PRESETS = {
for name, values in COMPARISON_THRESHOLD_PRESET_VALUES.items()
}
# Schema version a config file must declare in its top-level ``version`` key.
COMPARISON_CONFIG_VERSION = 1
# Preset used when neither the CLI nor the config file selects one.
COMPARISON_DEFAULT_PRESET = "default"
# Tables allowed at the top level of a config file (besides ``version``).
COMPARISON_CONFIG_TABLES = {"preset", "clear_gap", "same", "bulk"}
# For every config section: TOML key -> name of the threshold field it sets.
COMPARISON_CONFIG_KEYS = {
    "clear_gap": {"relative": "clear_gap_relative"},
    "same": {
        "center_relative": "same_center_relative",
        "overlap_fraction": "same_overlap_fraction",
        "relative_dispersion_ceiling": "same_relative_dispersion_ceiling",
    },
    "bulk": {
        "sample_coverage": "bulk_same_sample_coverage",
        "support_coverage": "bulk_same_support_coverage",
    },
    "bulk.rare_support": {
        "sample_fraction": "bulk_support_rare_sample_fraction",
        "max_removed_sample_fraction": "bulk_support_max_removed_sample_fraction",
    },
}
# Inclusive (minimum, maximum) bounds per threshold field; a maximum of None
# means the value has no upper bound.
COMPARISON_THRESHOLD_RANGES = {
    "clear_gap_relative": (0.0, None),
    "same_center_relative": (0.0, None),
    "same_overlap_fraction": (0.0, 1.0),
    "same_relative_dispersion_ceiling": (0.0, None),
    "bulk_same_sample_coverage": (0.0, 1.0),
    "bulk_same_support_coverage": (0.0, 1.0),
    "bulk_support_rare_sample_fraction": (0.0, 1.0),
    "bulk_support_max_removed_sample_fraction": (0.0, 1.0),
}
def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
try:
@@ -117,6 +164,192 @@ def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
raise ValueError(f"unknown comparison preset {preset_name!r}") from exc
def load_toml_module() -> TomlModule:
    """Import and return a TOML parser module on demand.

    ``tomllib`` is tried first and ``tomli`` is the fallback.

    Raises:
        ValueError: if neither module can be imported.
    """
    try:
        import tomllib as parser
    except ModuleNotFoundError:
        try:
            import tomli as parser
        except ModuleNotFoundError as missing:
            raise ValueError(
                "TOML config support requires Python 3.11+ or the tomli package"
            ) from missing
    return parser
def validate_config_table(value: object, table_name: str) -> None:
    """Ensure *value* is a mapping, i.e. was written as a TOML table.

    Raises:
        ValueError: naming ``[table_name]`` when *value* is not a mapping.
    """
    if isinstance(value, Mapping):
        return
    raise ValueError(f"config table [{table_name}] must be a TOML table")
def validate_config_float(value: object, key: str, field_name: str) -> float:
    """Validate one threshold value from the config and return it as a float.

    Args:
        value: Raw value taken from the parsed config.
        key: Dotted config key, used only in error messages.
        field_name: Threshold field whose range is looked up in
            ``COMPARISON_THRESHOLD_RANGES``.

    Returns:
        The value converted to ``float``.

    Raises:
        ValueError: if the value is not an int/float, is a bool, is not
            finite, or falls outside the inclusive range for *field_name*.
    """
    # bool is a subclass of int, so it has to be rejected explicitly.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError(f"config value {key!r} must be a finite number")
    try:
        number = float(value)
    except OverflowError as exc:
        # Python ints are unbounded, so an oversized integer cannot be
        # converted; report it like any other non-finite value instead of
        # letting OverflowError escape the ValueError-based error handling.
        raise ValueError(f"config value {key!r} must be finite") from exc
    if not math.isfinite(number):
        raise ValueError(f"config value {key!r} must be finite")
    minimum, maximum = COMPARISON_THRESHOLD_RANGES[field_name]
    if number < minimum:
        raise ValueError(f"config value {key!r} must be >= {minimum:g}")
    if maximum is not None and number > maximum:
        raise ValueError(f"config value {key!r} must be <= {maximum:g}")
    return number
def parse_config_section(
    table: Mapping[str, Any], section_name: str
) -> dict[str, float]:
    """Validate one config section and translate it to threshold overrides.

    Args:
        table: Parsed contents of the section.
        section_name: Section name used to look up the allowed keys in
            ``COMPARISON_CONFIG_KEYS`` and to build error messages.

    Returns:
        A dict mapping threshold field names to validated floats, holding
        only the keys that are present in *table*.

    Raises:
        ValueError: if *table* is not a mapping, contains unknown keys, or
            holds an invalid value.
    """
    validate_config_table(table, section_name)
    key_to_field = COMPARISON_CONFIG_KEYS[section_name]

    unexpected = sorted(set(table) - set(key_to_field))
    if unexpected:
        unknown = ", ".join(unexpected)
        raise ValueError(f"unknown config key(s) in [{section_name}]: {unknown}")

    return {
        field_name: validate_config_float(
            table[key], f"{section_name}.{key}", field_name
        )
        for key, field_name in key_to_field.items()
        if key in table
    }
def parse_comparison_config_data(
config_data: Mapping[str, Any],
) -> tuple[str | None, dict[str, float]]:
if not isinstance(config_data, Mapping):
raise ValueError("comparison config must be a TOML table")
unknown_top_level = set(config_data) - ({"version"} | COMPARISON_CONFIG_TABLES)
if unknown_top_level:
unknown = ", ".join(sorted(unknown_top_level))
raise ValueError(f"unknown top-level config key(s): {unknown}")
version = config_data.get("version")
if isinstance(version, bool) or not isinstance(version, int):
raise ValueError(
f"comparison config must specify integer version = {COMPARISON_CONFIG_VERSION}"
)
if version != COMPARISON_CONFIG_VERSION:
raise ValueError(
f"unsupported comparison config version {version!r}; "
f"expected {COMPARISON_CONFIG_VERSION}"
)
preset_name = None
if "preset" in config_data:
preset_table = config_data["preset"]
validate_config_table(preset_table, "preset")
unknown_keys = set(preset_table) - {"name"}
if unknown_keys:
unknown = ", ".join(sorted(unknown_keys))
raise ValueError(f"unknown config key(s) in [preset]: {unknown}")
if "name" in preset_table:
preset_name = preset_table["name"]
if not isinstance(preset_name, str):
raise ValueError("config value 'preset.name' must be a string")
get_comparison_thresholds(preset_name)
overrides = {}
for section_name in ("clear_gap", "same"):
if section_name in config_data:
overrides.update(
parse_config_section(config_data[section_name], section_name)
)
if "bulk" in config_data:
bulk_table = config_data["bulk"]
validate_config_table(bulk_table, "bulk")
known_bulk_keys = set(COMPARISON_CONFIG_KEYS["bulk"]) | {"rare_support"}
unknown_keys = set(bulk_table) - known_bulk_keys
if unknown_keys:
unknown = ", ".join(sorted(unknown_keys))
raise ValueError(f"unknown config key(s) in [bulk]: {unknown}")
bulk_values = {
key: value for key, value in bulk_table.items() if key != "rare_support"
}
overrides.update(parse_config_section(bulk_values, "bulk"))
if "rare_support" in bulk_table:
overrides.update(
parse_config_section(bulk_table["rare_support"], "bulk.rare_support")
)
return preset_name, overrides
def read_comparison_config_file(
    config_path: str | os.PathLike[str],
) -> tuple[str | None, dict[str, float]]:
    """Read a TOML comparison config file and return its validated contents.

    Args:
        config_path: Path of the TOML file to read.

    Returns:
        The ``(preset_name, overrides)`` pair produced by
        ``parse_comparison_config_data``.

    Raises:
        ValueError: if no TOML parser is available, the file cannot be read,
            the file is not valid TOML (including content that is not valid
            UTF-8), or the parsed data fails validation.
    """
    toml_module = load_toml_module()
    try:
        with open(config_path, "rb") as config_file:
            config_data = toml_module.load(config_file)
    except (toml_module.TOMLDecodeError, UnicodeDecodeError) as exc:
        # The parser decodes the file as UTF-8 before parsing, so a badly
        # encoded file raises UnicodeDecodeError rather than TOMLDecodeError.
        # Wrap both so the message always names the offending file.
        raise ValueError(
            f"failed to parse comparison config {config_path!r}: {exc}"
        ) from exc
    except OSError as exc:
        raise ValueError(
            f"failed to read comparison config {config_path!r}: {exc}"
        ) from exc
    return parse_comparison_config_data(config_data)
def resolve_comparison_thresholds(
    cli_preset_name: str | None = None,
    config_path: str | os.PathLike[str] | None = None,
) -> tuple[str, ComparisonThresholds]:
    """Combine the CLI preset and config file into effective thresholds.

    The base preset is the CLI preset if given, otherwise the preset named
    in the config file, otherwise ``COMPARISON_DEFAULT_PRESET``. Threshold
    values from the config file are then applied on top of that base.

    Args:
        cli_preset_name: Preset selected on the command line, if any.
        config_path: Path of a TOML config file, if any.

    Returns:
        ``(preset_name, thresholds)`` for the selected preset with the
        config overrides applied.
    """
    if config_path is None:
        config_preset_name = None
        config_overrides: dict[str, float] = {}
    else:
        config_preset_name, config_overrides = read_comparison_config_file(
            config_path
        )

    preset_name = cli_preset_name or config_preset_name or COMPARISON_DEFAULT_PRESET
    base_thresholds = get_comparison_thresholds(preset_name)
    return preset_name, replace(base_thresholds, **config_overrides)
def format_toml_float(value: float) -> str:
    """Render *value* as a float literal using Python's float ``repr``."""
    return f"{float(value)!r}"
def dump_comparison_config(preset_name: str, thresholds: ComparisonThresholds) -> str:
    """Render the preset name and threshold values as grouped TOML text.

    Args:
        preset_name: Name written under ``[preset]``.
        thresholds: Threshold values written under their grouped tables.

    Returns:
        The TOML document as a string ending in a single newline, with one
        blank line between tables.
    """
    # (table header, ((TOML key, value), ...)) in output order.
    sections = (
        ("clear_gap", (("relative", thresholds.clear_gap_relative),)),
        (
            "same",
            (
                ("center_relative", thresholds.same_center_relative),
                ("overlap_fraction", thresholds.same_overlap_fraction),
                (
                    "relative_dispersion_ceiling",
                    thresholds.same_relative_dispersion_ceiling,
                ),
            ),
        ),
        (
            "bulk",
            (
                ("sample_coverage", thresholds.bulk_same_sample_coverage),
                ("support_coverage", thresholds.bulk_same_support_coverage),
            ),
        ),
        (
            "bulk.rare_support",
            (
                ("sample_fraction", thresholds.bulk_support_rare_sample_fraction),
                (
                    "max_removed_sample_fraction",
                    thresholds.bulk_support_max_removed_sample_fraction,
                ),
            ),
        ),
    )

    # Every chunk ends with a newline; joining with "\n" adds the blank
    # separator line between consecutive tables.
    chunks = [
        f"version = {COMPARISON_CONFIG_VERSION}\n",
        f'[preset]\nname = "{preset_name}"\n',
    ]
    for header, entries in sections:
        body = "".join(
            f"{key} = {format_toml_float(value)}\n" for key, value in entries
        )
        chunks.append(f"[{header}]\n{body}")
    return "\n".join(chunks)
@dataclass(frozen=True)
class SupportFilterInfo:
activated: bool
@@ -2109,9 +2342,19 @@ def main() -> int:
parser.add_argument(
"--preset",
choices=sorted(COMPARISON_THRESHOLD_PRESETS),
default="default",
default=None,
help="comparison threshold preset",
)
parser.add_argument(
"--config",
default=None,
help="comparison threshold TOML config",
)
parser.add_argument(
"--dump-config",
action="store_true",
help="print the effective comparison threshold config and exit",
)
parser.add_argument(
"--display",
choices=["intervals", "legacy", "explain"],
@@ -2169,6 +2412,18 @@ def main() -> int:
)
args, files_or_dirs = parser.parse_known_args()
try:
comparison_preset, comparison_thresholds = resolve_comparison_thresholds(
args.preset, args.config
)
except ValueError as exc:
print(str(exc))
return 1
if args.dump_config:
print(dump_comparison_config(comparison_preset, comparison_thresholds), end="")
return 0
try:
filter_plan = build_benchmark_filter_plan(args.filter_actions)
reference_device_filter = parse_device_filter(
@@ -2177,7 +2432,6 @@ def main() -> int:
compare_device_filter = parse_device_filter(
args.compare_devices, "--compare-devices"
)
comparison_thresholds = get_comparison_thresholds(args.preset)
except ValueError as exc:
print(str(exc))
return 1

View File

@@ -1380,6 +1380,181 @@ def test_get_comparison_thresholds_returns_named_presets(nvbench_compare):
assert permissive.bulk_same_support_coverage < default.bulk_same_support_coverage
def test_dump_comparison_config_uses_grouped_toml(nvbench_compare):
    """The default preset is dumped as versioned, grouped TOML."""
    default_thresholds = nvbench_compare.get_comparison_thresholds("default")
    dumped = nvbench_compare.dump_comparison_config("default", default_thresholds)

    expected_fragments = (
        "version = 1\n",
        '[preset]\nname = "default"\n',
        "[clear_gap]\nrelative = 0.005\n",
        "[same]\n",
        "[bulk]\n",
        "sample_coverage = 0.97\n",
        "[bulk.rare_support]\n",
    )
    for fragment in expected_fragments:
        assert fragment in dumped
def test_resolve_comparison_thresholds_applies_config_overrides(
    monkeypatch, nvbench_compare
):
    """Config overrides apply on top of whichever preset ends up selected."""

    def read_config(_):
        return (
            "strict",
            {
                "bulk_same_sample_coverage": 0.93,
                "bulk_support_max_removed_sample_fraction": 0.02,
            },
        )

    monkeypatch.setattr(nvbench_compare, "read_comparison_config_file", read_config)

    # (preset passed as the CLI argument, preset expected to be selected):
    # without a CLI preset the config's "strict" is used, otherwise the CLI
    # preset takes precedence.
    scenarios = ((None, "strict"), ("permissive", "permissive"))
    for cli_preset, expected_preset in scenarios:
        preset, thresholds = nvbench_compare.resolve_comparison_thresholds(
            cli_preset, "settings.toml"
        )
        base = nvbench_compare.get_comparison_thresholds(expected_preset)

        assert preset == expected_preset
        assert thresholds.clear_gap_relative == pytest.approx(base.clear_gap_relative)
        assert thresholds.bulk_same_sample_coverage == pytest.approx(0.93)
        assert thresholds.bulk_support_max_removed_sample_fraction == pytest.approx(
            0.02
        )
def test_parse_comparison_config_data_validates_grouped_thresholds(nvbench_compare):
    """A fully populated config maps every grouped key to its threshold field."""
    full_config = {
        "version": 1,
        "preset": {"name": "strict"},
        "clear_gap": {"relative": 0.01},
        "same": {
            "center_relative": 0.002,
            "overlap_fraction": 0.75,
            "relative_dispersion_ceiling": 0.02,
        },
        "bulk": {
            "sample_coverage": 0.99,
            "support_coverage": 0.8,
            "rare_support": {
                "sample_fraction": 0.001,
                "max_removed_sample_fraction": 0.01,
            },
        },
    }
    expected_overrides = {
        "clear_gap_relative": 0.01,
        "same_center_relative": 0.002,
        "same_overlap_fraction": 0.75,
        "same_relative_dispersion_ceiling": 0.02,
        "bulk_same_sample_coverage": 0.99,
        "bulk_same_support_coverage": 0.8,
        "bulk_support_rare_sample_fraction": 0.001,
        "bulk_support_max_removed_sample_fraction": 0.01,
    }

    preset, overrides = nvbench_compare.parse_comparison_config_data(full_config)

    assert preset == "strict"
    assert overrides == expected_overrides
@pytest.mark.parametrize(
    ("config_data", "match"),
    [
        # Missing or unsupported schema version.
        ({}, "version"),
        ({"version": 2}, "unsupported"),
        # Unknown tables and keys.
        ({"version": 1, "rare_support": {}}, "unknown top-level"),
        ({"version": 1, "bulk": {"unknown": 0.1}}, r"\[bulk\]"),
        ({"version": 1, "clear_gap": {"rare_support": {}}}, r"\[clear_gap\]"),
        # Out-of-range and wrongly typed values.
        ({"version": 1, "bulk": {"sample_coverage": 1.5}}, "<= 1"),
        ({"version": 1, "same": {"center_relative": "tight"}}, "finite number"),
        # Preset name that does not exist.
        ({"version": 1, "preset": {"name": "aggressive"}}, "unknown comparison preset"),
    ],
)
def test_parse_comparison_config_data_rejects_invalid_config(
    nvbench_compare, config_data, match
):
    """Each invalid config raises ValueError with a message matching *match*."""
    parse = nvbench_compare.parse_comparison_config_data
    with pytest.raises(ValueError, match=match):
        parse(config_data)
def test_read_comparison_config_file_parses_toml_when_parser_is_available(
    tmp_path, nvbench_compare
):
    """A TOML file on disk is read and parsed into a preset plus overrides."""
    # Skip when the parser for the running interpreter is not installed.
    if sys.version_info >= (3, 11):
        parser_module = "tomllib"
    else:
        parser_module = "tomli"
    pytest.importorskip(parser_module)

    toml_text = """
version = 1
[preset]
name = "strict"
[bulk]
sample_coverage = 0.93
"""
    config_path = tmp_path / "settings.toml"
    config_path.write_text(toml_text, encoding="utf-8")

    preset, overrides = nvbench_compare.read_comparison_config_file(config_path)

    assert preset == "strict"
    assert overrides == {"bulk_same_sample_coverage": 0.93}
def test_main_dump_config_does_not_require_input_files(
    monkeypatch, capsys, nvbench_compare
):
    """``--dump-config`` succeeds without input files and reads no JSON."""

    def fail_on_read(_):
        raise AssertionError("dump-config should not read JSON files")

    monkeypatch.setattr(nvbench_compare.reader, "read_file", fail_on_read)
    argv = ["nvbench_compare", "--preset", "strict", "--dump-config"]
    monkeypatch.setattr(sys, "argv", argv)

    exit_code = nvbench_compare.main()
    assert exit_code == 0

    captured = capsys.readouterr()
    for fragment in ('name = "strict"', "[bulk.rare_support]"):
        assert fragment in captured.out
def test_main_dump_config_merges_config_and_cli_preset(
    monkeypatch, capsys, nvbench_compare
):
    """The CLI preset is the base and the config's threshold override is kept."""

    def read_config(_):
        return ("strict", {"bulk_same_sample_coverage": 0.93})

    monkeypatch.setattr(nvbench_compare, "read_comparison_config_file", read_config)
    argv = ["nvbench_compare"]
    argv += ["--config", "settings.toml"]
    argv += ["--preset", "permissive"]
    argv += ["--dump-config"]
    monkeypatch.setattr(sys, "argv", argv)

    exit_code = nvbench_compare.main()
    assert exit_code == 0

    captured = capsys.readouterr()
    expected_fragments = (
        'name = "permissive"',
        "relative = 0.0025",
        "sample_coverage = 0.93",
    )
    for fragment in expected_fragments:
        assert fragment in captured.out
def test_compare_benches_defaults_to_interval_display(monkeypatch, nvbench_compare):
run_data = make_comparison_run_data(nvbench_compare)
captured = {}