Explicitly handle unavailable timings in nvbench-compare

Treat matched states with unusable timing data as UNKNOWN instead of
dropping them from the comparison. This includes missing, non-finite, or
non-positive timing centers, skipped states, and states with missing GPU
timing summaries.

Add explicit reason codes for these cases so the summary points users at
the underlying data issue. Preserve available timing data from the other
side when only one side is missing, and render unavailable durations as
n/a in all display modes.

Also sort values returned by np.unique_counts before nearest-neighbor
coverage checks so the distance algorithm receives ordered inputs.

Add regression coverage for UNKNOWN counting, skipped states, missing
summaries, unavailable center formatting, and the updated coverage helper.
This commit is contained in:
Oleksandr Pavlyk
2026-06-24 15:06:08 -05:00
parent 17536fd4ff
commit b34dfbb348
2 changed files with 386 additions and 51 deletions

View File

@@ -160,6 +160,10 @@ COMPARISON_THRESHOLD_RANGES = {
}
def get_default_thresholds() -> ComparisonThresholds:
    """Return the thresholds stored under ``COMPARISON_DEFAULT_PRESET``."""
    default_preset = COMPARISON_DEFAULT_PRESET
    return COMPARISON_THRESHOLD_PRESETS[default_preset]
def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
try:
return COMPARISON_THRESHOLD_PRESETS[preset_name]
@@ -169,11 +173,14 @@ def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
def load_toml_module() -> TomlModule:
try:
# built-in Python module, added in 3.11 via PEP 680
import tomllib
return tomllib
except ModuleNotFoundError:
try:
# third-party library for Python 3.10
# note Python 3.10 EOL date is Oct. 31, 2026
import tomli
return tomli
@@ -465,6 +472,11 @@ REASON_DISPLAY_CODES = {
"same_summary": "sum-same",
"same_without_clock_rate": "same-no-clk",
"summary_cycle_gap_not_confirmed": "sc-gap-miss",
"gpu_timing_summaries_missing": "summ-miss",
"state_skipped": "state-skip",
"timing_center_missing": "center-miss",
"timing_center_nonfinite": "center-nonfin",
"timing_center_nonpositive": "center-nonpos",
"weak_interval_overlap": "weak-overlap",
}
@@ -494,12 +506,12 @@ class SummaryComparison:
cmp_interval: TimingInterval | None
ref_estimate: TimeEstimate
cmp_estimate: TimeEstimate
ref_time: float
cmp_time: float
ref_time: float | None
cmp_time: float | None
ref_noise: float | None
cmp_noise: float | None
diff: float
frac_diff: float
diff: float | None
frac_diff: float | None
diff_interval: tuple[float, float] | None
frac_diff_interval: tuple[float, float] | None
max_noise: float | None
@@ -915,6 +927,21 @@ def extract_gpu_timing_data(summaries, json_dir=None, float32_reader=read_float3
)
def make_empty_gpu_timing_data():
    """Build a ``GpuTimingData`` with every listed statistic set to ``None``.

    Used as a stand-in when one side of a comparison has no summaries to
    extract timing data from.
    """
    unset_fields = (
        "minimum",
        "maximum",
        "mean",
        "stdev",
        "stdev_relative",
        "first_quartile",
        "median",
        "third_quartile",
        "interquartile_range",
        "interquartile_range_relative",
    )
    # dict.fromkeys maps each field name to None, matching explicit
    # ``field=None`` keyword arguments.
    return GpuTimingData(**dict.fromkeys(unset_fields))
def resolve_bulk_source_filename(source: Float32BinarySource | None) -> str | None:
if source is None:
return None
@@ -1299,9 +1326,19 @@ def format_support_filter_info(filter_info):
return "off(disabled)"
def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
ref_unique, ref_counts = np.unique_counts(ref_values)
cmp_unique, cmp_counts = np.unique_counts(cmp_values)
def sorted_unique_counts(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
unique_values, unique_counts = np.unique_counts(values)
# unique is not guaranteed to return sorted values
# make sure to order them
sorting_indices = np.argsort(unique_values)
return unique_values[sorting_indices], unique_counts[sorting_indices]
def compute_nearest_neighbor_coverages(
ref_values: np.ndarray, cmp_values: np.ndarray, thresholds: ComparisonThresholds
) -> dict[str, Any] | None:
ref_unique, ref_counts = sorted_unique_counts(ref_values)
cmp_unique, cmp_counts = sorted_unique_counts(cmp_values)
if len(ref_unique) == 0 or len(cmp_unique) == 0:
return None
@@ -1697,27 +1734,67 @@ def compute_common_time_estimates(ref_timing, cmp_timing):
)
def unusable_timing_center_decision(ref_time, cmp_time):
    """Classify a pair of timing centers that cannot be compared.

    Checks are applied in order: missing (``None``), non-finite, then
    non-positive. The first check that fires on either side determines the
    reason code.

    Returns:
        An UNKNOWN decision built by ``make_decision``, or ``None`` when both
        centers are finite and strictly positive.
    """
    centers = (ref_time, cmp_time)
    if any(center is None for center in centers):
        code, message = "timing_center_missing", "timing center is missing"
    elif not all(math.isfinite(center) for center in centers):
        code, message = "timing_center_nonfinite", "timing center is non-finite"
    elif any(center <= 0.0 for center in centers):
        code, message = (
            "timing_center_nonpositive",
            "timing center is non-positive",
        )
    else:
        return None
    return make_decision(ComparisonStatus.UNKNOWN, code, message)
def make_unavailable_timing_comparison(decision, ref_timing, cmp_timing):
    """Build a ``SummaryComparison`` for a pair that cannot be compared.

    Intervals, estimates, centers, and noise figures are still taken from
    each side's timing data. The difference-related fields and ``max_noise``
    are set to ``None``, and status and reason come from *decision*.
    """
    ref_estimate, cmp_estimate = compute_common_time_estimates(
        ref_timing, cmp_timing
    )
    ref_interval = compute_timing_interval(ref_timing)
    cmp_interval = compute_timing_interval(cmp_timing)
    # Nothing derived from comparing the two sides is available.
    unavailable = dict.fromkeys(
        ("diff", "frac_diff", "diff_interval", "frac_diff_interval", "max_noise")
    )
    return SummaryComparison(
        ref_interval=ref_interval,
        cmp_interval=cmp_interval,
        ref_estimate=ref_estimate,
        cmp_estimate=cmp_estimate,
        ref_time=ref_estimate.center,
        cmp_time=cmp_estimate.center,
        ref_noise=ref_estimate.relative_dispersion,
        cmp_noise=cmp_estimate.relative_dispersion,
        status=decision.status,
        reason=decision.reason,
        **unavailable,
    )
def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
if comparison_thresholds is None:
comparison_thresholds = ComparisonThresholds()
comparison_thresholds = get_default_thresholds()
ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing)
cmp_time = cmp_estimate.center
ref_time = ref_estimate.center
if cmp_time is None or ref_time is None:
return None
if not math.isfinite(cmp_time) or not math.isfinite(ref_time):
return None
if cmp_time <= 0.0 or ref_time <= 0.0:
return None
cmp_noise = cmp_estimate.relative_dispersion
ref_noise = ref_estimate.relative_dispersion
unusable_center_decision = unusable_timing_center_decision(ref_time, cmp_time)
if unusable_center_decision is not None:
return make_unavailable_timing_comparison(
unusable_center_decision, ref_timing, cmp_timing
)
ref_interval = compute_timing_interval(ref_timing)
cmp_interval = compute_timing_interval(cmp_timing)
diff = cmp_time - ref_time
@@ -1769,6 +1846,53 @@ def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
)
def get_state_summaries(state: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Return ``state["summaries"]``, or an empty list if absent or ``None``."""
    found = state.get("summaries")
    if found is None:
        return []
    return found
def state_has_summaries(state):
    """Return True when the state's ``"summaries"`` entry is present and truthy."""
    summaries = state.get("summaries")
    return bool(summaries)
def format_skipped_state_reason(side, state):
    """Describe a skipped state, appending its ``skip_reason`` when non-empty.

    Args:
        side: Label placed at the start of the message.
        state: Mapping that may carry a ``"skip_reason"`` entry.
    """
    detail = state.get("skip_reason")
    suffix = f": {detail}" if detail else ""
    return f"{side} state skipped{suffix}"
def missing_state_summaries_decision(ref_state, cmp_state):
    """Return an UNKNOWN decision when either state cannot supply summaries.

    A state flagged ``is_skipped`` takes priority: if either side is skipped
    the decision uses the ``state_skipped`` code and lists each skipped side.
    Otherwise, sides without summaries produce a
    ``gpu_timing_summaries_missing`` decision.

    Returns:
        The decision, or ``None`` when neither side is skipped and both have
        summaries.
    """
    sides = (("reference", ref_state), ("compare", cmp_state))

    skip_notes = [
        format_skipped_state_reason(label, state)
        for label, state in sides
        if state.get("is_skipped")
    ]
    if skip_notes:
        return make_decision(
            ComparisonStatus.UNKNOWN,
            "state_skipped",
            "; ".join(skip_notes),
        )

    lacking = [label for label, state in sides if not state_has_summaries(state)]
    if not lacking:
        return None

    if len(lacking) == 2:
        message = "reference and compare GPU timing summaries are missing"
    else:
        message = f"{lacking[0]} GPU timing summaries are missing"
    return make_decision(
        ComparisonStatus.UNKNOWN,
        "gpu_timing_summaries_missing",
        message,
    )
def find_matching_bench(needle, haystack):
for hay in haystack:
if hay["name"] == needle["name"]:
@@ -1959,7 +2083,18 @@ def matching_axis_filters(state, axis_filter_groups):
)
def format_duration(seconds):
def is_finite_number(value):
    """Return True when *value* is not ``None`` and ``math.isfinite`` accepts it."""
    if value is None:
        return False
    return math.isfinite(value)
def format_duration(seconds, *, allow_negative=False, allow_zero=False):
if (
not is_finite_number(seconds)
or (seconds < 0.0 and not allow_negative)
or (seconds == 0.0 and not allow_zero)
):
return "n/a"
if seconds >= 1:
multiplier = 1.0
units = "s"
@@ -1976,6 +2111,10 @@ def format_duration(seconds):
def select_duration_units(*seconds_values):
seconds_values = [value for value in seconds_values if is_finite_number(value)]
if not seconds_values:
return 1e6, "us"
max_abs_seconds = max(abs(value) for value in seconds_values)
if max_abs_seconds >= 1:
return 1.0, "s"
@@ -1985,6 +2124,9 @@ def select_duration_units(*seconds_values):
def duration_precision_for_center(center, delta_multiplier):
if not is_finite_number(center):
return 3
center_multiplier, _ = select_duration_units(center)
center_quantum = 10.0**-3 * (delta_multiplier / center_multiplier)
if center_quantum >= 1.0:
@@ -1996,6 +2138,9 @@ def format_duration_range(bounds):
if bounds is None:
return "n/a"
lower, upper = bounds
if not is_finite_number(lower) or not is_finite_number(upper):
return "n/a"
multiplier, units = select_duration_units(lower, upper)
return f"[{lower * multiplier:0.2f}, {upper * multiplier:0.2f}] {units}"
@@ -2003,7 +2148,7 @@ def format_duration_range(bounds):
def format_timing_with_interval(
center, interval, *, center_width=None, interval_width=None
):
if center is None:
if center is None or not is_positive_finite(center):
return "n/a"
if interval is None:
if center_width is not None and interval_width is not None:
@@ -2068,6 +2213,13 @@ def align_interval_values(values, widths=None):
def explicit_interval_values(center, interval):
if (
not is_positive_finite(interval.lower)
or not is_positive_finite(interval.center)
or not is_positive_finite(interval.upper)
):
return None
multiplier, units = select_duration_units(
interval.lower, interval.center, interval.upper
)
@@ -2086,10 +2238,13 @@ def explicit_interval_column_widths(comparisons, center_getter, interval_getter)
for comparison in comparisons:
center = center_getter(comparison)
interval = interval_getter(comparison)
if center is None or interval is None:
if not is_positive_finite(center) or interval is None:
continue
values, _ = explicit_interval_values(center, interval)
interval_values = explicit_interval_values(center, interval)
if interval_values is None:
continue
values, _ = interval_values
prefix = longest_common_prefix(values)
if common_numeric_prefix_is_useful(prefix):
continue
@@ -2099,12 +2254,15 @@ def explicit_interval_column_widths(comparisons, center_getter, interval_getter)
def format_timing_with_explicit_interval(center, interval, *, value_widths=None):
if center is None:
if center is None or not is_positive_finite(center):
return "n/a"
if interval is None:
return format_duration(center)
values, units = explicit_interval_values(center, interval)
interval_values = explicit_interval_values(center, interval)
if interval_values is None:
return "n/a"
values, units = interval_values
prefix = longest_common_prefix(values)
if not common_numeric_prefix_is_useful(prefix):
values = align_interval_values(values, value_widths)
@@ -2180,7 +2338,9 @@ def append_display_row(row, comparison, no_color, display):
row.append(format_percentage(comparison.ref_noise))
row.append(format_duration(comparison.cmp_time))
row.append(format_percentage(comparison.cmp_noise))
row.append(format_duration(comparison.diff))
row.append(
format_duration(comparison.diff, allow_negative=True, allow_zero=True)
)
row.append(format_percentage(comparison.frac_diff))
row.append(colorize_comparison_status(comparison.status, no_color))
return
@@ -2230,7 +2390,7 @@ def timing_interval_column_widths(comparisons, center_getter, interval_getter):
interval_width = 0
for comparison in comparisons:
center = center_getter(comparison)
if center is None:
if not is_positive_finite(center):
continue
center_multiplier, center_units = select_duration_units(center)
@@ -2411,7 +2571,7 @@ def compare_benches(
bulk_debug_rows=None,
):
if comparison_thresholds is None:
comparison_thresholds = ComparisonThresholds()
comparison_thresholds = get_default_thresholds()
if plot_along:
import matplotlib.pyplot as plt
@@ -2511,23 +2671,42 @@ def compare_benches(
axis_value_name = axis_value["name"]
row.append(format_axis_value(axis_value_name, axis_value, axes))
cmp_summaries = cmp_state["summaries"]
ref_summaries = ref_state["summaries"]
if not ref_summaries or not cmp_summaries:
continue
cmp_summaries = get_state_summaries(cmp_state)
ref_summaries = get_state_summaries(ref_state)
# TODO: Use other timings, too. Maybe multiple rows, with a
# "Timing" column + values "CPU/GPU/Batch"?
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
comparison = compare_gpu_timings(
ref_gpu_time, cmp_gpu_time, comparison_thresholds
missing_summaries_decision = missing_state_summaries_decision(
ref_state, cmp_state
)
if missing_summaries_decision is None:
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
comparison = compare_gpu_timings(
ref_gpu_time, cmp_gpu_time, comparison_thresholds
)
else:
cmp_gpu_time = (
extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
if cmp_summaries
else make_empty_gpu_timing_data()
)
ref_gpu_time = (
extract_gpu_timing_data(ref_summaries, ref_json_dir)
if ref_summaries
else make_empty_gpu_timing_data()
)
comparison = make_unavailable_timing_comparison(
missing_summaries_decision, ref_gpu_time, cmp_gpu_time
)
if comparison is None:
continue
if plot_along:
if (
plot_along
and is_positive_finite(comparison.ref_time)
and is_positive_finite(comparison.cmp_time)
):
axis_name_parts = []
axis_value = None
for av in axis_values:
@@ -2554,7 +2733,10 @@ def compare_benches(
)
run_data.stats.record(comparison.status, comparison.reason)
if abs(comparison.frac_diff) >= threshold:
if comparison.status == ComparisonStatus.UNKNOWN or (
comparison.frac_diff is not None
and abs(comparison.frac_diff) >= threshold
):
axis_filters = matching_axis_filters(cmp_state, axis_filter_groups)
append_display_row(row, comparison, no_color, display)

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import importlib.util
import math
import sys
import types
from pathlib import Path
@@ -308,7 +309,9 @@ def test_compare_benches_matches_duplicate_states_after_axis_filter(
assert run_data.stats.unknown_count == 0
def test_compare_benches_skips_non_finite_centers(monkeypatch, nvbench_compare):
def test_compare_benches_counts_non_finite_centers_as_unknown(
monkeypatch, nvbench_compare
):
run_data = make_comparison_run_data(nvbench_compare)
ref_benches = [
@@ -342,12 +345,12 @@ def test_compare_benches_skips_non_finite_centers(monkeypatch, nvbench_compare):
no_color=True,
)
assert run_data.stats.config_count == 1
assert run_data.stats.config_count == 3
assert run_data.stats.pass_count == 0
assert run_data.stats.improvement_count == 0
assert run_data.stats.regression_count == 0
assert run_data.stats.undecided_count == 1
assert run_data.stats.unknown_count == 0
assert run_data.stats.unknown_count == 2
def test_gpu_timing_data_loads_samples_and_frequencies_lazily(
@@ -858,6 +861,15 @@ def test_compare_gpu_timings_uses_bulk_data_to_confirm_same(nvbench_compare):
def test_format_diff_and_percent_ranges(nvbench_compare):
assert nvbench_compare.format_duration(None) == "n/a"
assert nvbench_compare.format_duration(math.nan) == "n/a"
assert nvbench_compare.format_duration(math.inf) == "n/a"
assert nvbench_compare.format_duration(-1.0) == "n/a"
assert nvbench_compare.format_duration(0.0) == "n/a"
assert (
nvbench_compare.format_duration(-1.0, allow_negative=True) == "-1000000.000 us"
)
assert nvbench_compare.format_duration(0.0, allow_zero=True) == "0.000 us"
assert nvbench_compare.format_duration_range((-12e-6, 8e-6)) == "[-12.00, 8.00] us"
assert (
nvbench_compare.format_percentage_bounds(
@@ -1056,9 +1068,15 @@ def test_compare_gpu_timings_keeps_bulk_mismatch_undecided(nvbench_compare):
assert comparison is not None
assert comparison.status == nvbench_compare.ComparisonStatus.UNDECIDED
assert comparison.reason.code == "bulk_time_support_mismatch"
assert "sample: min(ref=0.0%, cmp=0.0%) >= 99.0%" in comparison.reason.message
sample_threshold = (
nvbench_compare.get_default_thresholds().bulk_same_sample_coverage * 100.0
)
assert (
f"sample: min(ref=0.0%, cmp=0.0%) >= {sample_threshold:0.1f}%"
in comparison.reason.message
)
assert "support: min(ref=0.0%, cmp=0.0%) >= 80.0%" in comparison.reason.message
assert "99.0%" in comparison.reason.message
assert f"{sample_threshold:0.1f}%" in comparison.reason.message
assert "80.0%" in comparison.reason.message
@@ -1203,18 +1221,31 @@ def test_reason_legend_omits_trivial_aliases(nvbench_compare):
]
@pytest.mark.parametrize("ref_time, cmp_time", [(None, 1.0), (1.0, None), (0.0, 1.0)])
def test_compare_gpu_timings_rejects_unusable_centers(
nvbench_compare, ref_time, cmp_time
@pytest.mark.parametrize(
"ref_time, cmp_time, reason_code",
[
(None, 1.0, "timing_center_missing"),
(1.0, None, "timing_center_missing"),
(math.nan, 1.0, "timing_center_nonfinite"),
(math.inf, 1.0, "timing_center_nonfinite"),
(0.0, 1.0, "timing_center_nonpositive"),
(-1.0, 1.0, "timing_center_nonpositive"),
],
)
def test_compare_gpu_timings_reports_unusable_centers_as_unknown(
nvbench_compare, ref_time, cmp_time, reason_code
):
assert (
nvbench_compare.compare_gpu_timings(
make_gpu_timing_data(nvbench_compare, mean=ref_time),
make_gpu_timing_data(nvbench_compare, mean=cmp_time),
)
is None
comparison = nvbench_compare.compare_gpu_timings(
make_gpu_timing_data(nvbench_compare, mean=ref_time),
make_gpu_timing_data(nvbench_compare, mean=cmp_time),
)
assert comparison is not None
assert comparison.status == nvbench_compare.ComparisonStatus.UNKNOWN
assert comparison.reason.code == reason_code
assert comparison.diff is None
assert comparison.frac_diff is None
def test_compare_benches_reports_regression_when_robust_intervals_and_clock_confirm(
monkeypatch, nvbench_compare
@@ -1856,6 +1887,128 @@ def test_main_prints_bulk_debug_python_to_stdout(monkeypatch, capsys, nvbench_co
assert "# NVB-BULK-END" in output
def test_compare_benches_counts_unusable_timing_as_unknown(
    monkeypatch, nvbench_compare
):
    """A NaN reference mean yields one UNKNOWN row instead of a dropped state."""
    run_data = make_comparison_run_data(nvbench_compare)
    table = {}

    def record_table(rows, headers, *args, **kwargs):
        # Capture what would have been rendered; the output text is not needed.
        table["rows"] = rows
        table["headers"] = headers
        return ""

    monkeypatch.setattr(nvbench_compare.tabulate, "tabulate", record_table)

    ref_state = make_state(nvbench_compare, "state", mean="nan")
    ref_benches = [make_benchmark([ref_state])]
    cmp_state = make_state(nvbench_compare, "state", mean="1.0")
    cmp_benches = [make_benchmark([cmp_state])]

    nvbench_compare.compare_benches(
        run_data,
        ref_benches,
        cmp_benches,
        threshold=1.0,
        plot_along=None,
        plot=False,
        dark=False,
        filter_plan=make_filter_plan(nvbench_compare),
        no_color=True,
    )

    stats = run_data.stats
    assert stats.config_count == 1
    assert stats.unknown_count == 1

    assert table["headers"][-4:] == ["Ref", "Cmp", "Change", "Status"]
    first_row = table["rows"][0]
    assert first_row[-4] == "n/a"
    assert first_row[-3] == "1.000 s"
    assert first_row[-2] == ""
    assert first_row[-1] == "\U0001f7e1 ????"
def test_compare_benches_counts_skipped_state_as_unknown(monkeypatch, nvbench_compare):
    """A skipped reference state is reported as UNKNOWN with its skip reason."""
    run_data = make_comparison_run_data(nvbench_compare)
    table = {}

    def record_table(rows, headers, *args, **kwargs):
        # Capture what would have been rendered; the output text is not needed.
        table["rows"] = rows
        table["headers"] = headers
        return ""

    monkeypatch.setattr(nvbench_compare.tabulate, "tabulate", record_table)

    # Reference side: skipped by the benchmark, so it carries no summaries.
    ref_state = make_state(nvbench_compare, "state")
    ref_state.update(
        summaries=None,
        is_skipped=True,
        skip_reason="requested by benchmark",
    )
    cmp_state = make_state(nvbench_compare, "state", mean="1.0")

    nvbench_compare.compare_benches(
        run_data,
        [make_benchmark([ref_state])],
        [make_benchmark([cmp_state])],
        threshold=1.0,
        plot_along=None,
        plot=False,
        dark=False,
        filter_plan=make_filter_plan(nvbench_compare),
        no_color=True,
    )

    stats = run_data.stats
    assert stats.config_count == 1
    assert stats.unknown_count == 1

    legend_entry = stats.reason_legend["state-skip"]
    assert legend_entry.canonical_code == "state_skipped"
    assert legend_entry.message == "reference state skipped: requested by benchmark"

    assert table["headers"][-4:] == ["Ref", "Cmp", "Change", "Status"]
    first_row = table["rows"][0]
    assert first_row[-4] == "n/a"
    assert first_row[-3] == "1.000 s"
    assert first_row[-2] == ""
    assert first_row[-1] == "\U0001f7e1 ????"
def test_compare_benches_counts_missing_summaries_as_unknown(
    monkeypatch, nvbench_compare
):
    """A reference state with no ``summaries`` key is reported as UNKNOWN."""
    run_data = make_comparison_run_data(nvbench_compare)
    table = {}

    def record_table(rows, headers, *args, **kwargs):
        # Capture what would have been rendered; the output text is not needed.
        table["rows"] = rows
        table["headers"] = headers
        return ""

    monkeypatch.setattr(nvbench_compare.tabulate, "tabulate", record_table)

    # Reference side: remove the summaries entry entirely.
    ref_state = make_state(nvbench_compare, "state")
    ref_state.pop("summaries")
    cmp_state = make_state(nvbench_compare, "state", mean="1.0")

    nvbench_compare.compare_benches(
        run_data,
        [make_benchmark([ref_state])],
        [make_benchmark([cmp_state])],
        threshold=1.0,
        plot_along=None,
        plot=False,
        dark=False,
        filter_plan=make_filter_plan(nvbench_compare),
        no_color=True,
    )

    stats = run_data.stats
    assert stats.config_count == 1
    assert stats.unknown_count == 1

    legend_entry = stats.reason_legend["summ-miss"]
    assert legend_entry.canonical_code == "gpu_timing_summaries_missing"
    assert legend_entry.message == "reference GPU timing summaries are missing"

    assert table["headers"][-4:] == ["Ref", "Cmp", "Change", "Status"]
    first_row = table["rows"][0]
    assert first_row[-4] == "n/a"
    assert first_row[-3] == "1.000 s"
    assert first_row[-2] == ""
    assert first_row[-1] == "\U0001f7e1 ????"
def test_compare_benches_defaults_to_interval_display(monkeypatch, nvbench_compare):
run_data = make_comparison_run_data(nvbench_compare)
captured = {}