mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.4 update. (#2979)
This commit is contained in:
1832
examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py
Normal file
1832
examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py
Normal file
File diff suppressed because it is too large
Load Diff
107
examples/python/CuTeDSL/cute/export/export_to_c.py
Normal file
107
examples/python/CuTeDSL/cute/export/export_to_c.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
"""Example demonstrating how to export a CuTe function.
|
||||
|
||||
This example shows how to:
|
||||
1. Compile a CuTe function
|
||||
2. Export the compiled function to a object file and a C header file
|
||||
3. Compile the object file to a shared library
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# export the compiled function to a object file and a C header file
|
||||
python examples/python/CuTeDSL/cute/export/export_to_c.py
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
def print_tensor_kernel(a: cute.Tensor):
    """Device-side kernel that prints the whole tensor ``a``."""
    cute.printf("a: {}", a)
|
||||
|
||||
|
||||
@cute.jit
def print_tensor(a: cute.Tensor, stream: cuda.CUstream):
    """Launch :func:`print_tensor_kernel` on ``stream`` with a 1x1x1 grid."""
    kernel = print_tensor_kernel(a)
    kernel.launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)
|
||||
|
||||
|
||||
@cute.kernel
def add_one_kernel(a: cute.Tensor, b: cute.Tensor):
    """Device-side kernel that stores ``b[0] + 1`` into ``a[0]``."""
    a[0] = b[0] + 1
|
||||
|
||||
|
||||
@cute.jit
def add_one(a: cute.Tensor, b: cute.Tensor, stream: cuda.CUstream):
    """Launch :func:`add_one_kernel` on ``stream`` with a 1x1x1 grid."""
    kernel = add_one_kernel(a, b)
    kernel.launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)
|
||||
|
||||
|
||||
def run():
    """Compile both example functions, export them to C, and link a shared lib.

    Produces ``./build/print_tensor_example.{o,h}``,
    ``./build/add_one_example.{o,h}`` and ``./build/libexport_example.so``.
    """
    from cutlass.cute.runtime import make_fake_compact_tensor

    # Fake tensors carry only symbolic shape/stride information, which is all
    # the AOT compiler needs; no device memory is involved here.
    shape = (cute.SymInt(divisibility=16), cute.SymInt())
    a = make_fake_compact_tensor(cutlass.Float32, shape, stride_order=(1, 0))
    b = make_fake_compact_tensor(cutlass.Float32, shape, stride_order=(1, 0))
    stream = cuda.CUstream(0)

    compiled_print_tensor = cute.compile(print_tensor, a, stream=stream)
    compiled_add_one = cute.compile(add_one, a, b, stream=stream)

    # Emit one object file and one C header per compiled function.
    os.makedirs("./build", exist_ok=True)
    compiled_print_tensor.export_to_c(
        file_path="./build",
        file_name="print_tensor_example",
        function_prefix="print_tensor",
    )
    compiled_add_one.export_to_c(
        file_path="./build", file_name="add_one_example", function_prefix="add_one"
    )

    # Link both object files into a single shared library with the host
    # C compiler (overridable via the CC environment variable).
    cc = os.environ.get("CC", "gcc")
    link_cmd = [
        cc,
        "-shared",
        "-o",
        "./build/libexport_example.so",
        "./build/print_tensor_example.o",
        "./build/add_one_example.o",
    ]
    subprocess.run(link_cmd, check=True)


if __name__ == "__main__":
    run()
|
||||
79
examples/python/CuTeDSL/cute/export/load_in_python.py
Normal file
79
examples/python/CuTeDSL/cute/export/load_in_python.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
"""Example demonstrating how to load a CuTe module/function in Python.
|
||||
|
||||
This example shows how to:
|
||||
1. Load a CuTe module from a object file or a shared library
|
||||
2. Extract the function from the module and call it in Python
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# prerequisites: export the compiled functions to object files and compile them into a shared library
|
||||
python examples/cute/export/export_to_c.py
|
||||
# load the module from a object file or a shared library
|
||||
python examples/cute/export/load_in_python.py
|
||||
"""
|
||||
|
||||
|
||||
def run():
    """Load the exported CuTe functions back into Python and call them."""
    import torch
    from cutlass.cute.runtime import from_dlpack

    # Column-major views of shape (10, 16) over contiguous CUDA buffers.
    a = (
        torch.arange(16 * 10, dtype=torch.float32, device="cuda")
        .reshape(16, 10)
        .permute(1, 0)
    )
    b = torch.ones(16, 10, dtype=torch.float32, device="cuda").permute(1, 0)
    a_cute = from_dlpack(a).mark_layout_dynamic()
    b_cute = from_dlpack(b).mark_layout_dynamic()
    stream = cuda.CUstream(0)

    # 1) Each function can be loaded from its own object file.
    print_tensor_mod = cute.runtime.load_module("./build/print_tensor_example.o")
    print_tensor_mod.print_tensor(a_cute, stream=stream)

    add_one_mod = cute.runtime.load_module("./build/add_one_example.o")
    add_one_mod.add_one(a_cute, b_cute, stream=stream)
    assert a[0, 0] == b[0, 0] + 1

    # 2) Both functions can also be loaded from the combined shared library.
    shared_mod = cute.runtime.load_module("./build/libexport_example.so")
    shared_mod.print_tensor(a_cute, stream=stream)
    shared_mod.add_one(a_cute, b_cute, stream=stream)
    assert a[0, 0] == b[0, 0] + 1


if __name__ == "__main__":
    run()
|
||||
254
examples/python/CuTeDSL/cute/export/run_with_dynamic_loading.cpp
Normal file
254
examples/python/CuTeDSL/cute/export/run_with_dynamic_loading.cpp
Normal file
@@ -0,0 +1,254 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
|
||||
* All rights reserved. SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
/**
|
||||
* This example demonstrates how to load a CuTe module/function from a object
|
||||
* file or a shared library and run it in a CUDA kernel.
|
||||
*
|
||||
* To run this example:
|
||||
*
|
||||
* .. code-block:: bash
|
||||
*
|
||||
 * # prerequisites: export the compiled functions to object files and
 * # compile them into a shared library
 * python examples/cute/export/export_to_c.py
 *
 * # run the example
 * bash ./examples/cute/export/run_with_dynamic_loading.sh
|
||||
*/
|
||||
|
||||
#include "CuteDSLRuntime.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
// Read the whole of `filename` into memory as raw bytes.
//
// Bug fix: the original silently returned an *empty* vector when the file
// could not be opened, which the module-creation calls below would then try
// to use as object-file bytes.  Fail fast with a clear message instead.
std::vector<unsigned char> read_file(const std::string &filename) {
  std::ifstream file(filename, std::ios::binary);
  if (!file) {
    printf("Failed to open file: %s\n", filename.c_str());
    exit(1);
  }
  std::vector<unsigned char> content((std::istreambuf_iterator<char>(file)),
                                     std::istreambuf_iterator<char>());
  return content;
}
|
||||
|
||||
// Establish a CUDA context by selecting device 0 before any runtime work.
void initialize_cuda_context() {
  const int device = 0;
  cudaSetDevice(device);
}
|
||||
|
||||
// Abort the process when a CuTe DSL runtime call reports a failure.
void check_error(CuteDSLRT_Error_t error) {
  if (error == CuteDSLRT_Error_Success)
    return;
  printf("Got runtime error: %d\n", error);
  exit(1);
}
|
||||
|
||||
// Copy the definition of the tensor from `print_tensor_example.h` to here.
// Argument struct for the exported `print_tensor` function: a raw device
// pointer plus two dynamic extents and one dynamic stride slot
// (presumably the other stride is static in the generated header — confirm
// against `print_tensor_example.h`).
typedef struct {
  void *data;
  int32_t dynamic_shapes[2];
  int64_t dynamic_strides[1];
} print_tensor_Tensor_a_t;

// Copy the definition of the tensor from `add_one_example.h` to here.
// Same layout as above, one struct per tensor parameter of `add_one`.
typedef struct {
  void *data;
  int32_t dynamic_shapes[2];
  int64_t dynamic_strides[1];
} add_one_Tensor_a_t;
typedef struct {
  void *data;
  int32_t dynamic_shapes[2];
  int64_t dynamic_strides[1];
} add_one_Tensor_b_t;
|
||||
|
||||
// Build the argument struct for `print_tensor` describing a 32x16 tensor.
// The data pointer is intentionally left NULL in this example.
print_tensor_Tensor_a_t prepare_print_tensor_tensor() {
  print_tensor_Tensor_a_t t;
  t.data = NULL;
  t.dynamic_shapes[0] = 32;
  t.dynamic_shapes[1] = 16;
  t.dynamic_strides[0] = 16;
  return t;
}
|
||||
|
||||
// Allocate and fill the device-side `a` tensor (32x16 floats, a[i] = i)
// and describe it with the generated argument struct.
add_one_Tensor_a_t prepare_add_one_tensor_a() {
  const int num_elems = 32 * 16;
  float host_buf[32 * 16];
  for (int i = 0; i < num_elems; ++i) {
    host_buf[i] = i;
  }
  void *device_buf;
  cudaMalloc(&device_buf, sizeof(float) * num_elems);
  cudaMemcpy(device_buf, host_buf, sizeof(float) * num_elems,
             cudaMemcpyHostToDevice);

  add_one_Tensor_a_t t;
  t.data = (void *)device_buf;
  t.dynamic_shapes[0] = 32;
  t.dynamic_shapes[1] = 16;
  t.dynamic_strides[0] = 16;
  return t;
}
|
||||
|
||||
// Allocate and fill the device-side `b` tensor (32x16 floats,
// b[i] = 32*16 - i) and describe it with the generated argument struct.
add_one_Tensor_b_t prepare_add_one_tensor_b() {
  const int num_elems = 32 * 16;
  float host_buf[32 * 16];
  for (int i = 0; i < num_elems; ++i) {
    host_buf[i] = 32 * 16 - i;
  }
  void *device_buf;
  cudaMalloc(&device_buf, sizeof(float) * num_elems);
  cudaMemcpy(device_buf, host_buf, sizeof(float) * num_elems,
             cudaMemcpyHostToDevice);

  add_one_Tensor_b_t t;
  t.data = (void *)device_buf;
  t.dynamic_shapes[0] = 32;
  t.dynamic_shapes[1] = 16;
  t.dynamic_strides[0] = 16;
  return t;
}
|
||||
|
||||
void run_print_tensor_with_object_file() {
|
||||
|
||||
// prepare tensor
|
||||
print_tensor_Tensor_a_t tensor = prepare_print_tensor_tensor();
|
||||
|
||||
// prepare stream
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
|
||||
// load module and function
|
||||
CuteDSLRT_Module_t *module = nullptr;
|
||||
std::vector<unsigned char> print_tensor_example_o_bytes =
|
||||
read_file("build/print_tensor_example.o");
|
||||
size_t print_tensor_example_o_size = print_tensor_example_o_bytes.size();
|
||||
CuteDSLRT_Error_t error = CuteDSLRT_Module_Create_From_Bytes(
|
||||
&module, print_tensor_example_o_bytes.data(), print_tensor_example_o_size,
|
||||
nullptr, 0);
|
||||
check_error(error);
|
||||
|
||||
CuteDSLRT_Function_t *function = nullptr;
|
||||
error = CuteDSLRT_Module_Get_Function(&function, module, "print_tensor");
|
||||
check_error(error);
|
||||
|
||||
// run kernel, refer to the wrapper function in `print_tensor_example.h`
|
||||
int32_t cuda_result;
|
||||
void *args[3] = {&tensor, &stream, &cuda_result};
|
||||
error = CuteDSLRT_Function_Run(function, args, 3);
|
||||
check_error(error);
|
||||
|
||||
// synchronize stream
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
// unload module
|
||||
error = CuteDSLRT_Module_Destroy(module);
|
||||
check_error(error);
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
|
||||
void run_add_one_with_object_file() {
|
||||
// prepare a tensor
|
||||
add_one_Tensor_a_t tensor_a = prepare_add_one_tensor_a();
|
||||
|
||||
// prepare b tensor
|
||||
add_one_Tensor_b_t tensor_b = prepare_add_one_tensor_b();
|
||||
|
||||
// prepare stream
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
|
||||
// load module
|
||||
CuteDSLRT_Module_t *module = nullptr;
|
||||
std::vector<unsigned char> add_one_example_o_bytes =
|
||||
read_file("build/add_one_example.o");
|
||||
size_t add_one_example_o_size = add_one_example_o_bytes.size();
|
||||
CuteDSLRT_Error_t error = CuteDSLRT_Module_Create_From_Bytes(
|
||||
&module, add_one_example_o_bytes.data(), add_one_example_o_size, nullptr,
|
||||
0);
|
||||
check_error(error);
|
||||
|
||||
CuteDSLRT_Function_t *function = nullptr;
|
||||
error = CuteDSLRT_Module_Get_Function(&function, module, "add_one");
|
||||
check_error(error);
|
||||
|
||||
// run kernel, refer to the wrapper function in `add_one_example.h`
|
||||
int32_t cuda_result;
|
||||
void *args[4] = {&tensor_a, &tensor_b, &stream, &cuda_result};
|
||||
error = CuteDSLRT_Function_Run(function, args, 4);
|
||||
check_error(error);
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
// unload module
|
||||
error = CuteDSLRT_Module_Destroy(module);
|
||||
check_error(error);
|
||||
cudaStreamDestroy(stream);
|
||||
|
||||
// check result
|
||||
float a_host_ptr[32 * 16];
|
||||
cudaMemcpy(a_host_ptr, tensor_a.data, sizeof(float) * 32 * 16,
|
||||
cudaMemcpyDeviceToHost);
|
||||
|
||||
if (a_host_ptr[0] != 32 * 16 + 1) {
|
||||
printf("Error: a_host_ptr[0] = %f, expected %d\n", a_host_ptr[0], 32 * 16);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
void run_example_with_shared_library() {
|
||||
// prepare tensor
|
||||
print_tensor_Tensor_a_t tensor = prepare_print_tensor_tensor();
|
||||
|
||||
// prepare a tensor
|
||||
add_one_Tensor_a_t tensor_a = prepare_add_one_tensor_a();
|
||||
|
||||
// prepare b tensor
|
||||
add_one_Tensor_b_t tensor_b = prepare_add_one_tensor_b();
|
||||
|
||||
// prepare stream
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
|
||||
// load module
|
||||
CuteDSLRT_Module_t *module = nullptr;
|
||||
const char *shared_libs[] = {"build/libexport_example.so"};
|
||||
CuteDSLRT_Error_t error =
|
||||
CuteDSLRT_Module_Create_From_Bytes(&module, nullptr, 0, shared_libs, 1);
|
||||
check_error(error);
|
||||
|
||||
// get print tensor function
|
||||
CuteDSLRT_Function_t *print_tensor_function = nullptr;
|
||||
error = CuteDSLRT_Module_Get_Function(&print_tensor_function, module,
|
||||
"print_tensor");
|
||||
check_error(error);
|
||||
|
||||
// get add one function
|
||||
CuteDSLRT_Function_t *add_one_function = nullptr;
|
||||
error = CuteDSLRT_Module_Get_Function(&add_one_function, module, "add_one");
|
||||
check_error(error);
|
||||
|
||||
// run print tensor kernel
|
||||
int32_t cuda_result;
|
||||
void *print_tensor_args[3] = {&tensor, &stream, &cuda_result};
|
||||
error = CuteDSLRT_Function_Run(print_tensor_function, print_tensor_args, 3);
|
||||
check_error(error);
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
// run add one kernel
|
||||
void *add_one_args[4] = {&tensor_a, &tensor_b, &stream, &cuda_result};
|
||||
error = CuteDSLRT_Function_Run(add_one_function, add_one_args, 4);
|
||||
check_error(error);
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
// unload module
|
||||
error = CuteDSLRT_Module_Destroy(module);
|
||||
check_error(error);
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
|
||||
int main() {
  // A CUDA context must exist before any runtime-module calls.
  initialize_cuda_context();
  run_print_tensor_with_object_file();
  run_add_one_with_object_file();
  run_example_with_shared_library();
  return 0;
}
|
||||
80
examples/python/CuTeDSL/cute/export/run_with_dynamic_loading.sh
Executable file
80
examples/python/CuTeDSL/cute/export/run_with_dynamic_loading.sh
Executable file
@@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#!/bin/bash
set -eu

# Try to find the wheel path of nvidia-cutlass-dsl.
# Bug fix: capture the python output *before* appending "/../.." — otherwise
# the emptiness check below could never fire when the import fails.
CUTLASS_PKG_DIR=$(python3 -c "import cutlass, os; print(os.path.dirname(cutlass.__file__))" 2>/dev/null || true)

if [[ -z "$CUTLASS_PKG_DIR" ]]; then
    echo "nvidia-cutlass-dsl wheel not found in the current Python environment."
    exit 1
fi
WHEEL_PATH="${CUTLASS_PKG_DIR}/../.."
echo "nvidia-cutlass-dsl wheel path found at: $WHEEL_PATH"

# Bug fix: resolve CUDA_HOME *before* it is used in LD_LIBRARY_PATH, and
# expand it with ${CUDA_HOME:-} — under `set -u` the original aborted when
# CUDA_HOME was unset, before the default could ever be applied.
if [ -z "${CUDA_HOME:-}" ]; then
    CUDA_HOME=/usr/local/cuda
    echo "CUDA_HOME not set, using default: $CUDA_HOME"
else
    echo "CUDA_HOME found: $CUDA_HOME"
fi

CUTE_DSL_LIB_PATH="${WHEEL_PATH}/lib/"
export LD_LIBRARY_PATH=${CUTE_DSL_LIB_PATH}:${CUDA_HOME}/lib64:./build

SOURCE_FILE="$(dirname "$0")/run_with_dynamic_loading.cpp"

echo "Compiling the executable..."
# Search for a common C++ compiler: g++, clang++, or c++
if [ -n "${CXX-}" ] && command -v "$CXX" &> /dev/null; then
    CXX="$CXX"
elif command -v g++ &> /dev/null; then
    CXX="g++"
elif command -v clang++ &> /dev/null; then
    CXX="clang++"
elif command -v c++ &> /dev/null; then
    CXX="c++"
else
    echo "Error: No common C++ compiler found (g++, clang++, or c++). Please install a C++ compiler to continue."
    exit 1
fi

$CXX -o build/run_with_dynamic_loading \
    -I${CUDA_HOME}/include \
    -I${WHEEL_PATH}/include \
    ${SOURCE_FILE} \
    -L${CUTE_DSL_LIB_PATH} \
    -L${CUDA_HOME}/lib64 \
    -L./build \
    -lcudart \
    -lcute_dsl_runtime \
    -lexport_example

echo "Running the executable..."
./build/run_with_dynamic_loading
|
||||
124
examples/python/CuTeDSL/cute/export/run_with_static_linking.cpp
Normal file
124
examples/python/CuTeDSL/cute/export/run_with_static_linking.cpp
Normal file
@@ -0,0 +1,124 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
|
||||
* All rights reserved. SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
/**
|
||||
* This example demonstrates how to load a CuTe module/function from compilation
|
||||
* and static linking and run it in a CUDA kernel.
|
||||
*
|
||||
* To run this example:
|
||||
*
|
||||
* .. code-block:: bash
|
||||
*
|
||||
 * # prerequisites: export the compiled functions to object files and
 * # compile them into a shared library
 * python examples/cute/export/export_to_c.py
 *
 * # run the example
 * bash ./examples/cute/export/run_with_static_linking.sh
|
||||
*/
|
||||
|
||||
#include "add_one_example.h"
#include "print_tensor_example.h"

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
|
||||
|
||||
// Establish a CUDA context by selecting device 0.
void initialize_cuda_context() {
  int device_id = 0;
  cudaSetDevice(device_id);
}
|
||||
|
||||
void run_print_tensor() {
|
||||
// prepare tensor
|
||||
print_tensor_Tensor_a_t tensor;
|
||||
tensor.data = NULL;
|
||||
tensor.dynamic_shapes[0] = 32;
|
||||
tensor.dynamic_shapes[1] = 16;
|
||||
tensor.dynamic_strides[0] = 16;
|
||||
|
||||
// prepare stream
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
|
||||
// load module
|
||||
print_tensor_Kernel_Module_t module;
|
||||
print_tensor_Kernel_Module_Load(&module);
|
||||
|
||||
// run kernel
|
||||
cute_dsl_print_tensor_wrapper(&module, &tensor, stream);
|
||||
|
||||
// synchronize stream
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
// unload module
|
||||
print_tensor_Kernel_Module_Unload(&module);
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
|
||||
void run_add_one() {
|
||||
// prepare a tensor
|
||||
add_one_Tensor_a_t a_tensor;
|
||||
float a_host_ptr[32 * 16];
|
||||
for (int i = 0; i < 32 * 16; i++) {
|
||||
a_host_ptr[i] = i;
|
||||
}
|
||||
void *a_device_ptr;
|
||||
cudaMalloc(&a_device_ptr, sizeof(float) * 32 * 16);
|
||||
cudaMemcpy(a_device_ptr, a_host_ptr, sizeof(float) * 32 * 16,
|
||||
cudaMemcpyHostToDevice);
|
||||
a_tensor.data = (void *)a_device_ptr;
|
||||
a_tensor.dynamic_shapes[0] = 32;
|
||||
a_tensor.dynamic_shapes[1] = 16;
|
||||
a_tensor.dynamic_strides[0] = 16;
|
||||
|
||||
// prepare b tensor
|
||||
add_one_Tensor_b_t b_tensor;
|
||||
float b_host_ptr[32 * 16];
|
||||
for (int i = 0; i < 32 * 16; i++) {
|
||||
b_host_ptr[i] = 32 * 16 - i;
|
||||
}
|
||||
void *b_device_ptr;
|
||||
cudaMalloc(&b_device_ptr, sizeof(float) * 32 * 16);
|
||||
cudaMemcpy(b_device_ptr, b_host_ptr, sizeof(float) * 32 * 16,
|
||||
cudaMemcpyHostToDevice);
|
||||
|
||||
b_tensor.data = (void *)b_device_ptr;
|
||||
b_tensor.dynamic_shapes[0] = 32;
|
||||
b_tensor.dynamic_shapes[1] = 16;
|
||||
b_tensor.dynamic_strides[0] = 16;
|
||||
|
||||
// prepare stream
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
|
||||
// load module
|
||||
add_one_Kernel_Module_t module;
|
||||
add_one_Kernel_Module_Load(&module);
|
||||
|
||||
// run kernel
|
||||
cute_dsl_add_one_wrapper(&module, &a_tensor, &b_tensor, stream);
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
// unload module
|
||||
add_one_Kernel_Module_Unload(&module);
|
||||
cudaStreamDestroy(stream);
|
||||
|
||||
// check result
|
||||
cudaMemcpy(a_host_ptr, a_device_ptr, sizeof(float) * 32 * 16,
|
||||
cudaMemcpyDeviceToHost);
|
||||
|
||||
if (a_host_ptr[0] != 32 * 16 + 1) {
|
||||
printf("Error: a_host_ptr[0] = %f, expected %d\n", a_host_ptr[0], 32 * 16);
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
  // A CUDA context must exist before the statically linked kernels run.
  initialize_cuda_context();
  run_print_tensor();
  run_add_one();
  return 0;
}
|
||||
76
examples/python/CuTeDSL/cute/export/run_with_static_linking.sh
Executable file
76
examples/python/CuTeDSL/cute/export/run_with_static_linking.sh
Executable file
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#!/bin/bash
set -eu

# Try to find the wheel path of nvidia-cutlass-dsl.
# Bug fix: capture the python output *before* appending "/../.." — otherwise
# the emptiness check below could never fire when the import fails.
CUTLASS_PKG_DIR=$(python3 -c "import cutlass, os; print(os.path.dirname(cutlass.__file__))" 2>/dev/null || true)

if [[ -z "$CUTLASS_PKG_DIR" ]]; then
    echo "nvidia-cutlass-dsl wheel not found in the current Python environment."
    exit 1
fi
WHEEL_PATH="${CUTLASS_PKG_DIR}/../.."
echo "nvidia-cutlass-dsl wheel path found at: $WHEEL_PATH"

CUTE_DSL_LIB_PATH="${WHEEL_PATH}/lib/"
export LD_LIBRARY_PATH=${CUTE_DSL_LIB_PATH}

# Bug fix: expand with ${CUDA_HOME:-} — under `set -u` the original
# `[ -z "$CUDA_HOME" ]` aborted the script when CUDA_HOME was unset,
# before the default could ever be applied.
if [ -z "${CUDA_HOME:-}" ]; then
    CUDA_HOME=/usr/local/cuda
    echo "CUDA_HOME not set, using default: $CUDA_HOME"
else
    echo "CUDA_HOME found: $CUDA_HOME"
fi
SOURCE_FILE="$(dirname "$0")/run_with_static_linking.cpp"

echo "Compiling the executable..."
# Search for a common C++ compiler: g++, clang++, or c++
if [ -n "${CXX-}" ] && command -v "$CXX" &> /dev/null; then
    CXX="$CXX"
elif command -v g++ &> /dev/null; then
    CXX="g++"
elif command -v clang++ &> /dev/null; then
    CXX="clang++"
elif command -v c++ &> /dev/null; then
    CXX="c++"
else
    echo "Error: No common C++ compiler found (g++, clang++, or c++). Please install a C++ compiler to continue."
    exit 1
fi

# Note: The options -ldl, -lrt, and -lpthread are optional to satisfy symbol dependencies required by `libcudart_static.a`.
$CXX -o build/run_with_static_linking \
    -I${CUDA_HOME}/include \
    -I./build \
    ${SOURCE_FILE} build/add_one_example.o build/print_tensor_example.o \
    ${CUTE_DSL_LIB_PATH}/libcuda_dialect_runtime_static.a \
    ${CUDA_HOME}/lib64/libcudart_static.a -ldl -lrt -lpthread

echo "Running the executable..."
./build/run_with_static_linking
|
||||
@@ -37,21 +37,19 @@ To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/cute/tvm_ffi/aot_export.py
|
||||
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_export.py
|
||||
# run example to use in torch
|
||||
python examples/cute/tvm_ffi/aot_use_in_torch.py
|
||||
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_torch.py
|
||||
# run example to use in jax
|
||||
python examples/cute/tvm_ffi/aot_use_in_jax.py
|
||||
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_jax.py
|
||||
# run example to use in c++ bundle
|
||||
bash examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh
|
||||
bash cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import os
|
||||
import subprocess
|
||||
import tvm_ffi
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@@ -69,6 +67,8 @@ def add_one(a: cute.Tensor, b: cute.Tensor):
|
||||
|
||||
|
||||
def main():
|
||||
import torch
|
||||
|
||||
# compile the kernel with "--enable-tvm-ffi" option
|
||||
a_torch = torch.arange(10, dtype=torch.float32, device="cuda")
|
||||
b_torch = torch.zeros(10, dtype=torch.float32, device="cuda")
|
||||
|
||||
@@ -15,32 +15,31 @@
|
||||
// This example shows how to interface with an AOT compiled function in a C++
|
||||
// bundle. to build and run the example, run the following command in project
|
||||
// root bash
|
||||
// examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh
|
||||
// cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <tvm/ffi/container/tensor.h>
|
||||
#include <tvm/ffi/error.h>
|
||||
#include <tvm/ffi/extra/module.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace ffi = tvm::ffi;
|
||||
|
||||
struct CUDANDAlloc {
|
||||
void AllocData(DLTensor *tensor) {
|
||||
void AllocData(DLTensor* tensor) {
|
||||
size_t data_size = ffi::GetDataSize(*tensor);
|
||||
void *ptr = nullptr;
|
||||
void* ptr = nullptr;
|
||||
cudaError_t err = cudaMalloc(&ptr, data_size);
|
||||
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
|
||||
<< "cudaMalloc failed: " << cudaGetErrorString(err);
|
||||
TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaMalloc failed: " << cudaGetErrorString(err);
|
||||
tensor->data = ptr;
|
||||
}
|
||||
|
||||
void FreeData(DLTensor *tensor) {
|
||||
void FreeData(DLTensor* tensor) {
|
||||
if (tensor->data != nullptr) {
|
||||
cudaError_t err = cudaFree(tensor->data);
|
||||
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
|
||||
<< "cudaFree failed: " << cudaGetErrorString(err);
|
||||
TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaFree failed: " << cudaGetErrorString(err);
|
||||
tensor->data = nullptr;
|
||||
}
|
||||
}
|
||||
@@ -51,47 +50,45 @@ inline ffi::Tensor Empty(ffi::Shape shape, DLDataType dtype, DLDevice device) {
|
||||
}
|
||||
|
||||
// symbol from the shared library
|
||||
extern "C" int __tvm_ffi_add_one(void *, const TVMFFIAny *, int32_t,
|
||||
TVMFFIAny *);
|
||||
extern "C" int __tvm_ffi_add_one(void*, const TVMFFIAny*, int32_t, TVMFFIAny*);
|
||||
|
||||
// Redirects into the exported function in object
|
||||
void CallAddOne(ffi::TensorView x, ffi::TensorView y) {
|
||||
tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add_one, x, y);
|
||||
tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add_one, x, y);
|
||||
}
|
||||
|
||||
int main() {
|
||||
DLDataType f32_dtype{kDLFloat, 32, 1};
|
||||
DLDevice cuda_device{kDLCUDA, 0};
|
||||
DLDataType f32_dtype{kDLFloat, 32, 1};
|
||||
DLDevice cuda_device{kDLCUDA, 0};
|
||||
|
||||
constexpr int ARRAY_SIZE = 10;
|
||||
constexpr int ARRAY_SIZE = 10;
|
||||
|
||||
ffi::Tensor x = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
|
||||
ffi::Tensor y = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
|
||||
ffi::Tensor x = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
|
||||
ffi::Tensor y = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
|
||||
|
||||
std::vector<float> host_x(ARRAY_SIZE);
|
||||
for (int i = 0; i < ARRAY_SIZE; ++i) {
|
||||
host_x[i] = static_cast<float>(i);
|
||||
}
|
||||
std::vector<float> host_x(ARRAY_SIZE);
|
||||
for (int i = 0; i < ARRAY_SIZE; ++i) {
|
||||
host_x[i] = static_cast<float>(i);
|
||||
}
|
||||
|
||||
size_t nbytes = host_x.size() * sizeof(float);
|
||||
cudaError_t err =
|
||||
cudaMemcpy(x.data_ptr(), host_x.data(), nbytes, cudaMemcpyHostToDevice);
|
||||
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
|
||||
<< "cudaMemcpy host to device failed: " << cudaGetErrorString(err);
|
||||
size_t nbytes = host_x.size() * sizeof(float);
|
||||
cudaError_t err = cudaMemcpy(x.data_ptr(), host_x.data(), nbytes, cudaMemcpyHostToDevice);
|
||||
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
|
||||
<< "cudaMemcpy host to device failed: " << cudaGetErrorString(err);
|
||||
|
||||
// Call into the FFI function; tensors remain on device because they carry a
|
||||
// kDLCUDA device tag.
|
||||
CallAddOne(x, y);
|
||||
// Call into the FFI function; tensors remain on device because they carry a
|
||||
// kDLCUDA device tag.
|
||||
CallAddOne(x, y);
|
||||
|
||||
std::vector<float> host_y(host_x.size());
|
||||
err = cudaMemcpy(host_y.data(), y.data_ptr(), nbytes, cudaMemcpyDeviceToHost);
|
||||
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
|
||||
<< "cudaMemcpy device to host failed: " << cudaGetErrorString(err);
|
||||
std::vector<float> host_y(host_x.size());
|
||||
err = cudaMemcpy(host_y.data(), y.data_ptr(), nbytes, cudaMemcpyDeviceToHost);
|
||||
TVM_FFI_ICHECK_EQ(err, cudaSuccess)
|
||||
<< "cudaMemcpy device to host failed: " << cudaGetErrorString(err);
|
||||
|
||||
std::cout << "y after add_one_cuda(x, y)" << std::endl;
|
||||
for (float value : host_y) {
|
||||
std::cout << value << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
return 0;
|
||||
std::cout << "y after add_one_cuda(x, y)" << std::endl;
|
||||
for (float value : host_y) {
|
||||
std::cout << value << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -27,8 +27,8 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#!/bin/bash
|
||||
CUDA_DIALECT_PATH="build/lib/"
|
||||
export LD_LIBRARY_PATH=${CUDA_DIALECT_PATH}:`tvm-ffi-config --libdir`
|
||||
# Set up library paths for runtime
|
||||
export LD_LIBRARY_PATH=$(python3 -m cutlass.cute.export.aot_config --libdir):$(tvm-ffi-config --libdir)
|
||||
|
||||
CUDA_HOME=/usr/local/cuda
|
||||
SOURCE_FILE="$(dirname "$0")/aot_use_in_cpp_bundle.cpp"
|
||||
@@ -38,11 +38,11 @@ g++ -o build/aot_use_in_cpp_bundle \
|
||||
-I${CUDA_HOME}/include \
|
||||
`tvm-ffi-config --cxxflags` \
|
||||
${SOURCE_FILE} build/add_one.o \
|
||||
-L${CUDA_DIALECT_PATH} \
|
||||
$(python3 -m cutlass.cute.export.aot_config --ldflags) \
|
||||
-L${CUDA_HOME}/lib64 \
|
||||
-lcuda_dialect_runtime -lcuda -lcudart \
|
||||
`tvm-ffi-config --ldflags` \
|
||||
`tvm-ffi-config --libs`
|
||||
$(python3 -m cutlass.cute.export.aot_config --libs) -lcuda -lcudart \
|
||||
$(tvm-ffi-config --ldflags) \
|
||||
$(tvm-ffi-config --libs)
|
||||
|
||||
echo "Running the executable..."
|
||||
./build/aot_use_in_cpp_bundle
|
||||
|
||||
@@ -37,7 +37,7 @@ def main():
|
||||
a_jax = jnp.arange(10, dtype=jnp.float32)
|
||||
b_jax = jnp.zeros(10, dtype=jnp.float32)
|
||||
lib_path = "./build/add_one.so"
|
||||
aot_mod = cute.runtime.load_module(lib_path)
|
||||
aot_mod = cute.runtime.load_module(lib_path, enable_tvm_ffi=True)
|
||||
jax_tvm_ffi.register_ffi_target("add_one_cute", aot_mod.add_one, platform="gpu")
|
||||
b_jax = jax.ffi.ffi_call(
|
||||
"add_one_cute",
|
||||
|
||||
@@ -27,14 +27,16 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import cutlass.cute as cute
|
||||
import torch
|
||||
|
||||
|
||||
# now load it back
|
||||
def main():
|
||||
import torch
|
||||
|
||||
a_torch = torch.arange(10, dtype=torch.float32, device="cuda")
|
||||
b_torch = torch.zeros(10, dtype=torch.float32, device="cuda")
|
||||
lib_path = "./build/add_one.so"
|
||||
aot_mod = cute.runtime.load_module(lib_path)
|
||||
aot_mod = cute.runtime.load_module(lib_path, enable_tvm_ffi=True)
|
||||
aot_mod.add_one(a_torch, b_torch)
|
||||
print("result of b after aot_mod.add_one(a, b)")
|
||||
print(b_torch)
|
||||
|
||||
@@ -36,8 +36,9 @@ To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/cute/tvm_ffi/error_reporting.py
|
||||
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/error_reporting.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@@ -39,7 +39,7 @@ To run this example:
|
||||
pip install jax-tvm-ffi
|
||||
pip install jax[cuda13]
|
||||
|
||||
python examples/cute/tvm_ffi/jit_and_use_in_jax.py
|
||||
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/jit_and_use_in_jax.py
|
||||
"""
|
||||
|
||||
import jax
|
||||
|
||||
@@ -36,9 +36,9 @@ To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/cute/tvm_ffi/jit_and_use_in_torch.py
|
||||
python cutlass_ir/compiler/python/examples/cute/tvm_ffi/jit_and_use_in_torch.py
|
||||
"""
|
||||
import torch
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
@@ -56,6 +56,8 @@ def add_one(a: cute.Tensor, b: cute.Tensor):
|
||||
|
||||
|
||||
def main():
|
||||
import torch
|
||||
|
||||
# compile the kernel with "--enable-tvm-ffi" option
|
||||
a_torch = torch.arange(10, dtype=torch.float32, device="cuda")
|
||||
b_torch = torch.zeros(10, dtype=torch.float32, device="cuda")
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
apache-tvm-ffi
|
||||
torch-c-dlpack-ext
|
||||
torch-c-dlpack-ext
|
||||
|
||||
260
examples/python/CuTeDSL/jax/cutlass_call_basic.py
Normal file
260
examples/python/CuTeDSL/jax/cutlass_call_basic.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from functools import partial
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.jax as cjax
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
"""
|
||||
Examples of calling CuTe DSL from jax.jit function using cutlass_call.
|
||||
|
||||
cutlass_call is a Jax primitive the enables calling of CuTe DSL kernels within a
|
||||
a jit-compiled Jax function. During the lowering process cutlass_call will
|
||||
trigger compilation of the kernel and embed it into the HLO computation. It can
|
||||
then be efficiently launched by XLA without callback to Python.
|
||||
|
||||
This example assumes familiarity with CuTe DSL concepts such as layouts and
|
||||
dynamic shapes.
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Run with addition operation
|
||||
python examples/jax/cutlass_call_basic.py
|
||||
"""
|
||||
|
||||
|
||||
# This is a typical CuTe DSL kernel function that accepts both tensor and scalar values.
|
||||
@cute.jit
|
||||
def launch(
|
||||
A: cute.Tensor,
|
||||
B: cute.Tensor,
|
||||
x: cute.Int32,
|
||||
y: cute.Int32,
|
||||
C: cute.Tensor,
|
||||
D: cute.Tensor,
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
# Print layouts
|
||||
print("A layout: ", A.layout)
|
||||
print("B layout: ", B.layout)
|
||||
print("C layout: ", C.layout)
|
||||
print("D layout: ", D.layout)
|
||||
cute.printf("A layout: {}", A.layout)
|
||||
cute.printf("B layout: {}", B.layout)
|
||||
cute.printf("C layout: {}", C.layout)
|
||||
cute.printf("D layout: {}", D.layout)
|
||||
cute.printf("")
|
||||
|
||||
# Print non-tensor values
|
||||
print("X is: ", x)
|
||||
print("Y is: ", y)
|
||||
cute.printf("X is: {}", x)
|
||||
cute.printf("Y is: {}", y)
|
||||
print()
|
||||
|
||||
|
||||
# cutlass_call uses a fixed function signature to pass arguments between Jax and CuTeDSL kernel.
|
||||
#
|
||||
# Function Signature Requirement:
|
||||
# stream, inputs, outputs, *, kwargs...
|
||||
#
|
||||
# The first argument must be the CUstream that the kernel is run. This stream is managed by the XLA runtime
|
||||
# and is necessary to schedule and synchronize launches with the rest of your computation.
|
||||
#
|
||||
# The second set of arguments are the Jax arrays for inputs and outputs. Inputs must be passed before
|
||||
# outputs.
|
||||
#
|
||||
# Lastly static arguments (i.e. static_argnums or static_argnames) values are passed as keyword only arguments
|
||||
# by name.
|
||||
#
|
||||
# The the kernel does not match this signature a wrapper functions like the one shown below can be written
|
||||
# or an inline lambda function can be used to rebind the arguments into the appropriate order.
|
||||
@cute.jit
|
||||
def launch_jax_wrapper(
|
||||
stream: cuda.CUstream,
|
||||
A: cute.Tensor,
|
||||
B: cute.Tensor,
|
||||
C: cute.Tensor,
|
||||
D: cute.Tensor,
|
||||
*,
|
||||
x: cute.Int32,
|
||||
y: cute.Int32,
|
||||
):
|
||||
launch(A, B, x, y, C, D, stream)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch_aliased(
|
||||
A: cute.Tensor, B: cute.Tensor, x: cute.Int32, y: cute.Int32, stream: cuda.CUstream
|
||||
):
|
||||
# Print layouts
|
||||
print("A layout: ", A.layout)
|
||||
print("B layout: ", B.layout)
|
||||
cute.printf("A layout: {}", A.layout)
|
||||
cute.printf("B layout: {}", B.layout)
|
||||
cute.printf("")
|
||||
|
||||
# Print non-tensor values
|
||||
print("X is: ", x)
|
||||
print("Y is: ", y)
|
||||
cute.printf("X is: {}", x)
|
||||
cute.printf("Y is: {}", y)
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@partial(jax.jit, static_argnums=[2, 3])
|
||||
def run_cutlass_kernel(a, b, x, y):
|
||||
call = cjax.cutlass_call(
|
||||
launch_jax_wrapper,
|
||||
# Jax requires output shapes/dtype information for each output
|
||||
output_shape_dtype=(
|
||||
jax.ShapeDtypeStruct(a.shape, a.dtype),
|
||||
jax.ShapeDtypeStruct(b.shape, a.dtype),
|
||||
),
|
||||
# Static jit arguments are passed via additional keyword arguments
|
||||
x=x,
|
||||
y=y,
|
||||
)
|
||||
|
||||
# Returned value is a callable to invoke the kernel passing only jax arrays.
|
||||
return call(a, b)
|
||||
|
||||
print("\nExample: example_basic_call_from_jit")
|
||||
A = jnp.zeros((512, 32, 64))
|
||||
B = jnp.zeros((1, 256, 64, 128))
|
||||
C, D = run_cutlass_kernel(A, B, 0, 1)
|
||||
|
||||
@partial(jax.jit, static_argnums=[2, 3])
|
||||
def run_cutlass_kernel_lambda(a, b, x, y):
|
||||
call = cjax.cutlass_call(
|
||||
# A lambda function may be used to wrap and bind arguments passed by jax
|
||||
# to the kernel. Alternatively you can wrap using another separate cute.jit
|
||||
# function.
|
||||
lambda stream, a, b, c, d, *, x, y: launch(a, b, x, y, c, d, stream),
|
||||
# Jax requires output shapes/dtype information for each output
|
||||
output_shape_dtype=(
|
||||
jax.ShapeDtypeStruct(a.shape, a.dtype),
|
||||
jax.ShapeDtypeStruct(b.shape, a.dtype),
|
||||
),
|
||||
# Static jit arguments are passed via additional keyword arguments
|
||||
x=x,
|
||||
y=y,
|
||||
)
|
||||
|
||||
# Returned value is a callable to invoke the kernel passing only jax arrays.
|
||||
return call(a, b)
|
||||
|
||||
print("\nExample: run_cutlass_kernel_lambda")
|
||||
A = jnp.zeros((512, 32, 64))
|
||||
B = jnp.zeros((1, 256, 64, 128))
|
||||
C, D = run_cutlass_kernel_lambda(A, B, 1, 2)
|
||||
|
||||
@partial(jax.jit, static_argnums=[2, 3])
|
||||
def run_cutlass_kernel_static_shapes(a, b, x, y):
|
||||
call = cjax.cutlass_call(
|
||||
lambda stream, a, b, c, d, *, x, y: launch(a, b, x, y, c, d, stream),
|
||||
output_shape_dtype=(
|
||||
jax.ShapeDtypeStruct(a.shape, a.dtype),
|
||||
jax.ShapeDtypeStruct(b.shape, a.dtype),
|
||||
),
|
||||
# By default cutlass_call will treat all tensors as dynamic shape.
|
||||
# Dynamic shapes are often expected for kernels so this default ensures
|
||||
# the broadest support. If you know that a kernel can accept fully static
|
||||
# tensors then you can enable this flag to pass all tensors shapes and
|
||||
# layouts known at compile time.
|
||||
use_static_tensors=True,
|
||||
x=x,
|
||||
y=y,
|
||||
)
|
||||
return call(a, b)
|
||||
|
||||
print("\nExample: run_cutlass_kernel_static_shapes")
|
||||
A = jnp.zeros((512, 32, 64))
|
||||
B = jnp.zeros((1, 256, 64, 128))
|
||||
C, D = run_cutlass_kernel_static_shapes(A, B, 3, 4)
|
||||
|
||||
@partial(jax.jit, static_argnums=[2, 3])
|
||||
def run_cutlass_kernel_with_modes(a, b, x, y):
|
||||
call = cjax.cutlass_call(
|
||||
lambda stream, a, b, c, d, *, x, y: launch(a, b, x, y, c, d, stream),
|
||||
output_shape_dtype=(
|
||||
jax.ShapeDtypeStruct(a.shape, a.dtype),
|
||||
jax.ShapeDtypeStruct(b.shape, a.dtype),
|
||||
),
|
||||
# The modes of the layout for each tensor can be specified using the
|
||||
# TensorSpec. By default modes will align with the physical layout
|
||||
# but can be mapped to specific index position. If None is passed
|
||||
# then the default mode is assumed for that tensor.
|
||||
#
|
||||
# Individual static/dynamic settings may also be applied. For example
|
||||
# a specific tensor can be marked to have static shape.
|
||||
input_spec=(
|
||||
cjax.TensorSpec(mode=(1, 0, 2), static=True),
|
||||
cjax.TensorSpec(mode=(3, 1, 2, 0)),
|
||||
),
|
||||
output_spec=(None, cjax.TensorSpec(mode=(0, 1, 3, 2))),
|
||||
x=x,
|
||||
y=y,
|
||||
)
|
||||
return call(a, b)
|
||||
|
||||
print("\nExample: run_cutlass_kernel_with_modes")
|
||||
A = jnp.zeros((512, 32, 64))
|
||||
B = jnp.zeros((1, 256, 64, 128))
|
||||
C, D = run_cutlass_kernel_with_modes(A, B, 5, 6)
|
||||
|
||||
@partial(jax.jit, static_argnums=[2, 3], donate_argnums=[0, 1])
|
||||
def run_cutlass_kernel_aliased_outputs(a, b, x, y):
|
||||
call = cjax.cutlass_call(
|
||||
lambda stream, a, b, *, x, y: launch_aliased(a, b, x, y, stream),
|
||||
output_shape_dtype=(
|
||||
jax.ShapeDtypeStruct(a.shape, a.dtype),
|
||||
jax.ShapeDtypeStruct(b.shape, b.dtype),
|
||||
),
|
||||
# Can specify the input tensors that are aliasing outputs of this call.
|
||||
# To avoid allocating separate output buffers. This is useful for kernels
|
||||
# that update a tensor.
|
||||
input_output_aliases={0: 0, 1: 1},
|
||||
x=x,
|
||||
y=y,
|
||||
)
|
||||
return call(a, b)
|
||||
|
||||
print("\nExample: run_cutlass_kernel_aliased_outputs")
|
||||
A = jnp.zeros((512, 32, 64))
|
||||
B = jnp.zeros((1, 256, 64, 128))
|
||||
A, B = run_cutlass_kernel_aliased_outputs(A, B, 7, 8)
|
||||
168
examples/python/CuTeDSL/jax/cutlass_call_export.py
Normal file
168
examples/python/CuTeDSL/jax/cutlass_call_export.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import pytest
|
||||
from functools import partial
|
||||
import argparse
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import export
|
||||
|
||||
from cutlass.jax import cutlass_call, get_export_disabled_safety_checks
|
||||
from cutlass.jax.testing import create_tensor
|
||||
|
||||
"""
|
||||
Examples of using jax.export APIs with functions using cutlass_call.
|
||||
|
||||
This example demonstrates the use of jax.export with CuTe DSL kernel. It assumes
|
||||
familiarity with CuTe DSL concepts such as layouts and dynamic shapes as well as
|
||||
Jax's exporting and serialization features:
|
||||
https://docs.jax.dev/en/latest/export/index.html#export
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Run with defaults
|
||||
python examples/jax/cutlass_call_export.py
|
||||
|
||||
# Run with shape (1024, 512)
|
||||
python examples/jax/cutlass_call_export.py --M 1024 --N 512
|
||||
|
||||
# Export with symbolic shapes.
|
||||
python examples/jax/cutlass_call_export.py --export_symbolic
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
bdim, _, _ = cute.arch.block_dim()
|
||||
|
||||
thread_idx = bidx * bdim + tidx
|
||||
|
||||
m, n = gA.shape
|
||||
ni = thread_idx % n
|
||||
mi = thread_idx // n
|
||||
|
||||
a_val = gA[mi, ni]
|
||||
b_val = gB[mi, ni]
|
||||
gC[mi, ni] = a_val + b_val
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch(stream: cuda.CUstream, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
|
||||
print("mA: ", mA.layout)
|
||||
print("mB: ", mB.layout)
|
||||
print("mC: ", mC.layout)
|
||||
num_threads_per_block = 256
|
||||
m, n = mA.shape
|
||||
kernel(mA, mB, mC).launch(
|
||||
grid=((m * n) // num_threads_per_block, 1, 1),
|
||||
block=(num_threads_per_block, 1, 1),
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
||||
def run_example(M, N, export_symbolic_shapes):
|
||||
@jax.jit
|
||||
def f(a, b):
|
||||
call = cutlass_call(launch, output_shape_dtype=a)
|
||||
return jax.nn.sigmoid(call(a, b))
|
||||
|
||||
@jax.jit
|
||||
def ref_f(a, b):
|
||||
return jax.nn.sigmoid(a + b)
|
||||
|
||||
# Symbolic or partially shapes are supported by cutlass_call and cute.Tensor
|
||||
# This allows export of functions calling Cut eDSL kernels w/o having to re-compile
|
||||
# the kernel for each new shape.
|
||||
if export_symbolic_shapes:
|
||||
a, b = export.symbolic_shape("a, b")
|
||||
export_shape_dtype = jax.ShapeDtypeStruct((a, b), jnp.float32)
|
||||
else:
|
||||
export_shape_dtype = jax.ShapeDtypeStruct((M, N), jnp.float32)
|
||||
|
||||
print("Exporting with input signature: ")
|
||||
print(f"({export_shape_dtype}, {export_shape_dtype})")
|
||||
|
||||
# jax.export can be used to export a jit function containing cutlass_call.
|
||||
# The function get_export_disabled_safety_checks() returns a list of custom
|
||||
# call targets that are used by cutlass_call not part of Jax's built-in
|
||||
# list of stable custom calls.
|
||||
exported = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks())
|
||||
traced = exported(export_shape_dtype, export_shape_dtype)
|
||||
|
||||
# Serialize the computation to a byte blob.
|
||||
blob = traced.serialize()
|
||||
print(f"Serialized computation is {len(blob)} bytes.")
|
||||
|
||||
# Deserialize and run
|
||||
rehydrated = export.deserialize(blob)
|
||||
|
||||
key = jax.random.key(1123)
|
||||
a_key, b_key = jax.random.split(key, 2)
|
||||
|
||||
a = create_tensor((M, N), dtype=jnp.float32, key=a_key)
|
||||
b = create_tensor((M, N), dtype=jnp.float32, key=b_key)
|
||||
c = rehydrated.call(a, b)
|
||||
assert jnp.allclose(c, ref_f(a, b))
|
||||
|
||||
# If the computation was exported with dynamic shapes then we can also
|
||||
# call it with different shapes. The kernel will not be re-compiled
|
||||
# even though the shapes are changing.
|
||||
if export_symbolic_shapes:
|
||||
a = create_tensor((M * 2, N * 4), dtype=jnp.float32, key=a_key)
|
||||
b = create_tensor((M * 2, N * 4), dtype=jnp.float32, key=b_key)
|
||||
c = rehydrated.call(a, b)
|
||||
assert jnp.allclose(c, ref_f(a, b))
|
||||
|
||||
a = create_tensor((M * 4, N * 4), dtype=jnp.float32, key=a_key)
|
||||
b = create_tensor((M * 4, N * 4), dtype=jnp.float32, key=b_key)
|
||||
c = rehydrated.call(a, b)
|
||||
assert jnp.allclose(c, ref_f(a, b))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Demonstration of using jax.export with functions with cutlass_call"
|
||||
)
|
||||
parser.add_argument("--M", default=512, type=int)
|
||||
parser.add_argument("--N", default=256, type=int)
|
||||
parser.add_argument("--export_symbolic", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
run_example(args.M, args.N, args.export_symbolic)
|
||||
print("PASS")
|
||||
140
examples/python/CuTeDSL/jax/cutlass_call_sharding.py
Normal file
140
examples/python/CuTeDSL/jax/cutlass_call_sharding.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from functools import partial
|
||||
import argparse
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.jax as cjax
|
||||
from cutlass.jax.testing import create_tensor
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
"""
|
||||
Examples of combining jax.jit and jax.shard_map for sharding and executing kernels
|
||||
across multiple GPU devices.
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Run with addition operation
|
||||
python examples/jax/cutlass_call_sharding.py
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def kernel(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)
|
||||
frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type)
|
||||
frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type)
|
||||
|
||||
cute.autovec_copy(a[None, tidx, bidx], frgA)
|
||||
cute.autovec_copy(b[None, tidx, bidx], frgB)
|
||||
frgC.store(frgA.load() + frgB.load())
|
||||
cute.autovec_copy(frgC, c[None, tidx, bidx])
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch(
|
||||
stream: cuda.CUstream,
|
||||
a: cute.Tensor,
|
||||
b: cute.Tensor,
|
||||
c: cute.Tensor,
|
||||
):
|
||||
cute.printf("a: {}", a.layout)
|
||||
cute.printf("b: {}", b.layout)
|
||||
cute.printf("c: {}", c.layout)
|
||||
kernel(a, b, c).launch(
|
||||
grid=[a.shape[-1], 1, 1], block=[a.shape[-2], 1, 1], stream=stream
|
||||
)
|
||||
|
||||
|
||||
def run_example():
|
||||
# Create a device mesh with one axis b
|
||||
ngpu = jax.device_count()
|
||||
mesh = jax.make_mesh((ngpu,), "b")
|
||||
|
||||
if ngpu == 1:
|
||||
print("Note: only 1 GPU was detected.")
|
||||
|
||||
# We will shard our 3D tensors over b
|
||||
sharding = P("b", None, None)
|
||||
|
||||
@partial(jax.jit, static_argnums=[0, 1])
|
||||
def allocate_sharded_tensors(shape, dtype):
|
||||
key = jax.random.key(1123)
|
||||
a_key, b_keys = jax.random.split(key, 2)
|
||||
a = create_tensor(shape, dtype, a_key)
|
||||
b = create_tensor(shape, dtype, b_keys)
|
||||
a = jax.lax.with_sharding_constraint(a, NamedSharding(mesh, sharding))
|
||||
b = jax.lax.with_sharding_constraint(b, NamedSharding(mesh, sharding))
|
||||
return a, b
|
||||
|
||||
@jax.jit
|
||||
def compute(a, b):
|
||||
# This jax.shard_map partitions the cutlass_call over the mesh.
|
||||
@partial(
|
||||
jax.shard_map,
|
||||
mesh=mesh,
|
||||
in_specs=(sharding, sharding),
|
||||
out_specs=(sharding, sharding),
|
||||
)
|
||||
def sharded_call(a_block, b_block):
|
||||
call = cjax.cutlass_call(
|
||||
launch,
|
||||
use_static_tensors=True,
|
||||
output_shape_dtype=jax.ShapeDtypeStruct(a_block.shape, a_block.dtype),
|
||||
)
|
||||
ref_result = a_block + b_block
|
||||
return call(a_block, b_block), ref_result
|
||||
|
||||
return sharded_call(a, b)
|
||||
|
||||
# Allocate (32, 16, 64) on each GPU
|
||||
shape = (32 * ngpu, 16, 64)
|
||||
dtype = jnp.float32
|
||||
|
||||
a, b = allocate_sharded_tensors(shape, dtype)
|
||||
c, c_ref = compute(a, b)
|
||||
|
||||
assert jnp.allclose(c, c_ref)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_example()
|
||||
print("PASS")
|
||||
329
examples/python/CuTeDSL/jax/elementwise_apply_example.py
Normal file
329
examples/python/CuTeDSL/jax/elementwise_apply_example.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import argparse
|
||||
import operator
|
||||
from functools import partial
|
||||
from typing import List, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
"""
|
||||
An Elementwise Apply Example using CuTe DSL with cutlass.jax.cutlass_call
|
||||
|
||||
This example is similar to examples/ampere/elementwise_apply.py but demonstrates
|
||||
how to run the code in a jax specific way using the cutlass_call primitive. It assumes
|
||||
familiarity with basic CuTe DSL concepts as well as the cutlass_call primitive.
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Run with addition operation
|
||||
python examples/jax/elementwise_apply_example.py --M 1024 --N 512 --op add
|
||||
|
||||
# Run with multiplication operation
|
||||
python examples/jax/elementwise_apply_example.py --M 1024 --N 512 --op mul
|
||||
|
||||
# Run with subtraction operation
|
||||
python examples/jax/elementwise_apply_example.py --M 1024 --N 512 --op sub
|
||||
"""
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def elementwise_apply_kernel(
|
||||
op: cutlass.Constexpr,
|
||||
mInputs: List[cute.Tensor],
|
||||
mC: cute.Tensor,
|
||||
cC: cute.Tensor, # coordinate tensor
|
||||
shape: cute.Shape,
|
||||
tv_layout: cute.Layout, # (tid, vid) -> logic coord
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, bidy, _ = cute.arch.block_idx()
|
||||
|
||||
###############################################################################
|
||||
# Slice to local tile of thread block
|
||||
###############################################################################
|
||||
blk_crd = ((None, None), (bidx, bidy))
|
||||
|
||||
# Leverage the meta-programming capability of the DSL to slice the tensors for each input
|
||||
# All for loops below on input tensors would be fully unrolled automatically at compile time
|
||||
# logical coord -> memory address
|
||||
gInputs = [t[blk_crd] for t in mInputs] # (TileM, TileN)
|
||||
gC = mC[blk_crd] # (TileM, TileN)
|
||||
gCrd = cC[blk_crd] # (TileM, TileN)
|
||||
|
||||
print("[DSL INFO] Sliced Tensors per thread block:")
|
||||
for i in cutlass.range_constexpr(len(gInputs)):
|
||||
print(f"[DSL INFO] ctaInputs{i} = {gInputs[i].type}")
|
||||
print(f"[DSL INFO] gC = {gC.type}")
|
||||
print(f"[DSL INFO] gCrd = {gCrd.type}")
|
||||
|
||||
###############################################################################
|
||||
# Compose with thread block TV layout to map thread & value indices to memory address
|
||||
###############################################################################
|
||||
# (tid, vid) -> memory address
|
||||
tidfrgInputs = [cute.composition(t, tv_layout) for t in gInputs]
|
||||
tidfrgC = cute.composition(gC, tv_layout)
|
||||
tidfrgCrd = cute.composition(gCrd, tv_layout)
|
||||
|
||||
# repeat None like vid to remove hierarchy of layout
|
||||
thr_crd = (tidx, cute.repeat_like(None, tidfrgInputs[0][1]))
|
||||
|
||||
###############################################################################
|
||||
# Slice to local tile of thread
|
||||
###############################################################################
|
||||
# vid -> address
|
||||
thrInputs = [t[thr_crd] for t in tidfrgInputs] # (V)
|
||||
thrC = tidfrgC[thr_crd] # (V)
|
||||
thrCrd = tidfrgCrd[thr_crd]
|
||||
|
||||
print("[DSL INFO] Sliced Tensors per thread:")
|
||||
for i in cutlass.range_constexpr(len(thrInputs)):
|
||||
print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}")
|
||||
print(f"[DSL INFO] thrC = {thrC.type}")
|
||||
print(f"[DSL INFO] thrCrd = {thrCrd.type}")
|
||||
|
||||
###############################################################################
|
||||
# Compute predicate for out of boundary checks
|
||||
###############################################################################
|
||||
frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
|
||||
print(f"[DSL INFO] frgPred = {frgPred.type}")
|
||||
|
||||
for i in cutlass.range_constexpr(cute.size(frgPred)):
|
||||
frgPred[i] = cute.elem_less(thrCrd[i], shape)
|
||||
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.print_tensor(frgPred)
|
||||
|
||||
##########################################################
|
||||
# Load data and compute result
|
||||
##########################################################
|
||||
|
||||
# Load data before use. The compiler will optimize the copy and load
|
||||
# operations to convert some memory ld/st into register uses.
|
||||
result = op(*[thrInput.load() for thrInput in thrInputs])
|
||||
thrC.store(result)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def elementwise_apply(
|
||||
op: cutlass.Constexpr, inputs, result: cute.Tensor, stream: cuda.CUstream
|
||||
):
|
||||
"""CUDA kernel applying binary operator on each element of two n-D input tensors in
|
||||
CuTe Python and store to result tensor.
|
||||
|
||||
:param op: Binary operator or lambda function to apply element-wise
|
||||
:type op: cutlass.Constexpr
|
||||
:param a: First input tensor
|
||||
:type a: cute.Tensor
|
||||
:param b: Second input tensor
|
||||
:type b: cute.Tensor
|
||||
:param result: Output tensor to store the results of op(a, b)
|
||||
:type result: cute.Tensor
|
||||
:return: None
|
||||
:rtype: None
|
||||
"""
|
||||
|
||||
# Baseline: naive TV layout
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (512, 4) tile
|
||||
# * tidx maps to mode-0 but input layout is contiguous on mode-1, performance will be bad
|
||||
# tv_layout = cute.make_layout((128, (4, 4)), stride=(4, (512, 1)))
|
||||
# cta_tiler = (512, 4)
|
||||
|
||||
# Opt-1: better TV layout with better 1D thread layout (SOL with 1D thread layout)
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (4, 512) tile
|
||||
# * tidx maps to mode-1 which is leading mode of input tensor for coalesced load
|
||||
# tv_layout = cute.make_layout((128, (4, 4)), stride=(16, (4, 1)))
|
||||
# cta_tiler = (4, 512)
|
||||
|
||||
# Opt-2: 2D tile but worse
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (128, 16) logical tile
|
||||
# * V layout is bad as contiguous mode is not on right-most
|
||||
# * `cute.copy` only supports vectorize when stride-1 of v-layout on right-most )
|
||||
# tv_layout = cute.make_layout(((32, 4), (4, 4)), stride=((4, 512), (1, 128)))
|
||||
# cta_tiler = (128, 16)
|
||||
|
||||
# Opt-3: SOL with 2D thread tile
|
||||
# * mA layout: (4096, 4096):(4096, 1)
|
||||
# * TV layout map to (64, 256) logical tile
|
||||
# * tidx maps to mode-1 and input layout is contiguous on mode-1 for coalesced load-store
|
||||
|
||||
# Use 128bit(16B) load as canonicalized form of val_layout then recast to target element-type
|
||||
coalesced_ldst_bytes = 16
|
||||
|
||||
# Compile time validation: expect same element type for all input tensors
|
||||
assert all(t.element_type == inputs[0].element_type for t in inputs)
|
||||
dtype = inputs[0].element_type
|
||||
|
||||
thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))
|
||||
val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))
|
||||
val_layout = cute.recast_layout(dtype.width, 8, val_layout)
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
print("[DSL INFO] Input Tensors:")
|
||||
for i, t in enumerate(inputs):
|
||||
print(f"[DSL INFO] inputs{i} = {t}")
|
||||
print(f"[DSL INFO] result = {result}")
|
||||
|
||||
print("[DSL INFO] Tiling Parameters:")
|
||||
print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
|
||||
print(f"[DSL INFO] tv_layout = {tv_layout}")
|
||||
|
||||
print("[DSL INFO] Tiled Tensors:")
|
||||
mInputs = [cute.zipped_divide(input, tiler_mn) for input in inputs]
|
||||
# ((TileM, TileN), (RestM, RestN))
|
||||
mC = cute.zipped_divide(result, tiler_mn)
|
||||
|
||||
# (RestM, RestN) -> (RestN, RestM)
|
||||
remap_block = cute.make_ordered_layout(
|
||||
cute.select(mInputs[0].shape[1], mode=[1, 0]), order=(1, 0)
|
||||
)
|
||||
for i, t in enumerate(mInputs):
|
||||
print(f"[DSL INFO] gInputs{i} = {mInputs[i]}")
|
||||
mInputs[i] = cute.composition(t, (None, remap_block))
|
||||
print(f"[DSL INFO] gInputs{i} (remapped) = {mInputs[i]}")
|
||||
|
||||
mC = cute.composition(mC, (None, remap_block))
|
||||
print(f"[DSL INFO] gC = {mC}")
|
||||
|
||||
idC = cute.make_identity_tensor(result.shape)
|
||||
cC = cute.zipped_divide(idC, tiler=tiler_mn)
|
||||
print(f"[DSL INFO] coord tensor = {cC}")
|
||||
|
||||
# Launch the kernel asynchronously
|
||||
# Group input tensors into a list as a single argument
|
||||
elementwise_apply_kernel(op, mInputs, mC, cC, result.shape, tv_layout).launch(
|
||||
# Compute production at each mode of mC.shape[1] to get multi-dimensional grid size
|
||||
grid=cute.product_each(mC.shape[1]),
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
||||
@cutlass.dsl_user_op
|
||||
def leaky_relu(x, alpha, *, loc=None, ip=None):
|
||||
return cute.where(x > 0, x, alpha * x, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def leaky_relu_ref(x, alpha):
|
||||
import jax.numpy as jnp
|
||||
|
||||
return jnp.where(x > 0, x, alpha * x)
|
||||
|
||||
|
||||
def run_and_verify(op, M, N, dtype, skip_ref_check=False):
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import cutlass.jax as cjax
|
||||
import cutlass.jax.testing as testing
|
||||
|
||||
if op == "leaky_relu":
|
||||
op = partial(leaky_relu, alpha=0.01)
|
||||
ref_op = partial(leaky_relu_ref, alpha=0.01)
|
||||
num_inputs = 1
|
||||
else:
|
||||
op = getattr(operator, op)
|
||||
ref_op = op
|
||||
num_inputs = 2
|
||||
|
||||
# This jax function is transformed using jax.jit to compile its contents
|
||||
# into an efficient HLO executable.
|
||||
@partial(jax.jit, static_argnums=[1])
|
||||
def jax_function(inputs, op):
|
||||
call = cjax.cutlass_call(
|
||||
# Bind jax arguments to kernel signature
|
||||
lambda stream, inputs, output, *, op: elementwise_apply(
|
||||
op, inputs, output, stream
|
||||
),
|
||||
# Specify output shape/dtype of result
|
||||
output_shape_dtype=jax.ShapeDtypeStruct(inputs[0].shape, inputs[0].dtype),
|
||||
# Pass static/constexpr values as kwargs
|
||||
op=op,
|
||||
)
|
||||
|
||||
# Call the kernel!
|
||||
return call(inputs)
|
||||
|
||||
@partial(jax.jit, static_argnums=[1])
|
||||
def jax_ref_function(inputs, op):
|
||||
return op(*inputs)
|
||||
|
||||
print("\nRunning Elementwise Apply test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"Input and Output Data type: {dtype}")
|
||||
|
||||
jax_dtype = cjax.cutlass_to_jax_dtype(dtype)
|
||||
keys = jax.random.split(jax.random.key(1435), num_inputs)
|
||||
inputs = [testing.create_tensor((M, N), jax_dtype, key) for key in keys]
|
||||
|
||||
print("Input tensor shapes:")
|
||||
for i in range(num_inputs):
|
||||
print(f"inputs[{i}]: {inputs[i].shape}, dtype: {inputs[i].dtype}")
|
||||
|
||||
epsilon = 1.2
|
||||
if op in (operator.truediv, operator.floordiv):
|
||||
inputs[1] = jnp.where(inputs[1] == 0, epsilon, inputs[1])
|
||||
|
||||
# Call the jax.jit function which will compile the kernel
|
||||
c = jax_function(inputs, op)
|
||||
|
||||
if not skip_ref_check:
|
||||
print("Executing elementwise apply kernel...")
|
||||
c = jax_function(inputs, op)
|
||||
print("Verifying results...")
|
||||
assert jnp.allclose(ref_op(*inputs), c)
|
||||
print("Results verified successfully!")
|
||||
print(f"First few elements of result: \n{c[:3, :3]}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Demonstration of calling a kernel with cutlass_call"
|
||||
)
|
||||
parser.add_argument("--M", default=4096, type=int)
|
||||
parser.add_argument("--N", default=4096, type=int)
|
||||
parser.add_argument("--op", default="add", type=str)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
run_and_verify(
|
||||
args.op,
|
||||
args.M,
|
||||
args.N,
|
||||
dtype=cutlass.Float32,
|
||||
skip_ref_check=args.skip_ref_check,
|
||||
)
|
||||
print("\nPASS")
|
||||
Reference in New Issue
Block a user