v4.4 update. (#2979)

This commit is contained in:
Junkai-Wu
2026-01-25 00:46:17 +08:00
committed by GitHub
parent 2fafefb7b9
commit 9fba3195f9
293 changed files with 46344 additions and 2996 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,107 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass
import cutlass.cute as cute
import os
import subprocess
import cuda.bindings.driver as cuda
"""Example demonstrating how to export a CuTe function.
This example shows how to:
1. Compile a CuTe function
2. Export the compiled function to a object file and a C header file
3. Compile the object file to a shared library
To run this example:
.. code-block:: bash
# export the compiled function to a object file and a C header file
python cutlass_ir/compiler/python/examples/cute/export/export_to_c.py
"""
@cute.kernel
def print_tensor_kernel(a: cute.Tensor):
cute.printf("a: {}", a)
@cute.jit
def print_tensor(a: cute.Tensor, stream: cuda.CUstream):
print_tensor_kernel(a).launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)
@cute.kernel
def add_one_kernel(a: cute.Tensor, b: cute.Tensor):
a[0] = b[0] + 1
@cute.jit
def add_one(a: cute.Tensor, b: cute.Tensor, stream: cuda.CUstream):
add_one_kernel(a, b).launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)
def run():
from cutlass.cute.runtime import make_fake_compact_tensor
shape = (cute.SymInt(divisibility=16), cute.SymInt())
a = make_fake_compact_tensor(cutlass.Float32, shape, stride_order=(1, 0))
b = make_fake_compact_tensor(cutlass.Float32, shape, stride_order=(1, 0))
stream = cuda.CUstream(0)
compiled_print_tensor = cute.compile(print_tensor, a, stream=stream)
compiled_add_one = cute.compile(add_one, a, b, stream=stream)
os.makedirs("./build", exist_ok=True)
compiled_print_tensor.export_to_c(
file_path="./build",
file_name="print_tensor_example",
function_prefix="print_tensor",
)
compiled_add_one.export_to_c(
file_path="./build", file_name="add_one_example", function_prefix="add_one"
)
cc = os.environ.get("CC", "gcc")
# compile the object file to a shared library
cmd = [
cc,
"-shared",
"-o",
"./build/libexport_example.so",
"./build/print_tensor_example.o",
"./build/add_one_example.o",
]
subprocess.run(cmd, check=True)
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,79 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass
import cutlass.cute as cute
import cuda.bindings.driver as cuda
"""Example demonstrating how to load a CuTe module/function in Python.
This example shows how to:
1. Load a CuTe module from a object file or a shared library
2. Extract the function from the module and call it in Python
To run this example:
.. code-block:: bash
# prerequesites: export the compiled functions to object files and compile them into a shared library
python examples/cute/export/export_to_c.py
# load the module from a object file or a shared library
python examples/cute/export/load_in_python.py
"""
def run():
import torch
from cutlass.cute.runtime import from_dlpack
a = (
torch.arange(16 * 10, dtype=torch.float32, device="cuda")
.reshape(16, 10)
.permute(1, 0)
)
b = torch.ones(16, 10, dtype=torch.float32, device="cuda").permute(1, 0)
a_cute = from_dlpack(a).mark_layout_dynamic()
b_cute = from_dlpack(b).mark_layout_dynamic()
stream = cuda.CUstream(0)
print_tensor_mod = cute.runtime.load_module("./build/print_tensor_example.o")
print_tensor_mod.print_tensor(a_cute, stream=stream)
add_one_mod = cute.runtime.load_module("./build/add_one_example.o")
add_one_mod.add_one(a_cute, b_cute, stream=stream)
assert a[0, 0] == b[0, 0] + 1
shared_mod = cute.runtime.load_module("./build/libexport_example.so")
shared_mod.print_tensor(a_cute, stream=stream)
shared_mod.add_one(a_cute, b_cute, stream=stream)
assert a[0, 0] == b[0, 0] + 1
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,254 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
* All rights reserved. SPDX-License-Identifier: LicenseRef-NvidiaProprietary
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
/**
* This example demonstrates how to load a CuTe module/function from a object
* file or a shared library and run it in a CUDA kernel.
*
* To run this example:
*
* .. code-block:: bash
*
* # prerequesites: export the compiled functions to object files and compile
* them into a shared library python examples/cute/export/export_to_c.py # run
* the example bash ./examples/cute/export/run_with_dynamic_loading.sh
*/
#include "CuteDSLRuntime.h"
#include <cuda_runtime.h>
#include <fstream>
#include <vector>
std::vector<unsigned char> read_file(const std::string &filename) {
std::ifstream file(filename, std::ios::binary);
std::vector<unsigned char> content((std::istreambuf_iterator<char>(file)),
std::istreambuf_iterator<char>());
return content;
}
void initialize_cuda_context() {
// Initialize cuda context
cudaSetDevice(0);
}
void check_error(CuteDSLRT_Error_t error) {
if (error != CuteDSLRT_Error_Success) {
printf("Got runtime error: %d\n", error);
exit(1);
}
}
// Copy the definition of the tensor from `print_tensor_example.h` to here
typedef struct {
void *data;
int32_t dynamic_shapes[2];
int64_t dynamic_strides[1];
} print_tensor_Tensor_a_t;
// Copy the definition of the tensor from `add_one_example.h` to here
typedef struct {
void *data;
int32_t dynamic_shapes[2];
int64_t dynamic_strides[1];
} add_one_Tensor_a_t;
typedef struct {
void *data;
int32_t dynamic_shapes[2];
int64_t dynamic_strides[1];
} add_one_Tensor_b_t;
print_tensor_Tensor_a_t prepare_print_tensor_tensor() {
print_tensor_Tensor_a_t tensor;
tensor.data = NULL;
tensor.dynamic_shapes[0] = 32;
tensor.dynamic_shapes[1] = 16;
tensor.dynamic_strides[0] = 16;
return tensor;
}
add_one_Tensor_a_t prepare_add_one_tensor_a() {
float a_host_ptr[32 * 16];
for (int i = 0; i < 32 * 16; i++) {
a_host_ptr[i] = i;
}
void *a_device_ptr;
cudaMalloc(&a_device_ptr, sizeof(float) * 32 * 16);
cudaMemcpy(a_device_ptr, a_host_ptr, sizeof(float) * 32 * 16,
cudaMemcpyHostToDevice);
add_one_Tensor_a_t tensor;
tensor.data = (void *)a_device_ptr;
tensor.dynamic_shapes[0] = 32;
tensor.dynamic_shapes[1] = 16;
tensor.dynamic_strides[0] = 16;
return tensor;
}
add_one_Tensor_b_t prepare_add_one_tensor_b() {
float b_host_ptr[32 * 16];
for (int i = 0; i < 32 * 16; i++) {
b_host_ptr[i] = 32 * 16 - i;
}
void *b_device_ptr;
cudaMalloc(&b_device_ptr, sizeof(float) * 32 * 16);
cudaMemcpy(b_device_ptr, b_host_ptr, sizeof(float) * 32 * 16,
cudaMemcpyHostToDevice);
add_one_Tensor_b_t tensor;
tensor.data = (void *)b_device_ptr;
tensor.dynamic_shapes[0] = 32;
tensor.dynamic_shapes[1] = 16;
tensor.dynamic_strides[0] = 16;
return tensor;
}
void run_print_tensor_with_object_file() {
// prepare tensor
print_tensor_Tensor_a_t tensor = prepare_print_tensor_tensor();
// prepare stream
cudaStream_t stream;
cudaStreamCreate(&stream);
// load module and function
CuteDSLRT_Module_t *module = nullptr;
std::vector<unsigned char> print_tensor_example_o_bytes =
read_file("build/print_tensor_example.o");
size_t print_tensor_example_o_size = print_tensor_example_o_bytes.size();
CuteDSLRT_Error_t error = CuteDSLRT_Module_Create_From_Bytes(
&module, print_tensor_example_o_bytes.data(), print_tensor_example_o_size,
nullptr, 0);
check_error(error);
CuteDSLRT_Function_t *function = nullptr;
error = CuteDSLRT_Module_Get_Function(&function, module, "print_tensor");
check_error(error);
// run kernel, refer to the wrapper function in `print_tensor_example.h`
int32_t cuda_result;
void *args[3] = {&tensor, &stream, &cuda_result};
error = CuteDSLRT_Function_Run(function, args, 3);
check_error(error);
// synchronize stream
cudaStreamSynchronize(stream);
// unload module
error = CuteDSLRT_Module_Destroy(module);
check_error(error);
cudaStreamDestroy(stream);
}
void run_add_one_with_object_file() {
// prepare a tensor
add_one_Tensor_a_t tensor_a = prepare_add_one_tensor_a();
// prepare b tensor
add_one_Tensor_b_t tensor_b = prepare_add_one_tensor_b();
// prepare stream
cudaStream_t stream;
cudaStreamCreate(&stream);
// load module
CuteDSLRT_Module_t *module = nullptr;
std::vector<unsigned char> add_one_example_o_bytes =
read_file("build/add_one_example.o");
size_t add_one_example_o_size = add_one_example_o_bytes.size();
CuteDSLRT_Error_t error = CuteDSLRT_Module_Create_From_Bytes(
&module, add_one_example_o_bytes.data(), add_one_example_o_size, nullptr,
0);
check_error(error);
CuteDSLRT_Function_t *function = nullptr;
error = CuteDSLRT_Module_Get_Function(&function, module, "add_one");
check_error(error);
// run kernel, refer to the wrapper function in `add_one_example.h`
int32_t cuda_result;
void *args[4] = {&tensor_a, &tensor_b, &stream, &cuda_result};
error = CuteDSLRT_Function_Run(function, args, 4);
check_error(error);
cudaStreamSynchronize(stream);
// unload module
error = CuteDSLRT_Module_Destroy(module);
check_error(error);
cudaStreamDestroy(stream);
// check result
float a_host_ptr[32 * 16];
cudaMemcpy(a_host_ptr, tensor_a.data, sizeof(float) * 32 * 16,
cudaMemcpyDeviceToHost);
if (a_host_ptr[0] != 32 * 16 + 1) {
printf("Error: a_host_ptr[0] = %f, expected %d\n", a_host_ptr[0], 32 * 16);
exit(1);
}
}
void run_example_with_shared_library() {
// prepare tensor
print_tensor_Tensor_a_t tensor = prepare_print_tensor_tensor();
// prepare a tensor
add_one_Tensor_a_t tensor_a = prepare_add_one_tensor_a();
// prepare b tensor
add_one_Tensor_b_t tensor_b = prepare_add_one_tensor_b();
// prepare stream
cudaStream_t stream;
cudaStreamCreate(&stream);
// load module
CuteDSLRT_Module_t *module = nullptr;
const char *shared_libs[] = {"build/libexport_example.so"};
CuteDSLRT_Error_t error =
CuteDSLRT_Module_Create_From_Bytes(&module, nullptr, 0, shared_libs, 1);
check_error(error);
// get print tensor function
CuteDSLRT_Function_t *print_tensor_function = nullptr;
error = CuteDSLRT_Module_Get_Function(&print_tensor_function, module,
"print_tensor");
check_error(error);
// get add one function
CuteDSLRT_Function_t *add_one_function = nullptr;
error = CuteDSLRT_Module_Get_Function(&add_one_function, module, "add_one");
check_error(error);
// run print tensor kernel
int32_t cuda_result;
void *print_tensor_args[3] = {&tensor, &stream, &cuda_result};
error = CuteDSLRT_Function_Run(print_tensor_function, print_tensor_args, 3);
check_error(error);
cudaStreamSynchronize(stream);
// run add one kernel
void *add_one_args[4] = {&tensor_a, &tensor_b, &stream, &cuda_result};
error = CuteDSLRT_Function_Run(add_one_function, add_one_args, 4);
check_error(error);
cudaStreamSynchronize(stream);
// unload module
error = CuteDSLRT_Module_Destroy(module);
check_error(error);
cudaStreamDestroy(stream);
}
int main() {
initialize_cuda_context();
run_print_tensor_with_object_file();
run_add_one_with_object_file();
run_example_with_shared_library();
return 0;
}

View File

@@ -0,0 +1,80 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#!/bin/bash
set -eu
# Try to find the wheel path of nvidia-cutlass-dsl
WHEEL_PATH=$(python3 -c "import cutlass, os; print(os.path.dirname(cutlass.__file__))" 2>/dev/null)/../..
if [[ -z "$WHEEL_PATH" ]]; then
echo "nvidia-cutlass-dsl wheel not found in the current Python environment."
exit 1
else
echo "nvidia-cutlass-dsl wheel path found at: $WHEEL_PATH"
fi
CUTE_DSL_LIB_PATH="${WHEEL_PATH}/lib/"
export LD_LIBRARY_PATH=${CUTE_DSL_LIB_PATH}:${CUDA_HOME}/lib64:./build
if [ -z "$CUDA_HOME" ]; then
CUDA_HOME=/usr/local/cuda
echo "CUDA_HOME not set, using default: $CUDA_HOME"
else
echo "CUDA_HOME found: $CUDA_HOME"
fi
SOURCE_FILE="$(dirname "$0")/run_with_dynamic_loading.cpp"
echo "Compiling the executable..."
# Search for a common C++ compiler: g++, clang++, or c++
if [ -n "${CXX-}" ] && command -v "$CXX" &> /dev/null; then
CXX="$CXX"
elif command -v g++ &> /dev/null; then
CXX="g++"
elif command -v clang++ &> /dev/null; then
CXX="clang++"
elif command -v c++ &> /dev/null; then
CXX="c++"
else
echo "Error: No common C++ compiler found (g++, clang++, or c++). Please install a C++ compiler to continue."
exit 1
fi
$CXX -o build/run_with_dynamic_loading \
-I${CUDA_HOME}/include \
-I${WHEEL_PATH}/include \
${SOURCE_FILE} \
-L${CUTE_DSL_LIB_PATH} \
-L${CUDA_HOME}/lib64 \
-L./build \
-lcudart \
-lcute_dsl_runtime \
-lexport_example
echo "Running the executable..."
./build/run_with_dynamic_loading

View File

@@ -0,0 +1,124 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
* All rights reserved. SPDX-License-Identifier: LicenseRef-NvidiaProprietary
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
/**
* This example demonstrates how to load a CuTe module/function from compilation
* and static linking and run it in a CUDA kernel.
*
* To run this example:
*
* .. code-block:: bash
*
* # prerequesites: export the compiled functions to object files and compile
* them into a shared library python examples/cute/export/export_to_c.py # run
* the example bash ./examples/cute/export/run_with_static_linking.sh
*/
#include "add_one_example.h"
#include "print_tensor_example.h"
#include <cuda_runtime.h>
void initialize_cuda_context() {
// Initialize cuda context
int device_id = 0;
cudaSetDevice(device_id);
}
void run_print_tensor() {
// prepare tensor
print_tensor_Tensor_a_t tensor;
tensor.data = NULL;
tensor.dynamic_shapes[0] = 32;
tensor.dynamic_shapes[1] = 16;
tensor.dynamic_strides[0] = 16;
// prepare stream
cudaStream_t stream;
cudaStreamCreate(&stream);
// load module
print_tensor_Kernel_Module_t module;
print_tensor_Kernel_Module_Load(&module);
// run kernel
cute_dsl_print_tensor_wrapper(&module, &tensor, stream);
// synchronize stream
cudaStreamSynchronize(stream);
// unload module
print_tensor_Kernel_Module_Unload(&module);
cudaStreamDestroy(stream);
}
void run_add_one() {
// prepare a tensor
add_one_Tensor_a_t a_tensor;
float a_host_ptr[32 * 16];
for (int i = 0; i < 32 * 16; i++) {
a_host_ptr[i] = i;
}
void *a_device_ptr;
cudaMalloc(&a_device_ptr, sizeof(float) * 32 * 16);
cudaMemcpy(a_device_ptr, a_host_ptr, sizeof(float) * 32 * 16,
cudaMemcpyHostToDevice);
a_tensor.data = (void *)a_device_ptr;
a_tensor.dynamic_shapes[0] = 32;
a_tensor.dynamic_shapes[1] = 16;
a_tensor.dynamic_strides[0] = 16;
// prepare b tensor
add_one_Tensor_b_t b_tensor;
float b_host_ptr[32 * 16];
for (int i = 0; i < 32 * 16; i++) {
b_host_ptr[i] = 32 * 16 - i;
}
void *b_device_ptr;
cudaMalloc(&b_device_ptr, sizeof(float) * 32 * 16);
cudaMemcpy(b_device_ptr, b_host_ptr, sizeof(float) * 32 * 16,
cudaMemcpyHostToDevice);
b_tensor.data = (void *)b_device_ptr;
b_tensor.dynamic_shapes[0] = 32;
b_tensor.dynamic_shapes[1] = 16;
b_tensor.dynamic_strides[0] = 16;
// prepare stream
cudaStream_t stream;
cudaStreamCreate(&stream);
// load module
add_one_Kernel_Module_t module;
add_one_Kernel_Module_Load(&module);
// run kernel
cute_dsl_add_one_wrapper(&module, &a_tensor, &b_tensor, stream);
cudaStreamSynchronize(stream);
// unload module
add_one_Kernel_Module_Unload(&module);
cudaStreamDestroy(stream);
// check result
cudaMemcpy(a_host_ptr, a_device_ptr, sizeof(float) * 32 * 16,
cudaMemcpyDeviceToHost);
if (a_host_ptr[0] != 32 * 16 + 1) {
printf("Error: a_host_ptr[0] = %f, expected %d\n", a_host_ptr[0], 32 * 16);
}
}
int main() {
initialize_cuda_context();
run_print_tensor();
run_add_one();
return 0;
}

View File

@@ -0,0 +1,76 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#!/bin/bash
set -eu
# Try to find the wheel path of nvidia-cutlass-dsl
WHEEL_PATH=$(python3 -c "import cutlass, os; print(os.path.dirname(cutlass.__file__))" 2>/dev/null)/../..
if [[ -z "$WHEEL_PATH" ]]; then
echo "nvidia-cutlass-dsl wheel not found in the current Python environment."
exit 1
else
echo "nvidia-cutlass-dsl wheel path found at: $WHEEL_PATH"
fi
CUTE_DSL_LIB_PATH="${WHEEL_PATH}/lib/"
export LD_LIBRARY_PATH=${CUTE_DSL_LIB_PATH}
if [ -z "$CUDA_HOME" ]; then
CUDA_HOME=/usr/local/cuda
echo "CUDA_HOME not set, using default: $CUDA_HOME"
else
echo "CUDA_HOME found: $CUDA_HOME"
fi
SOURCE_FILE="$(dirname "$0")/run_with_static_linking.cpp"
echo "Compiling the executable..."
# Search for a common C++ compiler: g++, clang++, or c++
if [ -n "${CXX-}" ] && command -v "$CXX" &> /dev/null; then
CXX="$CXX"
elif command -v g++ &> /dev/null; then
CXX="g++"
elif command -v clang++ &> /dev/null; then
CXX="clang++"
elif command -v c++ &> /dev/null; then
CXX="c++"
else
echo "Error: No common C++ compiler found (g++, clang++, or c++). Please install a C++ compiler to continue."
exit 1
fi
# Note: The options -ldl, -lrt, and -lpthread are optional to satisfy symbol dependencies required by `libcudart_static.a`.
$CXX -o build/run_with_static_linking \
-I${CUDA_HOME}/include \
-I./build \
${SOURCE_FILE} build/add_one_example.o build/print_tensor_example.o \
${CUTE_DSL_LIB_PATH}/libcuda_dialect_runtime_static.a \
${CUDA_HOME}/lib64/libcudart_static.a -ldl -lrt -lpthread
echo "Running the executable..."
./build/run_with_static_linking

View File

@@ -37,21 +37,19 @@ To run this example:
.. code-block:: bash
python examples/cute/tvm_ffi/aot_export.py
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_export.py
# run example to use in torch
python examples/cute/tvm_ffi/aot_use_in_torch.py
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_torch.py
# run example to use in jax
python examples/cute/tvm_ffi/aot_use_in_jax.py
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_jax.py
# run example to use in c++ bundle
bash examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh
bash cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh
"""
from pathlib import Path
import torch
import os
import subprocess
import tvm_ffi
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
@@ -69,6 +67,8 @@ def add_one(a: cute.Tensor, b: cute.Tensor):
def main():
import torch
# compile the kernel with "--enable-tvm-ffi" option
a_torch = torch.arange(10, dtype=torch.float32, device="cuda")
b_torch = torch.zeros(10, dtype=torch.float32, device="cuda")

View File

@@ -15,32 +15,31 @@
// This example shows how to interface with an AOT compiled function in a C++
// bundle. to build and run the example, run the following command in project
// root bash
// examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh
// cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh
#include <cuda_runtime.h>
#include <iostream>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/module.h>
#include <iostream>
#include <vector>
namespace ffi = tvm::ffi;
struct CUDANDAlloc {
void AllocData(DLTensor *tensor) {
void AllocData(DLTensor* tensor) {
size_t data_size = ffi::GetDataSize(*tensor);
void *ptr = nullptr;
void* ptr = nullptr;
cudaError_t err = cudaMalloc(&ptr, data_size);
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
<< "cudaMalloc failed: " << cudaGetErrorString(err);
TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaMalloc failed: " << cudaGetErrorString(err);
tensor->data = ptr;
}
void FreeData(DLTensor *tensor) {
void FreeData(DLTensor* tensor) {
if (tensor->data != nullptr) {
cudaError_t err = cudaFree(tensor->data);
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
<< "cudaFree failed: " << cudaGetErrorString(err);
TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaFree failed: " << cudaGetErrorString(err);
tensor->data = nullptr;
}
}
@@ -51,47 +50,45 @@ inline ffi::Tensor Empty(ffi::Shape shape, DLDataType dtype, DLDevice device) {
}
// symbol from the shared library
extern "C" int __tvm_ffi_add_one(void *, const TVMFFIAny *, int32_t,
TVMFFIAny *);
extern "C" int __tvm_ffi_add_one(void*, const TVMFFIAny*, int32_t, TVMFFIAny*);
// Redirects into the exported function in object
void CallAddOne(ffi::TensorView x, ffi::TensorView y) {
tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add_one, x, y);
tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add_one, x, y);
}
int main() {
DLDataType f32_dtype{kDLFloat, 32, 1};
DLDevice cuda_device{kDLCUDA, 0};
DLDataType f32_dtype{kDLFloat, 32, 1};
DLDevice cuda_device{kDLCUDA, 0};
constexpr int ARRAY_SIZE = 10;
constexpr int ARRAY_SIZE = 10;
ffi::Tensor x = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
ffi::Tensor y = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
ffi::Tensor x = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
ffi::Tensor y = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
std::vector<float> host_x(ARRAY_SIZE);
for (int i = 0; i < ARRAY_SIZE; ++i) {
host_x[i] = static_cast<float>(i);
}
std::vector<float> host_x(ARRAY_SIZE);
for (int i = 0; i < ARRAY_SIZE; ++i) {
host_x[i] = static_cast<float>(i);
}
size_t nbytes = host_x.size() * sizeof(float);
cudaError_t err =
cudaMemcpy(x.data_ptr(), host_x.data(), nbytes, cudaMemcpyHostToDevice);
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
<< "cudaMemcpy host to device failed: " << cudaGetErrorString(err);
size_t nbytes = host_x.size() * sizeof(float);
cudaError_t err = cudaMemcpy(x.data_ptr(), host_x.data(), nbytes, cudaMemcpyHostToDevice);
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
<< "cudaMemcpy host to device failed: " << cudaGetErrorString(err);
// Call into the FFI function; tensors remain on device because they carry a
// kDLCUDA device tag.
CallAddOne(x, y);
// Call into the FFI function; tensors remain on device because they carry a
// kDLCUDA device tag.
CallAddOne(x, y);
std::vector<float> host_y(host_x.size());
err = cudaMemcpy(host_y.data(), y.data_ptr(), nbytes, cudaMemcpyDeviceToHost);
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
<< "cudaMemcpy device to host failed: " << cudaGetErrorString(err);
std::vector<float> host_y(host_x.size());
err = cudaMemcpy(host_y.data(), y.data_ptr(), nbytes, cudaMemcpyDeviceToHost);
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
<< "cudaMemcpy device to host failed: " << cudaGetErrorString(err);
std::cout << "y after add_one_cuda(x, y)" << std::endl;
for (float value : host_y) {
std::cout << value << " ";
}
std::cout << std::endl;
return 0;
std::cout << "y after add_one_cuda(x, y)" << std::endl;
for (float value : host_y) {
std::cout << value << " ";
}
std::cout << std::endl;
return 0;
}

View File

@@ -27,8 +27,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#!/bin/bash
CUDA_DIALECT_PATH="build/lib/"
export LD_LIBRARY_PATH=${CUDA_DIALECT_PATH}:`tvm-ffi-config --libdir`
# Set up library paths for runtime
export LD_LIBRARY_PATH=$(python3 -m cutlass.cute.export.aot_config --libdir):$(tvm-ffi-config --libdir)
CUDA_HOME=/usr/local/cuda
SOURCE_FILE="$(dirname "$0")/aot_use_in_cpp_bundle.cpp"
@@ -38,11 +38,11 @@ g++ -o build/aot_use_in_cpp_bundle \
-I${CUDA_HOME}/include \
`tvm-ffi-config --cxxflags` \
${SOURCE_FILE} build/add_one.o \
-L${CUDA_DIALECT_PATH} \
$(python3 -m cutlass.cute.export.aot_config --ldflags) \
-L${CUDA_HOME}/lib64 \
-lcuda_dialect_runtime -lcuda -lcudart \
`tvm-ffi-config --ldflags` \
`tvm-ffi-config --libs`
$(python3 -m cutlass.cute.export.aot_config --libs) -lcuda -lcudart \
$(tvm-ffi-config --ldflags) \
$(tvm-ffi-config --libs)
echo "Running the executable..."
./build/aot_use_in_cpp_bundle

View File

@@ -37,7 +37,7 @@ def main():
a_jax = jnp.arange(10, dtype=jnp.float32)
b_jax = jnp.zeros(10, dtype=jnp.float32)
lib_path = "./build/add_one.so"
aot_mod = cute.runtime.load_module(lib_path)
aot_mod = cute.runtime.load_module(lib_path, enable_tvm_ffi=True)
jax_tvm_ffi.register_ffi_target("add_one_cute", aot_mod.add_one, platform="gpu")
b_jax = jax.ffi.ffi_call(
"add_one_cute",

View File

@@ -27,14 +27,16 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass.cute as cute
import torch
# now load it back
def main():
import torch
a_torch = torch.arange(10, dtype=torch.float32, device="cuda")
b_torch = torch.zeros(10, dtype=torch.float32, device="cuda")
lib_path = "./build/add_one.so"
aot_mod = cute.runtime.load_module(lib_path)
aot_mod = cute.runtime.load_module(lib_path, enable_tvm_ffi=True)
aot_mod.add_one(a_torch, b_torch)
print("result of b after aot_mod.add_one(a, b)")
print(b_torch)

View File

@@ -36,8 +36,9 @@ To run this example:
.. code-block:: bash
python examples/cute/tvm_ffi/error_reporting.py
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/error_reporting.py
"""
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

View File

@@ -39,7 +39,7 @@ To run this example:
pip install jax-tvm-ffi
pip install jax[cuda13]
python examples/cute/tvm_ffi/jit_and_use_in_jax.py
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/jit_and_use_in_jax.py
"""
import jax

View File

@@ -36,9 +36,9 @@ To run this example:
.. code-block:: bash
python examples/cute/tvm_ffi/jit_and_use_in_torch.py
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/jit_and_use_in_torch.py
"""
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
@@ -56,6 +56,8 @@ def add_one(a: cute.Tensor, b: cute.Tensor):
def main():
import torch
# compile the kernel with "--enable-tvm-ffi" option
a_torch = torch.arange(10, dtype=torch.float32, device="cuda")
b_torch = torch.zeros(10, dtype=torch.float32, device="cuda")

View File

@@ -1,2 +1,2 @@
apache-tvm-ffi
torch-c-dlpack-ext
torch-c-dlpack-ext

View File

@@ -0,0 +1,260 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from functools import partial
import jax
import jax.numpy as jnp
import cutlass
import cutlass.cute as cute
import cutlass.jax as cjax
import cuda.bindings.driver as cuda
"""
Examples of calling CuTe DSL from jax.jit function using cutlass_call.
cutlass_call is a Jax primitive the enables calling of CuTe DSL kernels within a
a jit-compiled Jax function. During the lowering process cutlass_call will
trigger compilation of the kernel and embed it into the HLO computation. It can
then be efficiently launched by XLA without callback to Python.
This example assumes familiarity with CuTe DSL concepts such as layouts and
dynamic shapes.
To run this example:
.. code-block:: bash
# Run with addition operation
python examples/jax/cutlass_call_basic.py
"""
# This is a typical CuTe DSL kernel function that accepts both tensor and scalar values.
@cute.jit
def launch(
A: cute.Tensor,
B: cute.Tensor,
x: cute.Int32,
y: cute.Int32,
C: cute.Tensor,
D: cute.Tensor,
stream: cuda.CUstream,
):
# Print layouts
print("A layout: ", A.layout)
print("B layout: ", B.layout)
print("C layout: ", C.layout)
print("D layout: ", D.layout)
cute.printf("A layout: {}", A.layout)
cute.printf("B layout: {}", B.layout)
cute.printf("C layout: {}", C.layout)
cute.printf("D layout: {}", D.layout)
cute.printf("")
# Print non-tensor values
print("X is: ", x)
print("Y is: ", y)
cute.printf("X is: {}", x)
cute.printf("Y is: {}", y)
print()
# cutlass_call uses a fixed function signature to pass arguments between Jax and CuTeDSL kernel.
#
# Function Signature Requirement:
# stream, inputs, outputs, *, kwargs...
#
# The first argument must be the CUstream that the kernel is run. This stream is managed by the XLA runtime
# and is necessary to schedule and synchronize launches with the rest of your computation.
#
# The second set of arguments are the Jax arrays for inputs and outputs. Inputs must be passed before
# outputs.
#
# Lastly static arguments (i.e. static_argnums or static_argnames) values are passed as keyword only arguments
# by name.
#
# The the kernel does not match this signature a wrapper functions like the one shown below can be written
# or an inline lambda function can be used to rebind the arguments into the appropriate order.
@cute.jit
def launch_jax_wrapper(
stream: cuda.CUstream,
A: cute.Tensor,
B: cute.Tensor,
C: cute.Tensor,
D: cute.Tensor,
*,
x: cute.Int32,
y: cute.Int32,
):
launch(A, B, x, y, C, D, stream)
@cute.jit
def launch_aliased(
A: cute.Tensor, B: cute.Tensor, x: cute.Int32, y: cute.Int32, stream: cuda.CUstream
):
# Print layouts
print("A layout: ", A.layout)
print("B layout: ", B.layout)
cute.printf("A layout: {}", A.layout)
cute.printf("B layout: {}", B.layout)
cute.printf("")
# Print non-tensor values
print("X is: ", x)
print("Y is: ", y)
cute.printf("X is: {}", x)
cute.printf("Y is: {}", y)
print()
if __name__ == "__main__":
@partial(jax.jit, static_argnums=[2, 3])
def run_cutlass_kernel(a, b, x, y):
call = cjax.cutlass_call(
launch_jax_wrapper,
# Jax requires output shapes/dtype information for each output
output_shape_dtype=(
jax.ShapeDtypeStruct(a.shape, a.dtype),
jax.ShapeDtypeStruct(b.shape, a.dtype),
),
# Static jit arguments are passed via additional keyword arguments
x=x,
y=y,
)
# Returned value is a callable to invoke the kernel passing only jax arrays.
return call(a, b)
print("\nExample: example_basic_call_from_jit")
A = jnp.zeros((512, 32, 64))
B = jnp.zeros((1, 256, 64, 128))
C, D = run_cutlass_kernel(A, B, 0, 1)
@partial(jax.jit, static_argnums=[2, 3])
def run_cutlass_kernel_lambda(a, b, x, y):
call = cjax.cutlass_call(
# A lambda function may be used to wrap and bind arguments passed by jax
# to the kernel. Alternatively you can wrap using another separate cute.jit
# function.
lambda stream, a, b, c, d, *, x, y: launch(a, b, x, y, c, d, stream),
# Jax requires output shapes/dtype information for each output
output_shape_dtype=(
jax.ShapeDtypeStruct(a.shape, a.dtype),
jax.ShapeDtypeStruct(b.shape, a.dtype),
),
# Static jit arguments are passed via additional keyword arguments
x=x,
y=y,
)
# Returned value is a callable to invoke the kernel passing only jax arrays.
return call(a, b)
print("\nExample: run_cutlass_kernel_lambda")
A = jnp.zeros((512, 32, 64))
B = jnp.zeros((1, 256, 64, 128))
C, D = run_cutlass_kernel_lambda(A, B, 1, 2)
@partial(jax.jit, static_argnums=[2, 3])
def run_cutlass_kernel_static_shapes(a, b, x, y):
call = cjax.cutlass_call(
lambda stream, a, b, c, d, *, x, y: launch(a, b, x, y, c, d, stream),
output_shape_dtype=(
jax.ShapeDtypeStruct(a.shape, a.dtype),
jax.ShapeDtypeStruct(b.shape, a.dtype),
),
# By default cutlass_call will treat all tensors as dynamic shape.
# Dynamic shapes are often expected for kernels so this default ensures
# the broadest support. If you know that a kernel can accept fully static
# tensors then you can enable this flag to pass all tensors shapes and
# layouts known at compile time.
use_static_tensors=True,
x=x,
y=y,
)
return call(a, b)
print("\nExample: run_cutlass_kernel_static_shapes")
A = jnp.zeros((512, 32, 64))
B = jnp.zeros((1, 256, 64, 128))
C, D = run_cutlass_kernel_static_shapes(A, B, 3, 4)
@partial(jax.jit, static_argnums=[2, 3])
def run_cutlass_kernel_with_modes(a, b, x, y):
call = cjax.cutlass_call(
lambda stream, a, b, c, d, *, x, y: launch(a, b, x, y, c, d, stream),
output_shape_dtype=(
jax.ShapeDtypeStruct(a.shape, a.dtype),
jax.ShapeDtypeStruct(b.shape, a.dtype),
),
# The modes of the layout for each tensor can be specified using the
# TensorSpec. By default modes will align with the physical layout
# but can be mapped to specific index position. If None is passed
# then the default mode is assumed for that tensor.
#
# Individual static/dynamic settings may also be applied. For example
# a specific tensor can be marked to have static shape.
input_spec=(
cjax.TensorSpec(mode=(1, 0, 2), static=True),
cjax.TensorSpec(mode=(3, 1, 2, 0)),
),
output_spec=(None, cjax.TensorSpec(mode=(0, 1, 3, 2))),
x=x,
y=y,
)
return call(a, b)
print("\nExample: run_cutlass_kernel_with_modes")
A = jnp.zeros((512, 32, 64))
B = jnp.zeros((1, 256, 64, 128))
C, D = run_cutlass_kernel_with_modes(A, B, 5, 6)
@partial(jax.jit, static_argnums=[2, 3], donate_argnums=[0, 1])
def run_cutlass_kernel_aliased_outputs(a, b, x, y):
call = cjax.cutlass_call(
lambda stream, a, b, *, x, y: launch_aliased(a, b, x, y, stream),
output_shape_dtype=(
jax.ShapeDtypeStruct(a.shape, a.dtype),
jax.ShapeDtypeStruct(b.shape, b.dtype),
),
# Can specify the input tensors that are aliasing outputs of this call.
# To avoid allocating separate output buffers. This is useful for kernels
# that update a tensor.
input_output_aliases={0: 0, 1: 1},
x=x,
y=y,
)
return call(a, b)
print("\nExample: run_cutlass_kernel_aliased_outputs")
A = jnp.zeros((512, 32, 64))
B = jnp.zeros((1, 256, 64, 128))
A, B = run_cutlass_kernel_aliased_outputs(A, B, 7, 8)

View File

@@ -0,0 +1,168 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
from functools import partial
import argparse
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import jax
import jax.numpy as jnp
from jax import export
from cutlass.jax import cutlass_call, get_export_disabled_safety_checks
from cutlass.jax.testing import create_tensor
"""
Examples of using jax.export APIs with functions using cutlass_call.
This example demonstrates the use of jax.export with CuTe DSL kernel. It assumes
familiarity with CuTe DSL concepts such as layouts and dynamic shapes as well as
Jax's exporting and serialization features:
https://docs.jax.dev/en/latest/export/index.html#export
To run this example:
.. code-block:: bash
# Run with defaults
python examples/jax/cutlass_call_export.py
# Run with shape (1024, 512)
python examples/jax/cutlass_call_export.py --M 1024 --N 512
# Export with symbolic shapes.
python examples/jax/cutlass_call_export.py --export_symbolic
"""
@cute.kernel
def kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
bdim, _, _ = cute.arch.block_dim()
thread_idx = bidx * bdim + tidx
m, n = gA.shape
ni = thread_idx % n
mi = thread_idx // n
a_val = gA[mi, ni]
b_val = gB[mi, ni]
gC[mi, ni] = a_val + b_val
@cute.jit
def launch(stream: cuda.CUstream, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
print("mA: ", mA.layout)
print("mB: ", mB.layout)
print("mC: ", mC.layout)
num_threads_per_block = 256
m, n = mA.shape
kernel(mA, mB, mC).launch(
grid=((m * n) // num_threads_per_block, 1, 1),
block=(num_threads_per_block, 1, 1),
stream=stream,
)
def run_example(M, N, export_symbolic_shapes):
@jax.jit
def f(a, b):
call = cutlass_call(launch, output_shape_dtype=a)
return jax.nn.sigmoid(call(a, b))
@jax.jit
def ref_f(a, b):
return jax.nn.sigmoid(a + b)
# Symbolic or partially shapes are supported by cutlass_call and cute.Tensor
# This allows export of functions calling Cut eDSL kernels w/o having to re-compile
# the kernel for each new shape.
if export_symbolic_shapes:
a, b = export.symbolic_shape("a, b")
export_shape_dtype = jax.ShapeDtypeStruct((a, b), jnp.float32)
else:
export_shape_dtype = jax.ShapeDtypeStruct((M, N), jnp.float32)
print("Exporting with input signature: ")
print(f"({export_shape_dtype}, {export_shape_dtype})")
# jax.export can be used to export a jit function containing cutlass_call.
# The function get_export_disabled_safety_checks() returns a list of custom
# call targets that are used by cutlass_call not part of Jax's built-in
# list of stable custom calls.
exported = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks())
traced = exported(export_shape_dtype, export_shape_dtype)
# Serialize the computation to a byte blob.
blob = traced.serialize()
print(f"Serialized computation is {len(blob)} bytes.")
# Deserialize and run
rehydrated = export.deserialize(blob)
key = jax.random.key(1123)
a_key, b_key = jax.random.split(key, 2)
a = create_tensor((M, N), dtype=jnp.float32, key=a_key)
b = create_tensor((M, N), dtype=jnp.float32, key=b_key)
c = rehydrated.call(a, b)
assert jnp.allclose(c, ref_f(a, b))
# If the computation was exported with dynamic shapes then we can also
# call it with different shapes. The kernel will not be re-compiled
# even though the shapes are changing.
if export_symbolic_shapes:
a = create_tensor((M * 2, N * 4), dtype=jnp.float32, key=a_key)
b = create_tensor((M * 2, N * 4), dtype=jnp.float32, key=b_key)
c = rehydrated.call(a, b)
assert jnp.allclose(c, ref_f(a, b))
a = create_tensor((M * 4, N * 4), dtype=jnp.float32, key=a_key)
b = create_tensor((M * 4, N * 4), dtype=jnp.float32, key=b_key)
c = rehydrated.call(a, b)
assert jnp.allclose(c, ref_f(a, b))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Demonstration of using jax.export with functions with cutlass_call"
)
parser.add_argument("--M", default=512, type=int)
parser.add_argument("--N", default=256, type=int)
parser.add_argument("--export_symbolic", action="store_true")
args = parser.parse_args()
run_example(args.M, args.N, args.export_symbolic)
print("PASS")

View File

@@ -0,0 +1,140 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from functools import partial
import argparse
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.custom_partitioning import custom_partitioning
import cutlass
import cutlass.cute as cute
import cutlass.jax as cjax
from cutlass.jax.testing import create_tensor
import cuda.bindings.driver as cuda
"""
Examples of combining jax.jit and jax.shard_map for sharding and executing kernels
across multiple GPU devices.
To run this example:
.. code-block:: bash
# Run with addition operation
python examples/jax/cutlass_call_sharding.py
"""
@cute.kernel
def kernel(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)
frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type)
frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type)
cute.autovec_copy(a[None, tidx, bidx], frgA)
cute.autovec_copy(b[None, tidx, bidx], frgB)
frgC.store(frgA.load() + frgB.load())
cute.autovec_copy(frgC, c[None, tidx, bidx])
@cute.jit
def launch(
stream: cuda.CUstream,
a: cute.Tensor,
b: cute.Tensor,
c: cute.Tensor,
):
cute.printf("a: {}", a.layout)
cute.printf("b: {}", b.layout)
cute.printf("c: {}", c.layout)
kernel(a, b, c).launch(
grid=[a.shape[-1], 1, 1], block=[a.shape[-2], 1, 1], stream=stream
)
def run_example():
# Create a device mesh with one axis b
ngpu = jax.device_count()
mesh = jax.make_mesh((ngpu,), "b")
if ngpu == 1:
print("Note: only 1 GPU was detected.")
# We will shard our 3D tensors over b
sharding = P("b", None, None)
@partial(jax.jit, static_argnums=[0, 1])
def allocate_sharded_tensors(shape, dtype):
key = jax.random.key(1123)
a_key, b_keys = jax.random.split(key, 2)
a = create_tensor(shape, dtype, a_key)
b = create_tensor(shape, dtype, b_keys)
a = jax.lax.with_sharding_constraint(a, NamedSharding(mesh, sharding))
b = jax.lax.with_sharding_constraint(b, NamedSharding(mesh, sharding))
return a, b
@jax.jit
def compute(a, b):
# This jax.shard_map partitions the cutlass_call over the mesh.
@partial(
jax.shard_map,
mesh=mesh,
in_specs=(sharding, sharding),
out_specs=(sharding, sharding),
)
def sharded_call(a_block, b_block):
call = cjax.cutlass_call(
launch,
use_static_tensors=True,
output_shape_dtype=jax.ShapeDtypeStruct(a_block.shape, a_block.dtype),
)
ref_result = a_block + b_block
return call(a_block, b_block), ref_result
return sharded_call(a, b)
# Allocate (32, 16, 64) on each GPU
shape = (32 * ngpu, 16, 64)
dtype = jnp.float32
a, b = allocate_sharded_tensors(shape, dtype)
c, c_ref = compute(a, b)
assert jnp.allclose(c, c_ref)
if __name__ == "__main__":
run_example()
print("PASS")

View File

@@ -0,0 +1,329 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import operator
from functools import partial
from typing import List, Type
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
"""
An Elementwise Apply Example using CuTe DSL with cutlass.jax.cutlass_call
This example is similar to examples/ampere/elementwise_apply.py but demonstrates
how to run the code in a jax specific way using the cutlass_call primitive. It assumes
familiarity with basic CuTe DSL concepts as well as the cutlass_call primitive.
To run this example:
.. code-block:: bash
# Run with addition operation
python examples/jax/elementwise_apply_example.py --M 1024 --N 512 --op add
# Run with multiplication operation
python examples/ampere/elementwise_apply_example.py --M 1024 --N 512 --op mul
# Run with subtraction operation
python examples/ampere/elementwise_apply_example.py --M 1024 --N 512 --op sub
"""
@cute.kernel
def elementwise_apply_kernel(
op: cutlass.Constexpr,
mInputs: List[cute.Tensor],
mC: cute.Tensor,
cC: cute.Tensor, # coordinate tensor
shape: cute.Shape,
tv_layout: cute.Layout, # (tid, vid) -> logic coord
):
tidx, _, _ = cute.arch.thread_idx()
bidx, bidy, _ = cute.arch.block_idx()
###############################################################################
# Slice to local tile of thread block
###############################################################################
blk_crd = ((None, None), (bidx, bidy))
# Leverage the meta-programming capability of the DSL to slice the tensors for each input
# All for loops below on input tensors would be fully unrolled automatically at compile time
# logical coord -> memory address
gInputs = [t[blk_crd] for t in mInputs] # (TileM, TileN)
gC = mC[blk_crd] # (TileM, TileN)
gCrd = cC[blk_crd] # (TileM, TileN)
print("[DSL INFO] Sliced Tensors per thread block:")
for i in cutlass.range_constexpr(len(gInputs)):
print(f"[DSL INFO] ctaInputs{i} = {gInputs[i].type}")
print(f"[DSL INFO] gC = {gC.type}")
print(f"[DSL INFO] gCrd = {gCrd.type}")
###############################################################################
# Compose with thread block TV layout to map thread & value indices to memory address
###############################################################################
# (tid, vid) -> memory address
tidfrgInputs = [cute.composition(t, tv_layout) for t in gInputs]
tidfrgC = cute.composition(gC, tv_layout)
tidfrgCrd = cute.composition(gCrd, tv_layout)
# repeat None like vid to remove hierarchy of layout
thr_crd = (tidx, cute.repeat_like(None, tidfrgInputs[0][1]))
###############################################################################
# Slice to local tile of thread
###############################################################################
# vid -> address
thrInputs = [t[thr_crd] for t in tidfrgInputs] # (V)
thrC = tidfrgC[thr_crd] # (V)
thrCrd = tidfrgCrd[thr_crd]
print("[DSL INFO] Sliced Tensors per thread:")
for i in cutlass.range_constexpr(len(thrInputs)):
print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}")
print(f"[DSL INFO] thrC = {thrC.type}")
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
###############################################################################
# Compute predicate for out of boundary checks
###############################################################################
frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
print(f"[DSL INFO] frgPred = {frgPred.type}")
for i in cutlass.range_constexpr(cute.size(frgPred)):
frgPred[i] = cute.elem_less(thrCrd[i], shape)
# if tidx == 0 and bidx == 0:
# cute.print_tensor(frgPred)
##########################################################
# Load data and compute result
##########################################################
# Load data before use. The compiler will optimize the copy and load
# operations to convert some memory ld/st into register uses.
result = op(*[thrInput.load() for thrInput in thrInputs])
thrC.store(result)
@cute.jit
def elementwise_apply(
op: cutlass.Constexpr, inputs, result: cute.Tensor, stream: cuda.CUstream
):
"""CUDA kernel applying binary operator on each element of two n-D input tensors in
CuTe Python and store to result tensor.
:param op: Binary operator or lambda function to apply element-wise
:type op: cutlass.Constexpr
:param a: First input tensor
:type a: cute.Tensor
:param b: Second input tensor
:type b: cute.Tensor
:param result: Output tensor to store the results of op(a, b)
:type result: cute.Tensor
:return: None
:rtype: None
"""
# Baseline: naive TV layout
# * mA layout: (4096, 4096):(4096, 1)
# * TV layout map to (512, 4) tile
# * tidx maps to mode-0 but input layout is contiguous on mode-1, performance will be bad
# tv_layout = cute.make_layout((128, (4, 4)), stride=(4, (512, 1)))
# cta_tiler = (512, 4)
# Opt-1: better TV layout with better 1D thread layout (SOL with 1D thread layout)
# * mA layout: (4096, 4096):(4096, 1)
# * TV layout map to (4, 512) tile
# * tidx maps to mode-1 which is leading mode of input tensor for coalesced load
# tv_layout = cute.make_layout((128, (4, 4)), stride=(16, (4, 1)))
# cta_tiler = (4, 512)
# Opt-2: 2D tile but worse
# * mA layout: (4096, 4096):(4096, 1)
# * TV layout map to (128, 16) logical tile
# * V layout is bad as contiguous mode is not on right-most
# * `cute.copy` only supports vectorize when stride-1 of v-layout on right-most )
# tv_layout = cute.make_layout(((32, 4), (4, 4)), stride=((4, 512), (1, 128)))
# cta_tiler = (128, 16)
# Opt-3: SOL with 2D thread tile
# * mA layout: (4096, 4096):(4096, 1)
# * TV layout map to (64, 256) logical tile
# * tidx maps to mode-1 and input layout is contiguous on mode-1 for coalesced load-store
# Use 128bit(16B) load as canonicalized form of val_layout then recast to target element-type
coalesced_ldst_bytes = 16
# Compile time validation: expect same element type for all input tensors
assert all(t.element_type == inputs[0].element_type for t in inputs)
dtype = inputs[0].element_type
thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))
val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))
val_layout = cute.recast_layout(dtype.width, 8, val_layout)
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
print("[DSL INFO] Input Tensors:")
for i, t in enumerate(inputs):
print(f"[DSL INFO] inputs{i} = {t}")
print(f"[DSL INFO] result = {result}")
print("[DSL INFO] Tiling Parameters:")
print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
print(f"[DSL INFO] tv_layout = {tv_layout}")
print("[DSL INFO] Tiled Tensors:")
mInputs = [cute.zipped_divide(input, tiler_mn) for input in inputs]
# ((TileM, TileN), (RestM, RestN))
mC = cute.zipped_divide(result, tiler_mn)
# (RestM, RestN) -> (RestN, RestM)
remap_block = cute.make_ordered_layout(
cute.select(mInputs[0].shape[1], mode=[1, 0]), order=(1, 0)
)
for i, t in enumerate(mInputs):
print(f"[DSL INFO] gInputs{i} = {mInputs[i]}")
mInputs[i] = cute.composition(t, (None, remap_block))
print(f"[DSL INFO] gInputs{i} (remapped) = {mInputs[i]}")
mC = cute.composition(mC, (None, remap_block))
print(f"[DSL INFO] gC = {mC}")
idC = cute.make_identity_tensor(result.shape)
cC = cute.zipped_divide(idC, tiler=tiler_mn)
print(f"[DSL INFO] coord tensor = {cC}")
# Launch the kernel asynchronously
# Group input tensors into a list as a single argument
elementwise_apply_kernel(op, mInputs, mC, cC, result.shape, tv_layout).launch(
# Compute production at each mode of mC.shape[1] to get multi-dimensional grid size
grid=cute.product_each(mC.shape[1]),
block=[cute.size(tv_layout, mode=[0]), 1, 1],
stream=stream,
)
@cutlass.dsl_user_op
def leaky_relu(x, alpha, *, loc=None, ip=None):
return cute.where(x > 0, x, alpha * x, loc=loc, ip=ip)
def leaky_relu_ref(x, alpha):
import jax.numpy as jnp
return jnp.where(x > 0, x, alpha * x)
def run_and_verify(op, M, N, dtype, skip_ref_check=False):
import jax
import jax.numpy as jnp
import cutlass.jax as cjax
import cutlass.jax.testing as testing
if op == "leaky_relu":
op = partial(leaky_relu, alpha=0.01)
ref_op = partial(leaky_relu_ref, alpha=0.01)
num_inputs = 1
else:
op = getattr(operator, op)
ref_op = op
num_inputs = 2
# This jax function is transformed using jax.jit to compile its contents
# into an efficient HLO executable.
@partial(jax.jit, static_argnums=[1])
def jax_function(inputs, op):
call = cjax.cutlass_call(
# Bind jax arguments to kernel signature
lambda stream, inputs, output, *, op: elementwise_apply(
op, inputs, output, stream
),
# Specify output shape/dtype of result
output_shape_dtype=jax.ShapeDtypeStruct(inputs[0].shape, inputs[0].dtype),
# Pass static/constexpr values as kwargs
op=op,
)
# Call the kernel!
return call(inputs)
@partial(jax.jit, static_argnums=[1])
def jax_ref_function(inputs, op):
return op(*inputs)
print("\nRunning Elementwise Apply test with:")
print(f"Tensor dimensions: [{M}, {N}]")
print(f"Input and Output Data type: {dtype}")
jax_dtype = cjax.cutlass_to_jax_dtype(dtype)
keys = jax.random.split(jax.random.key(1435), num_inputs)
inputs = [testing.create_tensor((M, N), jax_dtype, key) for key in keys]
print("Input tensor shapes:")
for i in range(num_inputs):
print(f"inputs[{i}]: {inputs[i].shape}, dtype: {inputs[i].dtype}")
epsilon = 1.2
if op in (operator.truediv, operator.floordiv):
inputs[1] = jnp.where(inputs[1] == 0, epsilon, inputs[1])
# Call the jax.jit function which will compile the kernel
c = jax_function(inputs, op)
if not skip_ref_check:
print("Executing elementwise apply kernel...")
c = jax_function(inputs, op)
print("Verifying results...")
assert jnp.allclose(ref_op(*inputs), c)
print("Results verified successfully!")
print(f"First few elements of result: \n{c[:3, :3]}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Demonstration of calling a kernel with cutlass_call"
)
parser.add_argument("--M", default=4096, type=int)
parser.add_argument("--N", default=4096, type=int)
parser.add_argument("--op", default="add", type=str)
parser.add_argument("--skip_ref_check", action="store_true")
args = parser.parse_args()
run_and_verify(
args.op,
args.M,
args.N,
dtype=cutlass.Float32,
skip_ref_check=args.skip_ref_check,
)
print("\nPASS")