From 7a582db94ed85c19ead1a339404fba5c8b60c970 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Thu, 4 Jun 2026 15:33:13 -0500 Subject: [PATCH] Improve nvbench-compare interval display readability Add compact reason labels for explain-mode tables while keeping canonical reason codes in the undecided summary. Emit a one-line legend only for non-trivial abbreviations. Refine interval displays so timing values align across table rows: - align Lo/Ce/Hi values in explain mode - align center values in intervals mode when some rows lack interval bounds - avoid repeating units when center and interval deltas use the same unit Add a Change column for non-legacy displays so FAST/SLOW rows show the signed interval-bound relative change, while SAME and UNDECIDED rows remain blank. Extend nvbench_compare tests to cover reason legend filtering, interval alignment, missing-interval alignment, and Change column formatting. --- python/scripts/nvbench_compare.py | 275 +++++++++++++++++++++++++--- python/test/test_nvbench_compare.py | 157 +++++++++++++++- 2 files changed, 396 insertions(+), 36 deletions(-) diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index 9ef895a..230e48d 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -441,6 +441,47 @@ class DecisionReason: severity: float = 0.0 +REASON_DISPLAY_CODES = { + "bulk_cycle_data_unusable": "bc-bad", + "bulk_cycle_gap_not_confirmed": "bc-gap-miss", + "bulk_cycle_same": "bc-same", + "bulk_cycle_support_mismatch": "bc-sup-miss", + "bulk_data_unavailable": "bulk-miss", + "bulk_same": "bulk-same", + "bulk_time_data_unusable": "bt-bad", + "bulk_time_same": "bt-same", + "bulk_time_support_mismatch": "bt-sup-miss", + "centers_not_close": "centers-far", + "clear_gap_confirmed_by_bulk_cycles": "bc-gap", + "clear_gap_confirmed_by_summary_cycles": "sc-gap", + "cycle_same_not_confirmed": "sc-same-miss", + "invalid_clock_rate": "clk-bad", + "missing_clock_rate": "clk-miss", + "missing_interval": "int-miss", + "no_clear_gap": "no-gap", + "noise_too_high": "noise-high", + "noise_unavailable": "noise-miss", + "same_confirmed_by_cycles": "sc-same", + "same_summary": "sum-same", + "same_without_clock_rate": "same-no-clk", + "summary_cycle_gap_not_confirmed": "sc-gap-miss", + "weak_interval_overlap": "weak-overlap", +} + + +def format_reason_display_code(code): + return REASON_DISPLAY_CODES.get(code, code) + + +def format_reason_legend_entries(reason_legend): + entries = [] + for code, reason_summary in sorted(reason_legend.items()): + if code == reason_summary.canonical_code.replace("_", "-"): + continue + entries.append(f"{code} = {reason_summary.canonical_code}") + return entries + + @dataclass(frozen=True) class TimingDecision: status: ComparisonStatus @@ -469,6 +510,7 @@ class SummaryComparison: @dataclass class DecisionReasonSummary: count: int = 0 + canonical_code: str = "" message: str = "" severity: float = 0.0 @@ -482,23 +524,43 @@ class ComparisonStats: undecided_count: int = 0 unknown_count: int = 0 undecided_reasons: dict[str, DecisionReasonSummary] = field(default_factory=dict) + reason_legend: dict[str, DecisionReasonSummary] = field(default_factory=dict) + + @staticmethod + def record_reason_summary( + summaries: dict[str, DecisionReasonSummary], + reason: DecisionReason, + *, + use_display_code, + ) -> None: + display_code = ( + format_reason_display_code(reason.code) if use_display_code else reason.code + ) + summary = summaries.setdefault( + display_code, DecisionReasonSummary(canonical_code=reason.code) + ) + if summary.count == 0 or reason.severity > summary.severity: + summary.canonical_code = reason.code + summary.message = reason.message + summary.severity = reason.severity + summary.count += 1 def record( self, status: ComparisonStatus, reason: DecisionReason | None = None ) -> None: self.config_count += 1 + if reason is not None: + self.record_reason_summary( + self.reason_legend, reason, use_display_code=True + ) if status == ComparisonStatus.UNKNOWN: self.unknown_count += 1 elif status == ComparisonStatus.UNDECIDED: self.undecided_count += 1 if reason is not None: - summary = self.undecided_reasons.setdefault( - reason.code, DecisionReasonSummary() + self.record_reason_summary( + self.undecided_reasons, reason, use_display_code=False ) - if summary.count == 0 or reason.severity > summary.severity: - summary.message = reason.message - summary.severity = reason.severity - summary.count += 1 elif status == ComparisonStatus.SAME: self.pass_count += 1 elif status == ComparisonStatus.FAST: @@ -1935,20 +1997,42 @@ def format_duration_range(bounds): return f"[{lower * multiplier:0.2f}, {upper * multiplier:0.2f}] {units}" -def format_timing_with_interval(center, interval): +def format_timing_with_interval( + center, interval, *, center_width=None, interval_width=None +): if center is None: return "n/a" if interval is None: + if center_width is not None and interval_width is not None: + center_multiplier, center_units = select_duration_units(center) + center_text = f"{center * center_multiplier:0.3f}" + center_text = f"{center_text:>{center_width}}" + if interval_width == 0: + return f"{center_text} {center_units}" + return f"{center_text} {' ' * interval_width} {center_units}" return format_duration(center) lower_delta = interval.lower - interval.center upper_delta = interval.upper - interval.center - multiplier, units = select_duration_units(lower_delta, upper_delta) - precision = duration_precision_for_center(center, multiplier) + center_multiplier, center_units = select_duration_units(center) + delta_multiplier, delta_units = select_duration_units(lower_delta, upper_delta) + precision = duration_precision_for_center(center, delta_multiplier) + if center_units == delta_units: + center_text = f"{center * center_multiplier:0.3f}" + interval_text = ( + f"[{lower_delta * delta_multiplier:+0.{precision}f}, " + f"{upper_delta * delta_multiplier:+0.{precision}f}]" + ) + if center_width is not None: + center_text = f"{center_text:>{center_width}}" + if interval_width is not None: + interval_text = f"{interval_text:>{interval_width}}" + return f"{center_text} {interval_text} {center_units}" + return ( f"{format_duration(center)} " - f"[{lower_delta * multiplier:+0.{precision}f}, " - f"{upper_delta * multiplier:+0.{precision}f}] {units}" + f"[{lower_delta * delta_multiplier:+0.{precision}f}, " + f"{upper_delta * delta_multiplier:+0.{precision}f}] {delta_units}" ) @@ -1964,22 +2048,63 @@ def longest_common_prefix(strings): return prefix -def format_timing_with_explicit_interval(center, interval): +def common_numeric_prefix_is_useful(prefix): + if "." not in prefix: + return False + + numeric_digits = sum(char.isdigit() for char in prefix) + fractional_prefix = prefix.split(".", 1)[1] + fractional_digits = sum(char.isdigit() for char in fractional_prefix) + return numeric_digits >= 2 and fractional_digits >= 1 + + +def align_interval_values(values, widths=None): + if widths is None: + widths = [max(len(value) for value in values)] * len(values) + return [f"{value:>{width}}" for value, width in zip(values, widths, strict=True)] + + +def explicit_interval_values(center, interval): + multiplier, units = select_duration_units( + interval.lower, interval.center, interval.upper + ) + return ( + [ + f"{interval.lower * multiplier:0.3f}", + f"{interval.center * multiplier:0.3f}", + f"{interval.upper * multiplier:0.3f}", + ], + units, + ) + + +def explicit_interval_column_widths(comparisons, center_getter, interval_getter): + widths = [0, 0, 0] + for comparison in comparisons: + center = center_getter(comparison) + interval = interval_getter(comparison) + if center is None or interval is None: + continue + + values, _ = explicit_interval_values(center, interval) + prefix = longest_common_prefix(values) + if common_numeric_prefix_is_useful(prefix): + continue + + widths = [max(width, len(value)) for width, value in zip(widths, values)] + return widths + + +def format_timing_with_explicit_interval(center, interval, *, value_widths=None): if center is None: return "n/a" if interval is None: return format_duration(center) - multiplier, units = select_duration_units( - interval.lower, interval.center, interval.upper - ) - values = [ - f"{interval.lower * multiplier:0.3f}", - f"{interval.center * multiplier:0.3f}", - f"{interval.upper * multiplier:0.3f}", - ] + values, units = explicit_interval_values(center, interval) prefix = longest_common_prefix(values) - if not prefix: + if not common_numeric_prefix_is_useful(prefix): + values = align_interval_values(values, value_widths) return f"[{values[0]} | {values[1]} | {values[2]}] {units}" suffixes = [value[len(prefix) :] for value in values] @@ -2007,6 +2132,12 @@ def format_percentage_bounds(bounds, status): return f"in [{lower * 100.0:+0.1f}%, {upper * 100.0:+0.1f}%]" +def format_change(comparison): + if comparison.status not in {ComparisonStatus.FAST, ComparisonStatus.SLOW}: + return "" + return format_percentage_bounds(comparison.frac_diff_interval, comparison.status) + + def get_display_headers(display): if display == "legacy": return ( @@ -2024,18 +2155,19 @@ def get_display_headers(display): if display == "explain": return ( [ - "Ref [L | C | H]", - "Cmp [L | C | H]", + "Ref [Lo | Ce | Hi]", + "Cmp [Lo | Ce | Hi]", "Ref Noise", "Cmp Noise", "Reason", + "Change", "Status", ], - ["right", "right", "right", "right", "left", "center"], + ["right", "right", "right", "right", "left", "right", "center"], ) return ( - ["Ref", "Cmp", "Status"], - ["right", "right", "center"], + ["Ref", "Cmp", "Change", "Status"], + ["right", "right", "right", "center"], ) @@ -2065,10 +2197,89 @@ def append_display_row(row, comparison, no_color, display): ) row.append(format_percentage(comparison.ref_noise)) row.append(format_percentage(comparison.cmp_noise)) - row.append(comparison.reason.code) + row.append(format_reason_display_code(comparison.reason.code)) + row.append(format_change(comparison)) row.append(colorize_comparison_status(comparison.status, no_color)) +def align_explain_interval_columns(rows, comparisons, axis_count): + ref_widths = explicit_interval_column_widths( + comparisons, + lambda comparison: comparison.ref_time, + lambda comparison: comparison.ref_interval, + ) + cmp_widths = explicit_interval_column_widths( + comparisons, + lambda comparison: comparison.cmp_time, + lambda comparison: comparison.cmp_interval, + ) + for row, comparison in zip(rows, comparisons, strict=True): + row[axis_count] = format_timing_with_explicit_interval( + comparison.ref_time, comparison.ref_interval, value_widths=ref_widths + ) + row[axis_count + 1] = format_timing_with_explicit_interval( + comparison.cmp_time, comparison.cmp_interval, value_widths=cmp_widths + ) + + +def timing_interval_column_widths(comparisons, center_getter, interval_getter): + center_width = 0 + interval_width = 0 + for comparison in comparisons: + center = center_getter(comparison) + if center is None: + continue + + center_multiplier, center_units = select_duration_units(center) + center_text = f"{center * center_multiplier:0.3f}" + center_width = max(center_width, len(center_text)) + + interval = interval_getter(comparison) + if interval is None: + continue + + lower_delta = interval.lower - interval.center + upper_delta = interval.upper - interval.center + delta_multiplier, delta_units = select_duration_units(lower_delta, upper_delta) + if center_units != delta_units: + continue + + precision = duration_precision_for_center(center, delta_multiplier) + interval_text = ( + f"[{lower_delta * delta_multiplier:+0.{precision}f}, " + f"{upper_delta * delta_multiplier:+0.{precision}f}]" + ) + interval_width = max(interval_width, len(interval_text)) + + return center_width, interval_width + + +def align_timing_interval_columns(rows, comparisons, axis_count): + ref_center_width, ref_interval_width = timing_interval_column_widths( + comparisons, + lambda comparison: comparison.ref_time, + lambda comparison: comparison.ref_interval, + ) + cmp_center_width, cmp_interval_width = timing_interval_column_widths( + comparisons, + lambda comparison: comparison.cmp_time, + lambda comparison: comparison.cmp_interval, + ) + for row, comparison in zip(rows, comparisons, strict=True): + row[axis_count] = format_timing_with_interval( + comparison.ref_time, + comparison.ref_interval, + center_width=ref_center_width, + interval_width=ref_interval_width, + ) + row[axis_count + 1] = format_timing_with_interval( + comparison.cmp_time, + comparison.cmp_interval, + center_width=cmp_center_width, + interval_width=cmp_interval_width, + ) + + def has_finite_noise(noise): return noise is not None and math.isfinite(noise) @@ -2272,6 +2483,7 @@ def compare_benches( ) rows = [] + row_comparisons = [] plot_data: dict[str, dict[str, dict[float, float | None]]] = { "cmp": {}, "ref": {}, @@ -2344,6 +2556,7 @@ def compare_benches( append_display_row(row, comparison, no_color, display) rows.append(row) + row_comparisons.append(comparison) if bulk_debug_rows is not None: bulk_debug_rows.append( make_bulk_debug_row( @@ -2386,6 +2599,10 @@ def compare_benches( if len(rows) == 0: continue + if display == "explain": + align_explain_interval_columns(rows, row_comparisons, len(axes)) + elif display == "intervals": + align_timing_interval_columns(rows, row_comparisons, len(axes)) cmp_device = find_device_by_id(cmp_device_id, run_data.cmp_devices) ref_device = find_device_by_id(ref_device_id, run_data.ref_devices) @@ -2748,6 +2965,10 @@ def main() -> int: reverse=True, ): print(f" - {code}: {reason_summary.count} ({reason_summary.message})") + if args.display == "explain" and stats.reason_legend: + legend_entries = format_reason_legend_entries(stats.reason_legend) + if legend_entries: + print(f" - Reason legend: {'; '.join(legend_entries)}") print(f" - Unknown (infinite or unavailable noise): {stats.unknown_count}") try: write_bulk_debug_python(bulk_debug_output, bulk_debug_rows or []) diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py index ddc6f21..9f0fdee 100644 --- a/python/test/test_nvbench_compare.py +++ b/python/test/test_nvbench_compare.py @@ -876,6 +876,30 @@ def test_format_diff_and_percent_ranges(nvbench_compare): ) +def test_format_change_only_reports_fast_and_slow_rows(nvbench_compare): + fast = types.SimpleNamespace( + status=nvbench_compare.ComparisonStatus.FAST, + frac_diff_interval=(-0.3, -0.05), + ) + slow = types.SimpleNamespace( + status=nvbench_compare.ComparisonStatus.SLOW, + frac_diff_interval=(0.07, 0.55), + ) + same = types.SimpleNamespace( + status=nvbench_compare.ComparisonStatus.SAME, + frac_diff_interval=(-0.01, 0.01), + ) + undecided = types.SimpleNamespace( + status=nvbench_compare.ComparisonStatus.UNDECIDED, + frac_diff_interval=(-0.01, 0.01), + ) + + assert nvbench_compare.format_change(fast) == "<= -5.0%" + assert nvbench_compare.format_change(slow) == ">= +7.0%" + assert nvbench_compare.format_change(same) == "" + assert nvbench_compare.format_change(undecided) == "" + + def test_format_timing_with_interval(nvbench_compare): interval = nvbench_compare.TimingInterval( lower=0.002237, upper=0.002389, center=0.0023 @@ -885,6 +909,14 @@ def test_format_timing_with_interval(nvbench_compare): == "2.300 ms [-63, +89] us" ) + interval = nvbench_compare.TimingInterval( + lower=19.380e-6, upper=20.508e-6, center=19.944e-6 + ) + assert ( + nvbench_compare.format_timing_with_interval(19.944e-6, interval) + == "19.944 [-0.564, +0.564] us" + ) + def test_format_timing_with_explicit_interval(nvbench_compare): interval = nvbench_compare.TimingInterval( @@ -895,6 +927,93 @@ def test_format_timing_with_explicit_interval(nvbench_compare): == "1.4[34 | 46 | 58] ms" ) + interval = nvbench_compare.TimingInterval( + lower=18.400e-6, upper=19.464e-6, center=18.736e-6 + ) + assert ( + nvbench_compare.format_timing_with_explicit_interval(18.736e-6, interval) + == "[18.400 | 18.736 | 19.464] us" + ) + + interval = nvbench_compare.TimingInterval( + lower=19.380e-6, upper=20.508e-6, center=19.944e-6 + ) + assert ( + nvbench_compare.format_timing_with_explicit_interval(19.944e-6, interval) + == "[19.380 | 19.944 | 20.508] us" + ) + + interval = nvbench_compare.TimingInterval( + lower=99.094e-6, upper=100.882e-6, center=99.988e-6 + ) + assert ( + nvbench_compare.format_timing_with_explicit_interval(99.988e-6, interval) + == "[ 99.094 | 99.988 | 100.882] us" + ) + + +def test_align_explain_interval_columns_pads_values_across_rows(nvbench_compare): + rows = [["", ""], ["", ""]] + comparisons = [ + types.SimpleNamespace( + ref_time=19.944e-6, + ref_interval=nvbench_compare.TimingInterval( + lower=19.380e-6, center=19.944e-6, upper=20.508e-6 + ), + cmp_time=97.712e-6, + cmp_interval=nvbench_compare.TimingInterval( + lower=96.849e-6, center=97.712e-6, upper=98.574e-6 + ), + ), + types.SimpleNamespace( + ref_time=103.466e-6, + ref_interval=nvbench_compare.TimingInterval( + lower=102.739e-6, center=103.466e-6, upper=104.193e-6 + ), + cmp_time=101.868e-6, + cmp_interval=nvbench_compare.TimingInterval( + lower=100.916e-6, center=101.868e-6, upper=102.819e-6 + ), + ), + ] + + nvbench_compare.align_explain_interval_columns(rows, comparisons, axis_count=0) + + assert rows[0][0] == "[ 19.380 | 19.944 | 20.508] us" + assert rows[1][0] == "[102.739 | 103.466 | 104.193] us" + assert rows[0][1] == "[ 96.849 | 97.712 | 98.574] us" + assert rows[1][1] == "[100.916 | 101.868 | 102.819] us" + + +def test_align_timing_interval_columns_reserves_missing_interval_slot(nvbench_compare): + rows = [["", ""], ["", ""]] + comparisons = [ + types.SimpleNamespace( + ref_time=19.944e-6, + ref_interval=nvbench_compare.TimingInterval( + lower=19.380e-6, center=19.944e-6, upper=20.508e-6 + ), + cmp_time=18.736e-6, + cmp_interval=nvbench_compare.TimingInterval( + lower=18.400e-6, center=18.736e-6, upper=19.464e-6 + ), + ), + types.SimpleNamespace( + ref_time=20.390e-6, + ref_interval=nvbench_compare.TimingInterval( + lower=19.659e-6, center=20.390e-6, upper=21.121e-6 + ), + cmp_time=20.480e-6, + cmp_interval=None, + ), + ] + + nvbench_compare.align_timing_interval_columns(rows, comparisons, axis_count=0) + + cmp_interval_slot = len("[-0.336, +0.728]") + assert rows[0][1] == "18.736 [-0.336, +0.728] us" + assert rows[1][1] == f"20.480 {' ' * cmp_interval_slot} us" + def test_compare_gpu_timings_keeps_bulk_mismatch_undecided(nvbench_compare): ref_timing = make_gpu_timing_data( @@ -1059,6 +1178,19 @@ def test_comparison_stats_records_undecided_reason(nvbench_compare): assert summary.message == "more severe reason" +def test_reason_legend_omits_trivial_aliases(nvbench_compare): + reason_legend = { + "bulk-same": nvbench_compare.DecisionReasonSummary(canonical_code="bulk_same"), + "bt-sup-miss": nvbench_compare.DecisionReasonSummary( + canonical_code="bulk_time_support_mismatch" + ), + } + + assert nvbench_compare.format_reason_legend_entries(reason_legend) == [ + "bt-sup-miss = bulk_time_support_mismatch" + ] + + @pytest.mark.parametrize("ref_time, cmp_time", [(None, 1.0), (1.0, None), (0.0, 1.0)]) def test_compare_gpu_timings_rejects_unusable_centers( nvbench_compare, ref_time, cmp_time @@ -1463,12 +1595,15 @@ def test_main_prints_undecided_reason_summary(monkeypatch, capsys, nvbench_compa return ref_root if path == "ref.json" else cmp_root monkeypatch.setattr(nvbench_compare.reader, "read_file", read_file) - monkeypatch.setattr(sys, "argv", ["nvbench_compare", "ref.json", "cmp.json"]) + monkeypatch.setattr( + sys, "argv", ["nvbench_compare", "--display", "explain", "ref.json", "cmp.json"] + ) assert nvbench_compare.main() == 0 output = capsys.readouterr().out assert "Undecided (comparison requires more evidence): 1" in output assert "noise_too_high: 1" in output + assert "Reason legend: noise-high = noise_too_high" in output def test_get_comparison_thresholds_returns_named_presets(nvbench_compare): @@ -1733,10 +1868,11 @@ def test_compare_benches_defaults_to_interval_display(monkeypatch, nvbench_compa no_color=True, ) - assert captured["headers"][-3:] == ["Ref", "Cmp", "Status"] + assert captured["headers"][-4:] == ["Ref", "Cmp", "Change", "Status"] row = captured["rows"][0] - assert row[-3].startswith("1.000 s") - assert row[-2].startswith("1.010 s") + assert row[-4].startswith("1.000 s") + assert row[-3].startswith("1.010 s") + assert row[-2] == "" def test_compare_benches_legacy_display_uses_scalar_diff(monkeypatch, nvbench_compare): @@ -1829,17 +1965,20 @@ def test_compare_benches_explain_display_uses_explicit_intervals( display="explain", ) - assert captured["headers"][-6:] == [ - "Ref [L | C | H]", - "Cmp [L | C | H]", + assert captured["headers"][-7:] == [ + "Ref [Lo | Ce | Hi]", + "Cmp [Lo | Ce | Hi]", "Ref Noise", "Cmp Noise", "Reason", + "Change", "Status", ] row = captured["rows"][0] - assert row[-6] == "1.0[00 | 20 | 30] s" - assert row[-5] == "1.0[10 | 30 | 40] s" + assert row[-7] == "1.0[00 | 20 | 30] s" + assert row[-6] == "1.0[10 | 30 | 40] s" + assert row[-3] == "centers-far" + assert row[-2] == "" def test_main_passes_selected_preset_to_compare_benches(monkeypatch, nvbench_compare):