mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-07-01 19:57:41 +00:00
Load script tooling dependencies lazily
Add a shared nvbench_tooling_deps helper for importing packages required by NVBench console tools. Missing tooling packages now raise a dedicated error with an install recipe instead of failing with a raw ImportError. Update script imports to work both as installed package modules and as direct source-tree scripts by using the __package__ import pattern for nvbench_json and the new tooling helper. Defer nvbench-compare dependencies to the points where they are needed: NumPy/colorama during normal comparison setup, tabulate during table rendering, jsondiff only for device mismatch reporting, and plotting packages only for plot modes. Update tests to initialize compare tooling when calling internals directly and add coverage for the tooling dependency loader. Closes #384
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
@@ -16,15 +18,20 @@ from enum import Enum
|
||||
from functools import cached_property
|
||||
from typing import Any, BinaryIO, Callable, Protocol
|
||||
|
||||
import jsondiff
|
||||
import numpy as np
|
||||
import tabulate
|
||||
from colorama import Fore
|
||||
|
||||
try:
|
||||
from nvbench_json import reader
|
||||
except ImportError:
|
||||
from scripts.nvbench_json import reader
|
||||
if __package__:
|
||||
from .nvbench_json import reader
|
||||
from .nvbench_tooling_deps import (
|
||||
MissingToolingDependencyError,
|
||||
ToolingDependency,
|
||||
require_tooling_dependency,
|
||||
)
|
||||
else:
|
||||
from nvbench_json import reader # type: ignore[no-redef]
|
||||
from nvbench_tooling_deps import ( # type: ignore[no-redef]
|
||||
MissingToolingDependencyError,
|
||||
ToolingDependency,
|
||||
require_tooling_dependency,
|
||||
)
|
||||
|
||||
|
||||
# Parse version string into tuple, "x.y.z" -> (x, y, z)
|
||||
@@ -32,7 +39,40 @@ def version_tuple(v):
|
||||
return tuple(map(int, (v.split("."))))
|
||||
|
||||
|
||||
tabulate_version = version_tuple(tabulate.__version__)
|
||||
np: Any = None
|
||||
Fore: Any = None
|
||||
|
||||
|
||||
def load_nvbench_compare_tooling() -> None:
|
||||
global Fore, np
|
||||
|
||||
if np is None:
|
||||
np = require_tooling_dependency(
|
||||
ToolingDependency("numpy", "numpy", "bulk timing analysis"),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
if Fore is None:
|
||||
colorama = require_tooling_dependency(
|
||||
ToolingDependency("colorama", "colorama", "colored status output"),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
Fore = colorama.Fore
|
||||
|
||||
|
||||
def load_tabulate_for_table_output() -> tuple[Any, tuple[int, ...]]:
|
||||
tabulate_module = require_tooling_dependency(
|
||||
ToolingDependency("tabulate", "tabulate", "table output"),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
return tabulate_module, version_tuple(tabulate_module.__version__)
|
||||
|
||||
|
||||
def load_jsondiff_for_device_diff() -> Any:
|
||||
return require_tooling_dependency(
|
||||
ToolingDependency("jsondiff", "jsondiff", "device metadata diffs"),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
|
||||
|
||||
GPU_TIME_MIN_TAG = "nv/cold/time/gpu/min"
|
||||
GPU_TIME_MAX_TAG = "nv/cold/time/gpu/max"
|
||||
@@ -801,7 +841,7 @@ class Emoji(str, Enum):
|
||||
NONE = ""
|
||||
|
||||
|
||||
def colorize(msg: str, fore: Fore, emoji: Emoji, no_color: bool) -> str:
|
||||
def colorize(msg: str, fore: str, emoji: Emoji, no_color: bool) -> str:
|
||||
if no_color:
|
||||
prefix = ""
|
||||
if emoji_s := emoji.value:
|
||||
@@ -2746,13 +2786,22 @@ def plot_comparison_entries(entries, title=None, dark=False):
|
||||
print("No comparison data to plot.")
|
||||
return 1
|
||||
|
||||
matplotlib = require_tooling_dependency(
|
||||
ToolingDependency("matplotlib", "matplotlib", "plot rendering"),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
if not os.environ.get("DISPLAY"):
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.ticker import PercentFormatter
|
||||
plt = require_tooling_dependency(
|
||||
ToolingDependency("matplotlib.pyplot", "matplotlib", "plot rendering"),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
ticker = require_tooling_dependency(
|
||||
ToolingDependency("matplotlib.ticker", "matplotlib", "plot axis formatting"),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
PercentFormatter = ticker.PercentFormatter
|
||||
|
||||
labels, values, statuses, bench_names = map(list, zip(*entries))
|
||||
|
||||
@@ -2836,8 +2885,16 @@ def compare_benches(
|
||||
comparison_thresholds = get_default_thresholds()
|
||||
|
||||
if plot_along:
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
plt = require_tooling_dependency(
|
||||
ToolingDependency(
|
||||
"matplotlib.pyplot", "matplotlib", "per-axis plot rendering"
|
||||
),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
sns = require_tooling_dependency(
|
||||
ToolingDependency("seaborn", "seaborn", "per-axis plot styling"),
|
||||
tool_name="nvbench-compare",
|
||||
)
|
||||
|
||||
sns.set_theme()
|
||||
|
||||
@@ -3074,6 +3131,7 @@ def compare_benches(
|
||||
f"## [{ref_device['id']}] {ref_device['name']} vs. "
|
||||
f"[{cmp_device['id']}] {cmp_device['name']}\n"
|
||||
)
|
||||
tabulate, tabulate_version = load_tabulate_for_table_output()
|
||||
# colalign and github format require tabulate 0.8.3
|
||||
if tabulate_version >= (0, 8, 3):
|
||||
print(
|
||||
@@ -3291,6 +3349,12 @@ def main() -> int:
|
||||
print(dump_comparison_config(comparison_preset, comparison_thresholds), end="")
|
||||
return 0
|
||||
|
||||
try:
|
||||
load_nvbench_compare_tooling()
|
||||
except MissingToolingDependencyError as exc:
|
||||
print(str(exc), file=sys.stderr)
|
||||
return 1
|
||||
|
||||
try:
|
||||
filter_plan = build_benchmark_filter_plan(args.filter_actions)
|
||||
reference_device_filter = parse_device_filter(
|
||||
@@ -3362,6 +3426,12 @@ def main() -> int:
|
||||
return 1
|
||||
|
||||
if selected_ref_devices != selected_cmp_devices:
|
||||
try:
|
||||
jsondiff = load_jsondiff_for_device_diff()
|
||||
except MissingToolingDependencyError as exc:
|
||||
print(str(exc), file=sys.stderr)
|
||||
return 1
|
||||
|
||||
warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED
|
||||
msg_text = "Device sections do not match"
|
||||
print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="")
|
||||
@@ -3404,6 +3474,9 @@ def main() -> int:
|
||||
display=args.display,
|
||||
bulk_debug_rows=bulk_debug_rows,
|
||||
)
|
||||
except MissingToolingDependencyError as exc:
|
||||
print(str(exc), file=sys.stderr)
|
||||
return 1
|
||||
except ValueError as exc:
|
||||
print(str(exc))
|
||||
return 1
|
||||
|
||||
@@ -4,15 +4,50 @@ import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
||||
try:
|
||||
if __package__:
|
||||
from .nvbench_json import reader
|
||||
from .nvbench_tooling_deps import (
|
||||
MissingToolingDependencyError,
|
||||
ToolingDependency,
|
||||
require_tooling_dependency,
|
||||
)
|
||||
else:
|
||||
from nvbench_json import reader
|
||||
except ImportError:
|
||||
from scripts.nvbench_json import reader
|
||||
from nvbench_tooling_deps import (
|
||||
MissingToolingDependencyError,
|
||||
ToolingDependency,
|
||||
require_tooling_dependency,
|
||||
)
|
||||
|
||||
np = None
|
||||
pd = None
|
||||
plt = None
|
||||
sns = None
|
||||
|
||||
|
||||
def load_nvbench_histogram_tooling():
|
||||
global np, pd, plt, sns
|
||||
|
||||
if plt is None:
|
||||
plt = require_tooling_dependency(
|
||||
ToolingDependency("matplotlib.pyplot", "matplotlib", "histogram plotting"),
|
||||
tool_name="nvbench-histogram",
|
||||
)
|
||||
if np is None:
|
||||
np = require_tooling_dependency(
|
||||
ToolingDependency("numpy", "numpy", "sample loading"),
|
||||
tool_name="nvbench-histogram",
|
||||
)
|
||||
if pd is None:
|
||||
pd = require_tooling_dependency(
|
||||
ToolingDependency("pandas", "pandas", "sample table construction"),
|
||||
tool_name="nvbench-histogram",
|
||||
)
|
||||
if sns is None:
|
||||
sns = require_tooling_dependency(
|
||||
ToolingDependency("seaborn", "seaborn", "histogram plotting"),
|
||||
tool_name="nvbench-histogram",
|
||||
)
|
||||
|
||||
|
||||
def parse_files():
|
||||
@@ -115,12 +150,18 @@ def parse_json(filename):
|
||||
|
||||
def main():
|
||||
filenames = parse_files()
|
||||
try:
|
||||
load_nvbench_histogram_tooling()
|
||||
except MissingToolingDependencyError as exc:
|
||||
print(str(exc), file=sys.stderr)
|
||||
return 1
|
||||
|
||||
dfs = [parse_json(filename) for filename in filenames]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
|
||||
sns.displot(df, rug=True, kind="kde", fill=True)
|
||||
plt.show()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,13 +4,44 @@ import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.ticker import PercentFormatter
|
||||
|
||||
try:
|
||||
if __package__:
|
||||
from .nvbench_json import reader
|
||||
from .nvbench_tooling_deps import (
|
||||
MissingToolingDependencyError,
|
||||
ToolingDependency,
|
||||
require_tooling_dependency,
|
||||
)
|
||||
else:
|
||||
from nvbench_json import reader
|
||||
except ImportError:
|
||||
from scripts.nvbench_json import reader
|
||||
from nvbench_tooling_deps import (
|
||||
MissingToolingDependencyError,
|
||||
ToolingDependency,
|
||||
require_tooling_dependency,
|
||||
)
|
||||
|
||||
plt = None
|
||||
PercentFormatter = None
|
||||
|
||||
|
||||
def load_nvbench_plot_bwutil_tooling():
|
||||
global PercentFormatter, plt
|
||||
|
||||
if plt is None:
|
||||
plt = require_tooling_dependency(
|
||||
ToolingDependency(
|
||||
"matplotlib.pyplot", "matplotlib", "bandwidth plot rendering"
|
||||
),
|
||||
tool_name="nvbench-plot-bwutil",
|
||||
)
|
||||
if PercentFormatter is None:
|
||||
ticker = require_tooling_dependency(
|
||||
ToolingDependency(
|
||||
"matplotlib.ticker", "matplotlib", "plot axis formatting"
|
||||
),
|
||||
tool_name="nvbench-plot-bwutil",
|
||||
)
|
||||
PercentFormatter = ticker.PercentFormatter
|
||||
|
||||
|
||||
UTILIZATION_TAG = "nv/cold/bw/global/utilization"
|
||||
|
||||
@@ -263,6 +294,12 @@ def plot_entries(entries, title=None, output=None, dark=False):
|
||||
|
||||
def main():
|
||||
args, filenames = parse_files()
|
||||
try:
|
||||
load_nvbench_plot_bwutil_tooling()
|
||||
except MissingToolingDependencyError as exc:
|
||||
print(str(exc), file=sys.stderr)
|
||||
return 1
|
||||
|
||||
try:
|
||||
axis_filters = parse_axis_filters(args.axis)
|
||||
except ValueError as exc:
|
||||
|
||||
37
python/scripts/nvbench_tooling_deps.py
Normal file
37
python/scripts/nvbench_tooling_deps.py
Normal file
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolingDependency:
|
||||
import_name: str
|
||||
package_name: str
|
||||
purpose: str
|
||||
extra: str = "tools"
|
||||
|
||||
|
||||
class MissingToolingDependencyError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def require_tooling_dependency(
|
||||
dependency: ToolingDependency, *, tool_name: str
|
||||
) -> ModuleType:
|
||||
try:
|
||||
return importlib.import_module(dependency.import_name)
|
||||
except ImportError as exc:
|
||||
raise MissingToolingDependencyError(
|
||||
f"{tool_name} requires {dependency.package_name!r} for "
|
||||
f"{dependency.purpose}.\n\n"
|
||||
"Install the required tooling dependencies with:\n"
|
||||
f" python -m pip install 'cuda-bench[{dependency.extra}]'\n\n"
|
||||
f"Original import error: {exc}"
|
||||
) from exc
|
||||
@@ -5,12 +5,23 @@ import math
|
||||
import os
|
||||
import sys
|
||||
|
||||
import tabulate
|
||||
|
||||
try:
|
||||
if __package__:
|
||||
from .nvbench_json import reader
|
||||
from .nvbench_tooling_deps import (
|
||||
MissingToolingDependencyError,
|
||||
ToolingDependency,
|
||||
require_tooling_dependency,
|
||||
)
|
||||
else:
|
||||
from nvbench_json import reader
|
||||
except ImportError:
|
||||
from scripts.nvbench_json import reader
|
||||
from nvbench_tooling_deps import (
|
||||
MissingToolingDependencyError,
|
||||
ToolingDependency,
|
||||
require_tooling_dependency,
|
||||
)
|
||||
|
||||
tabulate = None
|
||||
tabulate_version = (0, 0, 0)
|
||||
|
||||
|
||||
# Parse version string into tuple, "x.y.z" -> (x, y, z)
|
||||
@@ -18,7 +29,18 @@ def version_tuple(v):
|
||||
return tuple(map(int, (v.split("."))))
|
||||
|
||||
|
||||
tabulate_version = version_tuple(tabulate.__version__)
|
||||
def load_nvbench_walltime_tooling():
|
||||
global tabulate, tabulate_version
|
||||
|
||||
if tabulate is not None:
|
||||
return
|
||||
|
||||
tabulate = require_tooling_dependency(
|
||||
ToolingDependency("tabulate", "tabulate", "table output"),
|
||||
tool_name="nvbench-walltime",
|
||||
)
|
||||
tabulate_version = version_tuple(tabulate.__version__)
|
||||
|
||||
|
||||
all_devices = []
|
||||
|
||||
@@ -341,6 +363,12 @@ def main():
|
||||
|
||||
filenames.sort()
|
||||
|
||||
try:
|
||||
load_nvbench_walltime_tooling()
|
||||
except MissingToolingDependencyError as exc:
|
||||
print(str(exc), file=sys.stderr)
|
||||
return 1
|
||||
|
||||
data = {}
|
||||
|
||||
files_out = {}
|
||||
@@ -355,6 +383,7 @@ def main():
|
||||
|
||||
print_overview_section(data)
|
||||
print_files_section(data)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user