Refactor nvbench-compare timing comparison state

Introduce GpuTimingData, SummaryComparison, ComparisonStats, and
ComparisonRunData to make timing extraction, classification, and run-level
state explicit.

Load sample-time and SM-frequency bulk data from JSON binary output into
GpuTimingData when available, preserving count validation between paired
sample and frequency arrays.

Move GPU timing comparison logic into compare_gpu_timings(), prefer robust
median/IQR data when available, and fall back to mean/stdev summaries otherwise.
Keep classifying comparisons whose noise is missing or non-finite as unknown.

Replace module-level comparison counters and selected-device globals with
per-run data passed into compare_benches(). Update tests to validate timing
classification, bulk-data loading, device pairing, filtered duplicate matching,
and summary counters through the new structures.
This commit is contained in:
Oleksandr Pavlyk
2026-06-02 15:04:39 -05:00
parent 9dfc742876
commit b43b1dcf70
2 changed files with 347 additions and 188 deletions

View File

@@ -10,6 +10,7 @@ import sys
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import Any, Mapping
import jsondiff
import numpy as np
@@ -29,14 +30,6 @@ def version_tuple(v):
tabulate_version = version_tuple(tabulate.__version__)
all_ref_devices: list[dict] = []
all_cmp_devices: list[dict] = []
config_count = 0
unknown_count = 0
improvement_count = 0
regression_count = 0
pass_count = 0
GPU_TIME_MIN_TAG = "nv/cold/time/gpu/min"
GPU_TIME_MAX_TAG = "nv/cold/time/gpu/max"
GPU_TIME_MEAN_TAG = "nv/cold/time/gpu/mean"
@@ -53,7 +46,7 @@ SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs"
@dataclass(frozen=True)
class GpuTimeSummary:
class GpuTimingData:
minimum: float | None
maximum: float | None
mean: float | None
@@ -72,6 +65,59 @@ class TimeEstimate:
relative_dispersion: float | None
class ComparisonStatus(str, Enum):
    """Outcome label for one reference-vs-compare GPU timing comparison.

    Each member's value is the short text label used when the status is
    displayed. Because of the ``str`` mixin, members also compare equal to
    their plain string values.
    """

    # Noise for at least one side is missing or non-finite.
    UNKNOWN = "????"
    # The fractional difference is within the larger of the two noise values.
    SAME = "SAME"
    # Outside the noise band, and the compared time is smaller.
    FAST = "FAST"
    # Outside the noise band, and the compared time is not smaller.
    SLOW = "SLOW"
@dataclass(frozen=True)
class SummaryComparison:
    """Immutable result of comparing one reference timing with one compared timing."""

    # Estimates (center + relative dispersion) the comparison was based on.
    ref_estimate: TimeEstimate
    cmp_estimate: TimeEstimate

    # Centers of the two estimates.
    ref_time: float
    cmp_time: float

    # Relative dispersions of the two estimates; None when unavailable.
    ref_noise: float | None
    cmp_noise: float | None

    # diff is cmp_time - ref_time; frac_diff is diff / ref_time.
    diff: float
    frac_diff: float

    # Larger of the two noise values, or None when either one is unusable.
    max_noise: float | None

    # Classification derived from frac_diff and max_noise.
    status: ComparisonStatus
@dataclass
class ComparisonStats:
    """Mutable tallies of comparison outcomes.

    ``config_count`` is the total number of recorded comparisons; the other
    four counters split that total by status.
    """

    config_count: int = 0
    pass_count: int = 0
    improvement_count: int = 0
    regression_count: int = 0
    unknown_count: int = 0

    def record(self, status: ComparisonStatus) -> None:
        """Count one comparison and bump the counter matching ``status``."""
        # Pick the per-status counter first, then increment it in one place.
        if status == ComparisonStatus.UNKNOWN:
            counter_name = "unknown_count"
        elif status == ComparisonStatus.SAME:
            counter_name = "pass_count"
        elif status == ComparisonStatus.FAST:
            counter_name = "improvement_count"
        else:
            # Anything that is not UNKNOWN/SAME/FAST is tallied as a regression.
            counter_name = "regression_count"

        self.config_count += 1
        setattr(self, counter_name, getattr(self, counter_name) + 1)
# One device description: a mapping from string keys to arbitrary values.
DeviceInfo = Mapping[str, Any]


@dataclass(frozen=True)
class ComparisonRunData:
    """State handed to a comparison run in place of module-level globals."""

    # The dataclass is frozen, so the fields cannot be rebound. The stats
    # object itself is deliberately mutable: counts accumulate in it over the
    # course of a comparison run.
    stats: ComparisonStats

    # Device metadata for the reference and compared sides; treated as
    # read-only.
    ref_devices: tuple[DeviceInfo, ...]
    cmp_devices: tuple[DeviceInfo, ...]
@dataclass(frozen=True)
class BenchmarkFilterScope:
benchmark_name: str
@@ -319,7 +365,7 @@ def extract_sample_frequencies(summaries, json_dir):
return read_float32_binary(frequency_count, frequencies_filename, json_dir)
def extract_gpu_time_summary(summaries, json_dir=None):
def extract_gpu_timing_data(summaries, json_dir=None):
samples = extract_sample_times(summaries, json_dir)
frequencies = extract_sample_frequencies(summaries, json_dir)
if (
@@ -332,7 +378,7 @@ def extract_gpu_time_summary(summaries, json_dir=None):
f"frequency count ({len(frequencies)})"
)
return GpuTimeSummary(
return GpuTimingData(
minimum=extract_summary_float(summaries, GPU_TIME_MIN_TAG),
maximum=extract_summary_float(summaries, GPU_TIME_MAX_TAG),
mean=extract_summary_float(summaries, GPU_TIME_MEAN_TAG),
@@ -384,59 +430,106 @@ def select_relative_dispersion(relative_dispersion, absolute_dispersion, center)
return compute_relative_dispersion(absolute_dispersion, center)
def compute_common_time_estimates(ref_timing, cmp_timing):
    """Build a (reference, compared) pair of TimeEstimate from matching statistics.

    Both estimates are always derived from the same kind of statistic so that
    they are comparable:

    1. If both inputs pass ``has_robust_estimate``, use the median as center
       and the interquartile range for the relative dispersion.
    2. Otherwise, if both pass ``has_mean_estimate``, use the mean as center
       and the standard deviation for the relative dispersion.
    3. Otherwise, still use the mean as center, with the relative dispersion
       computed directly from ``stdev`` and ``mean``; values may be ``None``
       here, so callers must validate the returned centers.

    Args:
        ref_timing: Timing data of the reference run.
        cmp_timing: Timing data of the compared run.

    Returns:
        A 2-tuple ``(ref_estimate, cmp_estimate)`` of ``TimeEstimate``.
    """
    # Reference first, compared second: this order is the order of the result.
    timings = (ref_timing, cmp_timing)

    if all(has_robust_estimate(timing) for timing in timings):
        return tuple(
            TimeEstimate(
                center=timing.median,
                relative_dispersion=select_relative_dispersion(
                    timing.interquartile_range_relative,
                    timing.interquartile_range,
                    timing.median,
                ),
            )
            for timing in timings
        )

    if all(has_mean_estimate(timing) for timing in timings):
        return tuple(
            TimeEstimate(
                center=timing.mean,
                relative_dispersion=select_relative_dispersion(
                    timing.stdev_relative, timing.stdev, timing.mean
                ),
            )
            for timing in timings
        )

    # Last resort: no pre-computed relative dispersion is consulted.
    return tuple(
        TimeEstimate(
            center=timing.mean,
            relative_dispersion=compute_relative_dispersion(
                timing.stdev, timing.mean
            ),
        )
        for timing in timings
    )
def compare_gpu_timings(ref_timing, cmp_timing):
    """Compare two GPU timing records and classify the difference.

    Args:
        ref_timing: Timing data of the reference run.
        cmp_timing: Timing data of the compared run.

    Returns:
        A ``SummaryComparison``, or ``None`` when either center time is
        missing, non-finite, or not strictly positive.
    """
    ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing)
    ref_time = ref_estimate.center
    cmp_time = cmp_estimate.center

    # Validate both centers stage by stage: presence, finiteness, positivity.
    centers = (cmp_time, ref_time)
    if any(center is None for center in centers):
        return None
    if not all(math.isfinite(center) for center in centers):
        return None
    if any(center <= 0.0 for center in centers):
        return None

    ref_noise = ref_estimate.relative_dispersion
    cmp_noise = cmp_estimate.relative_dispersion

    diff = cmp_time - ref_time
    frac_diff = diff / ref_time

    if has_finite_noise(ref_noise) and has_finite_noise(cmp_noise):
        # The larger noise sets the band inside which a change is not significant.
        max_noise = max(ref_noise, cmp_noise)
        if abs(frac_diff) <= max_noise:
            status = ComparisonStatus.SAME
        else:
            status = ComparisonStatus.FAST if diff < 0 else ComparisonStatus.SLOW
    else:
        # Without usable noise on both sides the difference cannot be judged.
        max_noise = None
        status = ComparisonStatus.UNKNOWN

    return SummaryComparison(
        ref_estimate=ref_estimate,
        cmp_estimate=cmp_estimate,
        ref_time=ref_time,
        cmp_time=cmp_time,
        ref_noise=ref_noise,
        cmp_noise=cmp_noise,
        diff=diff,
        frac_diff=frac_diff,
        max_noise=max_noise,
        status=status,
    )
def find_matching_bench(needle, haystack):
for hay in haystack:
if hay["name"] == needle["name"]:
@@ -657,6 +750,16 @@ def has_finite_noise(noise):
return noise is not None and math.isfinite(noise)
def colorize_comparison_status(status, no_color):
    """Return the status label passed through ``colorize`` with its status colour.

    UNKNOWN is yellow, SAME is blue, FAST is green, and any other status is
    red. ``no_color`` is forwarded to ``colorize`` unchanged.
    """
    if status == ComparisonStatus.UNKNOWN:
        fore, emoji = Fore.YELLOW, Emoji.YELLOW
    elif status == ComparisonStatus.SAME:
        fore, emoji = Fore.BLUE, Emoji.BLUE
    elif status == ComparisonStatus.FAST:
        fore, emoji = Fore.GREEN, Emoji.GREEN
    else:
        fore, emoji = Fore.RED, Emoji.RED
    return colorize(status.value, fore, emoji, no_color)
def format_axis_values(axis_values, axes, axis_filters=None):
if not axis_values:
return ""
@@ -749,6 +852,7 @@ def plot_comparison_entries(entries, title=None, dark=False):
def compare_benches(
run_data: ComparisonRunData,
ref_benches,
cmp_benches,
threshold,
@@ -847,8 +951,13 @@ def compare_benches(
)
rows = []
plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}}
counters = {}
plot_data: dict[str, dict[str, dict[float, float | None]]] = {
"cmp": {},
"ref": {},
"cmp_noise": {},
"ref_noise": {},
}
counters: dict[str, int] = {}
for cmp_state in cmp_device_states:
cmp_state_name = state_match_key(cmp_state)
@@ -874,45 +983,22 @@ def compare_benches(
# TODO: Use other timings, too. Maybe multiple rows, with a
# "Timing" column + values "CPU/GPU/Batch"?
cmp_gpu_time = extract_gpu_time_summary(cmp_summaries, cmp_json_dir)
ref_gpu_time = extract_gpu_time_summary(ref_summaries, ref_json_dir)
ref_estimate, cmp_estimate = compute_common_time_estimates(
ref_gpu_time, cmp_gpu_time
)
cmp_time = cmp_estimate.center
ref_time = ref_estimate.center
if cmp_time is None or ref_time is None:
cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir)
ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir)
comparison = compare_gpu_timings(ref_gpu_time, cmp_gpu_time)
if comparison is None:
continue
if not math.isfinite(cmp_time) or not math.isfinite(ref_time):
continue
if cmp_time <= 0.0 or ref_time <= 0.0:
continue
cmp_noise = cmp_estimate.relative_dispersion
ref_noise = ref_estimate.relative_dispersion
diff = cmp_time - ref_time
frac_diff = diff / ref_time
if not has_finite_noise(ref_noise) or not has_finite_noise(cmp_noise):
max_noise = None
else:
max_noise = max(ref_noise, cmp_noise)
if plot_along:
axis_name = []
axis_name_parts = []
axis_value = None
for av in axis_values:
if av["name"] != plot_along:
axis_name.append(f"""{av["name"]} = {av["value"]}""")
axis_name_parts.append(f"""{av["name"]} = {av["value"]}""")
else:
axis_value = float(av["value"])
if axis_value is not None:
axis_name = ", ".join(axis_name)
axis_name = ", ".join(axis_name_parts)
if axis_name not in plot_data["cmp"]:
plot_data["cmp"][axis_name] = {}
@@ -920,43 +1006,26 @@ def compare_benches(
plot_data["cmp_noise"][axis_name] = {}
plot_data["ref_noise"][axis_name] = {}
plot_data["cmp"][axis_name][axis_value] = cmp_time
plot_data["ref"][axis_name][axis_value] = ref_time
plot_data["cmp_noise"][axis_name][axis_value] = cmp_noise
plot_data["ref_noise"][axis_name][axis_value] = ref_noise
plot_data["cmp"][axis_name][axis_value] = comparison.cmp_time
plot_data["ref"][axis_name][axis_value] = comparison.ref_time
plot_data["cmp_noise"][axis_name][axis_value] = (
comparison.cmp_noise
)
plot_data["ref_noise"][axis_name][axis_value] = (
comparison.ref_noise
)
global config_count
global unknown_count
global pass_count
global improvement_count
global regression_count
run_data.stats.record(comparison.status)
status = colorize_comparison_status(comparison.status, no_color)
config_count += 1
if max_noise is None:
unknown_count += 1
status_label = "????"
status = colorize(status_label, Fore.YELLOW, Emoji.YELLOW, no_color)
elif abs(frac_diff) <= max_noise:
pass_count += 1
status_label = "SAME"
status = colorize(status_label, Fore.BLUE, Emoji.BLUE, no_color)
elif diff < 0:
improvement_count += 1
status_label = "FAST"
status = colorize(status_label, Fore.GREEN, Emoji.GREEN, no_color)
else:
regression_count += 1
status_label = "SLOW"
status = colorize(status_label, Fore.RED, Emoji.RED, no_color)
if abs(frac_diff) >= threshold:
if abs(comparison.frac_diff) >= threshold:
axis_filters = matching_axis_filters(cmp_state, axis_filter_groups)
row.append(format_duration(ref_time))
row.append(format_percentage(ref_noise))
row.append(format_duration(cmp_time))
row.append(format_percentage(cmp_noise))
row.append(format_duration(diff))
row.append(format_percentage(frac_diff))
row.append(format_duration(comparison.ref_time))
row.append(format_percentage(comparison.ref_noise))
row.append(format_duration(comparison.cmp_time))
row.append(format_percentage(comparison.cmp_noise))
row.append(format_duration(comparison.diff))
row.append(format_percentage(comparison.frac_diff))
row.append(status)
rows.append(row)
@@ -967,19 +1036,24 @@ def compare_benches(
else:
label = cmp_bench["name"]
cmp_device = find_device_by_id(
cmp_state["device"], all_cmp_devices
cmp_state["device"], run_data.cmp_devices
)
if cmp_device:
comparison_device_names.add(cmp_device["name"])
comparison_entries.append(
(label, frac_diff, status_label, cmp_bench["name"])
(
label,
comparison.frac_diff,
comparison.status.value,
cmp_bench["name"],
)
)
if len(rows) == 0:
continue
cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices)
ref_device = find_device_by_id(ref_device_id, all_ref_devices)
cmp_device = find_device_by_id(cmp_device_id, run_data.cmp_devices)
ref_device = find_device_by_id(ref_device_id, run_data.ref_devices)
if ref_device is None or cmp_device is None:
raise ValueError(
f"benchmark {cmp_bench['name']!r} references device pair "
@@ -1201,44 +1275,55 @@ def main() -> int:
else:
to_compare = [(files_or_dirs[0], files_or_dirs[1])]
stats = ComparisonStats()
for ref, comp in to_compare:
ref_root = reader.read_file(ref)
cmp_root = reader.read_file(comp)
global all_ref_devices
global all_cmp_devices
try:
all_ref_devices = select_devices(
selected_ref_devices = select_devices(
ref_root["devices"], reference_device_filter, "--reference-devices"
)
all_cmp_devices = select_devices(
selected_cmp_devices = select_devices(
cmp_root["devices"], compare_device_filter, "--compare-devices"
)
except ValueError as exc:
print(str(exc))
return 1
if len(all_ref_devices) != len(all_cmp_devices):
if len(selected_ref_devices) != len(selected_cmp_devices):
print(
f"--reference-devices selected {len(all_ref_devices)} device(s), "
f"but --compare-devices selected {len(all_cmp_devices)} device(s)"
f"--reference-devices selected {len(selected_ref_devices)} device(s), "
f"but --compare-devices selected {len(selected_cmp_devices)} device(s)"
)
return 1
if all_ref_devices != all_cmp_devices:
if selected_ref_devices != selected_cmp_devices:
warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED
msg_text = "Device sections do not match"
print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="")
print(": ", end="")
print(jsondiff.diff(all_ref_devices, all_cmp_devices, syntax="symmetric"))
print(
jsondiff.diff(
selected_ref_devices, selected_cmp_devices, syntax="symmetric"
)
)
if not args.ignore_devices and require_matching_device_sections(
reference_device_filter, compare_device_filter
):
return 1
run_data = ComparisonRunData(
stats=stats,
ref_devices=tuple(selected_ref_devices),
cmp_devices=tuple(selected_cmp_devices),
)
try:
compare_benches(
run_data,
ref_root["benchmarks"],
cmp_root["benchmarks"],
args.threshold,
@@ -1257,11 +1342,16 @@ def main() -> int:
return 1
print("# Summary\n")
print(f"- Total Matches: {config_count}")
print(f" - Pass (abs(%Diff) <= max_noise): {pass_count}")
print(f" - Improvement (abs(%Diff) > max_noise, %Diff < 0): {improvement_count}")
print(f" - Regression (abs(%Diff) > max_noise, %Diff > 0): {regression_count}")
print(f" - Unknown (infinite or unavailable noise): {unknown_count}")
print(f"- Total Matches: {stats.config_count}")
print(f" - Pass (abs(%Diff) <= max_noise): {stats.pass_count}")
print(
" - Improvement (abs(%Diff) > max_noise, %Diff < 0): "
f"{stats.improvement_count}"
)
print(
f" - Regression (abs(%Diff) > max_noise, %Diff > 0): {stats.regression_count}"
)
print(f" - Unknown (infinite or unavailable noise): {stats.unknown_count}")
return 0