Add nvbench_compare display modes and interval-based table views

Extend nvbench_compare with multiple table display modes and richer interval
formatting for timing comparisons.

Highlights:
  - add `--display` with `intervals`, `legacy`, and `explain` modes
  - keep `legacy` output using scalar Diff/%Diff
  - make `intervals` the default, showing compact center-plus-delta timing
    intervals
  - add `explain` mode with explicit `[L | C | H]` interval rendering and
    self-describing headers
  - compute and store diff and relative-diff intervals in SummaryComparison
  - add formatting helpers for absolute and relative interval displays
  - make default preset slightly more permissive by lowering
    `bulk_same_sample_coverage` to 0.97

Add focused tests covering:
  - diff/%diff interval computation
  - compact and explicit interval formatting
  - default, legacy, and explain table layouts
  - CLI propagation of `--display` and preset selection
This commit is contained in:
Oleksandr Pavlyk
2026-06-04 08:49:06 -05:00
parent 2a515c2569
commit 4cf75dcaf5
2 changed files with 400 additions and 39 deletions

View File

@@ -77,7 +77,7 @@ COMPARISON_THRESHOLD_PRESET_VALUES = {
"same_center_relative": 0.005,
"same_overlap_fraction": 0.5,
"same_relative_dispersion_ceiling": 0.02,
"bulk_same_sample_coverage": 0.99,
"bulk_same_sample_coverage": 0.97,
"bulk_same_support_coverage": 0.80,
"bulk_support_rare_sample_fraction": 0.001,
"bulk_support_max_removed_sample_fraction": 0.01,
@@ -204,6 +204,8 @@ class TimingDecision:
@dataclass(frozen=True)
class SummaryComparison:
ref_interval: TimingInterval | None
cmp_interval: TimingInterval | None
ref_estimate: TimeEstimate
cmp_estimate: TimeEstimate
ref_time: float
@@ -212,6 +214,8 @@ class SummaryComparison:
cmp_noise: float | None
diff: float
frac_diff: float
diff_interval: tuple[float, float] | None
frac_diff_interval: tuple[float, float] | None
max_noise: float | None
status: ComparisonStatus
reason: DecisionReason
@@ -680,6 +684,20 @@ def compare_intervals_for_clear_gap(ref_interval, cmp_interval, thresholds):
return None
def compute_diff_interval(ref_interval, cmp_interval):
return (
cmp_interval.lower - ref_interval.upper,
cmp_interval.upper - ref_interval.lower,
)
def compute_frac_diff_interval(ref_interval, cmp_interval):
return (
cmp_interval.lower / ref_interval.upper - 1.0,
cmp_interval.upper / ref_interval.lower - 1.0,
)
def centers_are_close(ref_center, cmp_center, thresholds):
if not is_positive_finite(ref_center) or not is_positive_finite(cmp_center):
return False
@@ -1240,8 +1258,15 @@ def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
cmp_noise = cmp_estimate.relative_dispersion
ref_noise = ref_estimate.relative_dispersion
ref_interval = compute_timing_interval(ref_timing)
cmp_interval = compute_timing_interval(cmp_timing)
diff = cmp_time - ref_time
frac_diff = diff / ref_time
diff_interval = None
frac_diff_interval = None
if ref_interval is not None and cmp_interval is not None:
diff_interval = compute_diff_interval(ref_interval, cmp_interval)
frac_diff_interval = compute_frac_diff_interval(ref_interval, cmp_interval)
if not has_finite_noise(ref_noise) or not has_finite_noise(cmp_noise):
max_noise = None
@@ -1266,6 +1291,8 @@ def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
decision = bulk_decision
return SummaryComparison(
ref_interval=ref_interval,
cmp_interval=cmp_interval,
ref_estimate=ref_estimate,
cmp_estimate=cmp_estimate,
ref_time=ref_time,
@@ -1274,6 +1301,8 @@ def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None):
cmp_noise=cmp_noise,
diff=diff,
frac_diff=frac_diff,
diff_interval=diff_interval,
frac_diff_interval=frac_diff_interval,
max_noise=max_noise,
status=decision.status,
reason=decision.reason,
@@ -1486,6 +1515,82 @@ def format_duration(seconds):
return f"{seconds * multiplier:0.3f} {units}"
def select_duration_units(*seconds_values):
max_abs_seconds = max(abs(value) for value in seconds_values)
if max_abs_seconds >= 1:
return 1.0, "s"
if max_abs_seconds >= 1e-3:
return 1e3, "ms"
return 1e6, "us"
def duration_precision_for_center(center, delta_multiplier):
center_multiplier, _ = select_duration_units(center)
center_quantum = 10.0**-3 * (delta_multiplier / center_multiplier)
if center_quantum >= 1.0:
return 0
return int(math.ceil(-math.log10(center_quantum)))
def format_duration_range(bounds):
if bounds is None:
return "n/a"
lower, upper = bounds
multiplier, units = select_duration_units(lower, upper)
return f"[{lower * multiplier:0.2f}, {upper * multiplier:0.2f}] {units}"
def format_timing_with_interval(center, interval):
if center is None:
return "n/a"
if interval is None:
return format_duration(center)
lower_delta = interval.lower - interval.center
upper_delta = interval.upper - interval.center
multiplier, units = select_duration_units(lower_delta, upper_delta)
precision = duration_precision_for_center(center, multiplier)
return (
f"{format_duration(center)} "
f"[{lower_delta * multiplier:+0.{precision}f}, "
f"{upper_delta * multiplier:+0.{precision}f}] {units}"
)
def longest_common_prefix(strings):
if not strings:
return ""
prefix = strings[0]
for text in strings[1:]:
while not text.startswith(prefix):
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
def format_timing_with_explicit_interval(center, interval):
if center is None:
return "n/a"
if interval is None:
return format_duration(center)
multiplier, units = select_duration_units(
interval.lower, interval.center, interval.upper
)
values = [
f"{interval.lower * multiplier:0.3f}",
f"{interval.center * multiplier:0.3f}",
f"{interval.upper * multiplier:0.3f}",
]
prefix = longest_common_prefix(values)
if not prefix:
return f"[{values[0]} | {values[1]} | {values[2]}] {units}"
suffixes = [value[len(prefix) :] for value in values]
return f"{prefix}[{suffixes[0]} | {suffixes[1]} | {suffixes[2]}] {units}"
def format_percentage(percentage):
if percentage is None:
return "n/a"
@@ -1496,6 +1601,79 @@ def format_percentage(percentage):
return f"{percentage * 100.0:0.2f}%"
def format_percentage_bounds(bounds, status):
if bounds is None:
return "n/a"
lower, upper = bounds
if status == ComparisonStatus.FAST:
return f"<= {upper * 100.0:+0.1f}%"
if status == ComparisonStatus.SLOW:
return f">= {lower * 100.0:+0.1f}%"
return f"in [{lower * 100.0:+0.1f}%, {upper * 100.0:+0.1f}%]"
def get_display_headers(display):
if display == "legacy":
return (
[
"Ref Time",
"Ref Noise",
"Cmp Time",
"Cmp Noise",
"Diff",
"%Diff",
"Status",
],
["right", "right", "right", "right", "right", "right", "center"],
)
if display == "explain":
return (
[
"Ref [L | C | H]",
"Cmp [L | C | H]",
"Ref Noise",
"Cmp Noise",
"Reason",
"Status",
],
["right", "right", "right", "right", "left", "center"],
)
return (
["Ref", "Cmp", "Status"],
["right", "right", "center"],
)
def append_display_row(row, comparison, no_color, display):
if display == "legacy":
row.append(format_duration(comparison.ref_time))
row.append(format_percentage(comparison.ref_noise))
row.append(format_duration(comparison.cmp_time))
row.append(format_percentage(comparison.cmp_noise))
row.append(format_duration(comparison.diff))
row.append(format_percentage(comparison.frac_diff))
row.append(colorize_comparison_status(comparison.status, no_color))
return
row.append(
format_timing_with_interval(comparison.ref_time, comparison.ref_interval)
)
row.append(
format_timing_with_interval(comparison.cmp_time, comparison.cmp_interval)
)
if display == "explain":
row[-2] = format_timing_with_explicit_interval(
comparison.ref_time, comparison.ref_interval
)
row[-1] = format_timing_with_explicit_interval(
comparison.cmp_time, comparison.cmp_interval
)
row.append(format_percentage(comparison.ref_noise))
row.append(format_percentage(comparison.cmp_noise))
row.append(comparison.reason.code)
row.append(colorize_comparison_status(comparison.status, no_color))
def has_finite_noise(noise):
return noise is not None and math.isfinite(noise)
@@ -1618,6 +1796,7 @@ def compare_benches(
ref_json_dir=None,
cmp_json_dir=None,
comparison_thresholds=None,
display="intervals",
):
if comparison_thresholds is None:
comparison_thresholds = ComparisonThresholds()
@@ -1664,21 +1843,9 @@ def compare_benches(
headers = [x["name"] for x in axes]
colalign = ["center"] * len(headers)
headers.append("Ref Time")
colalign.append("right")
headers.append("Ref Noise")
colalign.append("right")
headers.append("Cmp Time")
colalign.append("right")
headers.append("Cmp Noise")
colalign.append("right")
headers.append("Diff")
colalign.append("right")
headers.append("%Diff")
colalign.append("right")
headers.append("Status")
colalign.append("center")
display_headers, display_colalign = get_display_headers(display)
headers.extend(display_headers)
colalign.extend(display_colalign)
for cmp_device_index, cmp_device_id in enumerate(cmp_device_ids):
ref_device_id = ref_device_ids[cmp_device_index]
@@ -1774,17 +1941,9 @@ def compare_benches(
)
run_data.stats.record(comparison.status, comparison.reason)
status = colorize_comparison_status(comparison.status, no_color)
if abs(comparison.frac_diff) >= threshold:
axis_filters = matching_axis_filters(cmp_state, axis_filter_groups)
row.append(format_duration(comparison.ref_time))
row.append(format_percentage(comparison.ref_noise))
row.append(format_duration(comparison.cmp_time))
row.append(format_percentage(comparison.cmp_noise))
row.append(format_duration(comparison.diff))
row.append(format_percentage(comparison.frac_diff))
row.append(status)
append_display_row(row, comparison, no_color, display)
rows.append(row)
if plot:
@@ -1953,6 +2112,12 @@ def main() -> int:
default="default",
help="comparison threshold preset",
)
parser.add_argument(
"--display",
choices=["intervals", "legacy", "explain"],
default="intervals",
help="comparison table display mode",
)
parser.add_argument(
"--plot-along", type=str, dest="plot_along", default=None, help="plot results"
)
@@ -2091,17 +2256,18 @@ def main() -> int:
run_data,
ref_root["benchmarks"],
cmp_root["benchmarks"],
args.threshold,
args.plot_along,
args.plot,
args.dark,
filter_plan,
args.no_color,
reference_device_filter,
compare_device_filter,
os.path.dirname(ref),
os.path.dirname(comp),
comparison_thresholds,
threshold=args.threshold,
plot_along=args.plot_along,
plot=args.plot,
dark=args.dark,
filter_plan=filter_plan,
no_color=args.no_color,
reference_device_filter=reference_device_filter,
compare_device_filter=compare_device_filter,
ref_json_dir=os.path.dirname(ref),
cmp_json_dir=os.path.dirname(comp),
comparison_thresholds=comparison_thresholds,
display=args.display,
)
except ValueError as exc:
print(str(exc))