mirror of https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00

PR update (#3103)
parent 4ca61d0662
commit 418d38a5de

examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb (1260 lines, new file; diff suppressed because it is too large)
examples/python/CuTeDSL/jax/cute_dsl_jax_kernels.py (366 lines, new file)

@@ -0,0 +1,366 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import cutlass
import cutlass.cute as cute
import cutlass.jax as cjax
import cuda.bindings.driver as cuda

"""
|
||||
CuTe DSL kernels used by the ``cute_dsl_jax.ipynb`` notebook.
|
||||
|
||||
This module defines GPU kernels written in CuTe DSL (CUTLASS 4.x Python DSL)
|
||||
that are called from JAX via ``cutlass.jax.cutlass_call``. ``cutlass_call`` is a
|
||||
JAX primitive that triggers compilation of the kernel during lowering and embeds
|
||||
it into the HLO computation, so XLA can launch it efficiently without callback
|
||||
to Python.
|
||||
|
||||
Kernels provided:
|
||||
|
||||
- ``vector_add`` — element-wise c = a + b (3-D CuTe layout)
|
||||
- ``saxpy`` — y = alpha * x + y
|
||||
- ``relu`` — element-wise ReLU with flat indexing
|
||||
- ``fused_bias_relu`` — fused bias addition + ReLU
|
||||
- ``gemm`` — tiled matrix multiplication
|
||||
- ``elementwise_add`` — 2-D element-wise add (flat indexing, ``jax.export``-compatible)
|
||||
|
||||
The notebook imports these kernels and wraps each one with ``cutlass_call``
|
||||
inside ``@jax.jit`` functions. See ``cute_dsl_jax.ipynb`` for usage, validation,
|
||||
and step-by-step explanations.
|
||||
|
||||
This module is imported by the notebook and by ``cute_dsl_jax.py``. It can also
|
||||
be run directly to validate every kernel:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Interactive notebook (recommended for learning)
|
||||
jupyter lab cute_dsl_jax.ipynb
|
||||
|
||||
# Full demo as a standalone script
|
||||
python cute_dsl_jax_kernels.py
|
||||
"""
|
||||
|
||||
|
||||
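# A minimal usage sketch (illustrative only; the shape below is arbitrary and
# the pattern simply mirrors the self-tests at the bottom of this file):
#
#     import jax
#     import jax.numpy as jnp
#
#     a = jnp.ones((1, 256, 4), dtype=jnp.float32)
#     b = jnp.ones((1, 256, 4), dtype=jnp.float32)
#     add = cjax.cutlass_call(
#         launch_vector_add,
#         output_shape_dtype=jax.ShapeDtypeStruct(a.shape, a.dtype),
#         use_static_tensors=True,
#     )
#     c = jax.jit(add)(a, b)
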
# ------------------------------------------------------------------ #
# Vector Add: c = a + b                                              #
# ------------------------------------------------------------------ #
@cute.kernel
def vector_add_kernel(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
    """Per-thread kernel: each thread adds one element."""
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    # Register fragments sized to mode 0 (elements handled per thread).
    frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)
    frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type)
    frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type)

    # Vectorized copies: global -> registers, add, then registers -> global.
    cute.autovec_copy(a[None, tidx, bidx], frgA)
    cute.autovec_copy(b[None, tidx, bidx], frgB)
    frgC.store(frgA.load() + frgB.load())
    cute.autovec_copy(frgC, c[None, tidx, bidx])


@cute.jit
def launch_vector_add(
    stream: cuda.CUstream,
    a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
):
    # One block per mode-2 slice, one thread per mode-1 element.
    vector_add_kernel(a, b, c).launch(
        grid=[a.shape[-1], 1, 1],
        block=[a.shape[-2], 1, 1],
        stream=stream,
    )

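# How the 3-D layout maps to the launch above (worked example, using the
# self-test shape (1, 256, 4)): thread tidx of block bidx owns the slice
# a[None, tidx, bidx], here a single element. The launch uses
# block = [a.shape[-2], 1, 1] = [256, 1, 1] threads and
# grid = [a.shape[-1], 1, 1] = [4, 1, 1] blocks, covering all
# 1 * 256 * 4 = 1024 elements.
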
# ------------------------------------------------------------------ #
# SAXPY: y = alpha * x + y                                           #
# ------------------------------------------------------------------ #
@cute.kernel
def saxpy_kernel(x: cute.Tensor, y: cute.Tensor, out: cute.Tensor, alpha: float):
    """SAXPY: out[i] = alpha * x[i] + y[i]."""
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    frgX = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type)
    frgY = cute.make_rmem_tensor(cute.size(y, mode=[0]), y.element_type)
    frgO = cute.make_rmem_tensor(cute.size(out, mode=[0]), out.element_type)

    cute.autovec_copy(x[None, tidx, bidx], frgX)
    cute.autovec_copy(y[None, tidx, bidx], frgY)
    frgO.store(alpha * frgX.load() + frgY.load())
    cute.autovec_copy(frgO, out[None, tidx, bidx])


@cute.jit
def launch_saxpy(
    stream: cuda.CUstream,
    x: cute.Tensor, y: cute.Tensor, out: cute.Tensor,
    *, alpha: float,
):
    # ``alpha`` is keyword-only: it is supplied through ``cutlass_call`` as a
    # compile-time constant rather than as a device tensor.
    saxpy_kernel(x, y, out, alpha).launch(
        grid=[x.shape[-1], 1, 1],
        block=[x.shape[-2], 1, 1],
        stream=stream,
    )

# ------------------------------------------------------------------ #
# ReLU: out = max(0, x)                                              #
# ------------------------------------------------------------------ #
@cute.kernel
def relu_kernel(x: cute.Tensor, out: cute.Tensor, N: int):
    """Per-thread kernel: each thread computes ReLU of one element."""
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdx, _, _ = cute.arch.block_dim()

    # Global flat index with a bounds guard for the ragged last block.
    idx = bidx * bdx + tidx
    if idx < N:
        val = x[idx]
        out[idx] = cutlass.max(val, cutlass.Float32(0.0))


@cute.jit
def launch_relu(
    stream: cuda.CUstream,
    x: cute.Tensor, out: cute.Tensor,
    *, N: int,
):
    BLOCK_SIZE = 256
    # Ceiling division so the grid covers all N elements.
    grid_size = (N + BLOCK_SIZE - 1) // BLOCK_SIZE
    relu_kernel(x, out, N).launch(
        grid=[grid_size, 1, 1],
        block=[BLOCK_SIZE, 1, 1],
        stream=stream,
    )

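# Worked example (hypothetical N): with N = 1000 and BLOCK_SIZE = 256,
# grid_size = (1000 + 255) // 256 = 4, so the last block still launches 256
# threads but only 1000 - 3 * 256 = 232 of them pass the ``idx < N`` guard.
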
# ------------------------------------------------------------------ #
# Fused Bias + ReLU: out = max(0, x + bias[col])                     #
# ------------------------------------------------------------------ #
@cute.kernel
def fused_bias_relu_kernel(
    x: cute.Tensor, bias: cute.Tensor, out: cute.Tensor, N: int, width: int,
):
    """Per-thread: out[i] = max(0, x[i] + bias[i % width])."""
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdx, _, _ = cute.arch.block_dim()

    idx = bidx * bdx + tidx
    if idx < N:
        # ``x`` is a flattened (rows, width) matrix; indexing the bias by
        # column broadcasts one bias vector across all rows.
        col = idx % width
        val = x[idx] + bias[col]
        out[idx] = cutlass.max(val, cutlass.Float32(0.0))


@cute.jit
def launch_fused_bias_relu(
    stream: cuda.CUstream,
    x: cute.Tensor, bias: cute.Tensor, out: cute.Tensor,
    *, N: int, width: int,
):
    BLOCK_SIZE = 256
    grid_size = (N + BLOCK_SIZE - 1) // BLOCK_SIZE
    fused_bias_relu_kernel(x, bias, out, N, width).launch(
        grid=[grid_size, 1, 1],
        block=[BLOCK_SIZE, 1, 1],
        stream=stream,
    )

# ------------------------------------------------------------------ #
# GEMM: D = A @ B                                                    #
# ------------------------------------------------------------------ #
@cute.kernel
def gemm_kernel(
    A: cute.Tensor, B: cute.Tensor, D: cute.Tensor,
    M: int, N: int, K: int, BLOCK_M: int, BLOCK_N: int,
):
    """Tiled GEMM: each block owns a BLOCK_M x BLOCK_N output tile, and its
    threads stride over the tile's elements, accumulating one dot product each.
    """
    tidx, _, _ = cute.arch.thread_idx()
    bm, bn, _ = cute.arch.block_idx()
    bdx, _, _ = cute.arch.block_dim()

    # Threads cooperatively cover the flattened tile: thread tidx handles
    # elements tidx, tidx + bdx, tidx + 2 * bdx, ...
    for i in cutlass.range(tidx, BLOCK_M * BLOCK_N, bdx):
        row = i // BLOCK_N
        col = i % BLOCK_N
        m_idx = bm * BLOCK_M + row
        n_idx = bn * BLOCK_N + col
        if m_idx < M and n_idx < N:
            # A and B are flat row-major buffers: A is (M, K), B is (K, N).
            acc = cutlass.Float32(0.0)
            for k in cutlass.range(K):
                acc += A[m_idx * K + k] * B[k * N + n_idx]
            D[m_idx * N + n_idx] = acc


@cute.jit
def launch_gemm(
    stream: cuda.CUstream,
    A: cute.Tensor, B: cute.Tensor, D: cute.Tensor,
    *, M: int, N: int, K: int,
):
    BLOCK_M, BLOCK_N = 64, 64
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    gemm_kernel(A, B, D, M, N, K, BLOCK_M, BLOCK_N).launch(
        grid=[grid_m, grid_n, 1],
        block=[256, 1, 1],
        stream=stream,
    )

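# The flat indexing above is equivalent to a dense row-major matrix multiply,
# up to floating-point accumulation order (hence the looser tolerance in the
# self-test below), which checks exactly this identity:
#
#     D.reshape(M, N) == A.reshape(M, K) @ B.reshape(K, N)
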
# ------------------------------------------------------------------ #
# Element-wise Add (2-D, flat indexing)                              #
# ------------------------------------------------------------------ #
@cute.kernel
def elementwise_add_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    """Per-thread kernel: 2-D element-wise add using flat indexing."""
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx

    # Decompose the flat thread index into (row, column) coordinates.
    m, n = gA.shape
    ni = thread_idx % n
    mi = thread_idx // n

    a_val = gA[mi, ni]
    b_val = gB[mi, ni]
    gC[mi, ni] = a_val + b_val


@cute.jit
def launch_elementwise_add(
    stream: cuda.CUstream,
    mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor,
):
    # Note: the grid uses floor division and the kernel has no bounds guard,
    # so m * n must be a multiple of the block size (the self-test uses
    # 16 * 256 = 4096).
    num_threads_per_block = 256
    m, n = mA.shape
    elementwise_add_kernel(mA, mB, mC).launch(
        grid=((m * n) // num_threads_per_block, 1, 1),
        block=(num_threads_per_block, 1, 1),
        stream=stream,
    )

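# Unlike the other launchers, ``launch_elementwise_add`` takes no static
# scalar arguments: every size is read from the tensor shapes inside the
# ``@cute.jit`` function. Presumably this is what makes it the
# ``jax.export``-compatible variant mentioned in the module docstring, since
# no Python-side constants need to be baked into the exported computation.
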
# ------------------------------------------------------------------ #
# Self-tests                                                         #
# ------------------------------------------------------------------ #
if __name__ == '__main__':
    import os
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")

    import jax
    import jax.numpy as jnp
    import numpy as np

    BLOCK = 256
    N_BLOCKS = 4

    # ── Vector Add ────────────────────────────────────────────────────
    # 3-D CuTe layout: (elems_per_thread, threads_per_block, num_blocks)
    a = jax.random.normal(jax.random.PRNGKey(0), (1, BLOCK, N_BLOCKS), dtype=jnp.float32)
    b = jax.random.normal(jax.random.PRNGKey(1), (1, BLOCK, N_BLOCKS), dtype=jnp.float32)
    call = cjax.cutlass_call(
        launch_vector_add,
        output_shape_dtype=jax.ShapeDtypeStruct(a.shape, a.dtype),
        use_static_tensors=True,
    )
    c = jax.jit(call)(a, b)
    np.testing.assert_allclose(np.array(c), np.array(a + b), rtol=1e-5, atol=1e-5)
    print('vector_add: PASSED')

    # ── SAXPY ─────────────────────────────────────────────────────────
    x = jax.random.normal(jax.random.PRNGKey(2), (1, BLOCK, N_BLOCKS), dtype=jnp.float32)
    y = jax.random.normal(jax.random.PRNGKey(3), (1, BLOCK, N_BLOCKS), dtype=jnp.float32)
    alpha = 2.5
    call = cjax.cutlass_call(
        launch_saxpy,
        output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype),
        use_static_tensors=True,
        alpha=alpha,
    )
    out = jax.jit(call)(x, y)
    np.testing.assert_allclose(np.array(out), np.array(alpha * x + y), rtol=1e-5, atol=1e-5)
    print('saxpy: PASSED')

    # ── ReLU ──────────────────────────────────────────────────────────
    N_ELEM = BLOCK * N_BLOCKS
    x = jax.random.normal(jax.random.PRNGKey(4), (N_ELEM,), dtype=jnp.float32)
    call = cjax.cutlass_call(
        launch_relu,
        output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype),
        N=N_ELEM,
    )
    out = jax.jit(call)(x)
    np.testing.assert_allclose(np.array(out), np.array(jnp.maximum(x, 0)), rtol=1e-5, atol=1e-5)
    print('relu: PASSED')

    # ── Fused Bias + ReLU ─────────────────────────────────────────────
    ROWS, COLS = 16, 64
    x = jax.random.normal(jax.random.PRNGKey(5), (ROWS * COLS,), dtype=jnp.float32)
    bias = jax.random.normal(jax.random.PRNGKey(6), (COLS,), dtype=jnp.float32)
    call = cjax.cutlass_call(
        launch_fused_bias_relu,
        output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype),
        N=ROWS * COLS, width=COLS,
    )
    out = jax.jit(call)(x, bias)
    ref = jnp.maximum(x.reshape(ROWS, COLS) + bias, 0).reshape(-1)
    np.testing.assert_allclose(np.array(out), np.array(ref), rtol=1e-5, atol=1e-5)
    print('fused_bias_relu: PASSED')

    # ── GEMM ──────────────────────────────────────────────────────────
    M, N, K = 128, 128, 64
    A = jax.random.normal(jax.random.PRNGKey(7), (M * K,), dtype=jnp.float32)
    B = jax.random.normal(jax.random.PRNGKey(8), (K * N,), dtype=jnp.float32)
    call = cjax.cutlass_call(
        launch_gemm,
        output_shape_dtype=jax.ShapeDtypeStruct((M * N,), A.dtype),
        M=M, N=N, K=K,
    )
    D = jax.jit(call)(A, B)
    ref = A.reshape(M, K) @ B.reshape(K, N)
    np.testing.assert_allclose(np.array(D.reshape(M, N)), np.array(ref), rtol=1e-2, atol=1e-2)
    print('gemm: PASSED')

    # ── Elementwise Add (2-D) ─────────────────────────────────────────
    M, N = 16, 256
    a = jax.random.normal(jax.random.PRNGKey(9), (M, N), dtype=jnp.float32)
    b = jax.random.normal(jax.random.PRNGKey(10), (M, N), dtype=jnp.float32)
    call = cjax.cutlass_call(
        launch_elementwise_add,
        output_shape_dtype=jax.ShapeDtypeStruct(a.shape, a.dtype),
    )
    c = jax.jit(call)(a, b)
    np.testing.assert_allclose(np.array(c), np.array(a + b), rtol=1e-5, atol=1e-5)
    print('elementwise_add: PASSED')

    print('\nAll kernels passed.')