mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 17:25:45 +00:00
330 lines
13 KiB
Python
330 lines
13 KiB
Python
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
import argparse
|
|
import operator
|
|
from functools import partial
|
|
from typing import List
|
|
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
|
|
"""
|
|
An Elementwise Apply Example using CuTe DSL with cutlass.jax.cutlass_call
|
|
|
|
This example is similar to examples/ampere/elementwise_apply.py but demonstrates
|
|
how to run the code in a jax specific way using the cutlass_call primitive. It assumes
|
|
familiarity with basic CuTe DSL concepts as well as the cutlass_call primitive.
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
# Run with addition operation
|
|
python examples/jax/elementwise_apply_example.py --M 1024 --N 512 --op add
|
|
|
|
# Run with multiplication operation
|
|
python examples/jax/elementwise_apply_example.py --M 1024 --N 512 --op mul
|
|
|
|
# Run with subtraction operation
|
|
python examples/jax/elementwise_apply_example.py --M 1024 --N 512 --op sub
|
|
"""
|
|
|
|
|
|
@cute.kernel
def elementwise_apply_kernel(
    op: cutlass.Constexpr,
    mInputs: List[cute.Tensor],
    mC: cute.Tensor,
    cC: cute.Tensor, # coordinate tensor
    shape: cute.Shape,
    tv_layout: cute.Layout, # (tid, vid) -> logic coord
):
    """Apply ``op`` to this thread's slice of every input and store it to ``mC``.

    Each tensor is first sliced by the block index, then composed with
    ``tv_layout`` and sliced again by the thread index. The values loaded
    from the per-thread input slices are passed to ``op`` positionally and
    the returned value is stored to the per-thread slice of ``mC``.

    :param op: Callable applied to the values loaded from the input slices
    :type op: cutlass.Constexpr
    :param mInputs: Input tensors, indexed with ``((None, None), (bidx, bidy))``
    :type mInputs: List[cute.Tensor]
    :param mC: Output tensor, indexed the same way as the inputs
    :type mC: cute.Tensor
    :param cC: Coordinate tensor, sliced like ``mC`` and used to fill the
        bounds predicate
    :type cC: cute.Tensor
    :param shape: Bound each coordinate is compared against with
        ``cute.elem_less``
    :type shape: cute.Shape
    :param tv_layout: Layout composed with every block tile before the
        per-thread slice is taken
    :type tv_layout: cute.Layout
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()

    ###############################################################################
    # Slice to local tile of thread block
    ###############################################################################
    # Keep the whole tile mode (None, None) and pick the tile by block index.
    blk_crd = ((None, None), (bidx, bidy))

    # Leverage the meta-programming capability of the DSL to slice the tensors for each input
    # All for loops below on input tensors would be fully unrolled automatically at compile time
    # logical coord -> memory address
    gInputs = [t[blk_crd] for t in mInputs] # (TileM, TileN)
    gC = mC[blk_crd] # (TileM, TileN)
    gCrd = cC[blk_crd] # (TileM, TileN)

    # NOTE(review): these are plain Python prints of tensor types; presumably
    # they are emitted while the kernel is traced, not per GPU thread - confirm.
    print("[DSL INFO] Sliced Tensors per thread block:")
    for i in cutlass.range_constexpr(len(gInputs)):
        print(f"[DSL INFO] ctaInputs{i} = {gInputs[i].type}")
    print(f"[DSL INFO] gC = {gC.type}")
    print(f"[DSL INFO] gCrd = {gCrd.type}")

    ###############################################################################
    # Compose with thread block TV layout to map thread & value indices to memory address
    ###############################################################################
    # (tid, vid) -> memory address
    tidfrgInputs = [cute.composition(t, tv_layout) for t in gInputs]
    tidfrgC = cute.composition(gC, tv_layout)
    tidfrgCrd = cute.composition(gCrd, tv_layout)

    # repeat None like vid to remove hierarchy of layout
    thr_crd = (tidx, cute.repeat_like(None, tidfrgInputs[0][1]))

    ###############################################################################
    # Slice to local tile of thread
    ###############################################################################
    # vid -> address
    thrInputs = [t[thr_crd] for t in tidfrgInputs] # (V)
    thrC = tidfrgC[thr_crd] # (V)
    thrCrd = tidfrgCrd[thr_crd]

    print("[DSL INFO] Sliced Tensors per thread:")
    for i in cutlass.range_constexpr(len(thrInputs)):
        print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}")
    print(f"[DSL INFO] thrC = {thrC.type}")
    print(f"[DSL INFO] thrCrd = {thrCrd.type}")

    ###############################################################################
    # Compute predicate for out of boundary checks
    ###############################################################################
    # NOTE(review): frgPred is filled here but is not referenced by the
    # load/store at the end of this kernel. Confirm whether the store should
    # be predicated when `shape` is not a multiple of the tile.
    frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)
    print(f"[DSL INFO] frgPred = {frgPred.type}")

    for i in cutlass.range_constexpr(cute.size(frgPred)):
        # True where this thread's i-th coordinate is elementwise less than `shape`.
        frgPred[i] = cute.elem_less(thrCrd[i], shape)

    # if tidx == 0 and bidx == 0:
    # cute.print_tensor(frgPred)

    ##########################################################
    # Load data and compute result
    ##########################################################

    # Load data before use. The compiler will optimize the copy and load
    # operations to convert some memory ld/st into register uses.
    result = op(*[thrInput.load() for thrInput in thrInputs])
    thrC.store(result)
|
|
|
|
|
|
@cute.jit
def elementwise_apply(
    op: cutlass.Constexpr, inputs, result: cute.Tensor, stream: cuda.CUstream
):
    """Tile the input and result tensors and launch ``elementwise_apply_kernel``.

    The kernel applies ``op`` element-wise to the input tensors and stores
    the outcome in ``result``.

    :param op: Operator or lambda function to apply element-wise
    :type op: cutlass.Constexpr
    :param inputs: Indexable collection of input tensors; all of them must
        have the same ``element_type``
    :param result: Output tensor to store the results of op(*inputs)
    :type result: cute.Tensor
    :param stream: CUDA stream passed to the kernel launch
    :type stream: cuda.CUstream
    :return: None
    :rtype: None
    """

    # Baseline: naive TV layout
    # * mA layout: (4096, 4096):(4096, 1)
    # * TV layout map to (512, 4) tile
    # * tidx maps to mode-0 but input layout is contiguous on mode-1, performance will be bad
    # tv_layout = cute.make_layout((128, (4, 4)), stride=(4, (512, 1)))
    # cta_tiler = (512, 4)

    # Opt-1: better TV layout with better 1D thread layout (SOL with 1D thread layout)
    # * mA layout: (4096, 4096):(4096, 1)
    # * TV layout map to (4, 512) tile
    # * tidx maps to mode-1 which is leading mode of input tensor for coalesced load
    # tv_layout = cute.make_layout((128, (4, 4)), stride=(16, (4, 1)))
    # cta_tiler = (4, 512)

    # Opt-2: 2D tile but worse
    # * mA layout: (4096, 4096):(4096, 1)
    # * TV layout map to (128, 16) logical tile
    # * V layout is bad as contiguous mode is not on right-most
    # * `cute.copy` only supports vectorize when stride-1 of v-layout on right-most )
    # tv_layout = cute.make_layout(((32, 4), (4, 4)), stride=((4, 512), (1, 128)))
    # cta_tiler = (128, 16)

    # Opt-3: SOL with 2D thread tile
    # * mA layout: (4096, 4096):(4096, 1)
    # * TV layout map to (64, 256) logical tile
    # * tidx maps to mode-1 and input layout is contiguous on mode-1 for coalesced load-store

    # Use 128bit(16B) load as canonicalized form of val_layout then recast to target element-type
    coalesced_ldst_bytes = 16

    # Compile time validation: expect same element type for all input tensors
    assert all(t.element_type == inputs[0].element_type for t in inputs)
    dtype = inputs[0].element_type

    thr_layout = cute.make_ordered_layout((4, 64), order=(1, 0))
    # The value layout is first expressed in bytes, then recast from 8-bit
    # units to units of the element width.
    val_layout = cute.make_ordered_layout((16, coalesced_ldst_bytes), order=(1, 0))
    val_layout = cute.recast_layout(dtype.width, 8, val_layout)
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)

    print("[DSL INFO] Input Tensors:")
    for i, t in enumerate(inputs):
        print(f"[DSL INFO] inputs{i} = {t}")
    print(f"[DSL INFO] result = {result}")

    print("[DSL INFO] Tiling Parameters:")
    print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
    print(f"[DSL INFO] tv_layout = {tv_layout}")

    print("[DSL INFO] Tiled Tensors:")
    mInputs = [cute.zipped_divide(input, tiler_mn) for input in inputs]
    # ((TileM, TileN), (RestM, RestN))
    mC = cute.zipped_divide(result, tiler_mn)

    # (RestM, RestN) -> (RestN, RestM)
    remap_block = cute.make_ordered_layout(
        cute.select(mInputs[0].shape[1], mode=[1, 0]), order=(1, 0)
    )
    for i, t in enumerate(mInputs):
        print(f"[DSL INFO] gInputs{i} = {mInputs[i]}")
        # Only the second (block) mode is remapped; the tile mode is kept as is.
        mInputs[i] = cute.composition(t, (None, remap_block))
        print(f"[DSL INFO] gInputs{i} (remapped) = {mInputs[i]}")

    mC = cute.composition(mC, (None, remap_block))
    print(f"[DSL INFO] gC = {mC}")

    # NOTE(review): unlike mInputs and mC, the coordinate tensor cC is not
    # composed with remap_block, yet the kernel indexes all of them with the
    # same (bidx, bidy) block coordinate - confirm this is intended.
    idC = cute.make_identity_tensor(result.shape)
    cC = cute.zipped_divide(idC, tiler=tiler_mn)
    print(f"[DSL INFO] coord tensor = {cC}")

    # Launch the kernel asynchronously
    # Group input tensors into a list as a single argument
    elementwise_apply_kernel(op, mInputs, mC, cC, result.shape, tv_layout).launch(
        # Compute product at each mode of mC.shape[1] to get multi-dimensional grid size
        grid=cute.product_each(mC.shape[1]),
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
        stream=stream,
    )
|
|
|
|
|
|
@cutlass.dsl_user_op
def leaky_relu(x, alpha, *, loc=None, ip=None):
    """Leaky ReLU: keep ``x`` where it is positive, otherwise use ``alpha * x``.

    :param x: Value the activation is applied to
    :param alpha: Factor multiplied with ``x`` where ``x`` is not positive
    :param loc: Forwarded unchanged to ``cute.where``
    :param ip: Forwarded unchanged to ``cute.where``
    :return: Result of ``cute.where`` over the positive mask
    """
    is_positive = x > 0
    scaled = alpha * x
    return cute.where(is_positive, x, scaled, loc=loc, ip=ip)
|
|
|
|
|
|
def leaky_relu_ref(x, alpha):
    """Reference leaky ReLU written with ``jax.numpy``.

    :param x: Array (or scalar) the activation is applied to
    :param alpha: Factor multiplied with ``x`` where ``x`` is not positive
    :return: ``x`` where ``x > 0``, ``alpha * x`` elsewhere
    """
    # Imported lazily so the module can be loaded without importing jax.
    from jax import numpy as jnp

    is_positive = x > 0
    scaled = alpha * x
    return jnp.where(is_positive, x, scaled)
|
|
|
|
|
|
def run_and_verify(op, M, N, dtype, skip_ref_check=False):
    """Run the element-wise kernel through ``cutlass_call`` and check the result.

    :param op: Operation name: ``"leaky_relu"`` (one input) or the name of a
        function in the standard ``operator`` module such as ``"add"``
        (two inputs)
    :param M: Size of the first dimension of every input tensor
    :param N: Size of the second dimension of every input tensor
    :param dtype: CUTLASS element type; converted to a jax dtype with
        ``cutlass_to_jax_dtype``
    :param skip_ref_check: When true the kernel is still run once, but its
        output is not compared against the reference
    :raises AttributeError: If ``op`` is neither ``"leaky_relu"`` nor the
        name of an attribute of the ``operator`` module
    :raises AssertionError: If the kernel output differs from the reference
    """
    import jax
    import jax.numpy as jnp
    import cutlass.jax as cjax
    import cutlass.jax.testing as testing

    if op == "leaky_relu":
        op = partial(leaky_relu, alpha=0.01)
        ref_op = partial(leaky_relu_ref, alpha=0.01)
        num_inputs = 1
    else:
        op = getattr(operator, op)
        ref_op = op
        num_inputs = 2

    # This jax function is transformed using jax.jit to compile its contents
    # into an efficient HLO executable.
    @partial(jax.jit, static_argnums=[1])
    def jax_function(inputs, op):
        call = cjax.cutlass_call(
            # Bind jax arguments to kernel signature
            lambda stream, inputs, output, *, op: elementwise_apply(
                op, inputs, output, stream
            ),
            # Specify output shape/dtype of result
            output_shape_dtype=jax.ShapeDtypeStruct(inputs[0].shape, inputs[0].dtype),
            # Pass static/constexpr values as kwargs
            op=op,
        )

        # Call the kernel!
        return call(inputs)

    print("\nRunning Elementwise Apply test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")

    jax_dtype = cjax.cutlass_to_jax_dtype(dtype)
    keys = jax.random.split(jax.random.key(1435), num_inputs)
    inputs = [testing.create_tensor((M, N), jax_dtype, key) for key in keys]

    print("Input tensor shapes:")
    for i in range(num_inputs):
        print(f"inputs[{i}]: {inputs[i].shape}, dtype: {inputs[i].dtype}")

    if op in (operator.truediv, operator.floordiv):
        # Replace zeros in the divisor so the division is well defined.
        nonzero_fill = 1.2
        inputs[1] = jnp.where(inputs[1] == 0, nonzero_fill, inputs[1])

    # Call the jax.jit function which will compile the kernel
    c = jax_function(inputs, op)

    if not skip_ref_check:
        print("Executing elementwise apply kernel...")
        c = jax_function(inputs, op)
        print("Verifying results...")
        # Raise explicitly instead of using `assert`: an assert statement is
        # removed under `python -O`, which would report success unchecked.
        if not jnp.allclose(ref_op(*inputs), c):
            raise AssertionError(
                "Elementwise apply kernel output does not match the reference"
            )
        print("Results verified successfully!")
        print(f"First few elements of result: \n{c[:3, :3]}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the problem size and operation, run
    # the example once with Float32 tensors and report success.
    cli = argparse.ArgumentParser(
        description="Demonstration of calling a kernel with cutlass_call"
    )
    # Both tensor dimensions share the same default and type.
    for dim_flag in ("--M", "--N"):
        cli.add_argument(dim_flag, default=4096, type=int)
    cli.add_argument("--op", default="add", type=str)
    cli.add_argument("--skip_ref_check", action="store_true")

    cli_args = cli.parse_args()
    run_and_verify(
        cli_args.op,
        cli_args.M,
        cli_args.N,
        dtype=cutlass.Float32,
        skip_ref_check=cli_args.skip_ref_check,
    )
    print("\nPASS")
|