fix TVM FFI doc and update example

This commit is contained in:
Fung Xie
2025-11-25 05:46:12 -08:00
parent 1de3a576cc
commit 739fffce27
2 changed files with 46 additions and 22 deletions


@@ -1715,7 +1715,6 @@ def compile_bmm(
max_active_clusters: cutlass.Constexpr,
stream: cuda.CUstream,
epilogue_op: cutlass.Constexpr = lambda x: x,
options: str = "",
):
from cutlass.cute.runtime import make_fake_compact_tensor
@@ -1749,7 +1748,15 @@ def compile_bmm(
)
return cute.compile(
bmm, gemm_op, a, b, c, max_active_clusters, stream, epilogue_op, options=options
bmm,
gemm_op,
a,
b,
c,
max_active_clusters,
stream,
epilogue_op,
options="--enable-tvm-ffi",
)
@@ -1811,7 +1818,6 @@ def run(
iterations: int = 1,
skip_ref_check: bool = False,
use_cold_l2: bool = False,
use_tvm_ffi: bool = False,
benchmark: bool = False,
**kwargs,
):
@@ -1853,8 +1859,6 @@ def run(
:type skip_ref_check: bool, optional
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False.
:type use_cold_l2: bool, optional
:param use_tvm_ffi: Whether to use TVM FFI for the kernel, defaults to False.
:type use_tvm_ffi: bool, optional
:param benchmark: Whether to only benchmark the kernel, defaults to False.
:type benchmark: bool, optional
:raises RuntimeError: If CUDA GPU is not available.
@@ -1874,7 +1878,7 @@ def run(
print(f"Iterations: {iterations}")
print(f"Skip reference checking: {skip_ref_check}")
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
print(f"Use TVM FFI: {'True' if use_tvm_ffi else 'False'}")
print("Use TVM FFI")
import torch
from cutlass.torch import dtype as torch_dtype
@@ -1906,10 +1910,6 @@ def run(
cluster_shape_mn[0] * cluster_shape_mn[1]
)
options = []
if use_tvm_ffi:
options.append("--enable-tvm-ffi")
compiled_fn = compile_bmm(
gemm,
ab_dtype,
@@ -1920,7 +1920,6 @@ def run(
c_major,
max_active_clusters,
current_stream,
options=",".join(options),
)
# Run and verify BMM with torch
@@ -2043,12 +2042,6 @@ def prepare_parser():
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
parser.add_argument(
"--use_tvm_ffi",
action="store_true",
default=False,
help="Enable TVM FFI for the kernel, defaults to False using CuTe DSL's native runtime",
)
return parser
@@ -2090,7 +2083,6 @@ if __name__ == "__main__":
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
args.use_tvm_ffi,
args.benchmark,
)
print("PASS")


@@ -4,7 +4,8 @@
Compile with TVM FFI
====================
Apache TVM FFI is an open ABI and FFI for machine learning systems. More information can be found in the `official documentation <https://tvm.apache.org/ffi/>`_.
Apache TVM FFI is an open ABI and FFI for machine learning systems. More information can be found in
the `official documentation <https://tvm.apache.org/ffi/>`_.
To install TVM FFI, you can run the following command:
@@ -14,7 +15,9 @@ To install TVM FFI, you can run the following command:
# optional package for improved torch tensor calling performance
pip install torch-c-dlpack-ext
In |DSL|, TVM FFI can be enabled as an option for JIT-compiled functions. Using TVM FFI can lead to faster JIT function invocation and provides better interoperability with machine learning frameworks (e.g., directly take ``torch.Tensor`` as arguments).
In |DSL|, TVM FFI can be enabled as an option for JIT-compiled functions. Using TVM FFI can lead to faster
JIT function invocation and provides better interoperability with machine learning frameworks
(e.g., directly take ``torch.Tensor`` as arguments).
Enable Apache TVM FFI in |DSL|
@@ -40,7 +43,8 @@ There are two ways to enable TVM FFI in |DSL|:
Note that the object returned by ``cute.compile`` is a Python function specific to TVM FFI.
2. Alternatively, you can enable TVM FFI globally by setting the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``. Please note that this setting will apply to all JIT compilations within the environment.
2. Alternatively, you can enable TVM FFI globally by setting the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``.
Please note that this setting will apply to all JIT compilations within the environment.
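As an illustration of the global toggle (setting the variable to ``1`` is the documented usage; how the DSL interprets other values is an assumption and the check below is a hypothetical sketch):

```python
import os

# Enable TVM FFI globally for every JIT compilation in this process.
# "1" is the documented value.
os.environ["CUTE_DSL_ENABLE_TVM_FFI"] = "1"

# A library-side check might look like this (hypothetical sketch):
enabled = os.environ.get("CUTE_DSL_ENABLE_TVM_FFI", "0") == "1"
print(enabled)  # True
```

Since the variable is process-wide, prefer the per-compilation ``options`` flag when only some kernels should use TVM FFI.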
Minimizing Host Overhead
@@ -129,7 +133,8 @@ stride via the ``stride`` argument in the ``make_fake_tensor`` API.
``cute.Tensor`` adapter for TVM FFI
-----------------------------------
To adapt the ``cute.Tensor`` to the TVM FFI function, you can use the ``cute.runtime.from_dlpack`` function with the ``enable_tvm_ffi=True`` option or the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``. For example:
To adapt the ``cute.Tensor`` to the TVM FFI function, you can use the ``cute.runtime.from_dlpack`` function with the
``enable_tvm_ffi=True`` option or the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``. For example:
.. code-block:: python
@@ -288,6 +293,33 @@ composed of the types that are supported by TVM FFI. The example below shows how
example_add_one_with_tuple()
Limitations
-----------
The fake tensor flow is ONLY compatible with TVM FFI, because TVM FFI supports more flexible constraints on tensor arguments.
For instance, a fake tensor can specify a per-mode static shape, or constraints on shapes and strides, which the existing
``from_dlpack`` flow does not support. A JIT function compiled with fake tensors may therefore have a different ABI from one
called with tensors converted by ``from_dlpack``.
.. code-block:: python
import cutlass.cute as cute
import torch
n = cute.sym_int()
# Dynamic Shape
fake_a = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,))
# Compile without tvm-ffi
compiled_fn = cute.compile(foo, fake_a)
# Wrong: incompatible ABI
compiled_fn(cute.runtime.from_dlpack(a))
To avoid this issue, it is recommended that fake tensors be used only with TVM FFI.
Supported types
---------------