Explicitly handle unavailable timings in nvbench-compare

Treat matched states with unusable timing data as UNKNOWN instead of
dropping them from the comparison. This includes missing, non-finite, or
non-positive timing centers, skipped states, and states with missing GPU
timing summaries.

Add explicit reason codes for these cases so the summary points users at
the underlying data issue. Preserve available timing data from the other
side when only one side is missing, and render unavailable durations as
n/a in all display modes.

Also sort values returned by np.unique_counts before nearest-neighbor
coverage checks so the distance algorithm receives ordered inputs.

Add regression coverage for UNKNOWN counting, skipped states, missing
summaries, unavailable center formatting, and the updated coverage helper.
This commit is contained in:
Oleksandr Pavlyk
2026-06-24 15:06:08 -05:00
parent 17536fd4ff
commit b34dfbb348
2 changed files with 386 additions and 51 deletions

View File

@@ -160,6 +160,10 @@ COMPARISON_THRESHOLD_RANGES = {
}
def get_default_thresholds() -> ComparisonThresholds:
return COMPARISON_THRESHOLD_PRESETS[COMPARISON_DEFAULT_PRESET]
def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
try:
return COMPARISON_THRESHOLD_PRESETS[preset_name]
@@ -169,11 +173,14 @@ def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
def load_toml_module() -> TomlModule:
try:
# built-in Python module, added in 3.11 via PEP 680
import tomllib
return tomllib
except ModuleNotFoundError:
try:
# third-party library for Python 3.10
# note Python 3.10 EOL date is Oct. 31, 2026
import tomli
return tomli
@@ -465,6 +472,11 @@ REASON_DISPLAY_CODES = {
"same_summary": "sum-same",
"same_without_clock_rate": "same-no-clk",
"summary_cycle_gap_not_confirmed": "sc-gap-miss",
"gpu_timing_summaries_missing": "summ-miss",
"state_skipped": "state-skip",
"timing_center_missing": "center-miss",
"timing_center_nonfinite": "center-nonfin",
"timing_center_nonpositive": "center-nonpos",
"weak_interval_overlap": "weak-overlap",
}
@@ -494,12 +506,12 @@ class SummaryComparison:
cmp_interval: TimingInterval | None
ref_estimate: TimeEstimate
cmp_estimate: TimeEstimate
ref_time: float
cmp_time: float
ref_time: float | None
cmp_time: float | None
ref_noise: float | None
cmp_noise: float | None
diff: float
frac_diff: float
diff: float | None
frac_diff: float | None
diff_interval: tuple[float, float] | None
frac_diff_interval: tuple[float, float] | None
max_noise: float | None
@@ -915,6 +927,21 @@ def extract_gpu_timing_data(summaries, json_dir=None, float32_reader=read_float3
)
def make_empty_gpu_timing_data():
return GpuTimingData(
minimum=None,
maximum=None,
mean=None,
stdev=None,
stdev_relative=None,
first_quartile=None,
median=None,
third_quartile=None,
interquartile_range=None,
interquartile_range_relative=None,
)
def resolve_bulk_source_filename(source: Float32BinarySource | None) -> str | None:
if source is None:
return None
@@ -1299,9 +1326,19 @@ def format_support_filter_info(filter_info):
return "off(disabled)"
def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
ref_unique, ref_counts = np.unique_counts(ref_values)
cmp_unique, cmp_counts = np.unique_counts(cmp_values)
def sorted_unique_counts(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
unique_values, unique_counts = np.unique_counts(values)
# unique is not guaranteed to return sorted values
# make sure to order them
sorting_indices = np.argsort(unique_values)
return unique_values[sorting_indices], unique_counts[sorting_indices]
def compute_nearest_neighbor_coverages(
ref_values: np.ndarray, cmp_values: np.ndarray, thresholds: ComparisonThresholds
) -> dict[str, Any] | None:
ref_unique, ref_counts = sorted_unique_counts(ref_values)
cmp_unique, cmp_counts = sorted_unique_counts(cmp_values)
if len(ref_unique) == 0 or len(cmp_unique) == 0:
return None
@@ -1697,27 +1734,67 @@ def compute_common_time_estimates(ref_timing, cmp_timing):
)
def unusable_timing_center_decision(ref_time, cmp_time):
if ref_time is None or cmp_time is None:
return make_decision(
ComparisonStatus.UNKNOWN,
"timing_center_missing",
"timing center is missing",
)
if not math.isfinite(ref_time) or not math.isfinite(cmp_time):
return make_decision(
ComparisonStatus.UNKNOWN,
"timing_center_nonfinite",
"timing center is non-finite",
)
if ref_time <= 0.0 or cmp_time <= 0.0:
return make_decision(
ComparisonStatus.UNKNOWN,
"timing_center_nonpositive",
"timing center is non-positive",
)
return None
def make_unavailable_timing_comparison(decision, ref_timing, cmp_timing):
ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing)
return SummaryComparison(
ref_interval=compute_timing_interval(ref_timing),
cmp_interval=compute_timing_interval(cmp_timing),
ref_estimate=ref_estimate,
cmp_estimate=cmp_estimate,
ref_time=ref_estimate.center,
cmp_time=cmp_estimate.center,
ref_noise=ref_estimate.relative_dispersion,
cmp_noise=cmp_estimate.relative_dispersion,
diff=None,
frac_diff=None,
diff_interval=None,
frac_diff_interval=None,
max_noise=None,
status=decision.status,
reason=decision.reason,
)
def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
if comparison_thresholds is None:
comparison_thresholds = ComparisonThresholds()
comparison_thresholds = get_default_thresholds()
ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing)
cmp_time = cmp_estimate.center
ref_time = ref_estimate.center
if cmp_time is None or ref_time is None:
return None
if not math.isfinite(cmp_time) or not math.isfinite(ref_time):
return None
if cmp_time <= 0.0 or ref_time <= 0.0:
return None
cmp_noise = cmp_estimate.relative_dispersion
ref_noise = ref_estimate.relative_dispersion
unusable_center_decision = unusable_timing_center_decision(ref_time, cmp_time)
if unusable_center_decision is not None:
return make_unavailable_timing_comparison(
unusable_center_decision, ref_timing, cmp_timing
)
ref_interval = compute_timing_interval(ref_timing)
cmp_interval = compute_timing_interval(cmp_timing)
diff = cmp_time - ref_time
@@ -1769,6 +1846,53 @@ def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
)
def get_state_summaries(state: Mapping[str, Any]) -> list[dict[str, Any]]:
summaries = state.get("summaries")
return summaries if summaries is not None else []
def state_has_summaries(state):
return bool(state.get("summaries"))
def format_skipped_state_reason(side, state):
reason = state.get("skip_reason")
if reason:
return f"{side} state skipped: {reason}"
return f"{side} state skipped"
def missing_state_summaries_decision(ref_state, cmp_state):
skipped_messages = []
if ref_state.get("is_skipped"):
skipped_messages.append(format_skipped_state_reason("reference", ref_state))
if cmp_state.get("is_skipped"):
skipped_messages.append(format_skipped_state_reason("compare", cmp_state))
if skipped_messages:
return make_decision(
ComparisonStatus.UNKNOWN,
"state_skipped",
"; ".join(skipped_messages),
)
missing_sides = []
if not state_has_summaries(ref_state):
missing_sides.append("reference")
if not state_has_summaries(cmp_state):
missing_sides.append("compare")
if not missing_sides:
return None
if len(missing_sides) == 2:
message = "reference and compare GPU timing summaries are missing"
else:
message = f"{missing_sides[0]} GPU timing summaries are missing"
return make_decision(
ComparisonStatus.UNKNOWN,
"gpu_timing_summaries_missing",
message,
)
def find_matching_bench(needle, haystack):
for hay in haystack:
if hay["name"] == needle["name"]:
@@ -1959,7 +2083,18 @@ def matching_axis_filters(state, axis_filter_groups):
)
def format_duration(seconds):
def is_finite_number(value):
return value is not None and math.isfinite(value)
def format_duration(seconds, *, allow_negative=False, allow_zero=False):
if (
not is_finite_number(seconds)
or (seconds < 0.0 and not allow_negative)
or (seconds == 0.0 and not allow_zero)
):
return "n/a"
if seconds >= 1:
multiplier = 1.0
units = "s"
@@ -1976,6 +2111,10 @@ def format_duration(seconds):
def select_duration_units(*seconds_values):
seconds_values = [value for value in seconds_values if is_finite_number(value)]
if not seconds_values:
return 1e6, "us"
max_abs_seconds = max(abs(value) for value in seconds_values)
if max_abs_seconds >= 1:
return 1.0, "s"
@@ -1985,6 +2124,9 @@ def select_duration_units(*seconds_values):
def duration_precision_for_center(center, delta_multiplier):
if not is_finite_number(center):
return 3
center_multiplier, _ = select_duration_units(center)
center_quantum = 10.0**-3 * (delta_multiplier / center_multiplier)
if center_quantum >= 1.0:
@@ -1996,6 +2138,9 @@ def format_duration_range(bounds):
if bounds is None:
return "n/a"
lower, upper = bounds
if not is_finite_number(lower) or not is_finite_number(upper):
return "n/a"
multiplier, units = select_duration_units(lower, upper)
return f"[{lower * multiplier:0.2f}, {upper * multiplier:0.2f}] {units}"
@@ -2003,7 +2148,7 @@ def format_duration_range(bounds):
def format_timing_with_interval(
center, interval, *, center_width=None, interval_width=None
):
if center is None:
if center is None or not is_positive_finite(center):
return "n/a"
if interval is None:
if center_width is not None and interval_width is not None:
@@ -2068,6 +2213,13 @@ def align_interval_values(values, widths=None):
def explicit_interval_values(center, interval):
if (
not is_positive_finite(interval.lower)
or not is_positive_finite(interval.center)
or not is_positive_finite(interval.upper)
):
return None
multiplier, units = select_duration_units(
interval.lower, interval.center, interval.upper
)
@@ -2086,10 +2238,13 @@ def explicit_interval_column_widths(comparisons, center_getter, interval_getter)
for comparison in comparisons:
center = center_getter(comparison)
interval = interval_getter(comparison)
if center is None or interval is None:
if not is_positive_finite(center) or interval is None:
continue
values, _ = explicit_interval_values(center, interval)
interval_values = explicit_interval_values(center, interval)
if interval_values is None:
continue
values, _ = interval_values
prefix = longest_common_prefix(values)
if common_numeric_prefix_is_useful(prefix):
continue
@@ -2099,12 +2254,15 @@ def explicit_interval_column_widths(comparisons, center_getter, interval_getter)
def format_timing_with_explicit_interval(center, interval, *, value_widths=None):
if center is None:
if center is None or not is_positive_finite(center):
return "n/a"
if interval is None:
return format_duration(center)
values, units = explicit_interval_values(center, interval)
interval_values = explicit_interval_values(center, interval)
if interval_values is None:
return "n/a"
values, units = interval_values
prefix = longest_common_prefix(values)
if not common_numeric_prefix_is_useful(prefix):
values = align_interval_values(values, value_widths)
@@ -2180,7 +2338,9 @@ def append_display_row(row, comparison, no_color, display):
row.append(format_percentage(comparison.ref_noise))
row.append(format_duration(comparison.cmp_time))
row.append(format_percentage(comparison.cmp_noise))
row.append(format_duration(comparison.diff))
row.append(
format_duration(comparison.diff, allow_negative=True, allow_zero=True)
)
row.append(format_percentage(comparison.frac_diff))
row.append(colorize_comparison_status(comparison.status, no_color))
return
@@ -2230,7 +2390,7 @@ def timing_interval_column_widths(comparisons, center_getter, interval_getter):
interval_width = 0
for comparison in comparisons:
center = center_getter(comparison)
if center is None:
if not is_positive_finite(center):
continue
center_multiplier, center_units = select_duration_units(center)
@@ -2411,7 +2571,7 @@ def compare_benches(
bulk_debug_rows=None,
):
if comparison_thresholds is None:
comparison_thresholds = ComparisonThresholds()
comparison_thresholds = get_default_thresholds()
if plot_along:
import matplotlib.pyplot as plt
@@ -2511,23 +2671,42 @@ def compare_benches(
axis_value_name = axis_value["name"]
row.append(format_axis_value(axis_value_name, axis_value, axes))
cmp_summaries = cmp_state["summaries"]
ref_summaries = ref_state["summaries"]
if not ref_summaries or not cmp_summaries:
continue
cmp_summaries = get_state_summaries(cmp_state)
ref_summaries = get_state_summaries(ref_state)
# TODO: Use other timings, too. Maybe multiple rows, with a
# "Timing" column + values "CPU/GPU/Batch"?
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
comparison = compare_gpu_timings(
ref_gpu_time, cmp_gpu_time, comparison_thresholds
missing_summaries_decision = missing_state_summaries_decision(
ref_state, cmp_state
)
if missing_summaries_decision is None:
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
comparison = compare_gpu_timings(
ref_gpu_time, cmp_gpu_time, comparison_thresholds
)
else:
cmp_gpu_time = (
extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
if cmp_summaries
else make_empty_gpu_timing_data()
)
ref_gpu_time = (
extract_gpu_timing_data(ref_summaries, ref_json_dir)
if ref_summaries
else make_empty_gpu_timing_data()
)
comparison = make_unavailable_timing_comparison(
missing_summaries_decision, ref_gpu_time, cmp_gpu_time
)
if comparison is None:
continue
if plot_along:
if (
plot_along
and is_positive_finite(comparison.ref_time)
and is_positive_finite(comparison.cmp_time)
):
axis_name_parts = []
axis_value = None
for av in axis_values:
@@ -2554,7 +2733,10 @@ def compare_benches(
)
run_data.stats.record(comparison.status, comparison.reason)
if abs(comparison.frac_diff) >= threshold:
if comparison.status == ComparisonStatus.UNKNOWN or (
comparison.frac_diff is not None
and abs(comparison.frac_diff) >= threshold
):
axis_filters = matching_axis_filters(cmp_state, axis_filter_groups)
append_display_row(row, comparison, no_color, display)