mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-24 14:54:34 +00:00
v4.3.1 update. (#2817)
This commit is contained in:
@@ -35,13 +35,35 @@ There are two ways to enable TVM FFI in |DSL|:
|
||||
a_cute = cute.runtime.from_dlpack(a_torch, enable_tvm_ffi=True).mark_layout_dynamic()
|
||||
b_cute = cute.runtime.from_dlpack(b_torch, enable_tvm_ffi=True).mark_layout_dynamic()
|
||||
|
||||
compiled_add = cute.compile(add, a_cute, b_cute, options="--enable-tvm-ffi")
|
||||
compiled_add = cute.compile(add, a_torch, b_torch, options="--enable-tvm-ffi")
|
||||
|
||||
|
||||
Note that the object returned by ``cute.compile`` is a Python function specific to TVM FFI.
|
||||
|
||||
2. Alternatively, you can enable TVM FFI globally by setting the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``. Please note that this setting will apply to all JIT compilations within the environment.
|
||||
|
||||
|
||||
Minimizing Host Overhead
|
||||
------------------------
|
||||
|
||||
Eager kernel invocation overhead on the CPU host can sometimes become a bottleneck
|
||||
for latency-sensitive applications. TVM FFI can help greatly reduce this overhead.
|
||||
To maximize performance benefits, we recommend setting up your workflow as follows
|
||||
(detailed instructions are provided in subsequent sections):
|
||||
|
||||
- **Compile the kernel with TVM FFI enabled.**
|
||||
- **Declare shape constraints using fake tensors** and reuse the compiled function
|
||||
throughout your execution.
|
||||
- **Pass PyTorch tensors directly** to the compiled function to avoid explicit DLPack conversion.
|
||||
- **Use the environment stream flag** to implicitly synchronize with the current PyTorch stream.
|
||||
- **Rely on compiled argument validation** instead of Python-side attribute validation,
|
||||
as TVM FFI functions perform fast compiled checks.
|
||||
|
||||
Following these steps can significantly reduce the host-side overhead of eager kernel execution.
|
||||
The sections below provide detailed examples and explanations for each step.
|
||||
You may find it helpful to refer back to this summary after you review the implementation details.
|
||||
|
||||
|
||||
Fake tensor for compilation
|
||||
---------------------------
|
||||
|
||||
@@ -92,6 +114,18 @@ The fake tensor is a placeholder that mimics the interface of a real tensor but
|
||||
It is used in compilation or testing scenarios where only shape/type/layout information is needed.
|
||||
All attempts to access or mutate data will raise errors.
|
||||
|
||||
Note on Stride Order
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Note that CuTe's convention is to write the stride order for dimensions from left to right,
|
||||
where a lower order number means higher priority. In the context of the ``make_fake_compact_tensor`` API,
|
||||
for shape ``(2, 3, 4)`` and stride order ``(0, 1, 2)``, the stride is ``(1, 2, 6)``.
|
||||
This is commonly known as column-major order. If you want to create a fake tensor with compact row-major order,
|
||||
you should explicitly pass in ``stride_order=tuple(reversed(range(len(shape))))``
|
||||
to ``make_fake_compact_tensor``. Alternatively, you can always precisely control the
|
||||
stride via the ``stride`` argument in the ``make_fake_tensor`` API.
|
||||
|
||||
|
||||
``cute.Tensor`` adapter for TVM FFI
|
||||
-----------------------------------
|
||||
|
||||
@@ -173,6 +207,9 @@ The following example demonstrates this approach; the function accepts ``torch.c
|
||||
print("result of b_torch after compiled_add_one(a_torch, b_torch, torch_stream)")
|
||||
print(b_torch)
|
||||
|
||||
Using Environment Stream
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The second option is to rely on the environment-stream flag.
|
||||
Pass ``use_tvm_ffi_env_stream=True`` to ``make_fake_stream`` to mark the argument as an
|
||||
environment stream so it no longer has to be provided explicitly.
|
||||
@@ -202,8 +239,54 @@ before each call. The example below shows this flow:
|
||||
print("result of b_torch after compiled_add_one(a_torch, b_torch)")
|
||||
print(b_torch)
|
||||
|
||||
Using the environment-stream flag both speeds up calls and simplifies integration
|
||||
Using the environment stream flag both speeds up calls and simplifies integration
|
||||
with frameworks such as PyTorch, since no explicit stream parameter is required.
|
||||
We recommend using the environment stream flag to both simplify framework integration
|
||||
and minimize host-side calling overhead.
|
||||
|
||||
Working with Tuples
|
||||
-------------------
|
||||
|
||||
TVM FFI functions can also accept tuples as arguments. Tuples can be recursively
|
||||
composed of the types that are supported by TVM FFI. The example below shows how to use tuples as arguments:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
from cutlass import cute
|
||||
|
||||
@cute.kernel
|
||||
def device_add_one(a: cute.Tensor, b: cute.Tensor, c: cute.Float32):
|
||||
threads_per_block = 128
|
||||
cta_x_, _, _ = cute.arch.block_idx()
|
||||
tid_x, _, _ = cute.arch.thread_idx()
|
||||
tid = cta_x_ * threads_per_block + tid_x
|
||||
if tid < a.shape[0]:
|
||||
b[tid] = a[tid] + c
|
||||
|
||||
@cute.jit
|
||||
def add_one_with_tuple(a: Tuple[cute.Tensor, cute.Tensor, cute.Float32]):
|
||||
n = a[0].shape[0]
|
||||
threads_per_block = 128
|
||||
blocks = (n + threads_per_block - 1) // threads_per_block
|
||||
device_add_one(a[0], a[1], a[2]).launch(grid=(blocks, 1, 1), block=(threads_per_block, 1, 1))
|
||||
|
||||
def example_add_one_with_tuple():
|
||||
n = cute.sym_int()
|
||||
a_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,))
|
||||
b_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,))
|
||||
compiled_add_one = cute.compile(
|
||||
add_one_with_tuple, (a_cute, b_cute, cute.Float32(4)),
|
||||
options="--enable-tvm-ffi"
|
||||
)
|
||||
a_torch = torch.arange(10, dtype=torch.float32, device="cuda")
|
||||
b_torch = torch.empty(10, dtype=torch.float32, device="cuda")
|
||||
compiled_add_one((a_torch, b_torch, 5))
|
||||
print("result of b_torch after compiled_add_one((a_torch, b_torch, 5))")
|
||||
print(b_torch)
|
||||
|
||||
example_add_one_with_tuple()
|
||||
|
||||
|
||||
Supported types
|
||||
---------------
|
||||
@@ -230,6 +313,8 @@ The TVM FFI function supports the following |DSL|-specific types as arguments:
|
||||
- ``tvm_ffi.Shape`` or python tuple of ints.
|
||||
* - CUDA stream ``cuda.CUstream``
|
||||
- A stream class that implements the CUDA stream protocol (e.g. ``torch.cuda.Stream``, ``cuda.CUstream``).
|
||||
* - Tuple of types (e.g. ``Tuple[cute.Tensor, cute.Tensor, cutlass.Int32]``)
|
||||
- Python tuple of corresponding call-time types.
|
||||
|
||||
Error handling
|
||||
--------------
|
||||
|
||||
@@ -415,24 +415,36 @@ The following example demonstrates how to use ``mark_compact_shape_dynamic`` to
|
||||
)
|
||||
# Layout in DLTensorWrapper has int32 overflow risk. Please set use_32bit_stride to False.
|
||||
|
||||
Leveraging TVM FFI for Faster PyTorch Interop
|
||||
---------------------------------------------
|
||||
|
||||
The latest version of |DSL| supports TVM FFI to improve interoperability with PyTorch
|
||||
and other machine learning frameworks. Using TVM FFI provides the following features:
|
||||
|
||||
- Faster JIT function invocation.
|
||||
- Direct acceptance of ``torch.Tensor`` objects as function arguments.
|
||||
- Enhanced error handling and kernel validation.
|
||||
- Seamless integration with multiple programming languages.
|
||||
|
||||
For more details, see :doc:`compile_with_tvm_ffi`.
|
||||
|
||||
Bypass the DLPack Protocol
|
||||
--------------------------
|
||||
|
||||
In certain scenarios, users may wish to bypass the DLPack protocol and invoke the JIT function directly.
|
||||
This can be accomplished by creating a lightweight JIT wrapper around the existing JIT function,
|
||||
In certain scenarios, users may wish to bypass the DLPack protocol and invoke the JIT function directly.
|
||||
This can be accomplished by creating a lightweight JIT wrapper around the existing JIT function,
|
||||
utilizing ``cute.ptr`` and ``cute.make_tensor`` to pass pointers and construct tensors directly.
|
||||
|
||||
Typical use cases for bypassing DLPack include:
|
||||
1. Users want to call the JIT function directly to avoid the overhead introduced by the DLPack protocol.
|
||||
2. DLPack canonicalizes the stride of shape-1 dimensions to 1, which may result in incorrect alignment
|
||||
2. DLPack canonicalizes the stride of shape-1 dimensions to 1, which may result in incorrect alignment
|
||||
propagation and affect memory access or performance.
|
||||
3. DLPack may lack support for some narrow data types.
|
||||
|
||||
The following example illustrates how to bypass the DLPack protocol when invoking a JIT function.
|
||||
Assume we have a pre-defined ``TensorOpGemm`` kernel whose JIT interface expects three
|
||||
arguments of type ``cute.Tensor``. To enable direct invocation without DLPack, we first define a JIT wrapper
|
||||
function that accepts ``cute.Pointer`` types as parameters. Within this wrapper, we use ``cute.make_tensor``
|
||||
Assume we have a pre-defined ``TensorOpGemm`` kernel whose JIT interface expects three
|
||||
arguments of type ``cute.Tensor``. To enable direct invocation without DLPack, we first define a JIT wrapper
|
||||
function that accepts ``cute.Pointer`` types as parameters. Within this wrapper, we use ``cute.make_tensor``
|
||||
to construct tensors from the provided pointers, and then call the ``TensorOpGemm`` kernel as usual.
|
||||
|
||||
.. code-block:: python
|
||||
@@ -459,7 +471,7 @@ to construct tensors from the provided pointers, and then call the ``TensorOpGem
|
||||
mA = cute.make_tensor(a_ptr, layout=a_layout)
|
||||
mB = cute.make_tensor(b_ptr, layout=b_layout)
|
||||
mC = cute.make_tensor(c_ptr, layout=c_layout)
|
||||
|
||||
|
||||
# TensorOpGemm is a pre-defined kernel from our example
|
||||
tensor_op_gemm = TensorOpGemm(
|
||||
a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1)
|
||||
@@ -467,9 +479,9 @@ to construct tensors from the provided pointers, and then call the ``TensorOpGem
|
||||
|
||||
tensor_op_gemm(mA, mB, mC)
|
||||
|
||||
To pass a PyTorch tensor to this new JIT wrapper, we retrieve the raw pointer from the PyTorch tensor
|
||||
To pass a PyTorch tensor to this new JIT wrapper, we retrieve the raw pointer from the PyTorch tensor
|
||||
and create a ``cute.Pointer`` instance using ``cute.make_ptr``.
|
||||
This approach allows us to bypass the DLPack protocol entirely, avoiding its overhead and potential
|
||||
This approach allows us to bypass the DLPack protocol entirely, avoiding its overhead and potential
|
||||
issues with shape-1 dimension handling.
|
||||
|
||||
.. code-block:: python
|
||||
@@ -483,7 +495,7 @@ issues with shape-1 dimension handling.
|
||||
c = torch.randn(
|
||||
n, m, l, dtype=torch.float16, device="cuda"
|
||||
).permute(1, 2, 0)
|
||||
|
||||
|
||||
# from cutlass.cute.runtime import make_ptr
|
||||
a_ptr = make_ptr(
|
||||
cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
|
||||
@@ -495,3 +507,5 @@ issues with shape-1 dimension handling.
|
||||
cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
|
||||
)
|
||||
tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, m, n, k, l)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user