mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-07-01 19:57:41 +00:00
Explicitly handle unavailable timings in nvbench-compare
Treat matched states with unusable timing data as UNKNOWN instead of dropping them from the comparison. This includes missing, non-finite, or non-positive timing centers, skipped states, and states with missing GPU timing summaries. Add explicit reason codes for these cases so the summary points users at the underlying data issue. Preserve available timing data from the other side when only one side is missing, and render unavailable durations as n/a in all display modes. Also sort values returned by np.unique_counts before nearest-neighbor coverage checks so the distance algorithm receives ordered inputs. Add regression coverage for UNKNOWN counting, skipped states, missing summaries, unavailable center formatting, and the updated coverage helper.
This commit is contained in:
@@ -160,6 +160,10 @@ COMPARISON_THRESHOLD_RANGES = {
|
||||
}
|
||||
|
||||
|
||||
def get_default_thresholds() -> ComparisonThresholds:
|
||||
return COMPARISON_THRESHOLD_PRESETS[COMPARISON_DEFAULT_PRESET]
|
||||
|
||||
|
||||
def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
|
||||
try:
|
||||
return COMPARISON_THRESHOLD_PRESETS[preset_name]
|
||||
@@ -169,11 +173,14 @@ def get_comparison_thresholds(preset_name: str) -> ComparisonThresholds:
|
||||
|
||||
def load_toml_module() -> TomlModule:
|
||||
try:
|
||||
# built-in Python module, added in 3.11 via PEP 680
|
||||
import tomllib
|
||||
|
||||
return tomllib
|
||||
except ModuleNotFoundError:
|
||||
try:
|
||||
# third-party library for Python 3.10
|
||||
# note Python 3.10 EOL date is Oct. 31, 2026
|
||||
import tomli
|
||||
|
||||
return tomli
|
||||
@@ -465,6 +472,11 @@ REASON_DISPLAY_CODES = {
|
||||
"same_summary": "sum-same",
|
||||
"same_without_clock_rate": "same-no-clk",
|
||||
"summary_cycle_gap_not_confirmed": "sc-gap-miss",
|
||||
"gpu_timing_summaries_missing": "summ-miss",
|
||||
"state_skipped": "state-skip",
|
||||
"timing_center_missing": "center-miss",
|
||||
"timing_center_nonfinite": "center-nonfin",
|
||||
"timing_center_nonpositive": "center-nonpos",
|
||||
"weak_interval_overlap": "weak-overlap",
|
||||
}
|
||||
|
||||
@@ -494,12 +506,12 @@ class SummaryComparison:
|
||||
cmp_interval: TimingInterval | None
|
||||
ref_estimate: TimeEstimate
|
||||
cmp_estimate: TimeEstimate
|
||||
ref_time: float
|
||||
cmp_time: float
|
||||
ref_time: float | None
|
||||
cmp_time: float | None
|
||||
ref_noise: float | None
|
||||
cmp_noise: float | None
|
||||
diff: float
|
||||
frac_diff: float
|
||||
diff: float | None
|
||||
frac_diff: float | None
|
||||
diff_interval: tuple[float, float] | None
|
||||
frac_diff_interval: tuple[float, float] | None
|
||||
max_noise: float | None
|
||||
@@ -915,6 +927,21 @@ def extract_gpu_timing_data(summaries, json_dir=None, float32_reader=read_float3
|
||||
)
|
||||
|
||||
|
||||
def make_empty_gpu_timing_data():
|
||||
return GpuTimingData(
|
||||
minimum=None,
|
||||
maximum=None,
|
||||
mean=None,
|
||||
stdev=None,
|
||||
stdev_relative=None,
|
||||
first_quartile=None,
|
||||
median=None,
|
||||
third_quartile=None,
|
||||
interquartile_range=None,
|
||||
interquartile_range_relative=None,
|
||||
)
|
||||
|
||||
|
||||
def resolve_bulk_source_filename(source: Float32BinarySource | None) -> str | None:
|
||||
if source is None:
|
||||
return None
|
||||
@@ -1299,9 +1326,19 @@ def format_support_filter_info(filter_info):
|
||||
return "off(disabled)"
|
||||
|
||||
|
||||
def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds):
|
||||
ref_unique, ref_counts = np.unique_counts(ref_values)
|
||||
cmp_unique, cmp_counts = np.unique_counts(cmp_values)
|
||||
def sorted_unique_counts(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
unique_values, unique_counts = np.unique_counts(values)
|
||||
# unique is not guaranteed to return sorted values
|
||||
# make sure to order them
|
||||
sorting_indices = np.argsort(unique_values)
|
||||
return unique_values[sorting_indices], unique_counts[sorting_indices]
|
||||
|
||||
|
||||
def compute_nearest_neighbor_coverages(
|
||||
ref_values: np.ndarray, cmp_values: np.ndarray, thresholds: ComparisonThresholds
|
||||
) -> dict[str, Any] | None:
|
||||
ref_unique, ref_counts = sorted_unique_counts(ref_values)
|
||||
cmp_unique, cmp_counts = sorted_unique_counts(cmp_values)
|
||||
if len(ref_unique) == 0 or len(cmp_unique) == 0:
|
||||
return None
|
||||
|
||||
@@ -1697,27 +1734,67 @@ def compute_common_time_estimates(ref_timing, cmp_timing):
|
||||
)
|
||||
|
||||
|
||||
def unusable_timing_center_decision(ref_time, cmp_time):
|
||||
if ref_time is None or cmp_time is None:
|
||||
return make_decision(
|
||||
ComparisonStatus.UNKNOWN,
|
||||
"timing_center_missing",
|
||||
"timing center is missing",
|
||||
)
|
||||
if not math.isfinite(ref_time) or not math.isfinite(cmp_time):
|
||||
return make_decision(
|
||||
ComparisonStatus.UNKNOWN,
|
||||
"timing_center_nonfinite",
|
||||
"timing center is non-finite",
|
||||
)
|
||||
if ref_time <= 0.0 or cmp_time <= 0.0:
|
||||
return make_decision(
|
||||
ComparisonStatus.UNKNOWN,
|
||||
"timing_center_nonpositive",
|
||||
"timing center is non-positive",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def make_unavailable_timing_comparison(decision, ref_timing, cmp_timing):
|
||||
ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing)
|
||||
return SummaryComparison(
|
||||
ref_interval=compute_timing_interval(ref_timing),
|
||||
cmp_interval=compute_timing_interval(cmp_timing),
|
||||
ref_estimate=ref_estimate,
|
||||
cmp_estimate=cmp_estimate,
|
||||
ref_time=ref_estimate.center,
|
||||
cmp_time=cmp_estimate.center,
|
||||
ref_noise=ref_estimate.relative_dispersion,
|
||||
cmp_noise=cmp_estimate.relative_dispersion,
|
||||
diff=None,
|
||||
frac_diff=None,
|
||||
diff_interval=None,
|
||||
frac_diff_interval=None,
|
||||
max_noise=None,
|
||||
status=decision.status,
|
||||
reason=decision.reason,
|
||||
)
|
||||
|
||||
|
||||
def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
|
||||
if comparison_thresholds is None:
|
||||
comparison_thresholds = ComparisonThresholds()
|
||||
comparison_thresholds = get_default_thresholds()
|
||||
|
||||
ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing)
|
||||
|
||||
cmp_time = cmp_estimate.center
|
||||
ref_time = ref_estimate.center
|
||||
|
||||
if cmp_time is None or ref_time is None:
|
||||
return None
|
||||
|
||||
if not math.isfinite(cmp_time) or not math.isfinite(ref_time):
|
||||
return None
|
||||
|
||||
if cmp_time <= 0.0 or ref_time <= 0.0:
|
||||
return None
|
||||
|
||||
cmp_noise = cmp_estimate.relative_dispersion
|
||||
ref_noise = ref_estimate.relative_dispersion
|
||||
|
||||
unusable_center_decision = unusable_timing_center_decision(ref_time, cmp_time)
|
||||
if unusable_center_decision is not None:
|
||||
return make_unavailable_timing_comparison(
|
||||
unusable_center_decision, ref_timing, cmp_timing
|
||||
)
|
||||
|
||||
ref_interval = compute_timing_interval(ref_timing)
|
||||
cmp_interval = compute_timing_interval(cmp_timing)
|
||||
diff = cmp_time - ref_time
|
||||
@@ -1769,6 +1846,53 @@ def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
|
||||
)
|
||||
|
||||
|
||||
def get_state_summaries(state: Mapping[str, Any]) -> list[dict[str, Any]]:
|
||||
summaries = state.get("summaries")
|
||||
return summaries if summaries is not None else []
|
||||
|
||||
|
||||
def state_has_summaries(state):
|
||||
return bool(state.get("summaries"))
|
||||
|
||||
|
||||
def format_skipped_state_reason(side, state):
|
||||
reason = state.get("skip_reason")
|
||||
if reason:
|
||||
return f"{side} state skipped: {reason}"
|
||||
return f"{side} state skipped"
|
||||
|
||||
|
||||
def missing_state_summaries_decision(ref_state, cmp_state):
|
||||
skipped_messages = []
|
||||
if ref_state.get("is_skipped"):
|
||||
skipped_messages.append(format_skipped_state_reason("reference", ref_state))
|
||||
if cmp_state.get("is_skipped"):
|
||||
skipped_messages.append(format_skipped_state_reason("compare", cmp_state))
|
||||
if skipped_messages:
|
||||
return make_decision(
|
||||
ComparisonStatus.UNKNOWN,
|
||||
"state_skipped",
|
||||
"; ".join(skipped_messages),
|
||||
)
|
||||
|
||||
missing_sides = []
|
||||
if not state_has_summaries(ref_state):
|
||||
missing_sides.append("reference")
|
||||
if not state_has_summaries(cmp_state):
|
||||
missing_sides.append("compare")
|
||||
if not missing_sides:
|
||||
return None
|
||||
if len(missing_sides) == 2:
|
||||
message = "reference and compare GPU timing summaries are missing"
|
||||
else:
|
||||
message = f"{missing_sides[0]} GPU timing summaries are missing"
|
||||
return make_decision(
|
||||
ComparisonStatus.UNKNOWN,
|
||||
"gpu_timing_summaries_missing",
|
||||
message,
|
||||
)
|
||||
|
||||
|
||||
def find_matching_bench(needle, haystack):
|
||||
for hay in haystack:
|
||||
if hay["name"] == needle["name"]:
|
||||
@@ -1959,7 +2083,18 @@ def matching_axis_filters(state, axis_filter_groups):
|
||||
)
|
||||
|
||||
|
||||
def format_duration(seconds):
|
||||
def is_finite_number(value):
|
||||
return value is not None and math.isfinite(value)
|
||||
|
||||
|
||||
def format_duration(seconds, *, allow_negative=False, allow_zero=False):
|
||||
if (
|
||||
not is_finite_number(seconds)
|
||||
or (seconds < 0.0 and not allow_negative)
|
||||
or (seconds == 0.0 and not allow_zero)
|
||||
):
|
||||
return "n/a"
|
||||
|
||||
if seconds >= 1:
|
||||
multiplier = 1.0
|
||||
units = "s"
|
||||
@@ -1976,6 +2111,10 @@ def format_duration(seconds):
|
||||
|
||||
|
||||
def select_duration_units(*seconds_values):
|
||||
seconds_values = [value for value in seconds_values if is_finite_number(value)]
|
||||
if not seconds_values:
|
||||
return 1e6, "us"
|
||||
|
||||
max_abs_seconds = max(abs(value) for value in seconds_values)
|
||||
if max_abs_seconds >= 1:
|
||||
return 1.0, "s"
|
||||
@@ -1985,6 +2124,9 @@ def select_duration_units(*seconds_values):
|
||||
|
||||
|
||||
def duration_precision_for_center(center, delta_multiplier):
|
||||
if not is_finite_number(center):
|
||||
return 3
|
||||
|
||||
center_multiplier, _ = select_duration_units(center)
|
||||
center_quantum = 10.0**-3 * (delta_multiplier / center_multiplier)
|
||||
if center_quantum >= 1.0:
|
||||
@@ -1996,6 +2138,9 @@ def format_duration_range(bounds):
|
||||
if bounds is None:
|
||||
return "n/a"
|
||||
lower, upper = bounds
|
||||
if not is_finite_number(lower) or not is_finite_number(upper):
|
||||
return "n/a"
|
||||
|
||||
multiplier, units = select_duration_units(lower, upper)
|
||||
return f"[{lower * multiplier:0.2f}, {upper * multiplier:0.2f}] {units}"
|
||||
|
||||
@@ -2003,7 +2148,7 @@ def format_duration_range(bounds):
|
||||
def format_timing_with_interval(
|
||||
center, interval, *, center_width=None, interval_width=None
|
||||
):
|
||||
if center is None:
|
||||
if center is None or not is_positive_finite(center):
|
||||
return "n/a"
|
||||
if interval is None:
|
||||
if center_width is not None and interval_width is not None:
|
||||
@@ -2068,6 +2213,13 @@ def align_interval_values(values, widths=None):
|
||||
|
||||
|
||||
def explicit_interval_values(center, interval):
|
||||
if (
|
||||
not is_positive_finite(interval.lower)
|
||||
or not is_positive_finite(interval.center)
|
||||
or not is_positive_finite(interval.upper)
|
||||
):
|
||||
return None
|
||||
|
||||
multiplier, units = select_duration_units(
|
||||
interval.lower, interval.center, interval.upper
|
||||
)
|
||||
@@ -2086,10 +2238,13 @@ def explicit_interval_column_widths(comparisons, center_getter, interval_getter)
|
||||
for comparison in comparisons:
|
||||
center = center_getter(comparison)
|
||||
interval = interval_getter(comparison)
|
||||
if center is None or interval is None:
|
||||
if not is_positive_finite(center) or interval is None:
|
||||
continue
|
||||
|
||||
values, _ = explicit_interval_values(center, interval)
|
||||
interval_values = explicit_interval_values(center, interval)
|
||||
if interval_values is None:
|
||||
continue
|
||||
values, _ = interval_values
|
||||
prefix = longest_common_prefix(values)
|
||||
if common_numeric_prefix_is_useful(prefix):
|
||||
continue
|
||||
@@ -2099,12 +2254,15 @@ def explicit_interval_column_widths(comparisons, center_getter, interval_getter)
|
||||
|
||||
|
||||
def format_timing_with_explicit_interval(center, interval, *, value_widths=None):
|
||||
if center is None:
|
||||
if center is None or not is_positive_finite(center):
|
||||
return "n/a"
|
||||
if interval is None:
|
||||
return format_duration(center)
|
||||
|
||||
values, units = explicit_interval_values(center, interval)
|
||||
interval_values = explicit_interval_values(center, interval)
|
||||
if interval_values is None:
|
||||
return "n/a"
|
||||
values, units = interval_values
|
||||
prefix = longest_common_prefix(values)
|
||||
if not common_numeric_prefix_is_useful(prefix):
|
||||
values = align_interval_values(values, value_widths)
|
||||
@@ -2180,7 +2338,9 @@ def append_display_row(row, comparison, no_color, display):
|
||||
row.append(format_percentage(comparison.ref_noise))
|
||||
row.append(format_duration(comparison.cmp_time))
|
||||
row.append(format_percentage(comparison.cmp_noise))
|
||||
row.append(format_duration(comparison.diff))
|
||||
row.append(
|
||||
format_duration(comparison.diff, allow_negative=True, allow_zero=True)
|
||||
)
|
||||
row.append(format_percentage(comparison.frac_diff))
|
||||
row.append(colorize_comparison_status(comparison.status, no_color))
|
||||
return
|
||||
@@ -2230,7 +2390,7 @@ def timing_interval_column_widths(comparisons, center_getter, interval_getter):
|
||||
interval_width = 0
|
||||
for comparison in comparisons:
|
||||
center = center_getter(comparison)
|
||||
if center is None:
|
||||
if not is_positive_finite(center):
|
||||
continue
|
||||
|
||||
center_multiplier, center_units = select_duration_units(center)
|
||||
@@ -2411,7 +2571,7 @@ def compare_benches(
|
||||
bulk_debug_rows=None,
|
||||
):
|
||||
if comparison_thresholds is None:
|
||||
comparison_thresholds = ComparisonThresholds()
|
||||
comparison_thresholds = get_default_thresholds()
|
||||
|
||||
if plot_along:
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -2511,23 +2671,42 @@ def compare_benches(
|
||||
axis_value_name = axis_value["name"]
|
||||
row.append(format_axis_value(axis_value_name, axis_value, axes))
|
||||
|
||||
cmp_summaries = cmp_state["summaries"]
|
||||
ref_summaries = ref_state["summaries"]
|
||||
|
||||
if not ref_summaries or not cmp_summaries:
|
||||
continue
|
||||
cmp_summaries = get_state_summaries(cmp_state)
|
||||
ref_summaries = get_state_summaries(ref_state)
|
||||
|
||||
# TODO: Use other timings, too. Maybe multiple rows, with a
|
||||
# "Timing" column + values "CPU/GPU/Batch"?
|
||||
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
|
||||
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
|
||||
comparison = compare_gpu_timings(
|
||||
ref_gpu_time, cmp_gpu_time, comparison_thresholds
|
||||
missing_summaries_decision = missing_state_summaries_decision(
|
||||
ref_state, cmp_state
|
||||
)
|
||||
if missing_summaries_decision is None:
|
||||
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
|
||||
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
|
||||
comparison = compare_gpu_timings(
|
||||
ref_gpu_time, cmp_gpu_time, comparison_thresholds
|
||||
)
|
||||
else:
|
||||
cmp_gpu_time = (
|
||||
extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
|
||||
if cmp_summaries
|
||||
else make_empty_gpu_timing_data()
|
||||
)
|
||||
ref_gpu_time = (
|
||||
extract_gpu_timing_data(ref_summaries, ref_json_dir)
|
||||
if ref_summaries
|
||||
else make_empty_gpu_timing_data()
|
||||
)
|
||||
comparison = make_unavailable_timing_comparison(
|
||||
missing_summaries_decision, ref_gpu_time, cmp_gpu_time
|
||||
)
|
||||
if comparison is None:
|
||||
continue
|
||||
|
||||
if plot_along:
|
||||
if (
|
||||
plot_along
|
||||
and is_positive_finite(comparison.ref_time)
|
||||
and is_positive_finite(comparison.cmp_time)
|
||||
):
|
||||
axis_name_parts = []
|
||||
axis_value = None
|
||||
for av in axis_values:
|
||||
@@ -2554,7 +2733,10 @@ def compare_benches(
|
||||
)
|
||||
|
||||
run_data.stats.record(comparison.status, comparison.reason)
|
||||
if abs(comparison.frac_diff) >= threshold:
|
||||
if comparison.status == ComparisonStatus.UNKNOWN or (
|
||||
comparison.frac_diff is not None
|
||||
and abs(comparison.frac_diff) >= threshold
|
||||
):
|
||||
axis_filters = matching_axis_filters(cmp_state, axis_filter_groups)
|
||||
append_display_row(row, comparison, no_color, display)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user