mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-07-01 11:47:33 +00:00
Add TOML configuration for nvbench-compare thresholds
Add versioned TOML configuration support for nvbench-compare threshold
settings. The new --config option reads grouped settings for clear-gap,
same-result, bulk coverage, and rare-support filtering thresholds. The parser
validates the schema strictly so unknown tables, unknown keys, invalid types,
unsupported versions, and out-of-range values fail early.
Add --dump-config to print the effective configuration without requiring input
JSON files. This makes the currently selected preset and resolved threshold
values discoverable and gives users a starting point for custom configuration.
Preset resolution is:
- default is used when neither TOML nor CLI selects a preset
- [preset] name = "..." in TOML selects the base preset
- --preset ... overrides the TOML preset selection
- explicit threshold values in TOML override whichever base preset was selected
For example:
- nvbench-compare --dump-config
Prints the built-in default settings as grouped TOML.
- nvbench-compare --preset permissive --dump-config
Prints the permissive preset values as TOML.
- nvbench-compare --config compare.toml ref.json cmp.json
Compares using the preset named in compare.toml, plus any explicit TOML
threshold overrides.
- nvbench-compare --config compare.toml --preset strict ref.json cmp.json
Uses the strict preset as the base, while preserving explicit threshold
overrides from compare.toml.
Keep TOML parsing lazy: Python 3.11+ uses tomllib, while Python 3.10 only
requires tomli when --config is used. Add focused tests for grouped config
dumping, strict validation, preset/override precedence, and CLI dump behavior.
This commit is contained in:
@@ -9,10 +9,11 @@ import os
|
||||
import sys
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field, replace
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Mapping
|
||||
from typing import Any, BinaryIO, Callable, Protocol
|
||||
|
||||
import jsondiff
|
||||
import numpy as np
|
||||
@@ -51,6 +52,15 @@ SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs"
|
||||
Float32Reader = Callable[[str], object]
|
||||
|
||||
|
||||
class TomlModule(Protocol):
|
||||
# TOML support is imported lazily. This protocol documents the narrow
|
||||
# tomllib/tomli module surface used by this script.
|
||||
@property
|
||||
def TOMLDecodeError(self) -> type[BaseException]: ...
|
||||
|
||||
def load(self, fp: BinaryIO, /) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
def read_float32_file(filename: str) -> object:
|
||||
return np.fromfile(filename, dtype="<f4")
|
||||
|
||||
@@ -109,6 +119,43 @@ COMPARISON_THRESHOLD_PRESETS = {
|
||||
for name, values in COMPARISON_THRESHOLD_PRESET_VALUES.items()
|
||||
}
|
||||
|
||||
COMPARISON_CONFIG_VERSION = 1
|
||||
COMPARISON_DEFAULT_PRESET = "default"
|
||||
COMPARISON_CONFIG_TABLES = {
|
||||
"preset",
|
||||
"clear_gap",
|
||||
"same",
|
||||
"bulk",
|
||||
}
|
||||
COMPARISON_CONFIG_KEYS = {
|
||||
"clear_gap": {
|
||||
"relative": "clear_gap_relative",
|
||||
},
|
||||
"same": {
|
||||
"center_relative": "same_center_relative",
|
||||
"overlap_fraction": "same_overlap_fraction",
|
||||
"relative_dispersion_ceiling": "same_relative_dispersion_ceiling",
|
||||
},
|
||||
"bulk": {
|
||||
"sample_coverage": "bulk_same_sample_coverage",
|
||||
"support_coverage": "bulk_same_support_coverage",
|
||||
},
|
||||
"bulk.rare_support": {
|
||||
"sample_fraction": "bulk_support_rare_sample_fraction",
|
||||
"max_removed_sample_fraction": "bulk_support_max_removed_sample_fraction",
|
||||
},
|
||||
}
|
||||
COMPARISON_THRESHOLD_RANGES = {
|
||||
"clear_gap_relative": (0.0, None),
|
||||
"same_center_relative": (0.0, None),
|
||||
"same_overlap_fraction": (0.0, 1.0),
|
||||
"same_relative_dispersion_ceiling": (0.0, None),
|
||||
"bulk_same_sample_coverage": (0.0, 1.0),
|
||||
"bulk_same_support_coverage": (0.0, 1.0),
|
||||
"bulk_support_rare_sample_fraction": (0.0, 1.0),
|
||||
"bulk_support_max_removed_sample_fraction": (0.0, 1.0),
|
||||
}
|
||||
|
||||
|
||||
def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
|
||||
try:
|
||||
@@ -117,6 +164,192 @@ def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
|
||||
raise ValueError(f"unknown comparison preset {preset_name!r}") from exc
|
||||
|
||||
|
||||
def load_toml_module() -> TomlModule:
|
||||
try:
|
||||
import tomllib
|
||||
|
||||
return tomllib
|
||||
except ModuleNotFoundError:
|
||||
try:
|
||||
import tomli
|
||||
|
||||
return tomli
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
"TOML config support requires Python 3.11+ or the tomli package"
|
||||
) from exc
|
||||
|
||||
|
||||
def validate_config_table(value: object, table_name: str) -> None:
|
||||
if not isinstance(value, Mapping):
|
||||
raise ValueError(f"config table [{table_name}] must be a TOML table")
|
||||
|
||||
|
||||
def validate_config_float(value: object, key: str, field_name: str) -> float:
|
||||
if isinstance(value, bool) or not isinstance(value, int | float):
|
||||
raise ValueError(f"config value {key!r} must be a finite number")
|
||||
|
||||
value = float(value)
|
||||
if not math.isfinite(value):
|
||||
raise ValueError(f"config value {key!r} must be finite")
|
||||
|
||||
minimum, maximum = COMPARISON_THRESHOLD_RANGES[field_name]
|
||||
if value < minimum:
|
||||
raise ValueError(f"config value {key!r} must be >= {minimum:g}")
|
||||
if maximum is not None and value > maximum:
|
||||
raise ValueError(f"config value {key!r} must be <= {maximum:g}")
|
||||
return value
|
||||
|
||||
|
||||
def parse_config_section(
|
||||
table: Mapping[str, Any], section_name: str
|
||||
) -> dict[str, float]:
|
||||
validate_config_table(table, section_name)
|
||||
known_keys = COMPARISON_CONFIG_KEYS[section_name]
|
||||
unknown_keys = set(table) - set(known_keys)
|
||||
if unknown_keys:
|
||||
unknown = ", ".join(sorted(unknown_keys))
|
||||
raise ValueError(f"unknown config key(s) in [{section_name}]: {unknown}")
|
||||
|
||||
overrides = {}
|
||||
for key, field_name in known_keys.items():
|
||||
if key not in table:
|
||||
continue
|
||||
full_key = f"{section_name}.{key}"
|
||||
overrides[field_name] = validate_config_float(table[key], full_key, field_name)
|
||||
return overrides
|
||||
|
||||
|
||||
def parse_comparison_config_data(
|
||||
config_data: Mapping[str, Any],
|
||||
) -> tuple[str | None, dict[str, float]]:
|
||||
if not isinstance(config_data, Mapping):
|
||||
raise ValueError("comparison config must be a TOML table")
|
||||
|
||||
unknown_top_level = set(config_data) - ({"version"} | COMPARISON_CONFIG_TABLES)
|
||||
if unknown_top_level:
|
||||
unknown = ", ".join(sorted(unknown_top_level))
|
||||
raise ValueError(f"unknown top-level config key(s): {unknown}")
|
||||
|
||||
version = config_data.get("version")
|
||||
if isinstance(version, bool) or not isinstance(version, int):
|
||||
raise ValueError(
|
||||
f"comparison config must specify integer version = {COMPARISON_CONFIG_VERSION}"
|
||||
)
|
||||
if version != COMPARISON_CONFIG_VERSION:
|
||||
raise ValueError(
|
||||
f"unsupported comparison config version {version!r}; "
|
||||
f"expected {COMPARISON_CONFIG_VERSION}"
|
||||
)
|
||||
|
||||
preset_name = None
|
||||
if "preset" in config_data:
|
||||
preset_table = config_data["preset"]
|
||||
validate_config_table(preset_table, "preset")
|
||||
unknown_keys = set(preset_table) - {"name"}
|
||||
if unknown_keys:
|
||||
unknown = ", ".join(sorted(unknown_keys))
|
||||
raise ValueError(f"unknown config key(s) in [preset]: {unknown}")
|
||||
if "name" in preset_table:
|
||||
preset_name = preset_table["name"]
|
||||
if not isinstance(preset_name, str):
|
||||
raise ValueError("config value 'preset.name' must be a string")
|
||||
get_comparison_thresholds(preset_name)
|
||||
|
||||
overrides = {}
|
||||
for section_name in ("clear_gap", "same"):
|
||||
if section_name in config_data:
|
||||
overrides.update(
|
||||
parse_config_section(config_data[section_name], section_name)
|
||||
)
|
||||
|
||||
if "bulk" in config_data:
|
||||
bulk_table = config_data["bulk"]
|
||||
validate_config_table(bulk_table, "bulk")
|
||||
known_bulk_keys = set(COMPARISON_CONFIG_KEYS["bulk"]) | {"rare_support"}
|
||||
unknown_keys = set(bulk_table) - known_bulk_keys
|
||||
if unknown_keys:
|
||||
unknown = ", ".join(sorted(unknown_keys))
|
||||
raise ValueError(f"unknown config key(s) in [bulk]: {unknown}")
|
||||
|
||||
bulk_values = {
|
||||
key: value for key, value in bulk_table.items() if key != "rare_support"
|
||||
}
|
||||
overrides.update(parse_config_section(bulk_values, "bulk"))
|
||||
if "rare_support" in bulk_table:
|
||||
overrides.update(
|
||||
parse_config_section(bulk_table["rare_support"], "bulk.rare_support")
|
||||
)
|
||||
|
||||
return preset_name, overrides
|
||||
|
||||
|
||||
def read_comparison_config_file(
|
||||
config_path: str | os.PathLike[str],
|
||||
) -> tuple[str | None, dict[str, float]]:
|
||||
toml_module = load_toml_module()
|
||||
try:
|
||||
with open(config_path, "rb") as config_file:
|
||||
config_data = toml_module.load(config_file)
|
||||
except toml_module.TOMLDecodeError as exc:
|
||||
raise ValueError(
|
||||
f"failed to parse comparison config {config_path!r}: {exc}"
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise ValueError(
|
||||
f"failed to read comparison config {config_path!r}: {exc}"
|
||||
) from exc
|
||||
|
||||
return parse_comparison_config_data(config_data)
|
||||
|
||||
|
||||
def resolve_comparison_thresholds(
|
||||
cli_preset_name: str | None = None,
|
||||
config_path: str | os.PathLike[str] | None = None,
|
||||
) -> tuple[str, ComparisonThresholds]:
|
||||
config_preset_name = None
|
||||
config_overrides: dict[str, float] = {}
|
||||
if config_path is not None:
|
||||
config_preset_name, config_overrides = read_comparison_config_file(config_path)
|
||||
|
||||
preset_name = cli_preset_name or config_preset_name or COMPARISON_DEFAULT_PRESET
|
||||
thresholds = replace(get_comparison_thresholds(preset_name), **config_overrides)
|
||||
return preset_name, thresholds
|
||||
|
||||
|
||||
def format_toml_float(value: float) -> str:
|
||||
return repr(float(value))
|
||||
|
||||
|
||||
def dump_comparison_config(preset_name: str, thresholds: ComparisonThresholds) -> str:
|
||||
lines = [
|
||||
f"version = {COMPARISON_CONFIG_VERSION}",
|
||||
"",
|
||||
"[preset]",
|
||||
f'name = "{preset_name}"',
|
||||
"",
|
||||
"[clear_gap]",
|
||||
f"relative = {format_toml_float(thresholds.clear_gap_relative)}",
|
||||
"",
|
||||
"[same]",
|
||||
f"center_relative = {format_toml_float(thresholds.same_center_relative)}",
|
||||
f"overlap_fraction = {format_toml_float(thresholds.same_overlap_fraction)}",
|
||||
"relative_dispersion_ceiling = "
|
||||
f"{format_toml_float(thresholds.same_relative_dispersion_ceiling)}",
|
||||
"",
|
||||
"[bulk]",
|
||||
f"sample_coverage = {format_toml_float(thresholds.bulk_same_sample_coverage)}",
|
||||
f"support_coverage = {format_toml_float(thresholds.bulk_same_support_coverage)}",
|
||||
"",
|
||||
"[bulk.rare_support]",
|
||||
"sample_fraction = "
|
||||
f"{format_toml_float(thresholds.bulk_support_rare_sample_fraction)}",
|
||||
"max_removed_sample_fraction = "
|
||||
f"{format_toml_float(thresholds.bulk_support_max_removed_sample_fraction)}",
|
||||
]
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SupportFilterInfo:
|
||||
activated: bool
|
||||
@@ -2109,9 +2342,19 @@ def main() -> int:
|
||||
parser.add_argument(
|
||||
"--preset",
|
||||
choices=sorted(COMPARISON_THRESHOLD_PRESETS),
|
||||
default="default",
|
||||
default=None,
|
||||
help="comparison threshold preset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
default=None,
|
||||
help="comparison threshold TOML config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dump-config",
|
||||
action="store_true",
|
||||
help="print the effective comparison threshold config and exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display",
|
||||
choices=["intervals", "legacy", "explain"],
|
||||
@@ -2169,6 +2412,18 @@ def main() -> int:
|
||||
)
|
||||
|
||||
args, files_or_dirs = parser.parse_known_args()
|
||||
try:
|
||||
comparison_preset, comparison_thresholds = resolve_comparison_thresholds(
|
||||
args.preset, args.config
|
||||
)
|
||||
except ValueError as exc:
|
||||
print(str(exc))
|
||||
return 1
|
||||
|
||||
if args.dump_config:
|
||||
print(dump_comparison_config(comparison_preset, comparison_thresholds), end="")
|
||||
return 0
|
||||
|
||||
try:
|
||||
filter_plan = build_benchmark_filter_plan(args.filter_actions)
|
||||
reference_device_filter = parse_device_filter(
|
||||
@@ -2177,7 +2432,6 @@ def main() -> int:
|
||||
compare_device_filter = parse_device_filter(
|
||||
args.compare_devices, "--compare-devices"
|
||||
)
|
||||
comparison_thresholds = get_comparison_thresholds(args.preset)
|
||||
except ValueError as exc:
|
||||
print(str(exc))
|
||||
return 1
|
||||
|
||||
@@ -1380,6 +1380,181 @@ def test_get_comparison_thresholds_returns_named_presets(nvbench_compare):
|
||||
assert permissive.bulk_same_support_coverage < default.bulk_same_support_coverage
|
||||
|
||||
|
||||
def test_dump_comparison_config_uses_grouped_toml(nvbench_compare):
|
||||
config = nvbench_compare.dump_comparison_config(
|
||||
"default", nvbench_compare.get_comparison_thresholds("default")
|
||||
)
|
||||
|
||||
assert "version = 1\n" in config
|
||||
assert '[preset]\nname = "default"\n' in config
|
||||
assert "[clear_gap]\nrelative = 0.005\n" in config
|
||||
assert "[same]\n" in config
|
||||
assert "[bulk]\n" in config
|
||||
assert "sample_coverage = 0.97\n" in config
|
||||
assert "[bulk.rare_support]\n" in config
|
||||
|
||||
|
||||
def test_resolve_comparison_thresholds_applies_config_overrides(
|
||||
monkeypatch, nvbench_compare
|
||||
):
|
||||
def read_config(_):
|
||||
return (
|
||||
"strict",
|
||||
{
|
||||
"bulk_same_sample_coverage": 0.93,
|
||||
"bulk_support_max_removed_sample_fraction": 0.02,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(nvbench_compare, "read_comparison_config_file", read_config)
|
||||
|
||||
preset, thresholds = nvbench_compare.resolve_comparison_thresholds(
|
||||
None, "settings.toml"
|
||||
)
|
||||
assert preset == "strict"
|
||||
assert thresholds.clear_gap_relative == pytest.approx(
|
||||
nvbench_compare.get_comparison_thresholds("strict").clear_gap_relative
|
||||
)
|
||||
assert thresholds.bulk_same_sample_coverage == pytest.approx(0.93)
|
||||
assert thresholds.bulk_support_max_removed_sample_fraction == pytest.approx(0.02)
|
||||
|
||||
preset, thresholds = nvbench_compare.resolve_comparison_thresholds(
|
||||
"permissive", "settings.toml"
|
||||
)
|
||||
assert preset == "permissive"
|
||||
assert thresholds.clear_gap_relative == pytest.approx(
|
||||
nvbench_compare.get_comparison_thresholds("permissive").clear_gap_relative
|
||||
)
|
||||
assert thresholds.bulk_same_sample_coverage == pytest.approx(0.93)
|
||||
assert thresholds.bulk_support_max_removed_sample_fraction == pytest.approx(0.02)
|
||||
|
||||
|
||||
def test_parse_comparison_config_data_validates_grouped_thresholds(nvbench_compare):
|
||||
preset, overrides = nvbench_compare.parse_comparison_config_data(
|
||||
{
|
||||
"version": 1,
|
||||
"preset": {"name": "strict"},
|
||||
"clear_gap": {"relative": 0.01},
|
||||
"same": {
|
||||
"center_relative": 0.002,
|
||||
"overlap_fraction": 0.75,
|
||||
"relative_dispersion_ceiling": 0.02,
|
||||
},
|
||||
"bulk": {
|
||||
"sample_coverage": 0.99,
|
||||
"support_coverage": 0.8,
|
||||
"rare_support": {
|
||||
"sample_fraction": 0.001,
|
||||
"max_removed_sample_fraction": 0.01,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert preset == "strict"
|
||||
assert overrides == {
|
||||
"clear_gap_relative": 0.01,
|
||||
"same_center_relative": 0.002,
|
||||
"same_overlap_fraction": 0.75,
|
||||
"same_relative_dispersion_ceiling": 0.02,
|
||||
"bulk_same_sample_coverage": 0.99,
|
||||
"bulk_same_support_coverage": 0.8,
|
||||
"bulk_support_rare_sample_fraction": 0.001,
|
||||
"bulk_support_max_removed_sample_fraction": 0.01,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_data, match",
|
||||
[
|
||||
({}, "version"),
|
||||
({"version": 2}, "unsupported"),
|
||||
({"version": 1, "rare_support": {}}, "unknown top-level"),
|
||||
({"version": 1, "bulk": {"unknown": 0.1}}, r"\[bulk\]"),
|
||||
({"version": 1, "clear_gap": {"rare_support": {}}}, r"\[clear_gap\]"),
|
||||
({"version": 1, "bulk": {"sample_coverage": 1.5}}, "<= 1"),
|
||||
({"version": 1, "same": {"center_relative": "tight"}}, "finite number"),
|
||||
({"version": 1, "preset": {"name": "aggressive"}}, "unknown comparison preset"),
|
||||
],
|
||||
)
|
||||
def test_parse_comparison_config_data_rejects_invalid_config(
|
||||
nvbench_compare, config_data, match
|
||||
):
|
||||
with pytest.raises(ValueError, match=match):
|
||||
nvbench_compare.parse_comparison_config_data(config_data)
|
||||
|
||||
|
||||
def test_read_comparison_config_file_parses_toml_when_parser_is_available(
|
||||
tmp_path, nvbench_compare
|
||||
):
|
||||
parser_module = "tomllib" if sys.version_info >= (3, 11) else "tomli"
|
||||
pytest.importorskip(parser_module)
|
||||
config_path = tmp_path / "settings.toml"
|
||||
config_path.write_text(
|
||||
"""
|
||||
version = 1
|
||||
|
||||
[preset]
|
||||
name = "strict"
|
||||
|
||||
[bulk]
|
||||
sample_coverage = 0.93
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
preset, overrides = nvbench_compare.read_comparison_config_file(config_path)
|
||||
|
||||
assert preset == "strict"
|
||||
assert overrides == {"bulk_same_sample_coverage": 0.93}
|
||||
|
||||
|
||||
def test_main_dump_config_does_not_require_input_files(
|
||||
monkeypatch, capsys, nvbench_compare
|
||||
):
|
||||
def read_file(_):
|
||||
raise AssertionError("dump-config should not read JSON files")
|
||||
|
||||
monkeypatch.setattr(nvbench_compare.reader, "read_file", read_file)
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
["nvbench_compare", "--preset", "strict", "--dump-config"],
|
||||
)
|
||||
|
||||
assert nvbench_compare.main() == 0
|
||||
output = capsys.readouterr().out
|
||||
assert 'name = "strict"' in output
|
||||
assert "[bulk.rare_support]" in output
|
||||
|
||||
|
||||
def test_main_dump_config_merges_config_and_cli_preset(
|
||||
monkeypatch, capsys, nvbench_compare
|
||||
):
|
||||
def read_config(_):
|
||||
return ("strict", {"bulk_same_sample_coverage": 0.93})
|
||||
|
||||
monkeypatch.setattr(nvbench_compare, "read_comparison_config_file", read_config)
|
||||
monkeypatch.setattr(
|
||||
sys,
|
||||
"argv",
|
||||
[
|
||||
"nvbench_compare",
|
||||
"--config",
|
||||
"settings.toml",
|
||||
"--preset",
|
||||
"permissive",
|
||||
"--dump-config",
|
||||
],
|
||||
)
|
||||
|
||||
assert nvbench_compare.main() == 0
|
||||
output = capsys.readouterr().out
|
||||
assert 'name = "permissive"' in output
|
||||
assert "relative = 0.0025" in output
|
||||
assert "sample_coverage = 0.93" in output
|
||||
|
||||
|
||||
def test_compare_benches_defaults_to_interval_display(monkeypatch, nvbench_compare):
|
||||
run_data = make_comparison_run_data(nvbench_compare)
|
||||
captured = {}
|
||||
|
||||
Reference in New Issue
Block a user