Improve nvbench-compare interval display readability

Add compact reason labels for explain-mode tables while keeping canonical
reason codes in the undecided summary. Emit a one-line legend only for
non-trivial abbreviations.

Refine interval displays so timing values align across table rows:
  - align Lo/Ce/Hi values in explain mode
  - align center values in intervals mode when some rows lack interval bounds
  - avoid repeating units when center and interval deltas use the same unit

Add a Change column for non-legacy displays so FAST/SLOW rows show the
signed interval-bound relative change, while SAME and UNDECIDED rows remain
blank.

Extend nvbench_compare tests to cover reason legend filtering, interval
alignment, missing-interval alignment, and Change column formatting.
This commit is contained in:
Oleksandr Pavlyk
2026-06-04 15:33:13 -05:00
parent 70d728cba6
commit 7a582db94e
2 changed files with 396 additions and 36 deletions

View File

@@ -441,6 +441,47 @@ class DecisionReason:
severity: float = 0.0
REASON_DISPLAY_CODES = {
"bulk_cycle_data_unusable": "bc-bad",
"bulk_cycle_gap_not_confirmed": "bc-gap-miss",
"bulk_cycle_same": "bc-same",
"bulk_cycle_support_mismatch": "bc-sup-miss",
"bulk_data_unavailable": "bulk-miss",
"bulk_same": "bulk-same",
"bulk_time_data_unusable": "bt-bad",
"bulk_time_same": "bt-same",
"bulk_time_support_mismatch": "bt-sup-miss",
"centers_not_close": "centers-far",
"clear_gap_confirmed_by_bulk_cycles": "bc-gap",
"clear_gap_confirmed_by_summary_cycles": "sc-gap",
"cycle_same_not_confirmed": "sc-same-miss",
"invalid_clock_rate": "clk-bad",
"missing_clock_rate": "clk-miss",
"missing_interval": "int-miss",
"no_clear_gap": "no-gap",
"noise_too_high": "noise-high",
"noise_unavailable": "noise-miss",
"same_confirmed_by_cycles": "sc-same",
"same_summary": "sum-same",
"same_without_clock_rate": "same-no-clk",
"summary_cycle_gap_not_confirmed": "sc-gap-miss",
"weak_interval_overlap": "weak-overlap",
}
def format_reason_display_code(code):
return REASON_DISPLAY_CODES.get(code, code)
def format_reason_legend_entries(reason_legend):
entries = []
for code, reason_summary in sorted(reason_legend.items()):
if code == reason_summary.canonical_code.replace("_", "-"):
continue
entries.append(f"{code} = {reason_summary.canonical_code}")
return entries
@dataclass(frozen=True)
class TimingDecision:
status: ComparisonStatus
@@ -469,6 +510,7 @@ class SummaryComparison:
@dataclass
class DecisionReasonSummary:
count: int = 0
canonical_code: str = ""
message: str = ""
severity: float = 0.0
@@ -482,23 +524,43 @@ class ComparisonStats:
undecided_count: int = 0
unknown_count: int = 0
undecided_reasons: dict[str, DecisionReasonSummary] = field(default_factory=dict)
reason_legend: dict[str, DecisionReasonSummary] = field(default_factory=dict)
@staticmethod
def record_reason_summary(
summaries: dict[str, DecisionReasonSummary],
reason: DecisionReason,
*,
use_display_code,
) -> None:
display_code = (
format_reason_display_code(reason.code) if use_display_code else reason.code
)
summary = summaries.setdefault(
display_code, DecisionReasonSummary(canonical_code=reason.code)
)
if summary.count == 0 or reason.severity > summary.severity:
summary.canonical_code = reason.code
summary.message = reason.message
summary.severity = reason.severity
summary.count += 1
def record(
self, status: ComparisonStatus, reason: DecisionReason | None = None
) -> None:
self.config_count += 1
if reason is not None:
self.record_reason_summary(
self.reason_legend, reason, use_display_code=True
)
if status == ComparisonStatus.UNKNOWN:
self.unknown_count += 1
elif status == ComparisonStatus.UNDECIDED:
self.undecided_count += 1
if reason is not None:
summary = self.undecided_reasons.setdefault(
reason.code, DecisionReasonSummary()
self.record_reason_summary(
self.undecided_reasons, reason, use_display_code=False
)
if summary.count == 0 or reason.severity > summary.severity:
summary.message = reason.message
summary.severity = reason.severity
summary.count += 1
elif status == ComparisonStatus.SAME:
self.pass_count += 1
elif status == ComparisonStatus.FAST:
@@ -1935,20 +1997,42 @@ def format_duration_range(bounds):
return f"[{lower * multiplier:0.2f}, {upper * multiplier:0.2f}] {units}"
def format_timing_with_interval(center, interval):
def format_timing_with_interval(
center, interval, *, center_width=None, interval_width=None
):
if center is None:
return "n/a"
if interval is None:
if center_width is not None and interval_width is not None:
center_multiplier, center_units = select_duration_units(center)
center_text = f"{center * center_multiplier:0.3f}"
center_text = f"{center_text:>{center_width}}"
if interval_width == 0:
return f"{center_text} {center_units}"
return f"{center_text} {' ' * interval_width} {center_units}"
return format_duration(center)
lower_delta = interval.lower - interval.center
upper_delta = interval.upper - interval.center
multiplier, units = select_duration_units(lower_delta, upper_delta)
precision = duration_precision_for_center(center, multiplier)
center_multiplier, center_units = select_duration_units(center)
delta_multiplier, delta_units = select_duration_units(lower_delta, upper_delta)
precision = duration_precision_for_center(center, delta_multiplier)
if center_units == delta_units:
center_text = f"{center * center_multiplier:0.3f}"
interval_text = (
f"[{lower_delta * delta_multiplier:+0.{precision}f}, "
f"{upper_delta * delta_multiplier:+0.{precision}f}]"
)
if center_width is not None:
center_text = f"{center_text:>{center_width}}"
if interval_width is not None:
interval_text = f"{interval_text:>{interval_width}}"
return f"{center_text} {interval_text} {center_units}"
return (
f"{format_duration(center)} "
f"[{lower_delta * multiplier:+0.{precision}f}, "
f"{upper_delta * multiplier:+0.{precision}f}] {units}"
f"[{lower_delta * delta_multiplier:+0.{precision}f}, "
f"{upper_delta * delta_multiplier:+0.{precision}f}] {delta_units}"
)
@@ -1964,22 +2048,63 @@ def longest_common_prefix(strings):
return prefix
def format_timing_with_explicit_interval(center, interval):
def common_numeric_prefix_is_useful(prefix):
if "." not in prefix:
return False
numeric_digits = sum(char.isdigit() for char in prefix)
fractional_prefix = prefix.split(".", 1)[1]
fractional_digits = sum(char.isdigit() for char in fractional_prefix)
return numeric_digits >= 2 and fractional_digits >= 1
def align_interval_values(values, widths=None):
if widths is None:
widths = [max(len(value) for value in values)] * len(values)
return [f"{value:>{width}}" for value, width in zip(values, widths, strict=True)]
def explicit_interval_values(center, interval):
multiplier, units = select_duration_units(
interval.lower, interval.center, interval.upper
)
return (
[
f"{interval.lower * multiplier:0.3f}",
f"{interval.center * multiplier:0.3f}",
f"{interval.upper * multiplier:0.3f}",
],
units,
)
def explicit_interval_column_widths(comparisons, center_getter, interval_getter):
widths = [0, 0, 0]
for comparison in comparisons:
center = center_getter(comparison)
interval = interval_getter(comparison)
if center is None or interval is None:
continue
values, _ = explicit_interval_values(center, interval)
prefix = longest_common_prefix(values)
if common_numeric_prefix_is_useful(prefix):
continue
widths = [max(width, len(value)) for width, value in zip(widths, values)]
return widths
def format_timing_with_explicit_interval(center, interval, *, value_widths=None):
if center is None:
return "n/a"
if interval is None:
return format_duration(center)
multiplier, units = select_duration_units(
interval.lower, interval.center, interval.upper
)
values = [
f"{interval.lower * multiplier:0.3f}",
f"{interval.center * multiplier:0.3f}",
f"{interval.upper * multiplier:0.3f}",
]
values, units = explicit_interval_values(center, interval)
prefix = longest_common_prefix(values)
if not prefix:
if not common_numeric_prefix_is_useful(prefix):
values = align_interval_values(values, value_widths)
return f"[{values[0]} | {values[1]} | {values[2]}] {units}"
suffixes = [value[len(prefix) :] for value in values]
@@ -2007,6 +2132,12 @@ def format_percentage_bounds(bounds, status):
return f"in [{lower * 100.0:+0.1f}%, {upper * 100.0:+0.1f}%]"
def format_change(comparison):
if comparison.status not in {ComparisonStatus.FAST, ComparisonStatus.SLOW}:
return ""
return format_percentage_bounds(comparison.frac_diff_interval, comparison.status)
def get_display_headers(display):
if display == "legacy":
return (
@@ -2024,18 +2155,19 @@ def get_display_headers(display):
if display == "explain":
return (
[
"Ref [L | C | H]",
"Cmp [L | C | H]",
"Ref [Lo | Ce | Hi]",
"Cmp [Lo | Ce | Hi]",
"Ref Noise",
"Cmp Noise",
"Reason",
"Change",
"Status",
],
["right", "right", "right", "right", "left", "center"],
["right", "right", "right", "right", "left", "right", "center"],
)
return (
["Ref", "Cmp", "Status"],
["right", "right", "center"],
["Ref", "Cmp", "Change", "Status"],
["right", "right", "right", "center"],
)
@@ -2065,10 +2197,89 @@ def append_display_row(row, comparison, no_color, display):
)
row.append(format_percentage(comparison.ref_noise))
row.append(format_percentage(comparison.cmp_noise))
row.append(comparison.reason.code)
row.append(format_reason_display_code(comparison.reason.code))
row.append(format_change(comparison))
row.append(colorize_comparison_status(comparison.status, no_color))
def align_explain_interval_columns(rows, comparisons, axis_count):
ref_widths = explicit_interval_column_widths(
comparisons,
lambda comparison: comparison.ref_time,
lambda comparison: comparison.ref_interval,
)
cmp_widths = explicit_interval_column_widths(
comparisons,
lambda comparison: comparison.cmp_time,
lambda comparison: comparison.cmp_interval,
)
for row, comparison in zip(rows, comparisons, strict=True):
row[axis_count] = format_timing_with_explicit_interval(
comparison.ref_time, comparison.ref_interval, value_widths=ref_widths
)
row[axis_count + 1] = format_timing_with_explicit_interval(
comparison.cmp_time, comparison.cmp_interval, value_widths=cmp_widths
)
def timing_interval_column_widths(comparisons, center_getter, interval_getter):
center_width = 0
interval_width = 0
for comparison in comparisons:
center = center_getter(comparison)
if center is None:
continue
center_multiplier, center_units = select_duration_units(center)
center_text = f"{center * center_multiplier:0.3f}"
center_width = max(center_width, len(center_text))
interval = interval_getter(comparison)
if interval is None:
continue
lower_delta = interval.lower - interval.center
upper_delta = interval.upper - interval.center
delta_multiplier, delta_units = select_duration_units(lower_delta, upper_delta)
if center_units != delta_units:
continue
precision = duration_precision_for_center(center, delta_multiplier)
interval_text = (
f"[{lower_delta * delta_multiplier:+0.{precision}f}, "
f"{upper_delta * delta_multiplier:+0.{precision}f}]"
)
interval_width = max(interval_width, len(interval_text))
return center_width, interval_width
def align_timing_interval_columns(rows, comparisons, axis_count):
ref_center_width, ref_interval_width = timing_interval_column_widths(
comparisons,
lambda comparison: comparison.ref_time,
lambda comparison: comparison.ref_interval,
)
cmp_center_width, cmp_interval_width = timing_interval_column_widths(
comparisons,
lambda comparison: comparison.cmp_time,
lambda comparison: comparison.cmp_interval,
)
for row, comparison in zip(rows, comparisons, strict=True):
row[axis_count] = format_timing_with_interval(
comparison.ref_time,
comparison.ref_interval,
center_width=ref_center_width,
interval_width=ref_interval_width,
)
row[axis_count + 1] = format_timing_with_interval(
comparison.cmp_time,
comparison.cmp_interval,
center_width=cmp_center_width,
interval_width=cmp_interval_width,
)
def has_finite_noise(noise):
return noise is not None and math.isfinite(noise)
@@ -2272,6 +2483,7 @@ def compare_benches(
)
rows = []
row_comparisons = []
plot_data: dict[str, dict[str, dict[float, float | None]]] = {
"cmp": {},
"ref": {},
@@ -2344,6 +2556,7 @@ def compare_benches(
append_display_row(row, comparison, no_color, display)
rows.append(row)
row_comparisons.append(comparison)
if bulk_debug_rows is not None:
bulk_debug_rows.append(
make_bulk_debug_row(
@@ -2386,6 +2599,10 @@ def compare_benches(
if len(rows) == 0:
continue
if display == "explain":
align_explain_interval_columns(rows, row_comparisons, len(axes))
elif display == "intervals":
align_timing_interval_columns(rows, row_comparisons, len(axes))
cmp_device = find_device_by_id(cmp_device_id, run_data.cmp_devices)
ref_device = find_device_by_id(ref_device_id, run_data.ref_devices)
@@ -2748,6 +2965,10 @@ def main() -> int:
reverse=True,
):
print(f" - {code}: {reason_summary.count} ({reason_summary.message})")
if args.display == "explain" and stats.reason_legend:
legend_entries = format_reason_legend_entries(stats.reason_legend)
if legend_entries:
print(f" - Reason legend: {'; '.join(legend_entries)}")
print(f" - Unknown (infinite or unavailable noise): {stats.unknown_count}")
try:
write_bulk_debug_python(bulk_debug_output, bulk_debug_rows or [])