v4.3.1 update. (#2817)

Junkai-Wu
2025-11-27 22:49:30 +08:00
committed by GitHub
parent 2052fd3885
commit 1de3a576cc
44 changed files with 3316 additions and 510 deletions


@@ -35,13 +35,35 @@ There are two ways to enable TVM FFI in |DSL|:
a_cute = cute.runtime.from_dlpack(a_torch, enable_tvm_ffi=True).mark_layout_dynamic()
b_cute = cute.runtime.from_dlpack(b_torch, enable_tvm_ffi=True).mark_layout_dynamic()
compiled_add = cute.compile(add, a_cute, b_cute, options="--enable-tvm-ffi")
compiled_add = cute.compile(add, a_torch, b_torch, options="--enable-tvm-ffi")
Note that the object returned by ``cute.compile`` is a Python function specific to TVM FFI.
2. Alternatively, you can enable TVM FFI globally by setting the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``. Please note that this setting will apply to all JIT compilations within the environment.
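As a minimal sketch, the variable can also be set from Python instead of the shell. The text above does not say when the flag is read, so this example assumes it is read at import or compile time and sets it before the DSL is imported, which is the conservative choice.

```python
import os

# Assumption: the flag is read no later than the first JIT compilation,
# so it is set before the DSL is imported.
os.environ["CUTE_DSL_ENABLE_TVM_FFI"] = "1"

# from cutlass import cute  # import the DSL only after the flag is set

print(os.environ["CUTE_DSL_ENABLE_TVM_FFI"])
```

Because the setting applies to every JIT compilation in the process, prefer the per-call ``options="--enable-tvm-ffi"`` form when only some functions should use TVM FFI.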
Minimizing Host Overhead
------------------------
For latency-sensitive applications, the CPU-side overhead of eagerly invoking kernels
can become a bottleneck. TVM FFI can greatly reduce this overhead.
To maximize performance benefits, we recommend setting up your workflow as follows
(detailed instructions are provided in subsequent sections):
- **Compile the kernel with TVM FFI enabled.**
- **Declare shape constraints using fake tensors** and reuse the compiled function
throughout your execution.
- **Pass PyTorch tensors directly** to the compiled function to avoid explicit DLPack conversion.
- **Use the environment stream flag** to implicitly synchronize with the current PyTorch stream.
- **Rely on compiled argument validation** instead of Python-side attribute validation,
as TVM FFI functions perform fast compiled checks.
Following these steps can significantly reduce the host-side overhead of eager kernel execution.
The sections below provide detailed examples and explanations for each step.
You may find it helpful to refer back to this summary after you review the implementation details.
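To check whether these steps pay off in a given workflow, it helps to measure the per-call host cost before and after. The helper below is a hypothetical sketch rather than part of the DSL: it times repeated calls to any callable. It assumes kernel launches are asynchronous, so that without device synchronization inside the loop the measurement mostly reflects host-side overhead.

```python
import time
from typing import Callable


def mean_call_overhead_us(
    fn: Callable[[], object], warmup: int = 10, iters: int = 1000
) -> float:
    """Return the mean wall-clock time of one call to fn, in microseconds."""
    for _ in range(warmup):  # warm caches and trigger any lazy initialisation
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed = time.perf_counter() - start
    return elapsed / iters * 1e6


# Stand-in workload so the sketch runs anywhere; in practice pass something
# like ``lambda: compiled_add(a_torch, b_torch)``.
print(f"{mean_call_overhead_us(lambda: None):.3f} us per call")
```

Comparing this number for a function compiled with and without TVM FFI gives a rough view of the host-side saving; absolute values depend on the machine.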
Fake tensor for compilation
---------------------------
@@ -92,6 +114,18 @@ The fake tensor is a placeholder that mimics the interface of a real tensor but
It is used in compilation or testing scenarios where only shape/type/layout information is needed.
All attempts to access or mutate data will raise errors.
Note on Stride Order
~~~~~~~~~~~~~~~~~~~~
CuTe writes the stride order for dimensions from left to right, and a lower order number
marks a faster-varying dimension, that is, one with a smaller stride.
In the ``make_fake_compact_tensor`` API, shape ``(2, 3, 4)`` with stride order ``(0, 1, 2)``
therefore yields stride ``(1, 2, 6)``, which is commonly known as column-major order.
To create a fake tensor with compact row-major order,
explicitly pass ``stride_order=tuple(reversed(range(len(shape))))``
to ``make_fake_compact_tensor``. Alternatively, you can always control the
stride precisely via the ``stride`` argument of the ``make_fake_tensor`` API.
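The mapping from stride order to strides can be reproduced in plain Python. The function below is an illustrative sketch, not the DSL's implementation. It assumes ``stride_order[d]`` is the rank of dimension ``d``, with rank 0 receiving stride 1, which matches both examples given in this note.

```python
from typing import Sequence, Tuple


def compact_strides(
    shape: Sequence[int], stride_order: Sequence[int]
) -> Tuple[int, ...]:
    """Compact strides where the dimension with order 0 gets stride 1."""
    strides = [0] * len(shape)
    running = 1
    # Visit dimensions from the lowest order number (fastest varying) upwards.
    for dim in sorted(range(len(shape)), key=lambda d: stride_order[d]):
        strides[dim] = running
        running *= shape[dim]
    return tuple(strides)


shape = (2, 3, 4)
print(compact_strides(shape, (0, 1, 2)))  # column-major: (1, 2, 6)
print(compact_strides(shape, tuple(reversed(range(len(shape))))))  # row-major: (12, 4, 1)
```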
``cute.Tensor`` adapter for TVM FFI
-----------------------------------
@@ -173,6 +207,9 @@ The following example demonstrates this approach; the function accepts ``torch.c
print("result of b_torch after compiled_add_one(a_torch, b_torch, torch_stream)")
print(b_torch)
Using Environment Stream
~~~~~~~~~~~~~~~~~~~~~~~~
The second option is to rely on the environment stream flag.
Pass ``use_tvm_ffi_env_stream=True`` to ``make_fake_stream`` to mark the stream argument as an
environment stream, so that callers no longer have to provide it explicitly.
@@ -202,8 +239,54 @@ before each call. The example below shows this flow:
print("result of b_torch after compiled_add_one(a_torch, b_torch)")
print(b_torch)
Using the environment stream flag both speeds up calls and simplifies integration
with frameworks such as PyTorch, since no explicit stream parameter is required.
We recommend using the environment stream flag to both simplify framework integration
and minimize host-side calling overhead.
Working with Tuples
-------------------
TVM FFI functions can also accept tuples as arguments. Tuples can be recursively
composed of the types that are supported by TVM FFI. The example below shows how to use tuples as arguments:
.. code-block:: python

    from typing import Tuple

    import torch
    from cutlass import cute

    @cute.kernel
    def device_add_one(a: cute.Tensor, b: cute.Tensor, c: cute.Float32):
        threads_per_block = 128
        cta_x_, _, _ = cute.arch.block_idx()
        tid_x, _, _ = cute.arch.thread_idx()
        tid = cta_x_ * threads_per_block + tid_x
        # Guard against threads past the end of the tensor.
        if tid < a.shape[0]:
            b[tid] = a[tid] + c

    @cute.jit
    def add_one_with_tuple(a: Tuple[cute.Tensor, cute.Tensor, cute.Float32]):
        n = a[0].shape[0]
        threads_per_block = 128
        blocks = (n + threads_per_block - 1) // threads_per_block
        device_add_one(a[0], a[1], a[2]).launch(
            grid=(blocks, 1, 1), block=(threads_per_block, 1, 1)
        )

    def example_add_one_with_tuple():
        n = cute.sym_int()
        a_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,))
        b_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,))
        # The tuple passed at compile time fixes the argument structure and types.
        compiled_add_one = cute.compile(
            add_one_with_tuple, (a_cute, b_cute, cute.Float32(4)),
            options="--enable-tvm-ffi"
        )
        a_torch = torch.arange(10, dtype=torch.float32, device="cuda")
        b_torch = torch.empty(10, dtype=torch.float32, device="cuda")
        # At call time, pass a plain Python tuple of the corresponding values.
        compiled_add_one((a_torch, b_torch, 5))
        print("result of b_torch after compiled_add_one((a_torch, b_torch, 5))")
        print(b_torch)

    example_add_one_with_tuple()

Supported types
---------------
@@ -230,6 +313,8 @@ The TVM FFI function supports the following |DSL|-specific types as arguments:
- ``tvm_ffi.Shape`` or python tuple of ints.
* - CUDA stream ``cuda.CUstream``
- A stream class that implements the CUDA stream protocol (e.g. ``torch.cuda.Stream``, ``cuda.CUstream``).
* - Tuple of types (e.g. ``Tuple[cute.Tensor, cute.Tensor, cutlass.Int32]``)
- Python tuple of corresponding call-time types.
Error handling
--------------


@@ -415,24 +415,36 @@ The following example demonstrates how to use ``mark_compact_shape_dynamic`` to
)
# Layout in DLTensorWrapper has int32 overflow risk. Please set use_32bit_stride to False.
Leveraging TVM FFI for Faster PyTorch Interop
---------------------------------------------
The latest version of |DSL| supports TVM FFI to improve interoperability with PyTorch
and other machine learning frameworks. Using TVM FFI provides the following features:
- Faster JIT function invocation.
- Direct acceptance of ``torch.Tensor`` objects as function arguments.
- Enhanced error handling and kernel validation.
- Seamless integration with multiple programming languages.
For more details, see :doc:`compile_with_tvm_ffi`.
Bypass the DLPack Protocol
--------------------------
In certain scenarios, users may wish to bypass the DLPack protocol and invoke the JIT function directly.
This can be accomplished by creating a lightweight JIT wrapper around the existing JIT function,
utilizing ``cute.ptr`` and ``cute.make_tensor`` to pass pointers and construct tensors directly.
Typical use cases for bypassing DLPack include:
1. Users want to call the JIT function directly to avoid the overhead introduced by the DLPack protocol.
2. DLPack canonicalizes the stride of shape-1 dimensions to 1, which may result in incorrect alignment
propagation and affect memory access or performance.
3. DLPack may lack support for some narrow data types.
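The second use case can be illustrated with a small model in plain Python. This is a sketch of the effect, not the DSL's actual alignment analysis: ``canonicalize_unit_dims`` and ``strides_divisible_by`` are hypothetical helpers, and the shape and strides are made-up values for a single-batch ``(m, k, l)`` tensor.

```python
from typing import Sequence, Tuple


def canonicalize_unit_dims(
    shape: Sequence[int], strides: Sequence[int]
) -> Tuple[int, ...]:
    """Model of the canonicalization: size-1 dimensions get stride 1."""
    return tuple(1 if extent == 1 else stride for extent, stride in zip(shape, strides))


def strides_divisible_by(
    strides: Sequence[int], contiguous_dim: int, elems: int
) -> bool:
    """True if every stride except the contiguous one is a multiple of elems."""
    return all(s % elems == 0 for d, s in enumerate(strides) if d != contiguous_dim)


shape = (128, 64, 1)     # (m, k, l) with a single batch
strides = (64, 1, 8192)  # row-major m x k, batch stride m * k
canonical = canonicalize_unit_dims(shape, strides)

print(canonical)                              # (64, 1, 1)
print(strides_divisible_by(strides, 1, 8))    # True
print(strides_divisible_by(canonical, 1, 8))  # False
```

In this model the original batch stride is a multiple of 8 elements, but the canonicalized stride of 1 is not, so any alignment inferred from the strides becomes weaker than what the real memory layout allows.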
The following example illustrates how to bypass the DLPack protocol when invoking a JIT function.
Assume we have a pre-defined ``TensorOpGemm`` kernel whose JIT interface expects three
arguments of type ``cute.Tensor``. To enable direct invocation without DLPack, we first define a JIT wrapper
function that accepts ``cute.Pointer`` types as parameters. Within this wrapper, we use ``cute.make_tensor``
to construct tensors from the provided pointers, and then call the ``TensorOpGemm`` kernel as usual.
.. code-block:: python
@@ -459,7 +471,7 @@ to construct tensors from the provided pointers, and then call the ``TensorOpGem
mA = cute.make_tensor(a_ptr, layout=a_layout)
mB = cute.make_tensor(b_ptr, layout=b_layout)
mC = cute.make_tensor(c_ptr, layout=c_layout)
# TensorOpGemm is a pre-defined kernel from our example
tensor_op_gemm = TensorOpGemm(
a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1)
@@ -467,9 +479,9 @@ to construct tensors from the provided pointers, and then call the ``TensorOpGem
tensor_op_gemm(mA, mB, mC)
To pass a PyTorch tensor to this new JIT wrapper, we retrieve the raw pointer from the PyTorch tensor
and create a ``cute.Pointer`` instance using ``cute.make_ptr``.
This approach allows us to bypass the DLPack protocol entirely, avoiding its overhead and potential
issues with shape-1 dimension handling.
.. code-block:: python
@@ -483,7 +495,7 @@ issues with shape-1 dimension handling.
c = torch.randn(
n, m, l, dtype=torch.float16, device="cuda"
).permute(1, 2, 0)
# from cutlass.cute.runtime import make_ptr
a_ptr = make_ptr(
cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
@@ -495,3 +507,5 @@ issues with shape-1 dimension handling.
cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, m, n, k, l)
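The name ``assumed_align`` suggests that the alignment is taken on trust rather than verified. Under that assumption, it may be worth checking the raw address before building the pointer. The helper below is a hypothetical sketch in plain Python and uses made-up addresses so that it runs without a GPU.

```python
def is_aligned(address: int, alignment: int) -> bool:
    """True if address is a multiple of alignment (both in bytes)."""
    if alignment <= 0 or alignment & (alignment - 1):
        raise ValueError("alignment must be a positive power of two")
    return address % alignment == 0


# Made-up addresses for illustration only.
print(is_aligned(0x7F00_0000_0020, 32))  # True
print(is_aligned(0x7F00_0000_0010, 32))  # False
```

With real tensors, a check such as ``assert is_aligned(a.data_ptr(), 32)`` before calling ``make_ptr`` would catch views or offsets that break the assumed alignment.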