Load script tooling dependencies lazily

Add a shared nvbench_tooling_deps helper for importing packages required
by NVBench console tools. Missing tooling packages now raise a dedicated
error with an install recipe instead of failing with a raw ImportError.

Update script imports to work both as installed package modules and as
direct source-tree scripts by using the __package__ import pattern for
nvbench_json and the new tooling helper.

Defer nvbench-compare dependencies to the points where they are needed:
NumPy/colorama during normal comparison setup, tabulate during table
rendering, jsondiff only for device mismatch reporting, and plotting
packages only for plot modes.

Update tests to initialize compare tooling when calling internals
directly and add coverage for the tooling dependency loader.

Closes #384
This commit is contained in:
Oleksandr Pavlyk
2026-06-29 11:38:42 -05:00
parent 6dae814da5
commit 5fd21dd7fa
7 changed files with 348 additions and 41 deletions

View File

@@ -3,6 +3,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from __future__ import annotations
import argparse
import math
import os
@@ -16,15 +18,20 @@ from enum import Enum
from functools import cached_property
from typing import Any, BinaryIO, Callable, Protocol
import jsondiff
import numpy as np
import tabulate
from colorama import Fore
try:
from nvbench_json import reader
except ImportError:
from scripts.nvbench_json import reader
if __package__:
from .nvbench_json import reader
from .nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
else:
from nvbench_json import reader # type: ignore[no-redef]
from nvbench_tooling_deps import ( # type: ignore[no-redef]
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
# Parse version string into tuple, "x.y.z" -> (x, y, z)
@@ -32,7 +39,40 @@ def version_tuple(v):
return tuple(map(int, (v.split("."))))
tabulate_version = version_tuple(tabulate.__version__)
np: Any = None
Fore: Any = None
def load_nvbench_compare_tooling() -> None:
global Fore, np
if np is None:
np = require_tooling_dependency(
ToolingDependency("numpy", "numpy", "bulk timing analysis"),
tool_name="nvbench-compare",
)
if Fore is None:
colorama = require_tooling_dependency(
ToolingDependency("colorama", "colorama", "colored status output"),
tool_name="nvbench-compare",
)
Fore = colorama.Fore
def load_tabulate_for_table_output() -> tuple[Any, tuple[int, ...]]:
tabulate_module = require_tooling_dependency(
ToolingDependency("tabulate", "tabulate", "table output"),
tool_name="nvbench-compare",
)
return tabulate_module, version_tuple(tabulate_module.__version__)
def load_jsondiff_for_device_diff() -> Any:
return require_tooling_dependency(
ToolingDependency("jsondiff", "jsondiff", "device metadata diffs"),
tool_name="nvbench-compare",
)
GPU_TIME_MIN_TAG = "nv/cold/time/gpu/min"
GPU_TIME_MAX_TAG = "nv/cold/time/gpu/max"
@@ -801,7 +841,7 @@ class Emoji(str, Enum):
NONE = ""
def colorize(msg: str, fore: Fore, emoji: Emoji, no_color: bool) -> str:
def colorize(msg: str, fore: str, emoji: Emoji, no_color: bool) -> str:
if no_color:
prefix = ""
if emoji_s := emoji.value:
@@ -2746,13 +2786,22 @@ def plot_comparison_entries(entries, title=None, dark=False):
print("No comparison data to plot.")
return 1
matplotlib = require_tooling_dependency(
ToolingDependency("matplotlib", "matplotlib", "plot rendering"),
tool_name="nvbench-compare",
)
if not os.environ.get("DISPLAY"):
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
plt = require_tooling_dependency(
ToolingDependency("matplotlib.pyplot", "matplotlib", "plot rendering"),
tool_name="nvbench-compare",
)
ticker = require_tooling_dependency(
ToolingDependency("matplotlib.ticker", "matplotlib", "plot axis formatting"),
tool_name="nvbench-compare",
)
PercentFormatter = ticker.PercentFormatter
labels, values, statuses, bench_names = map(list, zip(*entries))
@@ -2836,8 +2885,16 @@ def compare_benches(
comparison_thresholds = get_default_thresholds()
if plot_along:
import matplotlib.pyplot as plt
import seaborn as sns
plt = require_tooling_dependency(
ToolingDependency(
"matplotlib.pyplot", "matplotlib", "per-axis plot rendering"
),
tool_name="nvbench-compare",
)
sns = require_tooling_dependency(
ToolingDependency("seaborn", "seaborn", "per-axis plot styling"),
tool_name="nvbench-compare",
)
sns.set_theme()
@@ -3074,6 +3131,7 @@ def compare_benches(
f"## [{ref_device['id']}] {ref_device['name']} vs. "
f"[{cmp_device['id']}] {cmp_device['name']}\n"
)
tabulate, tabulate_version = load_tabulate_for_table_output()
# colalign and github format require tabulate 0.8.3
if tabulate_version >= (0, 8, 3):
print(
@@ -3291,6 +3349,12 @@ def main() -> int:
print(dump_comparison_config(comparison_preset, comparison_thresholds), end="")
return 0
try:
load_nvbench_compare_tooling()
except MissingToolingDependencyError as exc:
print(str(exc), file=sys.stderr)
return 1
try:
filter_plan = build_benchmark_filter_plan(args.filter_actions)
reference_device_filter = parse_device_filter(
@@ -3362,6 +3426,12 @@ def main() -> int:
return 1
if selected_ref_devices != selected_cmp_devices:
try:
jsondiff = load_jsondiff_for_device_diff()
except MissingToolingDependencyError as exc:
print(str(exc), file=sys.stderr)
return 1
warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED
msg_text = "Device sections do not match"
print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="")
@@ -3404,6 +3474,9 @@ def main() -> int:
display=args.display,
bulk_debug_rows=bulk_debug_rows,
)
except MissingToolingDependencyError as exc:
print(str(exc), file=sys.stderr)
return 1
except ValueError as exc:
print(str(exc))
return 1

View File

@@ -4,15 +4,50 @@ import argparse
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
try:
if __package__:
from .nvbench_json import reader
from .nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
else:
from nvbench_json import reader
except ImportError:
from scripts.nvbench_json import reader
from nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
np = None
pd = None
plt = None
sns = None
def load_nvbench_histogram_tooling():
global np, pd, plt, sns
if plt is None:
plt = require_tooling_dependency(
ToolingDependency("matplotlib.pyplot", "matplotlib", "histogram plotting"),
tool_name="nvbench-histogram",
)
if np is None:
np = require_tooling_dependency(
ToolingDependency("numpy", "numpy", "sample loading"),
tool_name="nvbench-histogram",
)
if pd is None:
pd = require_tooling_dependency(
ToolingDependency("pandas", "pandas", "sample table construction"),
tool_name="nvbench-histogram",
)
if sns is None:
sns = require_tooling_dependency(
ToolingDependency("seaborn", "seaborn", "histogram plotting"),
tool_name="nvbench-histogram",
)
def parse_files():
@@ -115,12 +150,18 @@ def parse_json(filename):
def main():
filenames = parse_files()
try:
load_nvbench_histogram_tooling()
except MissingToolingDependencyError as exc:
print(str(exc), file=sys.stderr)
return 1
dfs = [parse_json(filename) for filename in filenames]
df = pd.concat(dfs, ignore_index=True)
sns.displot(df, rug=True, kind="kde", fill=True)
plt.show()
return 0
if __name__ == "__main__":

View File

@@ -4,13 +4,44 @@ import argparse
import os
import sys
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
try:
if __package__:
from .nvbench_json import reader
from .nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
else:
from nvbench_json import reader
except ImportError:
from scripts.nvbench_json import reader
from nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
plt = None
PercentFormatter = None
def load_nvbench_plot_bwutil_tooling():
global PercentFormatter, plt
if plt is None:
plt = require_tooling_dependency(
ToolingDependency(
"matplotlib.pyplot", "matplotlib", "bandwidth plot rendering"
),
tool_name="nvbench-plot-bwutil",
)
if PercentFormatter is None:
ticker = require_tooling_dependency(
ToolingDependency(
"matplotlib.ticker", "matplotlib", "plot axis formatting"
),
tool_name="nvbench-plot-bwutil",
)
PercentFormatter = ticker.PercentFormatter
UTILIZATION_TAG = "nv/cold/bw/global/utilization"
@@ -263,6 +294,12 @@ def plot_entries(entries, title=None, output=None, dark=False):
def main():
args, filenames = parse_files()
try:
load_nvbench_plot_bwutil_tooling()
except MissingToolingDependencyError as exc:
print(str(exc), file=sys.stderr)
return 1
try:
axis_filters = parse_axis_filters(args.axis)
except ValueError as exc:

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python
#
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from __future__ import annotations
import importlib
from dataclasses import dataclass
from types import ModuleType
@dataclass(frozen=True)
class ToolingDependency:
import_name: str
package_name: str
purpose: str
extra: str = "tools"
class MissingToolingDependencyError(RuntimeError):
pass
def require_tooling_dependency(
dependency: ToolingDependency, *, tool_name: str
) -> ModuleType:
try:
return importlib.import_module(dependency.import_name)
except ImportError as exc:
raise MissingToolingDependencyError(
f"{tool_name} requires {dependency.package_name!r} for "
f"{dependency.purpose}.\n\n"
"Install the required tooling dependencies with:\n"
f" python -m pip install 'cuda-bench[{dependency.extra}]'\n\n"
f"Original import error: {exc}"
) from exc

View File

@@ -5,12 +5,23 @@ import math
import os
import sys
import tabulate
try:
if __package__:
from .nvbench_json import reader
from .nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
else:
from nvbench_json import reader
except ImportError:
from scripts.nvbench_json import reader
from nvbench_tooling_deps import (
MissingToolingDependencyError,
ToolingDependency,
require_tooling_dependency,
)
tabulate = None
tabulate_version = (0, 0, 0)
# Parse version string into tuple, "x.y.z" -> (x, y, z)
@@ -18,7 +29,18 @@ def version_tuple(v):
return tuple(map(int, (v.split("."))))
tabulate_version = version_tuple(tabulate.__version__)
def load_nvbench_walltime_tooling():
global tabulate, tabulate_version
if tabulate is not None:
return
tabulate = require_tooling_dependency(
ToolingDependency("tabulate", "tabulate", "table output"),
tool_name="nvbench-walltime",
)
tabulate_version = version_tuple(tabulate.__version__)
all_devices = []
@@ -341,6 +363,12 @@ def main():
filenames.sort()
try:
load_nvbench_walltime_tooling()
except MissingToolingDependencyError as exc:
print(str(exc), file=sys.stderr)
return 1
data = {}
files_out = {}
@@ -355,6 +383,7 @@ def main():
print_overview_section(data)
print_files_section(data)
return 0
if __name__ == "__main__":