mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 14:59:01 +00:00
fix TVM FFI doc and update example
This commit is contained in:
@@ -1715,7 +1715,6 @@ def compile_bmm(
|
||||
max_active_clusters: cutlass.Constexpr,
|
||||
stream: cuda.CUstream,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
options: str = "",
|
||||
):
|
||||
from cutlass.cute.runtime import make_fake_compact_tensor
|
||||
|
||||
@@ -1749,7 +1748,15 @@ def compile_bmm(
|
||||
)
|
||||
|
||||
return cute.compile(
|
||||
bmm, gemm_op, a, b, c, max_active_clusters, stream, epilogue_op, options=options
|
||||
bmm,
|
||||
gemm_op,
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
epilogue_op,
|
||||
options="--enable-tvm-ffi",
|
||||
)
|
||||
|
||||
|
||||
@@ -1811,7 +1818,6 @@ def run(
|
||||
iterations: int = 1,
|
||||
skip_ref_check: bool = False,
|
||||
use_cold_l2: bool = False,
|
||||
use_tvm_ffi: bool = False,
|
||||
benchmark: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -1853,8 +1859,6 @@ def run(
|
||||
:type skip_ref_check: bool, optional
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False.
|
||||
:type use_cold_l2: bool, optional
|
||||
:param use_tvm_ffi: Whether to use TVM FFI for the kernel, defaults to False.
|
||||
:type use_tvm_ffi: bool, optional
|
||||
:param benchmark: Whether to only benchmark the kernel, defaults to False.
|
||||
:type benchmark: bool, optional
|
||||
:raises RuntimeError: If CUDA GPU is not available.
|
||||
@@ -1874,7 +1878,7 @@ def run(
|
||||
print(f"Iterations: {iterations}")
|
||||
print(f"Skip reference checking: {skip_ref_check}")
|
||||
print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}")
|
||||
print(f"Use TVM FFI: {'True' if use_tvm_ffi else 'False'}")
|
||||
print(f"Use TVM FFI")
|
||||
|
||||
import torch
|
||||
from cutlass.torch import dtype as torch_dtype
|
||||
@@ -1906,10 +1910,6 @@ def run(
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1]
|
||||
)
|
||||
|
||||
options = []
|
||||
if use_tvm_ffi:
|
||||
options.append("--enable-tvm-ffi")
|
||||
|
||||
compiled_fn = compile_bmm(
|
||||
gemm,
|
||||
ab_dtype,
|
||||
@@ -1920,7 +1920,6 @@ def run(
|
||||
c_major,
|
||||
max_active_clusters,
|
||||
current_stream,
|
||||
options=",".join(options),
|
||||
)
|
||||
|
||||
# Run and verify BMM with torch
|
||||
@@ -2043,12 +2042,6 @@ def prepare_parser():
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_tvm_ffi",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable TVM FFI for the kernel, defaults to False using CuTe DSL's native runtime",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@@ -2090,7 +2083,6 @@ if __name__ == "__main__":
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
args.use_tvm_ffi,
|
||||
args.benchmark,
|
||||
)
|
||||
print("PASS")
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
Compile with TVM FFI
|
||||
====================
|
||||
|
||||
Apache TVM FFI is an open ABI and FFI for machine learning systems. More information can be found in the `official documentation <https://tvm.apache.org/ffi/>`_.
|
||||
Apache TVM FFI is an open ABI and FFI for machine learning systems. More information can be found in
|
||||
the `official documentation <https://tvm.apache.org/ffi/>`_.
|
||||
|
||||
To install TVM FFI, you can run the following command:
|
||||
|
||||
@@ -14,7 +15,9 @@ To install TVM FFI, you can run the following command:
|
||||
# optional package for improved torch tensor calling performance
|
||||
pip install torch-c-dlpack-ext
|
||||
|
||||
In |DSL|, TVM FFI can be enabled as an option for JIT-compiled functions. Using TVM FFI can lead to faster JIT function invocation and provides better interoperability with machine learning frameworks (e.g., directly take ``torch.Tensor`` as arguments).
|
||||
In |DSL|, TVM FFI can be enabled as an option for JIT-compiled functions. Using TVM FFI can lead to faster
|
||||
JIT function invocation and provides better interoperability with machine learning frameworks
|
||||
(e.g., directly take ``torch.Tensor`` as arguments).
|
||||
|
||||
|
||||
Enable Apache TVM FFI in |DSL|
|
||||
@@ -40,7 +43,8 @@ There are two ways to enable TVM FFI in |DSL|:
|
||||
|
||||
Note that the object returned by ``cute.compile`` is a Python function specific to TVM FFI.
|
||||
|
||||
2. Alternatively, you can enable TVM FFI globally by setting the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``. Please note that this setting will apply to all JIT compilations within the environment.
|
||||
2. Alternatively, you can enable TVM FFI globally by setting the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``.
|
||||
Please note that this setting will apply to all JIT compilations within the environment.
|
||||
|
||||
|
||||
Minimizing Host Overhead
|
||||
@@ -129,7 +133,8 @@ stride via the ``stride`` argument in the ``make_fake_tensor`` API.
|
||||
``cute.Tensor`` adapter for TVM FFI
|
||||
-----------------------------------
|
||||
|
||||
To adapt the ``cute.Tensor`` to the TVM FFI function, you can use the ``cute.runtime.from_dlpack`` function with the ``enable_tvm_ffi=True`` option or the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``. For example:
|
||||
To adapt the ``cute.Tensor`` to the TVM FFI function, you can use the ``cute.runtime.from_dlpack`` function with the
|
||||
``enable_tvm_ffi=True`` option or the environment variable ``CUTE_DSL_ENABLE_TVM_FFI=1``. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -288,6 +293,33 @@ composed of the types that are supported by TVM FFI. The example below shows how
|
||||
example_add_one_with_tuple()
|
||||
|
||||
|
||||
Limitations
|
||||
-----------
|
||||
|
||||
The Fake Tensor flow is ONLY compatible with TVM FFI because TVM FFI supports more flexible constraints on Tensor arguments.
|
||||
For instance, a fake tensor can specify a per-mode static shape or constraints on shapes and strides, which are not supported by the
|
||||
existing `from_dlpack` flow. It is expected that a JIT function compiled with a fake tensor may have a different ABI from a
|
||||
tensor converted by `from_dlpack`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import cutlass.cute as cute
|
||||
import torch
|
||||
|
||||
n = cute.sym_int()
|
||||
# Dynamic Shape
|
||||
fake_a = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,))
|
||||
|
||||
# Compile without tvm-ffi
|
||||
compiled_fn = cute.compile(foo, fake_a)
|
||||
|
||||
# Wrong: incompatible ABI
|
||||
compiled_fn(from_dlpack(a))
|
||||
|
||||
|
||||
To avoid such issues, it is recommended that fake tensors only be used with TVM FFI.
|
||||
|
||||
|
||||
Supported types
|
||||
---------------
|
||||
|
||||
|
||||
Reference in New Issue
Block a user