From 20b3bd314843adfca962bdefda50e587fc955252 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Wed, 3 Jun 2026 15:21:26 -0500 Subject: [PATCH] Add nvbench_compare presets and rare-support-aware bulk coverage Introduce comparison threshold presets in nvbench_compare and thread the selected preset through main() into compare_benches. Refine bulk nearest-neighbor support handling by: - adding rare-support filtering thresholds - ignoring low-count support values only when removed sample mass is small - falling back to full support for all-unique or otherwise unusable support - keeping sample-weight coverage over all values Tighten bulk mismatch reporting to show compact min(ref, cmp) coverage summaries, and add tests covering: - rare-tail filtering - strict fallback when too much support mass would be removed - all-unique support preservation - preset lookup and CLI preset propagation --- python/scripts/nvbench_compare.py | 159 ++++++++++++++++++++++++++-- python/test/test_nvbench_compare.py | 91 ++++++++++++++-- 2 files changed, 237 insertions(+), 13 deletions(-) diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index c3dd5a8..3b69360 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -67,6 +67,61 @@ class ComparisonThresholds: same_relative_dispersion_ceiling: float = 0.02 bulk_same_sample_coverage: float = 0.99 bulk_same_support_coverage: float = 0.80 + bulk_support_rare_sample_fraction: float = 0.001 + bulk_support_max_removed_sample_fraction: float = 0.01 + + +COMPARISON_THRESHOLD_PRESET_VALUES = { + "default": { + "clear_gap_relative": 0.005, + "same_center_relative": 0.005, + "same_overlap_fraction": 0.5, + "same_relative_dispersion_ceiling": 0.02, + "bulk_same_sample_coverage": 0.99, + "bulk_same_support_coverage": 0.80, + "bulk_support_rare_sample_fraction": 0.001, + "bulk_support_max_removed_sample_fraction": 0.01, + }, + "strict": { + "clear_gap_relative": 0.01, + "same_center_relative": 0.0025, + "same_overlap_fraction": 0.75, + "same_relative_dispersion_ceiling": 0.01, + "bulk_same_sample_coverage": 0.995, + "bulk_same_support_coverage": 0.90, + "bulk_support_rare_sample_fraction": 0.001, + "bulk_support_max_removed_sample_fraction": 0.005, + }, + "permissive": { + "clear_gap_relative": 0.0025, + "same_center_relative": 0.01, + "same_overlap_fraction": 0.25, + "same_relative_dispersion_ceiling": 0.05, + "bulk_same_sample_coverage": 0.98, + "bulk_same_support_coverage": 0.60, + "bulk_support_rare_sample_fraction": 0.001, + "bulk_support_max_removed_sample_fraction": 0.02, + }, +} + +COMPARISON_THRESHOLD_PRESETS = { + name: ComparisonThresholds(**values) + for name, values in COMPARISON_THRESHOLD_PRESET_VALUES.items() +} + + +def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds: + try: + return COMPARISON_THRESHOLD_PRESETS[preset_name] + except KeyError as exc: + raise ValueError(f"unknown comparison preset {preset_name!r}") from exc + + +@dataclass(frozen=True) +class SupportFilterInfo: + activated: bool + reason: str + removed_sample_fraction: float @dataclass(frozen=True) @@ -677,6 +732,82 @@ def symmetric_nearest_log_distances(x, y): return symmetric_nearest_distances(np.log(x), np.log(y)) +def compute_effective_support_mask(counts, thresholds): + """Return the unique-value mask used for support coverage. + + Sample-weight coverage always uses all values. Support coverage may ignore + low-count values only when their total sample mass is small; otherwise it + falls back to full support, preserving all-unique datasets. + """ + counts = np.asarray(counts) + total_count = np.sum(counts) + if ( + len(counts) == 0 + or total_count <= 0 + or thresholds.bulk_support_rare_sample_fraction <= 0.0 + or thresholds.bulk_support_max_removed_sample_fraction <= 0.0 + ): + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="disabled", + removed_sample_fraction=0.0, + ) + + if np.all(counts == 1): + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="all_values_unique", + removed_sample_fraction=0.0, + ) + + min_count = max( + 2, + math.ceil(thresholds.bulk_support_rare_sample_fraction * total_count), + ) + support_mask = counts >= min_count + if np.all(support_mask): + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="no_rare_values", + removed_sample_fraction=0.0, + ) + if not np.any(support_mask): + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="would_remove_all_support", + removed_sample_fraction=0.0, + ) + + removed_sample_fraction = np.sum(counts[~support_mask]) / total_count + if removed_sample_fraction > thresholds.bulk_support_max_removed_sample_fraction: + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="would_remove_too_much_mass", + removed_sample_fraction=0.0, + ) + + return support_mask, SupportFilterInfo( + activated=True, + reason="filtered", + removed_sample_fraction=removed_sample_fraction, + ) + + +def format_support_filter_info(filter_info): + if filter_info.activated: + return f"on({format_coverage(filter_info.removed_sample_fraction)})" + + if filter_info.reason == "no_rare_values": + return "off(no rare values)" + if filter_info.reason == "all_values_unique": + return "off(all values unique)" + if filter_info.reason == "would_remove_too_much_mass": + return "off(would remove too much mass)" + if filter_info.reason == "would_remove_all_support": + return "off(would remove all support)" + return "off(disabled)" + + def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds): ref_unique, ref_counts = np.unique_counts(ref_values) cmp_unique, cmp_counts = np.unique_counts(cmp_values) @@ -689,12 +820,20 @@ def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds): tolerance = math.log1p(thresholds.same_center_relative) ref_covered = ref_distances <= tolerance cmp_covered = cmp_distances <= tolerance + ref_support_mask, ref_filter_info = compute_effective_support_mask( + ref_counts, thresholds + ) + cmp_support_mask, cmp_filter_info = compute_effective_support_mask( + cmp_counts, thresholds + ) return { "ref_sample": np.sum(ref_counts[ref_covered]) / np.sum(ref_counts), "cmp_sample": np.sum(cmp_counts[cmp_covered]) / np.sum(cmp_counts), - "ref_support": np.mean(ref_covered), - "cmp_support": np.mean(cmp_covered), + "ref_support": np.mean(ref_covered[ref_support_mask]), + "cmp_support": np.mean(cmp_covered[cmp_support_mask]), + "ref_support_filter": ref_filter_info, + "cmp_support_filter": cmp_filter_info, } @@ -732,10 +871,10 @@ def make_bulk_coverage_mismatch_decision(label, coverages, thresholds): return make_decision( ComparisonStatus.UNDECIDED, f"bulk_{label}_support_mismatch", - f"sample ref={format_coverage(coverages['ref_sample'])} " - f"cmp={format_coverage(coverages['cmp_sample'])} >= {sample_threshold}; " - f"support ref={format_coverage(coverages['ref_support'])} " - f"cmp={format_coverage(coverages['cmp_support'])} >= {support_threshold}", + f"sample: min(ref={format_coverage(coverages['ref_sample'])}, " + f"cmp={format_coverage(coverages['cmp_sample'])}) >= {sample_threshold}; " + f"support: min(ref={format_coverage(coverages['ref_support'])}, " + f"cmp={format_coverage(coverages['cmp_support'])}) >= {support_threshold}", severity=severity, ) @@ -1760,6 +1899,12 @@ def main() -> int: default=0.0, help="only show benchmarks where percentage diff is >= THRESHOLD", ) + parser.add_argument( + "--preset", + choices=sorted(COMPARISON_THRESHOLD_PRESETS), + default="default", + help="comparison threshold preset", + ) parser.add_argument( "--plot-along", type=str, dest="plot_along", default=None, help="plot results" ) @@ -1819,6 +1964,7 @@ def main() -> int: compare_device_filter = parse_device_filter( args.compare_devices, "--compare-devices" ) + comparison_thresholds = get_comparison_thresholds(args.preset) except ValueError as exc: print(str(exc)) return 1 @@ -1907,6 +2053,7 @@ def main() -> int: compare_device_filter, os.path.dirname(ref), os.path.dirname(comp), + comparison_thresholds, ) except ValueError as exc: print(str(exc)) diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index 8d2c003..70e4e8b 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -714,8 +714,8 @@ def test_compare_gpu_timings_keeps_bulk_mismatch_undecided(nvbench_compare): assert comparison is not None assert comparison.status == nvbench_compare.ComparisonStatus.UNDECIDED assert comparison.reason.code == "bulk_time_support_mismatch" - assert "sample ref=" in comparison.reason.message - assert "support ref=" in comparison.reason.message + assert "sample: min(ref=0.0%, cmp=0.0%) >= 99.0%" in comparison.reason.message + assert "support: min(ref=0.0%, cmp=0.0%) >= 80.0%" in comparison.reason.message assert "99.0%" in comparison.reason.message assert "80.0%" in comparison.reason.message @@ -756,11 +756,11 @@ def test_bulk_same_reports_sample_weight_coverage_mismatch(nvbench_compare): assert decision.status == nvbench_compare.ComparisonStatus.UNDECIDED assert decision.reason.code == "bulk_time_support_mismatch" - assert "sample ref=3.8%" in decision.reason.message - assert "support ref=80.0%" in decision.reason.message + assert "sample: min(ref=3.8%, cmp=100.0%) >= 99.0%" in decision.reason.message + assert "support: min(ref=80.0%, cmp=100.0%) >= 80.0%" in decision.reason.message -def test_bulk_same_reports_unique_support_coverage_mismatch(nvbench_compare): +def test_bulk_same_filters_rare_values_from_support_coverage(nvbench_compare): ref_values = [1.0] * 1000 + [1.02 + 0.01 * i for i in range(10)] cmp_values = [1.0] @@ -771,10 +771,47 @@ def test_bulk_same_reports_unique_support_coverage_mismatch(nvbench_compare): thresholds=nvbench_compare.ComparisonThresholds(), ) + assert decision.status == nvbench_compare.ComparisonStatus.SAME + assert decision.reason.code == "bulk_time_same" + + +def test_bulk_same_reports_unique_support_coverage_mismatch(nvbench_compare): + ref_values = [1.0] * 1000 + [1.02 + 0.01 * i for i in range(10)] + cmp_values = [1.0] + + decision = nvbench_compare.compare_values_for_bulk_same( + ref_values, + cmp_values, + label="time", + thresholds=nvbench_compare.ComparisonThresholds( + bulk_support_max_removed_sample_fraction=0.005 + ), + ) + assert decision.status == nvbench_compare.ComparisonStatus.UNDECIDED assert decision.reason.code == "bulk_time_support_mismatch" - assert "sample ref=99.0%" in decision.reason.message - assert "support ref=9.1%" in decision.reason.message + assert "sample: min(ref=99.0%, cmp=100.0%) >= 99.0%" in decision.reason.message + assert "support: min(ref=9.1%, cmp=100.0%) >= 80.0%" in decision.reason.message + + +def test_bulk_same_retains_full_support_when_all_values_are_unique(nvbench_compare): + coverages = nvbench_compare.compute_nearest_neighbor_coverages( + [1.0, 1.02], + [1.0], + thresholds=nvbench_compare.ComparisonThresholds( + bulk_support_rare_sample_fraction=1.0, + bulk_support_max_removed_sample_fraction=1.0, + ), + ) + + assert coverages is not None + assert coverages["ref_sample"] == 0.5 + assert coverages["ref_support"] == 0.5 + assert coverages["ref_support_filter"] == nvbench_compare.SupportFilterInfo( + activated=False, + reason="all_values_unique", + removed_sample_fraction=0.0, + ) def test_comparison_stats_records_undecided_status(nvbench_compare): @@ -1221,3 +1258,43 @@ def test_main_prints_undecided_reason_summary(monkeypatch, capsys, nvbench_compa output = capsys.readouterr().out assert "Undecided (comparison requires more evidence): 1" in output assert "noise_too_high: 1" in output + + +def test_get_comparison_thresholds_returns_named_presets(nvbench_compare): + default = nvbench_compare.get_comparison_thresholds("default") + strict = nvbench_compare.get_comparison_thresholds("strict") + permissive = nvbench_compare.get_comparison_thresholds("permissive") + + assert default == nvbench_compare.ComparisonThresholds() + assert strict.clear_gap_relative > default.clear_gap_relative + assert strict.same_center_relative < default.same_center_relative + assert strict.bulk_same_sample_coverage > default.bulk_same_sample_coverage + assert permissive.clear_gap_relative < default.clear_gap_relative + assert permissive.same_center_relative > default.same_center_relative + assert permissive.bulk_same_support_coverage < default.bulk_same_support_coverage + + +def test_main_passes_selected_preset_to_compare_benches(monkeypatch, nvbench_compare): + devices = [{"id": 0, "name": "Test GPU"}] + root = { + "devices": devices, + "benchmarks": [], + } + captured = {} + + monkeypatch.setattr(nvbench_compare.reader, "read_file", lambda _: root) + + def fake_compare_benches(*args, **kwargs): + captured["comparison_thresholds"] = args[-1] + + monkeypatch.setattr(nvbench_compare, "compare_benches", fake_compare_benches) + monkeypatch.setattr( + sys, + "argv", + ["nvbench_compare", "--preset", "strict", "ref.json", "cmp.json"], + ) + + assert nvbench_compare.main() == 0 + assert captured[ + "comparison_thresholds" + ] == nvbench_compare.get_comparison_thresholds("strict")