# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
|
|
Examples of using jax.export APIs with functions using cutlass_call.
|
|
|
|
This example demonstrates three export modes:
|
|
|
|
1. Concrete shapes -- shapes are fixed constants baked into the export.
|
|
2. Unconstrained symbolic shapes ("a, b")
|
|
3. Constrained symbolic shapes ("32*M, 16*N")
|
|
|
|
The JAX function being exported is the same in all three cases; only the
|
|
shape specification passed to jax.export differs.
|
|
|
|
It assumes familiarity with CuTe DSL concepts such as layouts and dynamic shapes
|
|
as well as JAX's exporting and serialization features:
|
|
https://docs.jax.dev/en/latest/export/index.html#export
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/jax/cutlass_call_export.py --M 512 --N 256
|
|
|
|
"""

import argparse

import cuda.bindings.driver as cuda

import cutlass.cute as cute

import jax
import jax.numpy as jnp
from jax import export

from cutlass.jax import cutlass_call, get_export_disabled_safety_checks, TensorSpec
from cutlass.jax.testing import create_tensor


# Simple element-wise addition kernel: gC[i,j] = gA[i,j] + gB[i,j]
@cute.kernel
def kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    # Flatten the 1D launch configuration into a global thread index, then
    # map it onto a 2D element coordinate.
    thread_idx = bidx * bdim + tidx

    m, n = gA.shape
    ni = thread_idx % n
    mi = thread_idx // n

    a_val = gA[mi, ni]
    b_val = gB[mi, ni]
    gC[mi, ni] = a_val + b_val


@cute.jit
def launch(stream: cuda.CUstream, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    num_threads_per_block = 256
    m, n = mA.shape
    # One thread per element; the grid size assumes m * n is an exact multiple
    # of the block size, which holds for every shape used in this example.
    kernel(mA, mB, mC).launch(
        grid=((m * n) // num_threads_per_block, 1, 1),
        block=(num_threads_per_block, 1, 1),
        stream=stream,
    )
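

# For orientation, cutlass_call can also be used directly, with no export
# involved -- a minimal sketch using only the cutlass_call pattern from
# run_example below (the shape and the name `add` are illustrative):
#
#   add = cutlass_call(
#       launch, output_shape_dtype=jax.ShapeDtypeStruct((512, 256), jnp.float32)
#   )
#   c = add(a, b)  # c[i, j] == a[i, j] + b[i, j]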


def _export_and_run(f, ref_f, input_shape_dtype, run_shapes):
    """Export f, serialize/deserialize, then run on each shape in run_shapes.

    Both inputs (a, b) are assumed to share the same input_shape_dtype.
    """
    print(f"Exporting with input signature: ({input_shape_dtype}, {input_shape_dtype})")

    # jax.export can be used to export a jit function containing cutlass_call.
    # CUTLASS custom call targets are not on JAX's built-in stable custom-call
    # allowlist, so we pass them via disabled_checks to suppress that safety check.
    exported = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks())
    traced = exported(input_shape_dtype, input_shape_dtype)

    blob = traced.serialize()
    print(f"Serialized computation is {len(blob)} bytes.")

    rehydrated = export.deserialize(blob)

    key = jax.random.key(1123)
    a_key, b_key = jax.random.split(key, 2)
    for shape in run_shapes:
        a = create_tensor(shape, dtype=jnp.float32, key=a_key)
        b = create_tensor(shape, dtype=jnp.float32, key=b_key)
        c = rehydrated.call(a, b)
        assert jnp.allclose(c, ref_f(a, b)), f"Mismatch at shape {shape}"
        print(f"  shape {shape}: OK")


def run_example(M, N):
    @jax.jit
    def ref_f(a, b):
        return jax.nn.sigmoid(a + b)

    # The same JAX function is used in all three examples below. The export
    # mode is determined entirely by the shape spec passed to jax.export.
    @jax.jit
    def f(a, b):
        # The kernel output shares the shape and dtype of input a.
        call = cutlass_call(launch, output_shape_dtype=a)
        return jax.nn.sigmoid(call(a, b))

    # ── 1. Concrete shapes ────────────────────────────────────────────────────
    # Shapes are fixed constants baked into the export. The deserialized
    # computation only accepts exactly these dimensions at runtime.
    print("\nConcrete shapes:")

    input_shape_dtype = jax.ShapeDtypeStruct((M, N), jnp.float32)
    _export_and_run(
        f,
        ref_f,
        input_shape_dtype,
        run_shapes=[(M, N)],  # concrete exports reject any other shape
    )
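
    # For contrast, a sketch of what the concrete export rejects: calling the
    # deserialized computation with inputs of any other shape raises a
    # shape-mismatch error instead of triggering a recompile, e.g.
    #
    #   bad = jnp.zeros((M + 1, N), jnp.float32)
    #   rehydrated.call(bad, bad)  # error: shape does not match the export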

    # ── 2. Unconstrained symbolic shapes ─────────────────────────────────────
    # Both dimensions are fully dynamic. The exported computation accepts any
    # (M, N) at runtime without recompilation (the launch above still assumes
    # m * n is a multiple of the block size).
    print("\nUnconstrained symbolic shapes:")

    a_sym, b_sym = export.symbolic_shape("a, b")
    input_shape_dtype = jax.ShapeDtypeStruct((a_sym, b_sym), jnp.float32)
    _export_and_run(
        f,
        ref_f,
        input_shape_dtype,
        run_shapes=[(M, N), (M * 2, N * 4), (M * 4, N * 4)],
    )

    # ── 3. Constrained symbolic shapes (divisibility) ─────────────────────────
    # Shapes are declared as multiples of a tile size via TensorSpec.divisibility.
    # The symbolic expression "32*M, 16*N" tells jax.export that dim 0 is always
    # a multiple of 32 and dim 1 is always a multiple of 16. This lets the
    # compiler generate more efficient code (e.g. no remainder handling).
    # Runtime shapes must satisfy these divisibility constraints.
    print("\nConstrained symbolic shapes:")

    @jax.jit
    def f_divisible(a, b):
        spec = TensorSpec(divisibility=(32, 16))
        call = cutlass_call(
            launch,
            output_shape_dtype=a,
            input_spec=(spec, spec),
            output_spec=spec,
        )
        return jax.nn.sigmoid(call(a, b))

    m_sym, n_sym = export.symbolic_shape("32*M, 16*N")
    input_shape_dtype = jax.ShapeDtypeStruct((m_sym, n_sym), jnp.float32)
    _export_and_run(
        f_divisible,
        ref_f,
        input_shape_dtype,
        run_shapes=[(M, N), (M * 2, N * 2), (M * 4, N * 4)],
    )
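
    # Shapes that break the divisibility contract are rejected at call time --
    # a sketch (values illustrative; dim 0 must be a multiple of 32 and dim 1
    # a multiple of 16):
    #
    #   bad = jnp.zeros((33, 16), jnp.float32)
    #   rehydrated.call(bad, bad)  # error: 33 is not a multiple of 32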


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Demonstration of using jax.export with functions that use cutlass_call"
    )
    parser.add_argument("--M", default=512, type=int)
    parser.add_argument("--N", default=256, type=int)

    args = parser.parse_args()
    run_example(args.M, args.N)
    print("PASS")