Files
cutlass/examples/python/CuTeDSL/jax/cutlass_call_sharding.py
2026-04-07 12:16:05 -04:00

176 lines
5.6 KiB
Python

# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType
from jax.experimental.custom_partitioning import custom_partitioning
import cutlass.cute as cute
import cutlass.jax as cjax
from cutlass.jax.testing import create_tensor
import cuda.bindings.driver as cuda
"""
Examples of combining jax.jit, jax.shard_map and custom_partitioning for sharding
and executing kernels across multiple GPU devices.
To run this example:
.. code-block:: bash
# Run with addition operation
python examples/jax/cutlass_call_sharding.py
"""
@cute.kernel
def kernel(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
    """Element-wise ``c = a + b`` over the ``(None, thread, block)`` slices.

    Thread ``x`` of block ``x`` owns the slice ``[None, thread_x, block_x]``
    of each tensor. The whole mode-0 extent of that slice is copied into an
    rmem fragment, the two input fragments are summed, and the result is
    copied back into the matching slice of ``c``.
    """
    thread_x, _, _ = cute.arch.thread_idx()
    block_x, _, _ = cute.arch.block_idx()

    # One rmem fragment per operand, each sized to that operand's mode 0.
    lhs_frag = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)
    rhs_frag = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type)
    out_frag = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type)

    # Load this thread's input slices into the fragments.
    cute.autovec_copy(a[None, thread_x, block_x], lhs_frag)
    cute.autovec_copy(b[None, thread_x, block_x], rhs_frag)

    summed = lhs_frag.load() + rhs_frag.load()
    out_frag.store(summed)

    # Write the summed fragment back to this thread's slice of ``c``.
    cute.autovec_copy(out_frag, c[None, thread_x, block_x])
@cute.jit
def launch(
    stream: cuda.CUstream,
    a: cute.Tensor,
    b: cute.Tensor,
    c: cute.Tensor,
):
    """Print the operand layouts, then launch ``kernel`` on ``stream``.

    The grid has one block per index of ``a``'s last mode and each block has
    one thread per index of ``a``'s second-to-last mode, matching how
    ``kernel`` slices its operands as ``[None, thread, block]``.
    """
    cute.printf("a: {}", a.layout)
    cute.printf("b: {}", b.layout)
    cute.printf("c: {}", c.layout)

    num_blocks = a.shape[-1]
    threads_per_block = a.shape[-2]
    kernel(a, b, c).launch(
        grid=[num_blocks, 1, 1],
        block=[threads_per_block, 1, 1],
        stream=stream,
    )
def sharded_cutlass_call_impl(a_block, b_block):
    """Run the CuTe addition kernel on one device's shards of ``a`` and ``b``.

    Args:
        a_block: Local (per-device) block of ``a``.
        b_block: Local block of ``b``; must have the same shape as ``a_block``.

    Returns:
        A tuple ``(result, reference)``. ``result`` is the kernel output, with
        ``a_block``'s shape and dtype. ``reference`` is the plain JAX
        ``a_block + b_block`` used to validate it.

    Raises:
        ValueError: if ``a_block`` and ``b_block`` have different shapes.
    """
    # ``launch`` derives the grid from ``a`` alone and ``kernel`` indexes ``b``
    # with the same coordinates, so a smaller ``b`` would be read out of
    # bounds, while ``a_block + b_block`` would silently broadcast. Shapes are
    # static under jit, so this is a trace-time check.
    if a_block.shape != b_block.shape:
        raise ValueError(
            "a_block and b_block must have the same shape, got "
            f"{a_block.shape} and {b_block.shape}"
        )
    call = cjax.cutlass_call(
        launch,
        use_static_tensors=True,
        output_shape_dtype=jax.ShapeDtypeStruct(a_block.shape, a_block.dtype),
    )
    ref_result = a_block + b_block
    return call(a_block, b_block), ref_result
@custom_partitioning
def custom_shared_call(a, b):
    """Element-wise ``a + b`` through the CuTe kernel, as a custom-partitioned op.

    Returns the same ``(result, reference)`` pair as
    ``sharded_cutlass_call_impl``.
    """
    kernel_out, reference = sharded_cutlass_call_impl(a, b)
    return kernel_out, reference
def custom_shared_call_partitioner(mesh, arg_shapes, result_shape):
    """Partition callback for ``custom_shared_call``.

    Every operand and every result is assigned the sharding of the first
    operand ``a``, so each device runs ``sharded_cutlass_call_impl`` on local
    blocks of identical shape.

    Args:
        mesh: Device mesh supplied by JAX; returned unchanged.
        arg_shapes: Per-operand shape structs whose ``.sharding`` is the
            operand's current sharding. Only the first one is consulted.
        result_shape: Per-result shape structs. Only their count is used.

    Returns:
        ``(mesh, lower_fn, result_shardings, arg_shardings)``.
    """
    # The returned arg_shardings are the shardings lower_fn expects, and
    # custom_partitioning reshards operands to match them. Asking for a's
    # sharding for every operand (rather than echoing each operand's own)
    # guarantees a and b arrive as equally shaped per-device blocks.
    a_sharding = arg_shapes[0].sharding
    arg_shardings = (a_sharding,) * len(arg_shapes)
    result_shardings = (a_sharding,) * len(result_shape)

    def lower_fn(*args):
        # Per-device body: receives the local block of every operand.
        return sharded_cutlass_call_impl(*args)

    return mesh, lower_fn, result_shardings, arg_shardings
# Attach the partition callback to the custom-partitioned op.
custom_shared_call.def_partition(partition=custom_shared_call_partitioner)
def run_example():
    """Run the sharded element-wise addition kernel two ways and validate it.

    Both inputs are sharded along their leading axis over a 1-D mesh whose
    single axis is named ``"b"`` and spans ``jax.device_count()`` devices.
    The same computation is executed once through ``jax.shard_map`` and once
    through ``custom_partitioning``. Each kernel result is compared with the
    plain JAX reference.

    Raises:
        AssertionError: if a kernel result does not match its reference.
    """
    # Create a 1-D device mesh with a single axis named "b".
    ngpu = jax.device_count()
    mesh = jax.make_mesh((ngpu,), ("b",), axis_types=(AxisType.Explicit,))
    if ngpu == 1:
        print("Note: only 1 GPU was detected.")

    # We will shard our 3D tensors over b (leading axis only).
    sharding = P("b", None, None)
    named_sharding = NamedSharding(mesh, sharding)

    def _check(label, result, reference):
        # An explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let the script print PASS unchecked.
        if not jnp.allclose(result, reference):
            raise AssertionError(f"{label}: kernel result does not match reference")

    print("Testing shard_map...")

    @partial(
        jax.jit, static_argnums=[0, 1], out_shardings=(named_sharding, named_sharding)
    )
    def allocate_sharded_tensors(shape, dtype):
        key = jax.random.key(1123)
        a_key, b_key = jax.random.split(key, 2)
        a = create_tensor(shape, dtype, a_key)
        b = create_tensor(shape, dtype, b_key)
        return a, b

    @jax.jit
    def compute(a, b):
        # This jax.shard_map partitions the cutlass_call over the mesh.
        @partial(
            jax.shard_map,
            mesh=mesh,
            in_specs=(sharding, sharding),
            out_specs=(sharding, sharding),
        )
        def sharded_call(a_block, b_block):
            return sharded_cutlass_call_impl(a_block, b_block)

        return sharded_call(a, b)

    # Allocate (32, 16, 64) on each GPU
    shape = (32 * ngpu, 16, 64)
    dtype = jnp.float32
    a, b = allocate_sharded_tensors(shape, dtype)
    c, c_ref = compute(a, b)
    _check("shard_map", c, c_ref)

    print("Testing custom_partitioning...")

    # Test custom_partitioning implementation which should produce identical results
    @jax.jit
    def compute_cp(a, b):
        return custom_shared_call(a, b)

    c, c_ref = compute_cp(a, b)
    _check("custom_partitioning", c, c_ref)
if __name__ == "__main__":
    # Script entry point: run both sharding variants. "PASS" is printed only
    # if run_example() returns without raising.
    run_example()
    print("PASS")