Add multi-cuda wheel build (#289)

Co-authored-by: Ashwin Srinath <shwina@users.noreply.github.com>
Co-authored-by: Nader Al Awar <naderalawar@gmail.com>
This commit is contained in:
Ashwin Srinath
2026-01-28 10:37:55 -05:00
committed by GitHub
parent f3fa93f388
commit a681e2185d
13 changed files with 379 additions and 288 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.metadata
import warnings
@@ -31,29 +32,63 @@ except Exception as e:
f"Version is set to fall-back value '{__version__}' instead."
)
# Detect the CUDA major version (from the installed cuda-bindings package) and load the matching extension
def _get_cuda_major_version():
"""Detect the CUDA runtime major version."""
try:
import cuda.bindings
# Get CUDA version from cuda-bindings package version
# cuda-bindings version is in format like "12.9.1" or "13.0.0"
version_str = cuda.bindings.__version__
major = int(version_str.split(".")[0])
return major
except ImportError:
raise ImportError(
"cuda-bindings is required for runtime CUDA version detection. "
"Install with: pip install pynvbench[cu12] or pip install pynvbench[cu13]"
)
# Build the dotted name of the extension module for the detected CUDA major
# version, e.g. "cuda.bench.cu12._nvbench" when the major version is 12.
_cuda_major = _get_cuda_major_version()
_extra_name = f"cu{_cuda_major}"
_module_fullname = f"cuda.bench.{_extra_name}._nvbench"

try:
    _nvbench_module = importlib.import_module(_module_fullname)
except ImportError as e:
    # Re-raise with a user-facing explanation, chained to the original error
    # so the full traceback of the failed import is preserved.
    raise ImportError(
        f"No pynvbench extension found for CUDA {_cuda_major}.x. "
        f"This wheel may not include support for your CUDA version. "
        f"Supported CUDA versions: 12, 13. "
        f"Original error: {e}"
    ) from e
# Load required NVIDIA libraries
for _libname in ("cupti", "nvperf_target", "nvperf_host"):
    load_nvidia_dynamic_lib(_libname)
# A module-level loop variable stays bound after the loop; use a private name
# and delete it so it does not linger as an attribute of the package.
del _libname
from cuda.bench._nvbench import ( # noqa: E402
Benchmark as Benchmark,
)
from cuda.bench._nvbench import ( # noqa: E402
CudaStream as CudaStream,
)
from cuda.bench._nvbench import ( # noqa: E402
Launch as Launch,
)
from cuda.bench._nvbench import ( # noqa: E402
NVBenchRuntimeError as NVBenchRuntimeError,
)
from cuda.bench._nvbench import ( # noqa: E402
State as State,
)
from cuda.bench._nvbench import ( # noqa: E402
register as register,
)
from cuda.bench._nvbench import ( # noqa: E402
run_all_benchmarks as run_all_benchmarks,
)
# Import and expose all public symbols from the CUDA-specific extension
# Each name is bound at package level to the attribute of the same name on
# the dynamically imported extension module (``_nvbench_module``).
# NOTE(review): in this view the first seven names are also imported from
# ``cuda.bench._nvbench`` just above; these assignments run later and so
# rebind them — confirm whether the static imports are still needed.
Benchmark = _nvbench_module.Benchmark
CudaStream = _nvbench_module.CudaStream
Launch = _nvbench_module.Launch
NVBenchRuntimeError = _nvbench_module.NVBenchRuntimeError
State = _nvbench_module.State
register = _nvbench_module.register
run_all_benchmarks = _nvbench_module.run_all_benchmarks
# NOTE(review): presumably test-only hooks, judging by their names — confirm
# they are meant to be part of the package's public namespace.
test_cpp_exception = _nvbench_module.test_cpp_exception
test_py_exception = _nvbench_module.test_py_exception
# Expose the module as _nvbench for backward compatibility (e.g., for tests)
_nvbench = _nvbench_module

# Clean up internal symbols: remove the import-time helpers from the package
# namespace. Each name is deleted exactly once — deleting
# ``load_nvidia_dynamic_lib`` separately beforehand would make this statement
# raise NameError during import. ``_nvbench`` keeps the extension module
# reachable after ``_nvbench_module`` is removed.
del (
    load_nvidia_dynamic_lib,
    _nvbench_module,
    _cuda_major,
    _extra_name,
    _module_fullname,
    _get_cuda_major_version,
)