mirror of
https://github.com/NVIDIA/nvbench.git
synced 2026-04-20 14:58:54 +00:00
Add multi-cuda wheel build (#289)
Co-authored-by: Ashwin Srinath <shwina@users.noreply.github.com> Co-authored-by: Nader Al Awar <naderalawar@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import warnings
|
||||
|
||||
@@ -31,29 +32,63 @@ except Exception as e:
|
||||
f"Version is set to fall-back value '{__version__}' instead."
|
||||
)
|
||||
|
||||
|
||||
# Detect CUDA runtime version and load appropriate extension
|
||||
def _get_cuda_major_version():
|
||||
"""Detect the CUDA runtime major version."""
|
||||
try:
|
||||
import cuda.bindings
|
||||
|
||||
# Get CUDA version from cuda-bindings package version
|
||||
# cuda-bindings version is in format like "12.9.1" or "13.0.0"
|
||||
version_str = cuda.bindings.__version__
|
||||
major = int(version_str.split(".")[0])
|
||||
return major
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"cuda-bindings is required for runtime CUDA version detection. "
|
||||
"Install with: pip install pynvbench[cu12] or pip install pynvbench[cu13]"
|
||||
)
|
||||
|
||||
|
||||
_cuda_major = _get_cuda_major_version()
|
||||
_extra_name = f"cu{_cuda_major}"
|
||||
_module_fullname = f"cuda.bench.{_extra_name}._nvbench"
|
||||
|
||||
try:
|
||||
_nvbench_module = importlib.import_module(_module_fullname)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
f"No pynvbench extension found for CUDA {_cuda_major}.x. "
|
||||
f"This wheel may not include support for your CUDA version. "
|
||||
f"Supported CUDA versions: 12, 13. "
|
||||
f"Original error: {e}"
|
||||
)
|
||||
|
||||
# Load required NVIDIA libraries
|
||||
for libname in ("cupti", "nvperf_target", "nvperf_host"):
|
||||
load_nvidia_dynamic_lib(libname)
|
||||
|
||||
from cuda.bench._nvbench import ( # noqa: E402
|
||||
Benchmark as Benchmark,
|
||||
)
|
||||
from cuda.bench._nvbench import ( # noqa: E402
|
||||
CudaStream as CudaStream,
|
||||
)
|
||||
from cuda.bench._nvbench import ( # noqa: E402
|
||||
Launch as Launch,
|
||||
)
|
||||
from cuda.bench._nvbench import ( # noqa: E402
|
||||
NVBenchRuntimeError as NVBenchRuntimeError,
|
||||
)
|
||||
from cuda.bench._nvbench import ( # noqa: E402
|
||||
State as State,
|
||||
)
|
||||
from cuda.bench._nvbench import ( # noqa: E402
|
||||
register as register,
|
||||
)
|
||||
from cuda.bench._nvbench import ( # noqa: E402
|
||||
run_all_benchmarks as run_all_benchmarks,
|
||||
)
|
||||
# Import and expose all public symbols from the CUDA-specific extension
|
||||
Benchmark = _nvbench_module.Benchmark
|
||||
CudaStream = _nvbench_module.CudaStream
|
||||
Launch = _nvbench_module.Launch
|
||||
NVBenchRuntimeError = _nvbench_module.NVBenchRuntimeError
|
||||
State = _nvbench_module.State
|
||||
register = _nvbench_module.register
|
||||
run_all_benchmarks = _nvbench_module.run_all_benchmarks
|
||||
test_cpp_exception = _nvbench_module.test_cpp_exception
|
||||
test_py_exception = _nvbench_module.test_py_exception
|
||||
|
||||
del load_nvidia_dynamic_lib
|
||||
# Expose the module as _nvbench for backward compatibility (e.g., for tests)
|
||||
_nvbench = _nvbench_module
|
||||
|
||||
# Clean up internal symbols
|
||||
del (
|
||||
load_nvidia_dynamic_lib,
|
||||
_nvbench_module,
|
||||
_cuda_major,
|
||||
_extra_name,
|
||||
_module_fullname,
|
||||
_get_cuda_major_version,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user