Files
cutlass/examples/python/CuTeDSL/jax/cutlass_call_export.py
2026-04-07 12:16:05 -04:00

202 lines
7.4 KiB
Python

# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Examples of using jax.export APIs with functions using cutlass_call.
This example demonstrates three export modes:
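1. Concrete shapes -- shapes are fixed constants baked into the export.
2. Unconstrained symbolic shapes ("a, b") -- both dimensions are fully
   dynamic and any size is accepted at runtime.
3. Constrained symbolic shapes ("32*M, 16*N") -- dimensions are dynamic but
   declared to be multiples of a tile size.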
The JAX function being exported is the same in all three cases; only the
shape specification passed to jax.export differs.
It assumes familiarity with CuTe DSL concepts such as layouts and dynamic shapes
as well as JAX's exporting and serialization features:
https://docs.jax.dev/en/latest/export/index.html#export
To run this example from the repository root:

.. code-block:: bash

    python examples/python/CuTeDSL/jax/cutlass_call_export.py --M 512 --N 256
"""
import argparse
import cuda.bindings.driver as cuda
import cutlass.cute as cute
import jax
import jax.numpy as jnp
from jax import export
from cutlass.jax import cutlass_call, get_export_disabled_safety_checks, TensorSpec
from cutlass.jax.testing import create_tensor
# Simple element-wise addition kernel: gC[i,j] = gA[i,j] + gB[i,j]
@cute.kernel
def kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    """Element-wise addition where each thread handles one element.

    The flat thread index is mapped to a ``(mi, ni)`` coordinate using the
    shape of ``gA``; threads whose flat index is past the last element do
    nothing.

    Args:
        gA: first 2D input tensor.
        gB: second 2D input tensor, indexed with the same coordinates as gA.
        gC: 2D output tensor receiving ``gA[mi, ni] + gB[mi, ni]``.
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    # Flat index of this thread across the 1D launch grid.
    thread_idx = bidx * bdim + tidx

    m, n = gA.shape

    # The launch grid is rounded up to whole blocks, so the last block may
    # contain threads beyond the final element. Those must not touch memory.
    if thread_idx < m * n:
        ni = thread_idx % n
        mi = thread_idx // n
        gC[mi, ni] = gA[mi, ni] + gB[mi, ni]


@cute.jit
def launch(stream: cuda.CUstream, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    """Launch ``kernel`` on ``stream`` with one thread per element of ``mA``.

    Args:
        stream: CUDA stream the kernel is launched on.
        mA: first 2D input tensor; its shape determines the grid size.
        mB: second 2D input tensor.
        mC: 2D output tensor.
    """
    num_threads_per_block = 256
    m, n = mA.shape

    # Ceil-divide so a partial trailing block is still launched. Flooring
    # here would silently skip the tail whenever m * n is not a multiple of
    # the block size (and launch nothing at all for fewer than 256 elements).
    num_blocks = (m * n + num_threads_per_block - 1) // num_threads_per_block

    kernel(mA, mB, mC).launch(
        grid=(num_blocks, 1, 1),
        block=(num_threads_per_block, 1, 1),
        stream=stream,
    )
def _export_and_run(f, ref_f, input_shape_dtype, run_shapes):
    """Export f, serialize/deserialize, then run on each shape in run_shapes.

    Both inputs (a, b) are assumed to share the same input_shape_dtype.

    Args:
        f: jitted two-argument function to export.
        ref_f: reference function called as ``ref_f(a, b)``; the output of
            the deserialized export is compared against its result.
        input_shape_dtype: shape/dtype specification used for both
            arguments when exporting.
        run_shapes: iterable of shapes; float32 inputs are created for each
            one and fed to the deserialized computation.

    Raises:
        AssertionError: if the deserialized computation's output is not
            close to the reference output for some shape.
    """
    print(f"Exporting with input signature: ({input_shape_dtype}, {input_shape_dtype})")

    # jax.export can be used to export a jit function containing cutlass_call.
    # CUTLASS custom call targets are not on JAX's built-in stable custom-call
    # allowlist, so we pass them via disabled_checks to suppress that safety check.
    export_fn = export.export(f, disabled_checks=get_export_disabled_safety_checks())
    exported = export_fn(input_shape_dtype, input_shape_dtype)

    # Round-trip through bytes so the computation that runs below is the
    # deserialized one, not the in-memory export.
    blob = exported.serialize()
    print(f"Serialized computation is {len(blob)} bytes.")
    rehydrated = export.deserialize(blob)

    key = jax.random.key(1123)
    a_key, b_key = jax.random.split(key, 2)
    for shape in run_shapes:
        a = create_tensor(shape, dtype=jnp.float32, key=a_key)
        b = create_tensor(shape, dtype=jnp.float32, key=b_key)
        c = rehydrated.call(a, b)
        # Raise explicitly instead of using `assert`, which is stripped under
        # `python -O` and would let the script report success unchecked.
        if not jnp.allclose(c, ref_f(a, b)):
            raise AssertionError(f"Mismatch at shape {shape}")
        print(f" shape {shape}: OK")
def run_example(M, N):
    """Export and run the cutlass_call-based add+sigmoid in three shape modes.

    Args:
        M: base size of dimension 0 used for the runtime shapes.
        N: base size of dimension 1 used for the runtime shapes.
    """

    # Pure-JAX reference the exported computations are checked against.
    @jax.jit
    def ref_f(a, b):
        return jax.nn.sigmoid(a + b)

    # One JAX function serves examples 1 and 2: what distinguishes the export
    # modes is only the shape specification handed to jax.export.
    @jax.jit
    def f(a, b):
        summed = cutlass_call(launch, output_shape_dtype=a)(a, b)
        return jax.nn.sigmoid(summed)

    # Variant used by example 3: identical math, but every tensor carries a
    # TensorSpec declaring its per-dimension divisibility.
    @jax.jit
    def f_divisible(a, b):
        tile_spec = TensorSpec(divisibility=(32, 16))
        summed = cutlass_call(
            launch,
            output_shape_dtype=a,
            input_spec=(tile_spec, tile_spec),
            output_spec=tile_spec,
        )(a, b)
        return jax.nn.sigmoid(summed)

    # -- 1. Concrete shapes ---------------------------------------------------
    # The dimensions are constants baked into the export, so the deserialized
    # computation accepts exactly (M, N) at runtime and nothing else.
    print("\nConcrete shapes:")
    concrete_sds = jax.ShapeDtypeStruct((M, N), jnp.float32)
    _export_and_run(f, ref_f, concrete_sds, run_shapes=[(M, N)])

    # -- 2. Unconstrained symbolic shapes -------------------------------------
    # Both dimensions are fully dynamic: the exported computation takes any
    # (M, N) at runtime without recompilation.
    print("\nUnconstrained symbolic shapes:")
    dim0, dim1 = export.symbolic_shape("a, b")
    dynamic_sds = jax.ShapeDtypeStruct((dim0, dim1), jnp.float32)
    _export_and_run(
        f,
        ref_f,
        dynamic_sds,
        run_shapes=[(M, N), (2 * M, 4 * N), (4 * M, 4 * N)],
    )

    # -- 3. Constrained symbolic shapes (divisibility) ------------------------
    # The symbolic expression "32*M, 16*N" tells jax.export that dim 0 is
    # always a multiple of 32 and dim 1 always a multiple of 16, matching the
    # TensorSpec.divisibility declared in f_divisible. This lets the compiler
    # generate more efficient code (e.g. no remainder handling). Runtime
    # shapes must satisfy these divisibility constraints.
    print("\nConstrained symbolic shapes:")
    tiled0, tiled1 = export.symbolic_shape("32*M, 16*N")
    tiled_sds = jax.ShapeDtypeStruct((tiled0, tiled1), jnp.float32)
    _export_and_run(
        f_divisible,
        ref_f,
        tiled_sds,
        run_shapes=[(M, N), (2 * M, 2 * N), (4 * M, 4 * N)],
    )
if __name__ == "__main__":
    # Command-line entry point: --M / --N select the base problem size.
    cli = argparse.ArgumentParser(
        description="Demonstration of using jax.export with functions with cutlass_call"
    )
    for flag, default_dim in (("--M", 512), ("--N", 256)):
        cli.add_argument(flag, type=int, default=default_dim)
    opts = cli.parse_args()

    run_example(opts.M, opts.N)
    print("PASS")