mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-06-29 18:57:44 +00:00
Add nvbench_compare presets and rare-support-aware bulk coverage
Introduce comparison threshold presets in nvbench_compare and thread the selected preset through main() into compare_benches. Refine bulk nearest-neighbor support handling by: - adding rare-support filtering thresholds - ignoring low-count support values only when removed sample mass is small - falling back to full support for all-unique or otherwise unusable support - keeping sample-weight coverage over all values Tighten bulk mismatch reporting to show compact min(ref, cmp) coverage summaries, and add tests covering: - rare-tail filtering - strict fallback when too much support mass would be removed - all-unique support preservation - preset lookup and CLI preset propagation
This commit is contained in:
@@ -67,6 +67,61 @@ class ComparisonThresholds:
|
||||
same_relative_dispersion_ceiling: float = 0.02
|
||||
bulk_same_sample_coverage: float = 0.99
|
||||
bulk_same_support_coverage: float = 0.80
|
||||
bulk_support_rare_sample_fraction: float = 0.001
|
||||
bulk_support_max_removed_sample_fraction: float = 0.01
|
||||
|
||||
|
||||
COMPARISON_THRESHOLD_PRESET_VALUES = {
|
||||
"default": {
|
||||
"clear_gap_relative": 0.005,
|
||||
"same_center_relative": 0.005,
|
||||
"same_overlap_fraction": 0.5,
|
||||
"same_relative_dispersion_ceiling": 0.02,
|
||||
"bulk_same_sample_coverage": 0.99,
|
||||
"bulk_same_support_coverage": 0.80,
|
||||
"bulk_support_rare_sample_fraction": 0.001,
|
||||
"bulk_support_max_removed_sample_fraction": 0.01,
|
||||
},
|
||||
"strict": {
|
||||
"clear_gap_relative": 0.01,
|
||||
"same_center_relative": 0.0025,
|
||||
"same_overlap_fraction": 0.75,
|
||||
"same_relative_dispersion_ceiling": 0.01,
|
||||
"bulk_same_sample_coverage": 0.995,
|
||||
"bulk_same_support_coverage": 0.90,
|
||||
"bulk_support_rare_sample_fraction": 0.001,
|
||||
"bulk_support_max_removed_sample_fraction": 0.005,
|
||||
},
|
||||
"permissive": {
|
||||
"clear_gap_relative": 0.0025,
|
||||
"same_center_relative": 0.01,
|
||||
"same_overlap_fraction": 0.25,
|
||||
"same_relative_dispersion_ceiling": 0.05,
|
||||
"bulk_same_sample_coverage": 0.98,
|
||||
"bulk_same_support_coverage": 0.60,
|
||||
"bulk_support_rare_sample_fraction": 0.001,
|
||||
"bulk_support_max_removed_sample_fraction": 0.02,
|
||||
},
|
||||
}
|
||||
|
||||
COMPARISON_THRESHOLD_PRESETS = {
|
||||
name: ComparisonThresholds(**values)
|
||||
for name, values in COMPARISON_THRESHOLD_PRESET_VALUES.items()
|
||||
}
|
||||
|
||||
|
||||
def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
|
||||
try:
|
||||
return COMPARISON_THRESHOLD_PRESETS[preset_name]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"unknown comparison preset {preset_name!r}") from exc
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SupportFilterInfo:
|
||||
activated: bool
|
||||
reason: str
|
||||
removed_sample_fraction: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -677,6 +732,82 @@ def symmetric_nearest_log_distances(x, y):
|
||||
return symmetric_nearest_distances(np.log(x), np.log(y))
|
||||
|
||||
|
||||
def compute_effective_support_mask(counts, thresholds):
|
||||
"""Return the unique-value mask used for support coverage.
|
||||
|
||||
Sample-weight coverage always uses all values. Support coverage may ignore
|
||||
low-count values only when their total sample mass is small; otherwise it
|
||||
falls back to full support, preserving all-unique datasets.
|
||||
"""
|
||||
counts = np.asarray(counts)
|
||||
total_count = np.sum(counts)
|
||||
if (
|
||||
len(counts) == 0
|
||||
or total_count <= 0
|
||||
or thresholds.bulk_support_rare_sample_fraction <= 0.0
|
||||
or thresholds.bulk_support_max_removed_sample_fraction <= 0.0
|
||||
):
|
||||
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
|
||||
activated=False,
|
||||
reason="disabled",
|
||||
removed_sample_fraction=0.0,
|
||||
)
|
||||
|
||||
if np.all(counts == 1):
|
||||
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
|
||||
activated=False,
|
||||
reason="all_values_unique",
|
||||
removed_sample_fraction=0.0,
|
||||
)
|
||||
|
||||
min_count = max(
|
||||
2,
|
||||
math.ceil(thresholds.bulk_support_rare_sample_fraction * total_count),
|
||||
)
|
||||
support_mask = counts >= min_count
|
||||
if np.all(support_mask):
|
||||
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
|
||||
activated=False,
|
||||
reason="no_rare_values",
|
||||
removed_sample_fraction=0.0,
|
||||
)
|
||||
if not np.any(support_mask):
|
||||
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
|
||||
activated=False,
|
||||
reason="would_remove_all_support",
|
||||
removed_sample_fraction=0.0,
|
||||
)
|
||||
|
||||
removed_sample_fraction = np.sum(counts[~support_mask]) / total_count
|
||||
if removed_sample_fraction > thresholds.bulk_support_max_removed_sample_fraction:
|
||||
return np.ones(len(counts), dtype=bool), SupportFilterInfo(
|
||||
activated=False,
|
||||
reason="would_remove_too_much_mass",
|
||||
removed_sample_fraction=0.0,
|
||||
)
|
||||
|
||||
return support_mask, SupportFilterInfo(
|
||||
activated=True,
|
||||
reason="filtered",
|
||||
removed_sample_fraction=removed_sample_fraction,
|
||||
)
|
||||
|
||||
|
||||
def format_support_filter_info(filter_info):
|
||||
if filter_info.activated:
|
||||
return f"on({format_coverage(filter_info.removed_sample_fraction)})"
|
||||
|
||||
if filter_info.reason == "no_rare_values":
|
||||
return "off(no rare values)"
|
||||
if filter_info.reason == "all_values_unique":
|
||||
return "off(all values unique)"
|
||||
if filter_info.reason == "would_remove_too_much_mass":
|
||||
return "off(would remove too much mass)"
|
||||
if filter_info.reason == "would_remove_all_support":
|
||||
return "off(would remove all support)"
|
||||
return "off(disabled)"
|
||||
|
||||
|
||||
def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
|
||||
ref_unique, ref_counts = np.unique_counts(ref_values)
|
||||
cmp_unique, cmp_counts = np.unique_counts(cmp_values)
|
||||
@@ -689,12 +820,20 @@ def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
|
||||
tolerance = math.log1p(thresholds.same_center_relative)
|
||||
ref_covered = ref_distances <= tolerance
|
||||
cmp_covered = cmp_distances <= tolerance
|
||||
ref_support_mask, ref_filter_info = compute_effective_support_mask(
|
||||
ref_counts, thresholds
|
||||
)
|
||||
cmp_support_mask, cmp_filter_info = compute_effective_support_mask(
|
||||
cmp_counts, thresholds
|
||||
)
|
||||
|
||||
return {
|
||||
"ref_sample": np.sum(ref_counts[ref_covered]) / np.sum(ref_counts),
|
||||
"cmp_sample": np.sum(cmp_counts[cmp_covered]) / np.sum(cmp_counts),
|
||||
"ref_support": np.mean(ref_covered),
|
||||
"cmp_support": np.mean(cmp_covered),
|
||||
"ref_support": np.mean(ref_covered[ref_support_mask]),
|
||||
"cmp_support": np.mean(cmp_covered[cmp_support_mask]),
|
||||
"ref_support_filter": ref_filter_info,
|
||||
"cmp_support_filter": cmp_filter_info,
|
||||
}
|
||||
|
||||
|
||||
@@ -732,10 +871,10 @@ def make_bulk_coverage_mismatch_decision(label, coverages, thresholds):
|
||||
return make_decision(
|
||||
ComparisonStatus.UNDECIDED,
|
||||
f"bulk_{label}_support_mismatch",
|
||||
f"sample ref={format_coverage(coverages['ref_sample'])} "
|
||||
f"cmp={format_coverage(coverages['cmp_sample'])} >= {sample_threshold}; "
|
||||
f"support ref={format_coverage(coverages['ref_support'])} "
|
||||
f"cmp={format_coverage(coverages['cmp_support'])} >= {support_threshold}",
|
||||
f"sample: min(ref={format_coverage(coverages['ref_sample'])}, "
|
||||
f"cmp={format_coverage(coverages['cmp_sample'])}) >= {sample_threshold}; "
|
||||
f"support: min(ref={format_coverage(coverages['ref_support'])}, "
|
||||
f"cmp={format_coverage(coverages['cmp_support'])}) >= {support_threshold}",
|
||||
severity=severity,
|
||||
)
|
||||
|
||||
@@ -1760,6 +1899,12 @@ def main() -> int:
|
||||
default=0.0,
|
||||
help="only show benchmarks where percentage diff is >= THRESHOLD",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preset",
|
||||
choices=sorted(COMPARISON_THRESHOLD_PRESETS),
|
||||
default="default",
|
||||
help="comparison threshold preset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-along", type=str, dest="plot_along", default=None, help="plot results"
|
||||
)
|
||||
@@ -1819,6 +1964,7 @@ def main() -> int:
|
||||
compare_device_filter = parse_device_filter(
|
||||
args.compare_devices, "--compare-devices"
|
||||
)
|
||||
comparison_thresholds = get_comparison_thresholds(args.preset)
|
||||
except ValueError as exc:
|
||||
print(str(exc))
|
||||
return 1
|
||||
@@ -1907,6 +2053,7 @@ def main() -> int:
|
||||
compare_device_filter,
|
||||
os.path.dirname(ref),
|
||||
os.path.dirname(comp),
|
||||
comparison_thresholds,
|
||||
)
|
||||
except ValueError as exc:
|
||||
print(str(exc))
|
||||
|
||||
Reference in New Issue
Block a user