Implement --bulk-debug-python option

Use this option to generate a Python script with the information needed to
load bulk data from the reference/compare datasets for further drill-down
into the data.
This commit is contained in:
Oleksandr Pavlyk
2026-06-04 12:49:55 -05:00
parent 997e0be9db
commit 9890aad294
3 changed files with 387 additions and 0 deletions

View File

@@ -6,6 +6,7 @@
import argparse
import math
import os
import pprint
import sys
import warnings
from collections import Counter
@@ -403,6 +404,15 @@ class GpuTimingData:
return self.frequency_source.values
@dataclass(frozen=True)
class BulkDebugOutput:
    """Destination for the script produced by ``--bulk-debug-python``."""

    # Either a file path or the word "stdout" (matched in any letter case).
    destination: str

    @property
    def is_stdout(self) -> bool:
        """Return True when the destination is "stdout", ignoring case."""
        lowered = self.destination.lower()
        return lowered == "stdout"
@dataclass(frozen=True)
class TimeEstimate:
center: float | None
@@ -842,6 +852,137 @@ def extract_gpu_timing_data(summaries, json_dir=None, float32_reader=read_float3
)
def resolve_bulk_source_filename(source: Float32BinarySource | None) -> str | None:
    """Return the resolved file name of *source*, or None when it is None.

    Resolution is delegated to ``resolve_binary_filename``, which receives the
    source's ``json_dir`` and ``filename`` attributes.
    """
    if source is not None:
        return resolve_binary_filename(source.json_dir, source.filename)
    return None
def get_bulk_source_count(source: Float32BinarySource | None) -> int | None:
    """Return ``source.count``, or None when *source* is None."""
    return None if source is None else source.count
def make_axis_debug_values(axis_values, axes) -> list[dict[str, Any]]:
    """Describe every axis value as a plain dict for the bulk debug output.

    Each entry carries the axis value's ``name``, ``type`` and ``value``
    (None when the key is absent) plus a ``display`` string obtained from
    ``format_axis_value``.
    """

    def describe(entry) -> dict[str, Any]:
        # NOTE(review): "display" indexes entry["name"] directly, so an entry
        # without a "name" key fails here even though "name" itself uses
        # .get() -- presumably entries always carry a name; confirm.
        return {
            "name": entry.get("name"),
            "type": entry.get("type"),
            "value": entry.get("value"),
            "display": format_axis_value(entry["name"], entry, axes),
        }

    return [describe(entry) for entry in axis_values]
def make_bulk_debug_row(
    *,
    row_index: int,
    table_row_index: int,
    benchmark_name: str,
    ref_json_path: str | None,
    cmp_json_path: str | None,
    ref_device_id: int,
    cmp_device_id: int,
    cmp_state_name: str,
    occurrence: int,
    occurrence_count: int,
    axis_values,
    axes,
    ref_timing: GpuTimingData,
    cmp_timing: GpuTimingData,
    comparison: SummaryComparison,
) -> dict[str, Any]:
    """Assemble the bulk debug record for one comparison row.

    The record is a flat dict built in three stages: identification of the
    row, the outcome taken from *comparison*, and the resolved file name and
    value count of each sample/frequency source of both timings. Keys are
    inserted in a fixed order, which is the order they are later printed in.

    All arguments are keyword-only.
    """
    # Stage 1: where the row comes from.
    row: dict[str, Any] = {
        "row_index": row_index,
        "table_row_index": table_row_index,
        "benchmark": benchmark_name,
        "reference_json": ref_json_path,
        "compare_json": cmp_json_path,
        "reference_device_id": ref_device_id,
        "compare_device_id": cmp_device_id,
        "state_key": cmp_state_name,
        "occurrence": occurrence,
        "occurrence_count": occurrence_count,
        "axis_values": make_axis_debug_values(axis_values, axes),
    }

    # Stage 2: what the comparison concluded.
    row["status"] = comparison.status.value
    reason = comparison.reason
    row["reason"] = reason.code
    row["reason_message"] = reason.message
    row["reference_time"] = comparison.ref_time
    row["compare_time"] = comparison.cmp_time
    row["fractional_difference"] = comparison.frac_diff

    # Stage 3: bulk files backing the reference timing ...
    ref_samples = ref_timing.sample_source
    row["reference_sample_filename"] = resolve_bulk_source_filename(ref_samples)
    row["reference_sample_count"] = get_bulk_source_count(ref_samples)
    ref_frequencies = ref_timing.frequency_source
    row["reference_frequency_filename"] = resolve_bulk_source_filename(
        ref_frequencies
    )
    row["reference_frequency_count"] = get_bulk_source_count(ref_frequencies)

    # ... and the compare timing.
    cmp_samples = cmp_timing.sample_source
    row["compare_sample_filename"] = resolve_bulk_source_filename(cmp_samples)
    row["compare_sample_count"] = get_bulk_source_count(cmp_samples)
    cmp_frequencies = cmp_timing.frequency_source
    row["compare_frequency_filename"] = resolve_bulk_source_filename(
        cmp_frequencies
    )
    row["compare_frequency_count"] = get_bulk_source_count(cmp_frequencies)

    return row
def format_bulk_debug_python(bulk_rows: list[dict[str, Any]]) -> str:
    """Render *bulk_rows* as the source text of a standalone Python script.

    The generated script contains:

    * ``bulk_rows`` -- a literal produced by ``pprint.pformat`` with the
      insertion order of the keys preserved;
    * ``read_float32(filename, expected_count=None)`` -- loads little-endian
      float32 values with ``numpy.fromfile`` and raises ``ValueError`` when
      the number of values differs from ``expected_count``;
    * ``load_bulk_data(row)`` -- loads the reference/compare sample and
      frequency arrays named by one row.

    Args:
        bulk_rows: Row dictionaries to embed in the script.

    Returns:
        The complete script text, terminated by a newline.
    """
    rows_literal = pprint.pformat(bulk_rows, sort_dicts=False)
    return (
        "# Generated by nvbench-compare --bulk-debug-python.\n"
        "import numpy as np\n\n"
        # repr() spells non-finite floats as the bare words inf / nan, which
        # are not Python literals. Binding the names in the generated script
        # keeps the bulk_rows literal loadable when a time or a fractional
        # difference is infinite or NaN.
        "# Names used by the repr of non-finite floats in bulk_rows.\n"
        "inf = float('inf')\n"
        "nan = float('nan')\n\n"
        f"bulk_rows = {rows_literal}\n\n"
        "def read_float32(filename, expected_count=None):\n"
        "    if filename is None:\n"
        "        return None\n"
        "    values = np.fromfile(filename, dtype='<f4')\n"
        "    if expected_count is not None and len(values) != expected_count:\n"
        "        raise ValueError(\n"
        "            f'{filename!r}: expected {expected_count} float32 values, '\n"
        "            f'found {len(values)}'\n"
        "        )\n"
        "    return values\n\n"
        "def load_bulk_data(row):\n"
        "    return {\n"
        "        'reference_samples': read_float32(\n"
        "            row['reference_sample_filename'], row['reference_sample_count']\n"
        "        ),\n"
        "        'reference_frequencies': read_float32(\n"
        "            row['reference_frequency_filename'], row['reference_frequency_count']\n"
        "        ),\n"
        "        'compare_samples': read_float32(\n"
        "            row['compare_sample_filename'], row['compare_sample_count']\n"
        "        ),\n"
        "        'compare_frequencies': read_float32(\n"
        "            row['compare_frequency_filename'], row['compare_frequency_count']\n"
        "        ),\n"
        "    }\n\n"
        "# Examples:\n"
        "# row = bulk_rows[0]\n"
        "# arrays = load_bulk_data(row)\n"
        "# undecided = [row for row in bulk_rows if row['status'] == 'UNDECIDED']\n"
    )
def write_bulk_debug_python(
    output: BulkDebugOutput | None, bulk_rows: list[dict[str, Any]]
) -> None:
    """Emit the bulk debug script for *bulk_rows* to *output*.

    Nothing is formatted or written when *output* is None. Otherwise the
    script goes to standard output (without an extra trailing newline) when
    ``output.is_stdout`` is true, and to the UTF-8 text file named by
    ``output.destination`` in every other case.
    """
    if output is None:
        return

    text = format_bulk_debug_python(bulk_rows)
    if output.is_stdout:
        print(text, end="")
    else:
        with open(output.destination, "w", encoding="utf-8") as stream:
            stream.write(text)
def compute_relative_dispersion(dispersion, center):
if (
dispersion is None
@@ -2049,8 +2190,11 @@ def compare_benches(
compare_device_filter=None,
ref_json_dir=None,
cmp_json_dir=None,
ref_json_path=None,
cmp_json_path=None,
comparison_thresholds=None,
display="intervals",
bulk_debug_rows=None,
):
if comparison_thresholds is None:
comparison_thresholds = ComparisonThresholds()
@@ -2200,6 +2344,26 @@ def compare_benches(
append_display_row(row, comparison, no_color, display)
rows.append(row)
if bulk_debug_rows is not None:
bulk_debug_rows.append(
make_bulk_debug_row(
row_index=len(bulk_debug_rows),
table_row_index=len(rows) - 1,
benchmark_name=cmp_bench["name"],
ref_json_path=ref_json_path,
cmp_json_path=cmp_json_path,
ref_device_id=ref_device_id,
cmp_device_id=cmp_device_id,
cmp_state_name=cmp_state_name,
occurrence=occurrence,
occurrence_count=cmp_state_counts[cmp_state_name],
axis_values=axis_values,
axes=axes,
ref_timing=ref_gpu_time,
cmp_timing=cmp_gpu_time,
comparison=comparison,
)
)
if plot:
axis_label = format_axis_values(axis_values, axes, axis_filters)
if axis_label:
@@ -2382,6 +2546,14 @@ def main() -> int:
default="intervals",
help="comparison table display mode",
)
parser.add_argument(
"--bulk-debug-python",
default=None,
help=(
"Write Python code that describes bulk sample/frequency files for "
"each displayed row. Use 'stdout' to print the code to stdout."
),
)
parser.add_argument(
"--plot-along", type=str, dest="plot_along", default=None, help="plot results"
)
@@ -2461,6 +2633,15 @@ def main() -> int:
parser.print_help()
return 1
bulk_debug_output = (
None
if args.bulk_debug_python is None
else BulkDebugOutput(args.bulk_debug_python)
)
bulk_debug_rows: list[dict[str, Any]] | None = (
[] if bulk_debug_output is not None else None
)
# if provided two directories, find all the exactly named files
# in both and treat them as the reference and compare
to_compare = []
@@ -2541,8 +2722,11 @@ def main() -> int:
compare_device_filter=compare_device_filter,
ref_json_dir=os.path.dirname(ref),
cmp_json_dir=os.path.dirname(comp),
ref_json_path=ref,
cmp_json_path=comp,
comparison_thresholds=comparison_thresholds,
display=args.display,
bulk_debug_rows=bulk_debug_rows,
)
except ValueError as exc:
print(str(exc))
@@ -2565,6 +2749,11 @@ def main() -> int:
):
print(f" - {code}: {reason_summary.count} ({reason_summary.message})")
print(f" - Unknown (infinite or unavailable noise): {stats.unknown_count}")
try:
write_bulk_debug_python(bulk_debug_output, bulk_debug_rows or [])
except OSError as exc:
print(f"failed to write bulk debug Python output: {exc}")
return 1
return 0