mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-07-01 19:57:41 +00:00
Group nvbench-compare thresholds into a config object
Replace the scattered module-level comparison threshold constants with a ComparisonThresholds value object. Thread this object through compare_benches, compare_gpu_timings, and the lower-level clear-gap, summary-SAME, and bulk-SAME decision helpers. Keep existing behavior by constructing default ComparisonThresholds when callers do not provide one. This prepares nvbench-compare for future CLI-configurable decision thresholds while keeping one consistent configuration for an entire comparison run. Add test coverage that passes custom thresholds through compare_benches and verifies they affect the SAME decision.
This commit is contained in:
@@ -45,12 +45,6 @@ GPU_TIME_IR_RELATIVE_TAG = "nv/cold/time/gpu/ir/relative"
|
||||
GPU_SM_CLOCK_RATE_MEAN_TAG = "nv/cold/sm_clock_rate/mean"
|
||||
SAMPLE_TIMES_TAG = "nv/json/bin:nv/cold/sample_times"
|
||||
SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs"
|
||||
CLEAR_GAP_RELATIVE_THRESHOLD = 0.005
|
||||
SAME_CENTER_RELATIVE_THRESHOLD = 0.005
|
||||
SAME_OVERLAP_FRACTION_THRESHOLD = 0.5
|
||||
SAME_RELATIVE_DISPERSION_CEILING = 0.02
|
||||
BULK_SAME_SAMPLE_COVERAGE_THRESHOLD = 0.99
|
||||
BULK_SAME_SUPPORT_COVERAGE_THRESHOLD = 0.80
|
||||
|
||||
# The reader returns an object supporting the buffer protocol. Python 3.10 does
|
||||
# not provide a standard Buffer type annotation.
|
||||
@@ -65,6 +59,16 @@ def read_float32_file(filename: str) -> object:
|
||||
# accidental field reassignment but does not imply deep immutability.
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComparisonThresholds:
|
||||
clear_gap_relative: float = 0.005
|
||||
same_center_relative: float = 0.005
|
||||
same_overlap_fraction: float = 0.5
|
||||
same_relative_dispersion_ceiling: float = 0.02
|
||||
bulk_same_sample_coverage: float = 0.99
|
||||
bulk_same_support_coverage: float = 0.80
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Float32BinarySource:
|
||||
count: int
|
||||
@@ -594,26 +598,26 @@ def make_decision(status, code, message, *, severity=0.0):
|
||||
)
|
||||
|
||||
|
||||
def compare_intervals_for_clear_gap(ref_interval, cmp_interval):
|
||||
def compare_intervals_for_clear_gap(ref_interval, cmp_interval, thresholds):
|
||||
# These ratios are equivalent to log(ref/cmp) >= log(1 + delta), but avoid
|
||||
# evaluating logarithms on every comparison.
|
||||
if cmp_interval.upper < ref_interval.lower:
|
||||
gap = ref_interval.lower - cmp_interval.upper
|
||||
if gap / cmp_interval.upper >= CLEAR_GAP_RELATIVE_THRESHOLD:
|
||||
if gap / cmp_interval.upper >= thresholds.clear_gap_relative:
|
||||
return ComparisonStatus.FAST
|
||||
if cmp_interval.lower > ref_interval.upper:
|
||||
gap = cmp_interval.lower - ref_interval.upper
|
||||
if gap / ref_interval.upper >= CLEAR_GAP_RELATIVE_THRESHOLD:
|
||||
if gap / ref_interval.upper >= thresholds.clear_gap_relative:
|
||||
return ComparisonStatus.SLOW
|
||||
return None
|
||||
|
||||
|
||||
def centers_are_close(ref_center, cmp_center):
|
||||
def centers_are_close(ref_center, cmp_center, thresholds):
|
||||
if not is_positive_finite(ref_center) or not is_positive_finite(cmp_center):
|
||||
return False
|
||||
return (
|
||||
abs(ref_center - cmp_center) / min(ref_center, cmp_center)
|
||||
<= SAME_CENTER_RELATIVE_THRESHOLD
|
||||
<= thresholds.same_center_relative
|
||||
)
|
||||
|
||||
|
||||
@@ -644,10 +648,10 @@ def interval_overlap_fraction(ref_interval, cmp_interval):
|
||||
)
|
||||
|
||||
|
||||
def intervals_overlap_strongly(ref_interval, cmp_interval):
|
||||
def intervals_overlap_strongly(ref_interval, cmp_interval, thresholds):
|
||||
return (
|
||||
interval_overlap_fraction(ref_interval, cmp_interval)
|
||||
>= SAME_OVERLAP_FRACTION_THRESHOLD
|
||||
>= thresholds.same_overlap_fraction
|
||||
)
|
||||
|
||||
|
||||
@@ -673,7 +677,7 @@ def symmetric_nearest_log_distances(x, y):
|
||||
return symmetric_nearest_distances(np.log(x), np.log(y))
|
||||
|
||||
|
||||
def compute_nearest_neighbor_coverages(ref_values, cmp_values):
|
||||
def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
|
||||
ref_unique, ref_counts = np.unique_counts(ref_values)
|
||||
cmp_unique, cmp_counts = np.unique_counts(cmp_values)
|
||||
if len(ref_unique) == 0 or len(cmp_unique) == 0:
|
||||
@@ -682,7 +686,7 @@ def compute_nearest_neighbor_coverages(ref_values, cmp_values):
|
||||
ref_distances, cmp_distances = symmetric_nearest_log_distances(
|
||||
ref_unique, cmp_unique
|
||||
)
|
||||
tolerance = math.log1p(SAME_CENTER_RELATIVE_THRESHOLD)
|
||||
tolerance = math.log1p(thresholds.same_center_relative)
|
||||
ref_covered = ref_distances <= tolerance
|
||||
cmp_covered = cmp_distances <= tolerance
|
||||
|
||||
@@ -694,12 +698,12 @@ def compute_nearest_neighbor_coverages(ref_values, cmp_values):
|
||||
}
|
||||
|
||||
|
||||
def coverages_support_same(coverages):
|
||||
def coverages_support_same(coverages, thresholds):
|
||||
return (
|
||||
coverages["ref_sample"] >= BULK_SAME_SAMPLE_COVERAGE_THRESHOLD
|
||||
and coverages["cmp_sample"] >= BULK_SAME_SAMPLE_COVERAGE_THRESHOLD
|
||||
and coverages["ref_support"] >= BULK_SAME_SUPPORT_COVERAGE_THRESHOLD
|
||||
and coverages["cmp_support"] >= BULK_SAME_SUPPORT_COVERAGE_THRESHOLD
|
||||
coverages["ref_sample"] >= thresholds.bulk_same_sample_coverage
|
||||
and coverages["cmp_sample"] >= thresholds.bulk_same_sample_coverage
|
||||
and coverages["ref_support"] >= thresholds.bulk_same_support_coverage
|
||||
and coverages["cmp_support"] >= thresholds.bulk_same_support_coverage
|
||||
)
|
||||
|
||||
|
||||
@@ -711,17 +715,17 @@ def format_coverage(value):
|
||||
return f"{value * 100.0:.1f}%"
|
||||
|
||||
|
||||
def make_bulk_coverage_mismatch_decision(label, coverages):
|
||||
sample_threshold = format_coverage_threshold(BULK_SAME_SAMPLE_COVERAGE_THRESHOLD)
|
||||
support_threshold = format_coverage_threshold(BULK_SAME_SUPPORT_COVERAGE_THRESHOLD)
|
||||
def make_bulk_coverage_mismatch_decision(label, coverages, thresholds):
|
||||
sample_threshold = format_coverage_threshold(thresholds.bulk_same_sample_coverage)
|
||||
support_threshold = format_coverage_threshold(thresholds.bulk_same_support_coverage)
|
||||
sample_deficit = max(
|
||||
BULK_SAME_SAMPLE_COVERAGE_THRESHOLD - coverages["ref_sample"],
|
||||
BULK_SAME_SAMPLE_COVERAGE_THRESHOLD - coverages["cmp_sample"],
|
||||
thresholds.bulk_same_sample_coverage - coverages["ref_sample"],
|
||||
thresholds.bulk_same_sample_coverage - coverages["cmp_sample"],
|
||||
0.0,
|
||||
)
|
||||
support_deficit = max(
|
||||
BULK_SAME_SUPPORT_COVERAGE_THRESHOLD - coverages["ref_support"],
|
||||
BULK_SAME_SUPPORT_COVERAGE_THRESHOLD - coverages["cmp_support"],
|
||||
thresholds.bulk_same_support_coverage - coverages["ref_support"],
|
||||
thresholds.bulk_same_support_coverage - coverages["cmp_support"],
|
||||
0.0,
|
||||
)
|
||||
severity = max(sample_deficit, support_deficit)
|
||||
@@ -767,7 +771,7 @@ def scale_interval(interval, scale):
|
||||
|
||||
|
||||
def confirm_clear_gap_with_clock_rate(
|
||||
status, ref_timing, cmp_timing, ref_interval, cmp_interval
|
||||
status, ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds
|
||||
):
|
||||
if ref_timing.sm_clock_rate_mean is None or cmp_timing.sm_clock_rate_mean is None:
|
||||
return make_decision(
|
||||
@@ -785,7 +789,7 @@ def confirm_clear_gap_with_clock_rate(
|
||||
"clear timing gap was not confirmed because SM clock summaries are invalid",
|
||||
)
|
||||
|
||||
cycle_status = compare_intervals_for_clear_gap(ref_cycles, cmp_cycles)
|
||||
cycle_status = compare_intervals_for_clear_gap(ref_cycles, cmp_cycles, thresholds)
|
||||
if cycle_status == status:
|
||||
return make_decision(
|
||||
status,
|
||||
@@ -799,7 +803,7 @@ def confirm_clear_gap_with_clock_rate(
|
||||
)
|
||||
|
||||
|
||||
def compare_timings_for_clear_gap(ref_timing, cmp_timing):
|
||||
def compare_timings_for_clear_gap(ref_timing, cmp_timing, thresholds):
|
||||
ref_interval = compute_timing_interval(ref_timing)
|
||||
cmp_interval = compute_timing_interval(cmp_timing)
|
||||
if ref_interval is None or cmp_interval is None:
|
||||
@@ -809,7 +813,7 @@ def compare_timings_for_clear_gap(ref_timing, cmp_timing):
|
||||
"could not construct comparable timing intervals",
|
||||
)
|
||||
|
||||
status = compare_intervals_for_clear_gap(ref_interval, cmp_interval)
|
||||
status = compare_intervals_for_clear_gap(ref_interval, cmp_interval, thresholds)
|
||||
if status is None:
|
||||
return make_decision(
|
||||
ComparisonStatus.UNDECIDED,
|
||||
@@ -818,18 +822,18 @@ def compare_timings_for_clear_gap(ref_timing, cmp_timing):
|
||||
)
|
||||
|
||||
return confirm_clear_gap_with_clock_rate(
|
||||
status, ref_timing, cmp_timing, ref_interval, cmp_interval
|
||||
status, ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds
|
||||
)
|
||||
|
||||
|
||||
def compare_intervals_for_same(ref_interval, cmp_interval):
|
||||
if not centers_are_close(ref_interval.center, cmp_interval.center):
|
||||
def compare_intervals_for_same(ref_interval, cmp_interval, thresholds):
|
||||
if not centers_are_close(ref_interval.center, cmp_interval.center, thresholds):
|
||||
return make_decision(
|
||||
ComparisonStatus.UNDECIDED,
|
||||
"centers_not_close",
|
||||
"timing centers are not close enough to declare same",
|
||||
)
|
||||
if not intervals_overlap_strongly(ref_interval, cmp_interval):
|
||||
if not intervals_overlap_strongly(ref_interval, cmp_interval, thresholds):
|
||||
return make_decision(
|
||||
ComparisonStatus.UNDECIDED,
|
||||
"weak_interval_overlap",
|
||||
@@ -842,7 +846,9 @@ def compare_intervals_for_same(ref_interval, cmp_interval):
|
||||
)
|
||||
|
||||
|
||||
def confirm_same_with_clock_rate(ref_timing, cmp_timing, ref_interval, cmp_interval):
|
||||
def confirm_same_with_clock_rate(
|
||||
ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds
|
||||
):
|
||||
if ref_timing.sm_clock_rate_mean is None or cmp_timing.sm_clock_rate_mean is None:
|
||||
return make_decision(
|
||||
ComparisonStatus.SAME,
|
||||
@@ -859,7 +865,7 @@ def confirm_same_with_clock_rate(ref_timing, cmp_timing, ref_interval, cmp_inter
|
||||
"same decision was not confirmed because SM clock summaries are invalid",
|
||||
)
|
||||
|
||||
decision = compare_intervals_for_same(ref_cycles, cmp_cycles)
|
||||
decision = compare_intervals_for_same(ref_cycles, cmp_cycles, thresholds)
|
||||
if decision.status == ComparisonStatus.SAME:
|
||||
return make_decision(
|
||||
ComparisonStatus.SAME,
|
||||
@@ -873,24 +879,24 @@ def confirm_same_with_clock_rate(ref_timing, cmp_timing, ref_interval, cmp_inter
|
||||
)
|
||||
|
||||
|
||||
def compare_values_for_bulk_same(ref_values, cmp_values, *, label):
|
||||
coverages = compute_nearest_neighbor_coverages(ref_values, cmp_values)
|
||||
def compare_values_for_bulk_same(ref_values, cmp_values, *, label, thresholds):
|
||||
coverages = compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds)
|
||||
if coverages is None:
|
||||
return make_decision(
|
||||
ComparisonStatus.UNDECIDED,
|
||||
f"bulk_{label}_data_unusable",
|
||||
f"bulk {label} data is empty or unusable",
|
||||
)
|
||||
if coverages_support_same(coverages):
|
||||
if coverages_support_same(coverages, thresholds):
|
||||
return make_decision(
|
||||
ComparisonStatus.SAME,
|
||||
f"bulk_{label}_same",
|
||||
f"bulk {label} nearest-neighbor coverage supports same",
|
||||
)
|
||||
return make_bulk_coverage_mismatch_decision(label, coverages)
|
||||
return make_bulk_coverage_mismatch_decision(label, coverages, thresholds)
|
||||
|
||||
|
||||
def compare_timings_for_bulk_same(ref_timing, cmp_timing):
|
||||
def compare_timings_for_bulk_same(ref_timing, cmp_timing, thresholds):
|
||||
ref_bulk = get_bulk_time_and_cycles(ref_timing)
|
||||
cmp_bulk = get_bulk_time_and_cycles(cmp_timing)
|
||||
if ref_bulk is None or cmp_bulk is None:
|
||||
@@ -903,11 +909,15 @@ def compare_timings_for_bulk_same(ref_timing, cmp_timing):
|
||||
ref_times, ref_cycles = ref_bulk
|
||||
cmp_times, cmp_cycles = cmp_bulk
|
||||
|
||||
time_decision = compare_values_for_bulk_same(ref_times, cmp_times, label="time")
|
||||
time_decision = compare_values_for_bulk_same(
|
||||
ref_times, cmp_times, label="time", thresholds=thresholds
|
||||
)
|
||||
if time_decision.status != ComparisonStatus.SAME:
|
||||
return time_decision
|
||||
|
||||
cycle_decision = compare_values_for_bulk_same(ref_cycles, cmp_cycles, label="cycle")
|
||||
cycle_decision = compare_values_for_bulk_same(
|
||||
ref_cycles, cmp_cycles, label="cycle", thresholds=thresholds
|
||||
)
|
||||
if cycle_decision.status != ComparisonStatus.SAME:
|
||||
return cycle_decision
|
||||
|
||||
@@ -918,14 +928,14 @@ def compare_timings_for_bulk_same(ref_timing, cmp_timing):
|
||||
)
|
||||
|
||||
|
||||
def compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise):
|
||||
def compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise, thresholds):
|
||||
if not has_finite_noise(ref_noise) or not has_finite_noise(cmp_noise):
|
||||
return make_decision(
|
||||
ComparisonStatus.UNDECIDED,
|
||||
"noise_unavailable",
|
||||
"relative dispersion is unavailable or non-finite",
|
||||
)
|
||||
if max(ref_noise, cmp_noise) > SAME_RELATIVE_DISPERSION_CEILING:
|
||||
if max(ref_noise, cmp_noise) > thresholds.same_relative_dispersion_ceiling:
|
||||
return make_decision(
|
||||
ComparisonStatus.UNDECIDED,
|
||||
"noise_too_high",
|
||||
@@ -941,12 +951,12 @@ def compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise):
|
||||
"could not construct comparable timing intervals",
|
||||
)
|
||||
|
||||
decision = compare_intervals_for_same(ref_interval, cmp_interval)
|
||||
decision = compare_intervals_for_same(ref_interval, cmp_interval, thresholds)
|
||||
if decision.status != ComparisonStatus.SAME:
|
||||
return decision
|
||||
|
||||
return confirm_same_with_clock_rate(
|
||||
ref_timing, cmp_timing, ref_interval, cmp_interval
|
||||
ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds
|
||||
)
|
||||
|
||||
|
||||
@@ -1022,7 +1032,10 @@ def compute_common_time_estimates(ref_timing, cmp_timing):
|
||||
)
|
||||
|
||||
|
||||
def compare_gpu_timings(ref_timing, cmp_timing):
|
||||
def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
|
||||
if comparison_thresholds is None:
|
||||
comparison_thresholds = ComparisonThresholds()
|
||||
|
||||
ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing)
|
||||
|
||||
cmp_time = cmp_estimate.center
|
||||
@@ -1048,15 +1061,19 @@ def compare_gpu_timings(ref_timing, cmp_timing):
|
||||
else:
|
||||
max_noise = max(ref_noise, cmp_noise)
|
||||
|
||||
decision = compare_timings_for_clear_gap(ref_timing, cmp_timing)
|
||||
decision = compare_timings_for_clear_gap(
|
||||
ref_timing, cmp_timing, comparison_thresholds
|
||||
)
|
||||
if decision.status == ComparisonStatus.UNDECIDED and decision.reason.code in {
|
||||
"no_clear_gap",
|
||||
"missing_interval",
|
||||
}:
|
||||
bulk_decision = compare_timings_for_bulk_same(ref_timing, cmp_timing)
|
||||
bulk_decision = compare_timings_for_bulk_same(
|
||||
ref_timing, cmp_timing, comparison_thresholds
|
||||
)
|
||||
if bulk_decision.reason.code == "bulk_data_unavailable":
|
||||
decision = compare_timings_for_same(
|
||||
ref_timing, cmp_timing, ref_noise, cmp_noise
|
||||
ref_timing, cmp_timing, ref_noise, cmp_noise, comparison_thresholds
|
||||
)
|
||||
else:
|
||||
decision = bulk_decision
|
||||
@@ -1413,7 +1430,11 @@ def compare_benches(
|
||||
compare_device_filter=None,
|
||||
ref_json_dir=None,
|
||||
cmp_json_dir=None,
|
||||
comparison_thresholds=None,
|
||||
):
|
||||
if comparison_thresholds is None:
|
||||
comparison_thresholds = ComparisonThresholds()
|
||||
|
||||
if plot_along:
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
@@ -1533,7 +1554,9 @@ def compare_benches(
|
||||
# "Timing" column + values "CPU/GPU/Batch"?
|
||||
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
|
||||
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
|
||||
comparison = compare_gpu_timings(ref_gpu_time, cmp_gpu_time)
|
||||
comparison = compare_gpu_timings(
|
||||
ref_gpu_time, cmp_gpu_time, comparison_thresholds
|
||||
)
|
||||
if comparison is None:
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user