Add nvbench_compare presets and rare-support-aware bulk coverage

Introduce comparison threshold presets in nvbench_compare and thread the
selected preset through main() into compare_benches.

Refine bulk nearest-neighbor support handling by:
  - adding rare-support filtering thresholds
  - ignoring low-count support values only when removed sample mass is small
  - falling back to full support for all-unique or otherwise unusable support
  - keeping sample-weight coverage over all values

Tighten bulk mismatch reporting to show compact min(ref, cmp) coverage
summaries, and add tests covering:
  - rare-tail filtering
  - strict fallback when too much support mass would be removed
  - all-unique support preservation
  - preset lookup and CLI preset propagation
This commit is contained in:
Oleksandr Pavlyk
2026-06-03 15:21:26 -05:00
parent b791522d48
commit 20b3bd3148
2 changed files with 237 additions and 13 deletions

View File

@@ -67,6 +67,61 @@ class ComparisonThresholds:
same_relative_dispersion_ceiling: float = 0.02
bulk_same_sample_coverage: float = 0.99
bulk_same_support_coverage: float = 0.80
bulk_support_rare_sample_fraction: float = 0.001
bulk_support_max_removed_sample_fraction: float = 0.01
# Raw field values for every named comparison preset, keyed by preset name.
# Each preset assigns the same eight threshold fields; only the numbers differ.
# bulk_support_rare_sample_fraction is 0.001 in all three presets.
COMPARISON_THRESHOLD_PRESET_VALUES = {
    "default": dict(
        clear_gap_relative=0.005,
        same_center_relative=0.005,
        same_overlap_fraction=0.5,
        same_relative_dispersion_ceiling=0.02,
        bulk_same_sample_coverage=0.99,
        bulk_same_support_coverage=0.80,
        bulk_support_rare_sample_fraction=0.001,
        bulk_support_max_removed_sample_fraction=0.01,
    ),
    "strict": dict(
        clear_gap_relative=0.01,
        same_center_relative=0.0025,
        same_overlap_fraction=0.75,
        same_relative_dispersion_ceiling=0.01,
        bulk_same_sample_coverage=0.995,
        bulk_same_support_coverage=0.90,
        bulk_support_rare_sample_fraction=0.001,
        bulk_support_max_removed_sample_fraction=0.005,
    ),
    "permissive": dict(
        clear_gap_relative=0.0025,
        same_center_relative=0.01,
        same_overlap_fraction=0.25,
        same_relative_dispersion_ceiling=0.05,
        bulk_same_sample_coverage=0.98,
        bulk_same_support_coverage=0.60,
        bulk_support_rare_sample_fraction=0.001,
        bulk_support_max_removed_sample_fraction=0.02,
    ),
}
# Build one ComparisonThresholds instance per preset name by expanding that
# preset's raw field values as keyword arguments.
COMPARISON_THRESHOLD_PRESETS = {
    preset_name: ComparisonThresholds(**field_values)
    for preset_name, field_values in COMPARISON_THRESHOLD_PRESET_VALUES.items()
}
def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
    """Return the thresholds registered under ``preset_name``.

    Args:
        preset_name: Key into ``COMPARISON_THRESHOLD_PRESETS``.

    Returns:
        The ``ComparisonThresholds`` instance stored for that preset.

    Raises:
        ValueError: If ``preset_name`` is not a registered preset. The
            original ``KeyError`` is attached as the cause.
    """
    try:
        selected = COMPARISON_THRESHOLD_PRESETS[preset_name]
    except KeyError as missing:
        message = f"unknown comparison preset {preset_name!r}"
        raise ValueError(message) from missing
    return selected
@dataclass(frozen=True)
class SupportFilterInfo:
    """Outcome of deciding whether rare values are dropped from support.

    Describes whether a rare-value filter was applied to the set of unique
    values, why (or why not), and how much sample mass the filter removed.
    """

    # True when some unique values were excluded from the support mask.
    activated: bool
    # Short machine-readable code explaining the decision, e.g. "filtered",
    # "disabled", "all_values_unique", "no_rare_values".
    reason: str
    # Fraction of all samples that belong to the excluded values; 0.0 is
    # reported whenever the filter was not applied.
    removed_sample_fraction: float
@dataclass(frozen=True)
@@ -677,6 +732,82 @@ def symmetric_nearest_log_distances(x, y):
return symmetric_nearest_distances(np.log(x), np.log(y))
def compute_effective_support_mask(counts, thresholds):
    """Return the unique-value mask used for support coverage.

    Sample-weight coverage always uses all values. Support coverage may drop
    low-count values, but only when the samples they account for are a small
    share of the total; in every other situation the full support is kept,
    which also preserves datasets where every value is unique.

    Args:
        counts: Per-unique-value sample counts (converted with ``np.asarray``).
        thresholds: Object exposing ``bulk_support_rare_sample_fraction`` and
            ``bulk_support_max_removed_sample_fraction``.

    Returns:
        A ``(mask, info)`` tuple: ``mask`` is a boolean array with one entry
        per count (True = value participates in support coverage) and
        ``info`` is a ``SupportFilterInfo`` describing the decision.
    """
    counts = np.asarray(counts)
    n_values = len(counts)
    total_count = counts.sum()

    def keep_full_support(reason):
        # Every fallback path keeps all values and reports zero removed mass.
        everything = np.ones(n_values, dtype=bool)
        info = SupportFilterInfo(
            activated=False,
            reason=reason,
            removed_sample_fraction=0.0,
        )
        return everything, info

    # Nothing to filter, or filtering switched off by non-positive thresholds.
    if n_values == 0 or total_count <= 0:
        return keep_full_support("disabled")
    if (
        thresholds.bulk_support_rare_sample_fraction <= 0.0
        or thresholds.bulk_support_max_removed_sample_fraction <= 0.0
    ):
        return keep_full_support("disabled")

    if (counts == 1).all():
        return keep_full_support("all_values_unique")

    # A value is "rare" when it has fewer samples than this floor; the floor
    # is never below 2, so singletons are always candidates for removal.
    rare_floor = math.ceil(
        thresholds.bulk_support_rare_sample_fraction * total_count
    )
    min_count = max(2, rare_floor)
    support_mask = counts >= min_count

    if support_mask.all():
        return keep_full_support("no_rare_values")
    if not support_mask.any():
        return keep_full_support("would_remove_all_support")

    removed_sample_fraction = counts[~support_mask].sum() / total_count
    if (
        removed_sample_fraction
        > thresholds.bulk_support_max_removed_sample_fraction
    ):
        return keep_full_support("would_remove_too_much_mass")

    filtered_info = SupportFilterInfo(
        activated=True,
        reason="filtered",
        removed_sample_fraction=removed_sample_fraction,
    )
    return support_mask, filtered_info
def format_support_filter_info(filter_info):
    """Render a support-filter decision as a short human-readable label.

    Args:
        filter_info: Object with ``activated``, ``reason`` and
            ``removed_sample_fraction`` attributes.

    Returns:
        ``"on(<fraction>)"`` when the filter was applied, where the fraction
        is formatted by ``format_coverage``; otherwise an ``"off(...)"``
        label for the recorded reason. Unrecognized reasons are shown as
        ``"off(disabled)"``.
    """
    if filter_info.activated:
        removed = format_coverage(filter_info.removed_sample_fraction)
        return f"on({removed})"

    inactive_labels = {
        "no_rare_values": "off(no rare values)",
        "all_values_unique": "off(all values unique)",
        "would_remove_too_much_mass": "off(would remove too much mass)",
        "would_remove_all_support": "off(would remove all support)",
    }
    return inactive_labels.get(filter_info.reason, "off(disabled)")
def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
ref_unique, ref_counts = np.unique_counts(ref_values)
cmp_unique, cmp_counts = np.unique_counts(cmp_values)
@@ -689,12 +820,20 @@ def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
tolerance = math.log1p(thresholds.same_center_relative)
ref_covered = ref_distances <= tolerance
cmp_covered = cmp_distances <= tolerance
ref_support_mask, ref_filter_info = compute_effective_support_mask(
ref_counts, thresholds
)
cmp_support_mask, cmp_filter_info = compute_effective_support_mask(
cmp_counts, thresholds
)
return {
"ref_sample": np.sum(ref_counts[ref_covered]) / np.sum(ref_counts),
"cmp_sample": np.sum(cmp_counts[cmp_covered]) / np.sum(cmp_counts),
"ref_support": np.mean(ref_covered),
"cmp_support": np.mean(cmp_covered),
"ref_support": np.mean(ref_covered[ref_support_mask]),
"cmp_support": np.mean(cmp_covered[cmp_support_mask]),
"ref_support_filter": ref_filter_info,
"cmp_support_filter": cmp_filter_info,
}
@@ -732,10 +871,10 @@ def make_bulk_coverage_mismatch_decision(label, coverages, thresholds):
return make_decision(
ComparisonStatus.UNDECIDED,
f"bulk_{label}_support_mismatch",
f"sample ref={format_coverage(coverages['ref_sample'])} "
f"cmp={format_coverage(coverages['cmp_sample'])} >= {sample_threshold}; "
f"support ref={format_coverage(coverages['ref_support'])} "
f"cmp={format_coverage(coverages['cmp_support'])} >= {support_threshold}",
f"sample: min(ref={format_coverage(coverages['ref_sample'])}, "
f"cmp={format_coverage(coverages['cmp_sample'])}) >= {sample_threshold}; "
f"support: min(ref={format_coverage(coverages['ref_support'])}, "
f"cmp={format_coverage(coverages['cmp_support'])}) >= {support_threshold}",
severity=severity,
)
@@ -1760,6 +1899,12 @@ def main() -> int:
default=0.0,
help="only show benchmarks where percentage diff is >= THRESHOLD",
)
parser.add_argument(
"--preset",
choices=sorted(COMPARISON_THRESHOLD_PRESETS),
default="default",
help="comparison threshold preset",
)
parser.add_argument(
"--plot-along", type=str, dest="plot_along", default=None, help="plot results"
)
@@ -1819,6 +1964,7 @@ def main() -> int:
compare_device_filter = parse_device_filter(
args.compare_devices, "--compare-devices"
)
comparison_thresholds = get_comparison_thresholds(args.preset)
except ValueError as exc:
print(str(exc))
return 1
@@ -1907,6 +2053,7 @@ def main() -> int:
compare_device_filter,
os.path.dirname(ref),
os.path.dirname(comp),
comparison_thresholds,
)
except ValueError as exc:
print(str(exc))

View File

@@ -714,8 +714,8 @@ def test_compare_gpu_timings_keeps_bulk_mismatch_undecided(nvbench_compare):
assert comparison is not None
assert comparison.status == nvbench_compare.ComparisonStatus.UNDECIDED
assert comparison.reason.code == "bulk_time_support_mismatch"
assert "sample ref=" in comparison.reason.message
assert "support ref=" in comparison.reason.message
assert "sample: min(ref=0.0%, cmp=0.0%) >= 99.0%" in comparison.reason.message
assert "support: min(ref=0.0%, cmp=0.0%) >= 80.0%" in comparison.reason.message
assert "99.0%" in comparison.reason.message
assert "80.0%" in comparison.reason.message
@@ -756,11 +756,11 @@ def test_bulk_same_reports_sample_weight_coverage_mismatch(nvbench_compare):
assert decision.status == nvbench_compare.ComparisonStatus.UNDECIDED
assert decision.reason.code == "bulk_time_support_mismatch"
assert "sample ref=3.8%" in decision.reason.message
assert "support ref=80.0%" in decision.reason.message
assert "sample: min(ref=3.8%, cmp=100.0%) >= 99.0%" in decision.reason.message
assert "support: min(ref=80.0%, cmp=100.0%) >= 80.0%" in decision.reason.message
def test_bulk_same_reports_unique_support_coverage_mismatch(nvbench_compare):
def test_bulk_same_filters_rare_values_from_support_coverage(nvbench_compare):
ref_values = [1.0] * 1000 + [1.02 + 0.01 * i for i in range(10)]
cmp_values = [1.0]
@@ -771,10 +771,47 @@ def test_bulk_same_reports_unique_support_coverage_mismatch(nvbench_compare):
thresholds=nvbench_compare.ComparisonThresholds(),
)
assert decision.status == nvbench_compare.ComparisonStatus.SAME
assert decision.reason.code == "bulk_time_same"
def test_bulk_same_reports_unique_support_coverage_mismatch(nvbench_compare):
    """Full support is used when rare-value filtering would remove too much.

    The ten outlying reference values hold 10/1010 of the samples, which is
    more than the 0.005 removal budget configured here, so support coverage
    is computed over all eleven unique values and the mismatch is reported.
    """
    ref_values = [1.0] * 1000 + [1.02 + 0.01 * i for i in range(10)]
    cmp_values = [1.0]

    decision = nvbench_compare.compare_values_for_bulk_same(
        ref_values,
        cmp_values,
        label="time",
        thresholds=nvbench_compare.ComparisonThresholds(
            bulk_support_max_removed_sample_fraction=0.005
        ),
    )

    assert decision.status == nvbench_compare.ComparisonStatus.UNDECIDED
    assert decision.reason.code == "bulk_time_support_mismatch"
    # Only the compact "min(ref=..., cmp=...)" message format is asserted; the
    # older "sample ref=" / "support ref=" substrings no longer occur in it.
    assert "sample: min(ref=99.0%, cmp=100.0%) >= 99.0%" in decision.reason.message
    assert "support: min(ref=9.1%, cmp=100.0%) >= 80.0%" in decision.reason.message
def test_bulk_same_retains_full_support_when_all_values_are_unique(nvbench_compare):
    """All-unique data keeps full support even with maximal filter thresholds."""
    # Both filter thresholds are set to 1.0, the most aggressive setting.
    aggressive_thresholds = nvbench_compare.ComparisonThresholds(
        bulk_support_rare_sample_fraction=1.0,
        bulk_support_max_removed_sample_fraction=1.0,
    )

    coverages = nvbench_compare.compute_nearest_neighbor_coverages(
        [1.0, 1.02],
        [1.0],
        thresholds=aggressive_thresholds,
    )

    assert coverages is not None
    assert coverages["ref_sample"] == 0.5
    assert coverages["ref_support"] == 0.5

    unfiltered = nvbench_compare.SupportFilterInfo(
        activated=False,
        reason="all_values_unique",
        removed_sample_fraction=0.0,
    )
    assert coverages["ref_support_filter"] == unfiltered
def test_comparison_stats_records_undecided_status(nvbench_compare):
@@ -1221,3 +1258,43 @@ def test_main_prints_undecided_reason_summary(monkeypatch, capsys, nvbench_compa
output = capsys.readouterr().out
assert "Undecided (comparison requires more evidence): 1" in output
assert "noise_too_high: 1" in output
def test_get_comparison_thresholds_returns_named_presets(nvbench_compare):
    """Named presets resolve, and strict/permissive bracket the default."""
    lookup = nvbench_compare.get_comparison_thresholds
    baseline = lookup("default")
    tight = lookup("strict")
    loose = lookup("permissive")

    # The "default" preset must match the dataclass defaults exactly.
    assert baseline == nvbench_compare.ComparisonThresholds()

    assert tight.clear_gap_relative > baseline.clear_gap_relative
    assert tight.same_center_relative < baseline.same_center_relative
    assert tight.bulk_same_sample_coverage > baseline.bulk_same_sample_coverage

    assert loose.clear_gap_relative < baseline.clear_gap_relative
    assert loose.same_center_relative > baseline.same_center_relative
    assert loose.bulk_same_support_coverage < baseline.bulk_same_support_coverage
def test_main_passes_selected_preset_to_compare_benches(monkeypatch, nvbench_compare):
    """``--preset strict`` reaches compare_benches as its last positional arg."""
    fake_root = {
        "devices": [{"id": 0, "name": "Test GPU"}],
        "benchmarks": [],
    }
    seen = {}

    def fake_compare_benches(*args, **kwargs):
        # Record the trailing positional argument for inspection below.
        seen["comparison_thresholds"] = args[-1]

    # Every file read yields the same stub document, so no real files are needed.
    monkeypatch.setattr(nvbench_compare.reader, "read_file", lambda _: fake_root)
    monkeypatch.setattr(nvbench_compare, "compare_benches", fake_compare_benches)
    monkeypatch.setattr(
        sys,
        "argv",
        ["nvbench_compare", "--preset", "strict", "ref.json", "cmp.json"],
    )

    exit_code = nvbench_compare.main()

    assert exit_code == 0
    expected_thresholds = nvbench_compare.get_comparison_thresholds("strict")
    assert seen["comparison_thresholds"] == expected_thresholds