Add nvbench_compare presets and rare-support-aware bulk coverage

Introduce comparison threshold presets in nvbench_compare and thread the
selected preset through main() into compare_benches.

Refine bulk nearest-neighbor support handling by:
  - adding rare-support filtering thresholds
  - ignoring low-count support values only when removed sample mass is small
  - falling back to full support for all-unique or otherwise unusable support
  - keeping sample-weight coverage over all values

Tighten bulk mismatch reporting to show compact min(ref, cmp) coverage
summaries, and add tests covering:
  - rare-tail filtering
  - strict fallback when too much support mass would be removed
  - all-unique support preservation
  - preset lookup and CLI preset propagation
This commit is contained in:
Oleksandr Pavlyk
2026-06-03 15:21:26 -05:00
parent b791522d48
commit 20b3bd3148
2 changed files with 237 additions and 13 deletions

View File

@@ -67,6 +67,61 @@ class ComparisonThresholds:
same_relative_dispersion_ceiling: float = 0.02
bulk_same_sample_coverage: float = 0.99
bulk_same_support_coverage: float = 0.80
bulk_support_rare_sample_fraction: float = 0.001
bulk_support_max_removed_sample_fraction: float = 0.01
COMPARISON_THRESHOLD_PRESET_VALUES = {
"default": {
"clear_gap_relative": 0.005,
"same_center_relative": 0.005,
"same_overlap_fraction": 0.5,
"same_relative_dispersion_ceiling": 0.02,
"bulk_same_sample_coverage": 0.99,
"bulk_same_support_coverage": 0.80,
"bulk_support_rare_sample_fraction": 0.001,
"bulk_support_max_removed_sample_fraction": 0.01,
},
"strict": {
"clear_gap_relative": 0.01,
"same_center_relative": 0.0025,
"same_overlap_fraction": 0.75,
"same_relative_dispersion_ceiling": 0.01,
"bulk_same_sample_coverage": 0.995,
"bulk_same_support_coverage": 0.90,
"bulk_support_rare_sample_fraction": 0.001,
"bulk_support_max_removed_sample_fraction": 0.005,
},
"permissive": {
"clear_gap_relative": 0.0025,
"same_center_relative": 0.01,
"same_overlap_fraction": 0.25,
"same_relative_dispersion_ceiling": 0.05,
"bulk_same_sample_coverage": 0.98,
"bulk_same_support_coverage": 0.60,
"bulk_support_rare_sample_fraction": 0.001,
"bulk_support_max_removed_sample_fraction": 0.02,
},
}
COMPARISON_THRESHOLD_PRESETS = {
name: ComparisonThresholds(**values)
for name, values in COMPARISON_THRESHOLD_PRESET_VALUES.items()
}
def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
try:
return COMPARISON_THRESHOLD_PRESETS[preset_name]
except KeyError as exc:
raise ValueError(f"unknown comparison preset {preset_name!r}") from exc
@dataclass(frozen=True)
class SupportFilterInfo:
activated: bool
reason: str
removed_sample_fraction: float
@dataclass(frozen=True)
@@ -677,6 +732,82 @@ def symmetric_nearest_log_distances(x, y):
return symmetric_nearest_distances(np.log(x), np.log(y))
def compute_effective_support_mask(counts, thresholds):
"""Return the unique-value mask used for support coverage.
Sample-weight coverage always uses all values. Support coverage may ignore
low-count values only when their total sample mass is small; otherwise it
falls back to full support, preserving all-unique datasets.
"""
counts = np.asarray(counts)
total_count = np.sum(counts)
if (
len(counts) == 0
or total_count <= 0
or thresholds.bulk_support_rare_sample_fraction <= 0.0
or thresholds.bulk_support_max_removed_sample_fraction <= 0.0
):
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
activated=False,
reason="disabled",
removed_sample_fraction=0.0,
)
if np.all(counts == 1):
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
activated=False,
reason="all_values_unique",
removed_sample_fraction=0.0,
)
min_count = max(
2,
math.ceil(thresholds.bulk_support_rare_sample_fraction * total_count),
)
support_mask = counts >= min_count
if np.all(support_mask):
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
activated=False,
reason="no_rare_values",
removed_sample_fraction=0.0,
)
if not np.any(support_mask):
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
activated=False,
reason="would_remove_all_support",
removed_sample_fraction=0.0,
)
removed_sample_fraction = np.sum(counts[~support_mask]) / total_count
if removed_sample_fraction > thresholds.bulk_support_max_removed_sample_fraction:
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
activated=False,
reason="would_remove_too_much_mass",
removed_sample_fraction=0.0,
)
return support_mask, SupportFilterInfo(
activated=True,
reason="filtered",
removed_sample_fraction=removed_sample_fraction,
)
def format_support_filter_info(filter_info):
if filter_info.activated:
return f"on({format_coverage(filter_info.removed_sample_fraction)})"
if filter_info.reason == "no_rare_values":
return "off(no rare values)"
if filter_info.reason == "all_values_unique":
return "off(all values unique)"
if filter_info.reason == "would_remove_too_much_mass":
return "off(would remove too much mass)"
if filter_info.reason == "would_remove_all_support":
return "off(would remove all support)"
return "off(disabled)"
def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
ref_unique, ref_counts = np.unique_counts(ref_values)
cmp_unique, cmp_counts = np.unique_counts(cmp_values)
@@ -689,12 +820,20 @@ def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
tolerance = math.log1p(thresholds.same_center_relative)
ref_covered = ref_distances <= tolerance
cmp_covered = cmp_distances <= tolerance
ref_support_mask, ref_filter_info = compute_effective_support_mask(
ref_counts, thresholds
)
cmp_support_mask, cmp_filter_info = compute_effective_support_mask(
cmp_counts, thresholds
)
return {
"ref_sample": np.sum(ref_counts[ref_covered]) / np.sum(ref_counts),
"cmp_sample": np.sum(cmp_counts[cmp_covered]) / np.sum(cmp_counts),
"ref_support": np.mean(ref_covered),
"cmp_support": np.mean(cmp_covered),
"ref_support": np.mean(ref_covered[ref_support_mask]),
"cmp_support": np.mean(cmp_covered[cmp_support_mask]),
"ref_support_filter": ref_filter_info,
"cmp_support_filter": cmp_filter_info,
}
@@ -732,10 +871,10 @@ def make_bulk_coverage_mismatch_decision(label, coverages, thresholds):
return make_decision(
ComparisonStatus.UNDECIDED,
f"bulk_{label}_support_mismatch",
f"sample ref={format_coverage(coverages['ref_sample'])} "
f"cmp={format_coverage(coverages['cmp_sample'])} >= {sample_threshold}; "
f"support ref={format_coverage(coverages['ref_support'])} "
f"cmp={format_coverage(coverages['cmp_support'])} >= {support_threshold}",
f"sample: min(ref={format_coverage(coverages['ref_sample'])}, "
f"cmp={format_coverage(coverages['cmp_sample'])}) >= {sample_threshold}; "
f"support: min(ref={format_coverage(coverages['ref_support'])}, "
f"cmp={format_coverage(coverages['cmp_support'])}) >= {support_threshold}",
severity=severity,
)
@@ -1760,6 +1899,12 @@ def main() -> int:
default=0.0,
help="only show benchmarks where percentage diff is >= THRESHOLD",
)
parser.add_argument(
"--preset",
choices=sorted(COMPARISON_THRESHOLD_PRESETS),
default="default",
help="comparison threshold preset",
)
parser.add_argument(
"--plot-along", type=str, dest="plot_along", default=None, help="plot results"
)
@@ -1819,6 +1964,7 @@ def main() -> int:
compare_device_filter = parse_device_filter(
args.compare_devices, "--compare-devices"
)
comparison_thresholds = get_comparison_thresholds(args.preset)
except ValueError as exc:
print(str(exc))
return 1
@@ -1907,6 +2053,7 @@ def main() -> int:
compare_device_filter,
os.path.dirname(ref),
os.path.dirname(comp),
comparison_thresholds,
)
except ValueError as exc:
print(str(exc))