Group nvbench-compare thresholds into a config object

Replace the scattered module-level comparison threshold constants
with a ComparisonThresholds value object. Thread this object through
compare_benches, compare_gpu_timings, and the lower-level clear-gap,
summary-SAME, and bulk-SAME decision helpers.

Keep existing behavior by constructing default ComparisonThresholds
when callers do not provide one. This prepares nvbench-compare for
future CLI-configurable decision thresholds while keeping one consistent
configuration for an entire comparison run.

Add test coverage that passes custom thresholds through compare_benches and
verifies they affect the SAME decision.
This commit is contained in:
Oleksandr Pavlyk
2026-06-03 10:02:46 -05:00
parent 0f091438a5
commit d8efe3dd9e
2 changed files with 130 additions and 55 deletions

View File

@@ -45,12 +45,6 @@ GPU_TIME_IR_RELATIVE_TAG = "nv/cold/time/gpu/ir/relative"
GPU_SM_CLOCK_RATE_MEAN_TAG = "nv/cold/sm_clock_rate/mean"
SAMPLE_TIMES_TAG = "nv/json/bin:nv/cold/sample_times"
SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs"
CLEAR_GAP_RELATIVE_THRESHOLD = 0.005
SAME_CENTER_RELATIVE_THRESHOLD = 0.005
SAME_OVERLAP_FRACTION_THRESHOLD = 0.5
SAME_RELATIVE_DISPERSION_CEILING = 0.02
BULK_SAME_SAMPLE_COVERAGE_THRESHOLD = 0.99
BULK_SAME_SUPPORT_COVERAGE_THRESHOLD = 0.80
# The reader returns an object supporting the buffer protocol. Python 3.10 does
# not provide a standard Buffer type annotation.
@@ -65,6 +59,16 @@ def read_float32_file(filename: str) -> object:
# accidental field reassignment but does not imply deep immutability.
@dataclass(frozen=True)
class ComparisonThresholds:
clear_gap_relative: float = 0.005
same_center_relative: float = 0.005
same_overlap_fraction: float = 0.5
same_relative_dispersion_ceiling: float = 0.02
bulk_same_sample_coverage: float = 0.99
bulk_same_support_coverage: float = 0.80
@dataclass(frozen=True)
class Float32BinarySource:
count: int
@@ -594,26 +598,26 @@ def make_decision(status, code, message, *, severity=0.0):
)
def compare_intervals_for_clear_gap(ref_interval, cmp_interval):
def compare_intervals_for_clear_gap(ref_interval, cmp_interval, thresholds):
# These ratios are equivalent to log(ref/cmp) >= log(1 + delta), but avoid
# evaluating logarithms on every comparison.
if cmp_interval.upper < ref_interval.lower:
gap = ref_interval.lower - cmp_interval.upper
if gap / cmp_interval.upper >= CLEAR_GAP_RELATIVE_THRESHOLD:
if gap / cmp_interval.upper >= thresholds.clear_gap_relative:
return ComparisonStatus.FAST
if cmp_interval.lower > ref_interval.upper:
gap = cmp_interval.lower - ref_interval.upper
if gap / ref_interval.upper >= CLEAR_GAP_RELATIVE_THRESHOLD:
if gap / ref_interval.upper >= thresholds.clear_gap_relative:
return ComparisonStatus.SLOW
return None
def centers_are_close(ref_center, cmp_center):
def centers_are_close(ref_center, cmp_center, thresholds):
if not is_positive_finite(ref_center) or not is_positive_finite(cmp_center):
return False
return (
abs(ref_center - cmp_center) / min(ref_center, cmp_center)
<= SAME_CENTER_RELATIVE_THRESHOLD
<= thresholds.same_center_relative
)
@@ -644,10 +648,10 @@ def interval_overlap_fraction(ref_interval, cmp_interval):
)
def intervals_overlap_strongly(ref_interval, cmp_interval):
def intervals_overlap_strongly(ref_interval, cmp_interval, thresholds):
return (
interval_overlap_fraction(ref_interval, cmp_interval)
>= SAME_OVERLAP_FRACTION_THRESHOLD
>= thresholds.same_overlap_fraction
)
@@ -673,7 +677,7 @@ def symmetric_nearest_log_distances(x, y):
return symmetric_nearest_distances(np.log(x), np.log(y))
def compute_nearest_neighbor_coverages(ref_values, cmp_values):
def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
ref_unique, ref_counts = np.unique_counts(ref_values)
cmp_unique, cmp_counts = np.unique_counts(cmp_values)
if len(ref_unique) == 0 or len(cmp_unique) == 0:
@@ -682,7 +686,7 @@ def compute_nearest_neighbor_coverages(ref_values, cmp_values):
ref_distances, cmp_distances = symmetric_nearest_log_distances(
ref_unique, cmp_unique
)
tolerance = math.log1p(SAME_CENTER_RELATIVE_THRESHOLD)
tolerance = math.log1p(thresholds.same_center_relative)
ref_covered = ref_distances <= tolerance
cmp_covered = cmp_distances <= tolerance
@@ -694,12 +698,12 @@ def compute_nearest_neighbor_coverages(ref_values, cmp_values):
}
def coverages_support_same(coverages):
def coverages_support_same(coverages, thresholds):
return (
coverages["ref_sample"] >= BULK_SAME_SAMPLE_COVERAGE_THRESHOLD
and coverages["cmp_sample"] >= BULK_SAME_SAMPLE_COVERAGE_THRESHOLD
and coverages["ref_support"] >= BULK_SAME_SUPPORT_COVERAGE_THRESHOLD
and coverages["cmp_support"] >= BULK_SAME_SUPPORT_COVERAGE_THRESHOLD
coverages["ref_sample"] >= thresholds.bulk_same_sample_coverage
and coverages["cmp_sample"] >= thresholds.bulk_same_sample_coverage
and coverages["ref_support"] >= thresholds.bulk_same_support_coverage
and coverages["cmp_support"] >= thresholds.bulk_same_support_coverage
)
@@ -711,17 +715,17 @@ def format_coverage(value):
return f"{value * 100.0:.1f}%"
def make_bulk_coverage_mismatch_decision(label, coverages):
sample_threshold = format_coverage_threshold(BULK_SAME_SAMPLE_COVERAGE_THRESHOLD)
support_threshold = format_coverage_threshold(BULK_SAME_SUPPORT_COVERAGE_THRESHOLD)
def make_bulk_coverage_mismatch_decision(label, coverages, thresholds):
sample_threshold = format_coverage_threshold(thresholds.bulk_same_sample_coverage)
support_threshold = format_coverage_threshold(thresholds.bulk_same_support_coverage)
sample_deficit = max(
BULK_SAME_SAMPLE_COVERAGE_THRESHOLD - coverages["ref_sample"],
BULK_SAME_SAMPLE_COVERAGE_THRESHOLD - coverages["cmp_sample"],
thresholds.bulk_same_sample_coverage - coverages["ref_sample"],
thresholds.bulk_same_sample_coverage - coverages["cmp_sample"],
0.0,
)
support_deficit = max(
BULK_SAME_SUPPORT_COVERAGE_THRESHOLD - coverages["ref_support"],
BULK_SAME_SUPPORT_COVERAGE_THRESHOLD - coverages["cmp_support"],
thresholds.bulk_same_support_coverage - coverages["ref_support"],
thresholds.bulk_same_support_coverage - coverages["cmp_support"],
0.0,
)
severity = max(sample_deficit, support_deficit)
@@ -767,7 +771,7 @@ def scale_interval(interval, scale):
def confirm_clear_gap_with_clock_rate(
status, ref_timing, cmp_timing, ref_interval, cmp_interval
status, ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds
):
if ref_timing.sm_clock_rate_mean is None or cmp_timing.sm_clock_rate_mean is None:
return make_decision(
@@ -785,7 +789,7 @@ def confirm_clear_gap_with_clock_rate(
"clear timing gap was not confirmed because SM clock summaries are invalid",
)
cycle_status = compare_intervals_for_clear_gap(ref_cycles, cmp_cycles)
cycle_status = compare_intervals_for_clear_gap(ref_cycles, cmp_cycles, thresholds)
if cycle_status == status:
return make_decision(
status,
@@ -799,7 +803,7 @@ def confirm_clear_gap_with_clock_rate(
)
def compare_timings_for_clear_gap(ref_timing, cmp_timing):
def compare_timings_for_clear_gap(ref_timing, cmp_timing, thresholds):
ref_interval = compute_timing_interval(ref_timing)
cmp_interval = compute_timing_interval(cmp_timing)
if ref_interval is None or cmp_interval is None:
@@ -809,7 +813,7 @@ def compare_timings_for_clear_gap(ref_timing, cmp_timing):
"could not construct comparable timing intervals",
)
status = compare_intervals_for_clear_gap(ref_interval, cmp_interval)
status = compare_intervals_for_clear_gap(ref_interval, cmp_interval, thresholds)
if status is None:
return make_decision(
ComparisonStatus.UNDECIDED,
@@ -818,18 +822,18 @@ def compare_timings_for_clear_gap(ref_timing, cmp_timing):
)
return confirm_clear_gap_with_clock_rate(
status, ref_timing, cmp_timing, ref_interval, cmp_interval
status, ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds
)
def compare_intervals_for_same(ref_interval, cmp_interval):
if not centers_are_close(ref_interval.center, cmp_interval.center):
def compare_intervals_for_same(ref_interval, cmp_interval, thresholds):
if not centers_are_close(ref_interval.center, cmp_interval.center, thresholds):
return make_decision(
ComparisonStatus.UNDECIDED,
"centers_not_close",
"timing centers are not close enough to declare same",
)
if not intervals_overlap_strongly(ref_interval, cmp_interval):
if not intervals_overlap_strongly(ref_interval, cmp_interval, thresholds):
return make_decision(
ComparisonStatus.UNDECIDED,
"weak_interval_overlap",
@@ -842,7 +846,9 @@ def compare_intervals_for_same(ref_interval, cmp_interval):
)
def confirm_same_with_clock_rate(ref_timing, cmp_timing, ref_interval, cmp_interval):
def confirm_same_with_clock_rate(
ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds
):
if ref_timing.sm_clock_rate_mean is None or cmp_timing.sm_clock_rate_mean is None:
return make_decision(
ComparisonStatus.SAME,
@@ -859,7 +865,7 @@ def confirm_same_with_clock_rate(ref_timing, cmp_timing, ref_interval, cmp_inter
"same decision was not confirmed because SM clock summaries are invalid",
)
decision = compare_intervals_for_same(ref_cycles, cmp_cycles)
decision = compare_intervals_for_same(ref_cycles, cmp_cycles, thresholds)
if decision.status == ComparisonStatus.SAME:
return make_decision(
ComparisonStatus.SAME,
@@ -873,24 +879,24 @@ def confirm_same_with_clock_rate(ref_timing, cmp_timing, ref_interval, cmp_inter
)
def compare_values_for_bulk_same(ref_values, cmp_values, *, label):
coverages = compute_nearest_neighbor_coverages(ref_values, cmp_values)
def compare_values_for_bulk_same(ref_values, cmp_values, *, label, thresholds):
coverages = compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds)
if coverages is None:
return make_decision(
ComparisonStatus.UNDECIDED,
f"bulk_{label}_data_unusable",
f"bulk {label} data is empty or unusable",
)
if coverages_support_same(coverages):
if coverages_support_same(coverages, thresholds):
return make_decision(
ComparisonStatus.SAME,
f"bulk_{label}_same",
f"bulk {label} nearest-neighbor coverage supports same",
)
return make_bulk_coverage_mismatch_decision(label, coverages)
return make_bulk_coverage_mismatch_decision(label, coverages, thresholds)
def compare_timings_for_bulk_same(ref_timing, cmp_timing):
def compare_timings_for_bulk_same(ref_timing, cmp_timing, thresholds):
ref_bulk = get_bulk_time_and_cycles(ref_timing)
cmp_bulk = get_bulk_time_and_cycles(cmp_timing)
if ref_bulk is None or cmp_bulk is None:
@@ -903,11 +909,15 @@ def compare_timings_for_bulk_same(ref_timing, cmp_timing):
ref_times, ref_cycles = ref_bulk
cmp_times, cmp_cycles = cmp_bulk
time_decision = compare_values_for_bulk_same(ref_times, cmp_times, label="time")
time_decision = compare_values_for_bulk_same(
ref_times, cmp_times, label="time", thresholds=thresholds
)
if time_decision.status != ComparisonStatus.SAME:
return time_decision
cycle_decision = compare_values_for_bulk_same(ref_cycles, cmp_cycles, label="cycle")
cycle_decision = compare_values_for_bulk_same(
ref_cycles, cmp_cycles, label="cycle", thresholds=thresholds
)
if cycle_decision.status != ComparisonStatus.SAME:
return cycle_decision
@@ -918,14 +928,14 @@ def compare_timings_for_bulk_same(ref_timing, cmp_timing):
)
def compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise):
def compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise, thresholds):
if not has_finite_noise(ref_noise) or not has_finite_noise(cmp_noise):
return make_decision(
ComparisonStatus.UNDECIDED,
"noise_unavailable",
"relative dispersion is unavailable or non-finite",
)
if max(ref_noise, cmp_noise) > SAME_RELATIVE_DISPERSION_CEILING:
if max(ref_noise, cmp_noise) > thresholds.same_relative_dispersion_ceiling:
return make_decision(
ComparisonStatus.UNDECIDED,
"noise_too_high",
@@ -941,12 +951,12 @@ def compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise):
"could not construct comparable timing intervals",
)
decision = compare_intervals_for_same(ref_interval, cmp_interval)
decision = compare_intervals_for_same(ref_interval, cmp_interval, thresholds)
if decision.status != ComparisonStatus.SAME:
return decision
return confirm_same_with_clock_rate(
ref_timing, cmp_timing, ref_interval, cmp_interval
ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds
)
@@ -1022,7 +1032,10 @@ def compute_common_time_estimates(ref_timing, cmp_timing):
)
def compare_gpu_timings(ref_timing, cmp_timing):
def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
if comparison_thresholds is None:
comparison_thresholds = ComparisonThresholds()
ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing)
cmp_time = cmp_estimate.center
@@ -1048,15 +1061,19 @@ def compare_gpu_timings(ref_timing, cmp_timing):
else:
max_noise = max(ref_noise, cmp_noise)
decision = compare_timings_for_clear_gap(ref_timing, cmp_timing)
decision = compare_timings_for_clear_gap(
ref_timing, cmp_timing, comparison_thresholds
)
if decision.status == ComparisonStatus.UNDECIDED and decision.reason.code in {
"no_clear_gap",
"missing_interval",
}:
bulk_decision = compare_timings_for_bulk_same(ref_timing, cmp_timing)
bulk_decision = compare_timings_for_bulk_same(
ref_timing, cmp_timing, comparison_thresholds
)
if bulk_decision.reason.code == "bulk_data_unavailable":
decision = compare_timings_for_same(
ref_timing, cmp_timing, ref_noise, cmp_noise
ref_timing, cmp_timing, ref_noise, cmp_noise, comparison_thresholds
)
else:
decision = bulk_decision
@@ -1413,7 +1430,11 @@ def compare_benches(
compare_device_filter=None,
ref_json_dir=None,
cmp_json_dir=None,
comparison_thresholds=None,
):
if comparison_thresholds is None:
comparison_thresholds = ComparisonThresholds()
if plot_along:
import matplotlib.pyplot as plt
import seaborn as sns
@@ -1533,7 +1554,9 @@ def compare_benches(
# "Timing" column + values "CPU/GPU/Batch"?
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
comparison = compare_gpu_timings(ref_gpu_time, cmp_gpu_time)
comparison = compare_gpu_timings(
ref_gpu_time, cmp_gpu_time, comparison_thresholds
)
if comparison is None:
continue