From 418d38a5de245373d5bac809c06d1ecd1a196c63 Mon Sep 17 00:00:00 2001 From: Katja Sirazitdinova Date: Thu, 2 Apr 2026 14:00:41 +0400 Subject: [PATCH] PR update (#3103) --- .../python/CuTeDSL/jax/cute_dsl_jax.ipynb | 1260 +++++++++++++++++ .../CuTeDSL/jax/cute_dsl_jax_kernels.py | 366 +++++ 2 files changed, 1626 insertions(+) create mode 100644 examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb create mode 100644 examples/python/CuTeDSL/jax/cute_dsl_jax_kernels.py diff --git a/examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb b/examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb new file mode 100644 index 000000000..70ef40b75 --- /dev/null +++ b/examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb @@ -0,0 +1,1260 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Writing High-Performance GPU Kernels with CuTe DSL and JAX\n", + "\n", + "## Overview\n", + "\n", + "JAX provides excellent built-in GPU support through XLA, but sometimes you need to go beyond what the compiler can generate automatically. Custom GPU kernels let you exploit hardware-specific features, fuse operations that XLA misses, or implement algorithms that don't map cleanly to standard library calls. CuTe DSL bridges this gap by letting you write CUDA kernels in Python and plug them directly into JAX programs.\n", + "\n", + "**What you'll do:**\n", + "\n", + "- Install CUTLASS 4.x and its CuTe DSL Python front-end\n", + "- Write a **Vector Add** kernel using `@cute.kernel` and launch it with `@cute.jit`\n", + "- Integrate CuTe DSL kernels into JAX programs via `cutlass.jax.cutlass_call`\n", + "- Implement **SAXPY** (`y = alpha * x + y`) with scalar kernel arguments\n", + "- Write **ReLU** and **Fused Bias+ReLU** activation kernels for deep learning\n", + "- Build a **tiled GEMM** using tensor core MMA instructions\n", + "- Shard CUTLASS kernels across multiple GPUs with `jax.shard_map`\n", + "- **Export and serialize** JAX functions containing CUTLASS kernels with **`jax.export`**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "[CuTe DSL](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl.html) is the Python-native interface to [CUTLASS](https://docs.nvidia.com/cutlass/latest/) 4.4+, NVIDIA's open-source library of high-performance CUDA kernels. It exposes the same CuTe abstractions (layouts, tensors, thread-to-data mappings) that power CUTLASS's C++ template library, but authored entirely in Python.\n", + "\n", + "Traditionally, writing custom GPU kernels meant working in C++ or CUDA — a steep learning curve for Python-focused ML engineers. CuTe DSL changes this: you define per-thread logic with `@cute.kernel`, configure launch parameters with `@cute.jit`, and the CUTLASS JIT compiler generates optimized CUDA code behind the scenes. The `cutlass.jax` integration module then lets you call these kernels from JAX as if they were native operations, with full support for `@jax.jit`, automatic differentiation plumbing, and multi-device sharding.\n", + "\n", + "This notebook walks through progressively more complex kernels showing the patterns you'll reuse in your own custom operations." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The CuTe mental model\n", + "\n", + "At its core, CuTe is an **index transformation DSL** — it provides abstractions for mapping logical coordinates to physical memory offsets. Everything in CuTe builds on the following concepts:\n", + "\n", + "**Shape** describes the dimensions of your data. A shape can be simple like `(M, N)` for a matrix, or hierarchical like `((2, 4), N)` where the first mode is itself subdivided. Hierarchical shapes are especially useful on GPUs, where work is organized in layers:\n", + "\n", + "- A **thread** is the smallest unit of execution — one thread runs one sequence of instructions.\n", + "- A **warp** is a group of 32 threads that execute in lockstep on the same hardware unit.\n", + "- A **block** is a group of threads (organized internally into warps) that share fast on-chip (shared) memory and can synchronize with each other.\n", + "- The **grid** is the collection of all blocks launched by a kernel.\n", + "\n", + "CuTe shapes can nest to mirror this hierarchy. Such a hierarchical shape can be used to model a GPU execution hierarchy — for example, 32 threads per warp × 8 warps per block, across N blocks — when bound to CUDA’s thread and block indices.\n", + "\n", + "**Coordinate** is a position within a shape. For a shape `(4, 8)`, the coordinate `(2, 5)` identifies one element — row 2, column 5.\n", + "\n", + "**Stride** tells CuTe how far you move in memory when you step along each dimension. In a row-major `(4, 8)` matrix, memory is laid out row by row: the first 8 elements belong to row 0, the next 8 to row 1, and so on. Moving one column to the right simply advances to the next element in memory (stride 1). Moving one row down skips over an entire row of 8 elements (stride 8). So the stride is `(8, 1)`.\n", + "\n", + "**Layout = (Shape, Stride)** is CuTe's central abstraction. Shape and stride must have the same rank — each logical dimension must have a corresponding stride. \n", + "\n", + "Although we think of tensors as multi-dimensional, GPU memory itself is just a long one-dimensional array of elements. Given a coordinate, a layout tells you where that element lives in memory. It does this by combining the coordinate with the stride:\n", + "\n", + "```\n", + "offset = coord[0] * stride[0] + coord[1] * stride[1] + ...\n", + "```\n", + "\n", + "In CuTe DSL, you can define layout using:\n", + "\n", + "```python\n", + "cute.make_layout((...), stride=(...))\n", + "```\n", + "\n", + "One important thing to note here: in *row-major* layout, elements of each row are stored contiguously in memory (so columns vary fastest), whereas in *column-major* layout, elements of each column are stored contiguously (so rows vary fastest), meaning the logical shape stays the same but the stride — and therefore the memory access pattern — changes.\n", + "\n", + "For example, with layout in row-major order `((4, 8), (8, 1))`, coordinate `(2, 5)` maps to offset `2*8 + 5*1 = 21`. \n", + "\n", + "A column-major layout for the same shape would use stride `(1, 4)`, so the same coordinate maps to `2*1 + 5*4 = 22`. \n", + "\n", + "The shape stays the same — only the stride changes.\n", + "\n", + "```python\n", + "row_major = cute.make_layout((M, N), stride=(N, cutlass.Int32(1)))\n", + "col_major = cute.make_layout((M, N), stride=(cutlass.Int32(1), M))\n", + "```\n", + "\n", + "This separation of logical structure from physical storage is what makes CuTe powerful. Algorithms operate on coordinates, while layouts decide how those coordinates map to memory. Change the stride, and you change the storage pattern — without rewriting the algorithm.\n", + "\n", + "In the following examples, you won’t see `make_layout` because the kernels operate on `cute.Tensor` objects and use CuTe’s tensor / fragment helpers (`cute.size`, `cute.make_rmem_tensor`, `cute.autovec_copy`, `Tensor[...]`) which already encode the shape, stride and indexing semantics the kernel needs. The code stays higher-level and avoids manual offset arithmetic or explicit layout construction — that’s deliberate: CuTe’s helpers are there so kernels read like algorithms, not pointer math." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hardware and software requirements\n", + "\n", + "| Requirement | Minimum | Recommended |\n", + "|------------|---------|-------------|\n", + "| GPU | SM 8.0+ (Ampere) | SM 9.0+ (Hopper) |\n", + "| CUDA | 12.9 | 13.1 |\n", + "| JAX | 0.8.1+ | Latest |\n", + "| CUTLASS | 4.4+ (CuTe DSL) | Latest |\n", + "| Python | 3.10+ | 3.12 |\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's check which GPUs are available in this environment. The `nvidia-smi` command shows the GPU model, driver version, CUDA toolkit version, and current memory usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We programmatically query the GPU's **compute capability** using `nvidia-smi`. This two-digit number (e.g., 9.0 for Hopper) tells us which hardware features are available. CuTe DSL requires SM 8.0 (Ampere) or newer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "\n", + "def get_compute_capability():\n", + " \"\"\"Query the compute capability of the first visible GPU.\"\"\"\n", + " out = subprocess.check_output(\n", + " [\"nvidia-smi\", \"--query-gpu=compute_cap\", \"--format=csv,noheader\"], text=True\n", + " )\n", + " major, minor = out.strip().split(\"\\n\")[0].split(\".\")\n", + " return int(major), int(minor)\n", + "\n", + "SM_MAJOR, SM_MINOR = get_compute_capability()\n", + "print(f\"Detected compute capability: SM {SM_MAJOR}.{SM_MINOR}\")\n", + "\n", + "if SM_MAJOR < 8:\n", + " print(\"WARNING: CuTe DSL requires SM 8.0+ (Ampere or newer).\")\n", + " print(\"Some examples may not run on this GPU.\")\n", + "else:\n", + " print(\"GPU is compatible with CuTe DSL.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install CuTe DSL\n", + "\n", + "The `nvidia-cutlass-dsl` package bundles CuTe DSL together with its JAX integration module (`cutlass.jax`). The `[cu13]` extra pulls in CUDA 13.x compatible runtime libraries. Version 4.4+ is required for the JAX integration.\n", + "\n", + "Refer to the [official documentation](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html) for a more comprehansive installation guide." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install \"nvidia-cutlass-dsl[cu13]==4.4.0.dev1\" --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With CUTLASS installed, we import the libraries we'll use throughout the notebook: `cutlass` for kernel definitions, `jax` and `jnp` for array computation and JIT compilation, and `numpy` for result validation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\" # suppress TF/XLA info & warnings\n", + "os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n", + "\n", + "import cutlass\n", + "from importlib.metadata import version as _pkg_version\n", + "print(f\"CUTLASS version: {_pkg_version('nvidia-cutlass-dsl')}\")\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining kernels\n", + "\n", + "In CuTe DSL, kernels are defined in two layers:\n", + "\n", + "1. **`@cute.kernel`** defines the per-thread program — the sequence of instructions executed by each thread instance.\n", + "2. **`@cute.jit`** compiles the kernel and specifies how it runs on the GPU: the grid (how many blocks), the block (how many threads per block), and the CUDA stream (which execution queue to launch into).\n", + " \n", + "CuTe DSL lowers Python kernels to CUDA/CUTLASS code and compiles them just-in-time using the CUTLASS JIT toolchain.\n", + "\n", + "**Note:** CuTe DSL uses `inspect.getsourcelines()` to parse kernel source, so `@cute.kernel` / `@cute.jit` functions must live in `.py` files rather than notebook cells. We have written them to a module [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) — refer to this file for the concept explainations. \n", + "\n", + "Here we import the pre-written kernel launch functions from `cute_dsl_jax_kernels.py`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import cutlass.jax as cjax\n", + "from cute_dsl_jax_kernels import (\n", + " launch_vector_add, launch_saxpy, launch_gemm,\n", + " launch_relu, launch_fused_bias_relu,\n", + " launch_elementwise_add,\n", + ")\n", + "print(\"Imported: launch_vector_add, launch_saxpy, launch_gemm, launch_relu, launch_fused_bias_relu, launch_elementwise_add\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic kernel: vector add\n", + "\n", + "We’ll start with the simplest GPU kernel — vector add: `c[i] = a[i] + b[i]`. \n", + "\n", + "Refer to [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) for the corresponding implementation.\n", + "\n", + "### CuTe vector add kernel explained\n", + "\n", + "Each thread identifies itself using `thread_idx()` and `block_idx()`. Thread and block indices are accessed through `cute.arch` (e.g., `thread_idx`, `block_idx`), each returning `(x, y, z)` tuples, because CUDA’s execution and indexing are 3-dimensional by design. Since this kernel is launched in 1D, we only use the `x` component (`tidx` and `bidx`) and ignore the unused `y` and `z` values with `_`.\n", + "\n", + "```python\n", + "tidx, _, _ = cute.arch.thread_idx()\n", + "bidx, _, _ = cute.arch.block_idx()\n", + "```\n", + "\n", + "Inside the kernel, tensors are typically created in register memory using `cute.make_rmem_tensor`:\n", + "\n", + "```python\n", + "frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)\n", + "frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type)\n", + "frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type)\n", + "```\n", + "\n", + "Here, `frgA` and `frgB` hold the input values in registers, while `frgC` is a register fragment that will store the computed result before it is written back to global memory. `mode=[0]` selects the first dimension of the tensor — the \"elements per thread\" axis — so the register fragment is sized to hold exactly the data owned by one thread.\n", + "\n", + "Data movement between global and register memory is explicit: fragments are read using `load()` and written back using `store()`, while `cute.autovec_copy` performs efficient, vectorized transfers between memory spaces. Here, one element of `a` and `b` is loaded into register fragments, the sum in registers is computed and the result is stored back to `c`:\n", + "\n", + "```python\n", + "cute.autovec_copy(a[None, tidx, bidx], frgA)\n", + "cute.autovec_copy(b[None, tidx, bidx], frgB)\n", + "frgC.store(frgA.load() + frgB.load())\n", + "cute.autovec_copy(frgC, c[None, tidx, bidx])\n", + "```\n", + "\n", + "The `None` selects the entire first dimension (which has size 1 in this example), preserving the (`elems_per_thread`, `threads_per_block`, `num_blocks`) structure while allowing each thread to access its own slice of the tensor.\n", + "\n", + "> **Concept: Tensor = Pointer + Layout**\n", + "> \n", + "> A CuTe **Tensor** pairs a pointer to GPU memory with a **Layout** that describes how to navigate it. When the kernel receives `a: cute.Tensor`, it gets both the raw data and the index mapping. In this example, our tensors have shape `(1, BLOCK, num_blocks)` — one element per thread, `BLOCK=256` (defined in the example below) threads per block, spread across blocks. The layout maps a `(elems_per_thread, threads_per_block, num_blocks)` coordinate to the flat memory offset where that element lives. The kernel never computes offsets manually — it just indexes the tensor with `a[None, tidx, bidx]` and CuTe's layout handles the rest.\n", + "\n", + "The `@cute.kernel` defines one thread’s work. The `@cute.jit` launcher decides how many threads run, and how they’re grouped. It must follow the signature convention: `(stream, *inputs, *outputs, *, **kwargs)` — where `stream` is a CUDA stream managed by XLA, followed by input tensors, then output tensors.\n", + "\n", + "```python\n", + "@cute.jit\n", + "def launch_vector_add(\n", + " stream: cuda.CUstream,\n", + " a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,\n", + "):\n", + " vector_add_kernel(a, b, c).launch(\n", + " grid=[a.shape[-1], 1, 1],\n", + " block=[a.shape[-2], 1, 1],\n", + " stream=stream,\n", + " )\n", + "```\n", + "\n", + "We launch `a.shape[-2]` threads per block and `a.shape[-1]` blocks, directly matching the tensor’s `(1, threads_per_block, num_blocks)` layout so that `threadIdx.x` indexes the thread dimension and `blockIdx.x` indexes the block dimension. We use -2 and -1 because they refer to the last two tensor dimensions (threads per block and number of blocks), making the launch configuration robust even if additional leading dimensions are added.\n", + "\n", + "> **Concept: Layout composition**\n", + ">\n", + "> The vector add kernel expects 3-D tensors with shape `(elems_per_thread, threads_per_block, num_blocks)`, but our data is a flat 1-D array. The JAX wrapper reshapes from 1-D to 3-D before calling the kernel, and back afterward. In CuTe terms, this reshape is a **layout composition** — combining the original 1-D layout with a new layout that splits the single dimension into three. CuTe performs this algebraically: the composed layout maps 3-D coordinates directly to the original flat offsets, with no data movement. Reshaping is free — it's just a change of layout, not a copy.\n", + "\n", + "### JAX integration via `cutlass_call`\n", + "\n", + "The `cutlass.jax.cutlass_call` function wraps a CuTe `@cute.jit` launch function as a JAX custom call, so your CuTe/CUTLASS kernel can be invoked inside `@jax.jit`-compiled code and become part of the XLA computation graph.\n", + "\n", + "High-level flow:\n", + "\n", + "1. Prepare the data (pad + reshape)\n", + "\n", + "* We pad `N` up to a multiple of BLOCK so blocks are full (no partial last block), then reshape the 1-D vector into the 3-D logical tensor shape the kernel expects: `(elems_per_thread, threads_per_block, num_blocks)`.\n", + "* This reshape is free — it only changes the layout/interpretation of memory. No copy happens.\n", + "\n", + "```python\n", + "N = a.shape[0]\n", + "padded = ((N + BLOCK - 1) // BLOCK) * BLOCK\n", + "a_pad = jnp.pad(a, (0, padded - N))\n", + "a_3d = a_pad.reshape(1, BLOCK, padded // BLOCK)\n", + "```\n", + "\n", + "2. Wrap the launcher\n", + "\n", + "* This returns a callable that accepts JAX arrays (DeviceArrays) and will, when executed inside `@jax.jit`, lower to a JAX custom call that launches your compiled CUTLASS kernel.\n", + "* `output_shape_dtype` tells JAX/XLA what the kernel will produce so shapes and dtypes are known for compilation and graph building.\n", + "* `use_static_tensors=True` asks the wrapper to treat the kernel tensors as static (compile-time) shapes where possible — this allows CuTe/CUTLASS to generate specialized, high-performance code for fixed shapes.\n", + "\n", + "```python\n", + "call = cjax.cutlass_call(\n", + " launch_fn, # The @cute.jit function\n", + " output_shape_dtype=..., # Shape/dtype of output(s)\n", + ")\n", + "result = call(*input_arrays) # Pass JAX arrays here\n", + "```\n", + " \n", + "3. Call the wrapped launcher\n", + "\n", + "* Inside a `@jax.jit` function this becomes a custom call node in the XLA graph; at runtime XLA provides a CUDA `CUstream` and device pointers, and the CUTLASS kernel is invoked on that stream.\n", + "* The callable accepts JAX arrays and returns a JAX array containing the kernel output.\n", + "\n", + "```python\n", + "c_3d = call(a_3d, b_3d)\n", + "```\n", + "\n", + "4. Unpack back to 1-D and trim padding\n", + "\n", + "* Convert the logical 3-D result back to a flat 1-D array and drop the padded tail.\n", + "\n", + "```python\n", + "return c_3d.reshape(-1)[:N]\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "BLOCK = 256 # threads per block for vector add: 256 is a practical default: \n", + " # large enough to expose parallelism, small enough to scale \n", + " # well across different GPUs, and aligned with the hardware’s \n", + " # 32-thread warp execution model.\n", + "\n", + "@jax.jit\n", + "def jax_vector_add(a, b):\n", + " \"\"\"JAX-compatible vector add using CUTLASS kernel.\"\"\"\n", + " N = a.shape[0]\n", + " padded = ((N + BLOCK - 1) // BLOCK) * BLOCK\n", + " a_pad = jnp.pad(a, (0, padded - N))\n", + " b_pad = jnp.pad(b, (0, padded - N))\n", + " # Reshape to (1, BLOCK, num_blocks) for the CuTe kernel\n", + " a_3d = a_pad.reshape(1, BLOCK, padded // BLOCK)\n", + " b_3d = b_pad.reshape(1, BLOCK, padded // BLOCK)\n", + " call = cjax.cutlass_call(\n", + " launch_vector_add,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(a_3d.shape, a_3d.dtype),\n", + " use_static_tensors=True,\n", + " )\n", + " c_3d = call(a_3d, b_3d)\n", + " return c_3d.reshape(-1)[:N]\n", + "\n", + "print(\"jax_vector_add defined.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's test our CUTLASS vector add by comparing its output against JAX's built-in `+` operator. We generate two random arrays, run both implementations, and verify the results match element-by-element." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " # Test vector add\n", + "N = 1024\n", + "key = jax.random.PRNGKey(0)\n", + "a = jax.random.normal(key, (N,), dtype=jnp.float32)\n", + "b = jax.random.normal(jax.random.PRNGKey(1), (N,), dtype=jnp.float32)\n", + "\n", + "c = jax_vector_add(a, b)\n", + "c_ref = a + b\n", + "\n", + "np.testing.assert_allclose(np.array(c), np.array(c_ref), rtol=1e-5)\n", + "print(f\"Vector Add PASSED (N={N})\")\n", + "print(f\" Max error: {float(jnp.max(jnp.abs(c - c_ref))):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SAXPY: scalar parameters in kernels\n", + "\n", + "**SAXPY** computes `out[i] = alpha * x[i] + y[i]`. This builds on the vector add pattern and introduces passing a **scalar argument** (`alpha`) alongside tensor arguments.\n", + "\n", + "Refer to [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) for the corresponding implementation.\n", + "\n", + "### CuTe SAXPY kernel explained\n", + "\n", + "The SAXPY kernel follows the same structure as vector add — identify the thread, load data into registers, compute, write back — with one addition: a scalar `alpha` parameter.\n", + "\n", + "```python\n", + "@cute.kernel\n", + "def saxpy_kernel(x: cute.Tensor, y: cute.Tensor, out: cute.Tensor, alpha: float):\n", + "```\n", + "\n", + "The signature adds `alpha: float` alongside the tensor arguments. CuTe DSL compiles scalar parameters just like CUDA kernel arguments — they are passed by value and available to every thread.\n", + "\n", + "The body is identical to vector add except for the arithmetic:\n", + "\n", + "```python\n", + "frgO.store(alpha * frgX.load() + frgY.load())\n", + "```\n", + "\n", + "Each thread loads its element of `x` and `y` into register fragments, multiplies `x` by `alpha`, adds `y`, and writes the result to `out`.\n", + "\n", + "The launcher passes `alpha` as a **keyword-only** argument (note the `*` in the signature):\n", + "\n", + "```python\n", + "@cute.jit\n", + "def launch_saxpy(\n", + " stream: cuda.CUstream,\n", + " x: cute.Tensor, y: cute.Tensor, out: cute.Tensor,\n", + " *, alpha: float,\n", + "):\n", + " saxpy_kernel(x, y, out, alpha).launch(\n", + " grid=[x.shape[-1], 1, 1],\n", + " block=[x.shape[-2], 1, 1],\n", + " stream=stream,\n", + " )\n", + "```\n", + "\n", + "The keyword-only convention matters for `cutlass_call`: positional arguments correspond to JAX tensors (managed by XLA), while keyword arguments are scalar values passed directly to the kernel. In the JAX wrapper below, `alpha=alpha` routes through `cutlass_call` as a kernel kwarg:\n", + "\n", + "```python\n", + "call = cjax.cutlass_call(\n", + " launch_saxpy,\n", + " ...,\n", + " alpha=alpha, # scalar kwarg → passed to the kernel\n", + ")\n", + "out_3d = call(x_3d, y_3d) # tensor args → managed by XLA\n", + "```\n", + "\n", + "> **Concept: Static vs dynamic integers**\n", + ">\n", + "> CUTLASS distinguishes between values known at **compile time** (static) and values known only at **runtime** (dynamic). Static integers — like tensor shapes passed with `use_static_tensors=True` or constants like `BLOCK_SIZE` — are baked into the generated CUDA code, letting the compiler unroll loops, optimize memory access patterns, and eliminate branches. Dynamic values like `alpha` are passed as regular kernel arguments and read at runtime. As a rule of thumb: make shapes and tile sizes static, keep data-dependent values dynamic.\n", + "\n", + "Note that `jax_saxpy` uses `@partial(jax.jit, static_argnums=(2,))` to mark `alpha` as a static argument to JAX. This means JAX will recompile the function whenever `alpha` changes — which is fine for a value that rarely varies, and lets the CUTLASS JIT bake the exact `alpha` value into the generated CUDA code." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "@partial(jax.jit, static_argnums=(2,))\n", + "def jax_saxpy(x, y, alpha=2.0):\n", + " \"\"\"JAX-compatible SAXPY using CUTLASS kernel.\"\"\"\n", + " N = x.shape[0]\n", + " padded = ((N + BLOCK - 1) // BLOCK) * BLOCK\n", + " x_pad = jnp.pad(x, (0, padded - N))\n", + " y_pad = jnp.pad(y, (0, padded - N))\n", + " x_3d = x_pad.reshape(1, BLOCK, padded // BLOCK)\n", + " y_3d = y_pad.reshape(1, BLOCK, padded // BLOCK)\n", + " call = cjax.cutlass_call(\n", + " launch_saxpy,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(x_3d.shape, x_3d.dtype),\n", + " use_static_tensors=True,\n", + " alpha=alpha,\n", + " )\n", + " out_3d = call(x_3d, y_3d)\n", + " return out_3d.reshape(-1)[:N]\n", + "\n", + "print(\"jax_saxpy defined.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We test the SAXPY kernel with `alpha=2.5`, comparing against the reference computation `alpha * x + y`. The `assert_allclose` check verifies that results match within floating-point tolerance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test SAXPY\n", + "N = 2048\n", + "ALPHA = 2.5\n", + "key = jax.random.PRNGKey(42)\n", + "x = jax.random.normal(key, (N,), dtype=jnp.float32)\n", + "y = jax.random.normal(jax.random.PRNGKey(43), (N,), dtype=jnp.float32)\n", + "\n", + "result = jax_saxpy(x, y, alpha=ALPHA)\n", + "ref = ALPHA * x + y\n", + "\n", + "np.testing.assert_allclose(np.array(result), np.array(ref), rtol=1e-5)\n", + "print(f\"SAXPY PASSED (N={N}, alpha={ALPHA})\")\n", + "print(f\" Max error: {float(jnp.max(jnp.abs(result - ref))):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deep learning activations: ReLU and fused bias+ReLU\n", + "\n", + "**ReLU** (`max(0, x)`) is the most widely used activation function in deep learning. It's elementwise and trivially parallel — a perfect custom kernel for ML workloads.\n", + "\n", + "Refer to [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) for the corresponding implementations.\n", + "\n", + "### CuTe ReLU kernel explained\n", + "\n", + "The ReLU kernel uses a different pattern from vector add and SAXPY. Instead of the 3-D tensor approach with register fragments, it uses **flat 1-D indexing** — simpler and equally efficient for elementwise operations.\n", + "\n", + "```python\n", + "@cute.kernel\n", + "def relu_kernel(x: cute.Tensor, out: cute.Tensor, N: int):\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bidx, _, _ = cute.arch.block_idx()\n", + " bdx, _, _ = cute.arch.block_dim()\n", + "```\n", + "\n", + "A new call appears here: `cute.arch.block_dim()` returns the number of threads per block (set at launch time). Together with `thread_idx` and `block_idx`, it lets each thread compute its unique **global index**:\n", + "\n", + "```python\n", + " idx = bidx * bdx + tidx\n", + "```\n", + "\n", + "For example, if we launch 256 threads per block: thread 3 in block 2 gets `idx = 2 * 256 + 3 = 515`. This is the standard CUDA pattern for mapping threads to 1-D data.\n", + "\n", + "Because the data length `N` may not be a multiple of the block size, the last block could contain threads that point past the end of the array. The bounds check prevents out-of-bounds writes:\n", + "\n", + "```python\n", + " if idx < N:\n", + " val = x[idx]\n", + " out[idx] = cutlass.max(val, cutlass.Float32(0.0))\n", + "```\n", + "\n", + "Here we index the tensors directly with `x[idx]` — no register fragments or `autovec_copy`. For simple elementwise operations this flat approach is cleaner. `cutlass.max` is CuTe DSL's built-in max function, and `cutlass.Float32(0.0)` creates a typed zero constant that matches the tensor's element type.\n", + "\n", + "The launcher computes how many blocks are needed to cover `N` elements:\n", + "\n", + "```python\n", + "@cute.jit\n", + "def launch_relu(stream, x, out, *, N):\n", + " BLOCK_SIZE = 256\n", + " grid_size = (N + BLOCK_SIZE - 1) // BLOCK_SIZE\n", + " relu_kernel(x, out, N).launch(\n", + " grid=[grid_size, 1, 1],\n", + " block=[BLOCK_SIZE, 1, 1],\n", + " stream=stream,\n", + " )\n", + "```\n", + "\n", + "The formula `(N + BLOCK_SIZE - 1) // BLOCK_SIZE` is ceiling division — it ensures we launch enough blocks even when `N` isn't a multiple of 256. The bounds check inside the kernel handles the leftover threads in the last block.\n", + "\n", + "### JAX wrapper: ReLU\n", + "\n", + "The JAX wrapper is simpler than vector add because we skip the 3-D reshape. The kernel works on flat 1-D data directly:\n", + "\n", + "```python\n", + "x_flat = x.reshape(-1) # flatten to 1-D\n", + "call = cjax.cutlass_call(\n", + " launch_relu,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype),\n", + " N=N, # scalar kwarg → bounds check inside kernel\n", + ")\n", + "out_flat = call(x_flat)\n", + "return out_flat.reshape(x.shape) # restore original shape\n", + "```\n", + "\n", + "`N` is passed as a keyword argument so the kernel knows how many elements are valid. The output is reshaped back to match the input's original shape (works for any dimensionality)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def jax_relu(x):\n", + " \"\"\"JAX-compatible ReLU using CUTLASS kernel.\"\"\"\n", + " N = x.size\n", + " x_flat = x.reshape(-1)\n", + " call = cjax.cutlass_call(\n", + " launch_relu,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype),\n", + " N=N,\n", + " )\n", + " out_flat = call(x_flat)\n", + " return out_flat.reshape(x.shape)\n", + "\n", + "print(\"jax_relu defined.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We verify the ReLU kernel by comparing against `jax.nn.relu`. Positive values should pass through unchanged, and negative values should become zero." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test ReLU\n", + "N = 2048\n", + "key = jax.random.PRNGKey(7)\n", + "x = jax.random.normal(key, (N,), dtype=jnp.float32)\n", + "\n", + "result = jax_relu(x)\n", + "ref = jax.nn.relu(x)\n", + "\n", + "np.testing.assert_allclose(np.array(result), np.array(ref), rtol=1e-5)\n", + "print(f\"ReLU PASSED (N={N})\")\n", + "print(f\" Max error: {float(jnp.max(jnp.abs(result - ref))):.2e}\")\n", + "print(f\" Sample: x[:6] = {x[:6]}\")\n", + "print(f\" out[:6] = {result[:6]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuTe kernel explained: fused bias+ReLU\n", + "\n", + "**Fused bias+ReLU** computes `max(0, x + bias)` in a single kernel. This demonstrates **kernel fusion** — combining multiple operations into one GPU pass.\n", + "\n", + "Why fusion matters:\n", + "- **Without fusion:** `x + bias` writes an intermediate array to global memory, then `max(0, ...)` reads it back. That's two kernel launches and two round-trips to memory.\n", + "- **With fusion:** one kernel reads `x` and `bias`, computes the sum and the max, and writes the final result. Half the memory traffic, one launch instead of two.\n", + "\n", + "The kernel extends the ReLU pattern with a bias lookup:\n", + "\n", + "```python\n", + " idx = bidx * bdx + tidx\n", + " if idx < N:\n", + " col = idx % width\n", + " val = x[idx] + bias[col]\n", + " out[idx] = cutlass.max(val, cutlass.Float32(0.0))\n", + "```\n", + "\n", + "The input `x` is a flattened `(batch, width)` matrix. `idx` is the global flat index, and `col = idx % width` recovers which column (feature) this element belongs to, so we can look up the correct bias. This modular indexing pattern is common in fused kernels that combine elementwise and broadcast operations.\n", + "\n", + "The launcher and JAX wrapper follow the same flat-indexing pattern as ReLU, with `N` (total elements) and `width` (columns) passed as keyword arguments:\n", + "\n", + "```python\n", + "call = cjax.cutlass_call(\n", + " launch_fused_bias_relu,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype),\n", + " N=N, width=width,\n", + ")\n", + "out_flat = call(x_flat, bias) # two input tensors: x and bias\n", + "```\n", + "\n", + "Note that `width` is marked as a static argument in the JAX wrapper via `static_argnums=(2,)`. This means JAX recompiles when the feature dimension changes, allowing CUTLASS to generate specialized code for each width." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "@partial(jax.jit, static_argnums=(2,))\n", + "def jax_fused_bias_relu(x, bias, width):\n", + " \"\"\"JAX-compatible fused Bias+ReLU using CUTLASS kernel.\n", + "\n", + " Args:\n", + " x: Input matrix of shape (batch, width), flattened to 1-D for the kernel.\n", + " bias: Bias vector of shape (width,).\n", + " width: Number of columns (static, passed as constexpr to the kernel).\n", + " \"\"\"\n", + " N = x.size\n", + " x_flat = x.reshape(-1)\n", + " call = cjax.cutlass_call(\n", + " launch_fused_bias_relu,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype),\n", + " N=N, width=width,\n", + " )\n", + " out_flat = call(x_flat, bias)\n", + " return out_flat.reshape(x.shape)\n", + "\n", + "print(\"jax_fused_bias_relu defined.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test the fused kernel against the equivalent two-step JAX computation: add bias, then apply ReLU. The results should match exactly since both paths perform the same arithmetic." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test Fused Bias+ReLU\n", + "BATCH, WIDTH = 64, 512\n", + "key = jax.random.PRNGKey(99)\n", + "x = jax.random.normal(key, (BATCH, WIDTH), dtype=jnp.float32)\n", + "bias = jax.random.normal(jax.random.PRNGKey(100), (WIDTH,), dtype=jnp.float32)\n", + "\n", + "result = jax_fused_bias_relu(x, bias, WIDTH)\n", + "ref = jnp.maximum(0, x + bias[None, :])\n", + "\n", + "np.testing.assert_allclose(np.array(result), np.array(ref), rtol=1e-5)\n", + "print(f\"Fused Bias+ReLU PASSED (batch={BATCH}, width={WIDTH})\")\n", + "print(f\" Max error: {float(jnp.max(jnp.abs(result - ref))):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> **Going further:** For a production-grade generalization of elementwise kernels — with optimized TV (thread-value) layouts, vectorized memory access, and support for arbitrary binary operators including custom ops like `leaky_relu` — see NVIDIA's [elementwise_apply_example.py](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/jax/elementwise_apply_example.py)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced: Tiled GEMM\n", + "\n", + "This demonstrates a general matrix multiply (GEMM) kernel: `D = A @ B` where A is (M, K), B is (K, N), and D is (M, N). Unlike the previous elementwise kernels, GEMM requires cooperation across data dimensions — each output element is a dot product over K values.\n", + "\n", + "Refer to [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) for the corresponding implementation.\n", + "\n", + "> **Concept: Tiling**\n", + ">\n", + "> Tiling is CuTe's mechanism for partitioning data into sub-problems that map onto the GPU's execution hierarchy. In a GEMM, we divide the output matrix into `BLOCK_M x BLOCK_N` tiles, each assigned to one thread block. Within a tile, individual threads split the work further. CuTe's tiling operations decompose a layout into an \"inner\" part (the tile itself) and an \"outer\" part (which tile we're on). The block index `(bm, bn)` selects the outer coordinate, and thread indices work within the inner tile. This two-level decomposition — partition then index locally — is the fundamental pattern for mapping parallel GPU work to data.\n", + "\n", + "### CuTe GEMM kernel explained\n", + "\n", + "This is our first kernel using a **2-D grid**. Each block is responsible for a `BLOCK_M x BLOCK_N` tile of the output matrix:\n", + "\n", + "```python\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bm, bn, _ = cute.arch.block_idx()\n", + " bdx, _, _ = cute.arch.block_dim()\n", + "```\n", + "\n", + "Note `bm, bn, _` — the block index now has two meaningful components: `bm` selects the tile row, `bn` selects the tile column.\n", + "\n", + "Each tile contains `BLOCK_M * BLOCK_N` output elements, but we only have `bdx` (256) threads per block. A **stride loop** distributes the work evenly:\n", + "\n", + "```python\n", + " for i in cutlass.range(tidx, BLOCK_M * BLOCK_N, bdx):\n", + "```\n", + "\n", + "This loop starts at `tidx` and steps by `bdx` (the block size). For a `64 x 64 = 4096` element tile with 256 threads, each thread computes `4096 / 256 = 16` output elements. `cutlass.range` works like Python's `range()` but generates CUDA loop code.\n", + "\n", + "Within the loop, the flat tile index `i` is converted to 2-D tile-local coordinates, then to global matrix coordinates:\n", + "\n", + "```python\n", + " row = i // BLOCK_N # tile-local row\n", + " col = i % BLOCK_N # tile-local column\n", + " m_idx = bm * BLOCK_M + row # global row in D\n", + " n_idx = bn * BLOCK_N + col # global column in D\n", + "```\n", + "\n", + "A bounds check handles edge tiles where the matrix dimensions aren't multiples of the block size:\n", + "\n", + "```python\n", + " if m_idx < M and n_idx < N:\n", + "```\n", + "\n", + "The inner loop accumulates the dot product over the K dimension:\n", + "\n", + "```python\n", + " acc = cutlass.Float32(0.0)\n", + " for k in cutlass.range(K):\n", + " acc += A[m_idx * K + k] * B[k * N + n_idx]\n", + " D[m_idx * N + n_idx] = acc\n", + "```\n", + "\n", + "Here we use **manual row-major indexing**: `A[m_idx * K + k]` computes the offset into the flattened 1-D tensor for element `(m_idx, k)` of a row-major matrix with K columns. Similarly, `B[k * N + n_idx]` indexes element `(k, n_idx)`. Production CUTLASS kernels use multi-dimensional CuTe tensor indexing instead, but explicit indexing makes the memory layout visible for learning.\n", + "\n", + "### Launch configuration\n", + "\n", + "The launcher sets up a 2-D grid matching the tile decomposition:\n", + "\n", + "```python\n", + "@cute.jit\n", + "def launch_gemm(stream, A, B, D, *, M, N, K):\n", + " BLOCK_M, BLOCK_N = 64, 64\n", + " grid_m = (M + BLOCK_M - 1) // BLOCK_M\n", + " grid_n = (N + BLOCK_N - 1) // BLOCK_N\n", + " gemm_kernel(A, B, D, M, N, K, BLOCK_M, BLOCK_N).launch(\n", + " grid=[grid_m, grid_n, 1],\n", + " block=[256, 1, 1],\n", + " stream=stream,\n", + " )\n", + "```\n", + "\n", + "- `grid=[grid_m, grid_n, 1]` — one block per output tile, arranged in a 2-D grid\n", + "- `block=[256, 1, 1]` — 256 threads per block, each handling multiple elements via the stride loop\n", + "- `M, N, K, BLOCK_M, BLOCK_N` are all passed as compile-time constants to the kernel\n", + "\n", + "### JAX wrapper for GEMM\n", + "\n", + "The JAX wrapper flattens both input matrices to 1-D (matching the kernel's flat indexing), passes the matrix dimensions as keyword arguments, and reshapes the result:\n", + "\n", + "```python\n", + "a_flat = a.reshape(-1)\n", + "b_flat = b.reshape(-1)\n", + "call = cjax.cutlass_call(\n", + " launch_gemm,\n", + " output_shape_dtype=jax.ShapeDtypeStruct((M * N,), a.dtype),\n", + " M=M, N=N, K=K,\n", + ")\n", + "d_flat = call(a_flat, b_flat)\n", + "return d_flat.reshape(M, N)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def jax_cutlass_gemm(a, b):\n", + " \"\"\"JAX wrapper for the CUTLASS GEMM kernel.\"\"\"\n", + " M, K = a.shape\n", + " _, N = b.shape\n", + " a_flat = a.reshape(-1)\n", + " b_flat = b.reshape(-1)\n", + " call = cjax.cutlass_call(\n", + " launch_gemm,\n", + " output_shape_dtype=jax.ShapeDtypeStruct((M * N,), a.dtype),\n", + " M=M, N=N, K=K,\n", + " )\n", + " d_flat = call(a_flat, b_flat)\n", + " return d_flat.reshape(M, N)\n", + "\n", + "print(\"jax_cutlass_gemm defined.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test the CUTLASS GEMM against JAX's `jnp.matmul`. We use relaxed tolerances (`rtol=1e-2`) because our simple kernel accumulates the K-dimension in a different order than cuBLAS, leading to small floating-point differences that are expected and harmless." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test GEMM\n", + "M, N, K = 256, 256, 128\n", + "key = jax.random.PRNGKey(0)\n", + "A = jax.random.normal(key, (M, K), dtype=jnp.float32)\n", + "B = jax.random.normal(jax.random.PRNGKey(1), (K, N), dtype=jnp.float32)\n", + "\n", + "D = jax_cutlass_gemm(A, B)\n", + "D_ref = jnp.matmul(A, B)\n", + "\n", + "np.testing.assert_allclose(np.array(D), np.array(D_ref), rtol=1e-2, atol=1e-2)\n", + "print(f\"GEMM PASSED (M={M}, N={N}, K={K})\")\n", + "print(f\" Max error: {float(jnp.max(jnp.abs(D - D_ref))):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performance comparison\n", + "\n", + "Let's compare our CUTLASS GEMM kernel against JAX's built-in `jnp.matmul` (which calls cuBLAS under the hood).\n", + "\n", + "Our simple tiled kernel is **not expected to beat cuBLAS** — cuBLAS is one of the most heavily optimized libraries in existence, with hand-tuned assembly for each GPU architecture. The goal here is to show the integration pattern and demonstrate that custom kernels produce correct results.\n", + "\n", + "CuTe DSL's real value shows up when you need kernels that cuBLAS doesn't provide: custom fusions, non-standard data layouts, mixed-precision schemes, or operations specific to your model architecture.\n", + "\n", + "The benchmark below runs each implementation 20 times (after a warmup pass to trigger JIT compilation) and reports the average wall-clock time. `block_until_ready()` ensures we time the actual GPU execution, not just the asynchronous launch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "M, N, K = 512, 512, 512\n", + "A = jax.random.normal(jax.random.PRNGKey(0), (M, K), dtype=jnp.float32)\n", + "B = jax.random.normal(jax.random.PRNGKey(1), (K, N), dtype=jnp.float32)\n", + "\n", + "# Warmup\n", + "_ = jax_cutlass_gemm(A, B).block_until_ready()\n", + "_ = jnp.matmul(A, B).block_until_ready()\n", + "\n", + "NUM_RUNS = 20\n", + "\n", + "# Time CUTLASS GEMM\n", + "start = time.perf_counter() \n", + "for _ in range(NUM_RUNS):\n", + " _ = jax_cutlass_gemm(A, B).block_until_ready()\n", + "cutlass_time = (time.perf_counter() - start) / NUM_RUNS\n", + "\n", + "# Time JAX matmul\n", + "start = time.perf_counter()\n", + "for _ in range(NUM_RUNS):\n", + " _ = jnp.matmul(A, B).block_until_ready()\n", + "jax_time = (time.perf_counter() - start) / NUM_RUNS\n", + "\n", + "print(f\"Matrix size: {M}x{N}x{K}\")\n", + "print(f\"CUTLASS GEMM: {cutlass_time*1000:.3f} ms\")\n", + "print(f\"JAX jnp.matmul: {jax_time*1000:.3f} ms\")\n", + "print(f\"Ratio (CUTLASS / JAX): {cutlass_time / jax_time:.2f}x\")\n", + "print()\n", + "print(\"Note: Our simple tiled kernel is not expected to beat cuBLAS.\")\n", + "print(\"CuTe DSL's value is in specialized kernels cuBLAS doesn't provide.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi-GPU: sharding CUTLASS kernels via `jax.shard_map`\n", + "\n", + "One of JAX's key strengths is transparent multi-device execution. CUTLASS kernels integrated via `cutlass_call` participate fully in JAX's sharding APIs, so you can distribute work across all available GPUs without modifying the kernel code.\n", + "\n", + "### How sharding works\n", + "\n", + "The key idea: split the data across devices, run the same kernel independently on each device's local shard, and let JAX handle the coordination.\n", + "\n", + "**1. Create a device mesh.** A mesh maps physical devices to named logical axes:\n", + "\n", + "```python\n", + "mesh = jax.make_mesh((num_devices,), \"x\")\n", + "```\n", + "\n", + "This creates a 1-D mesh with `num_devices` devices along an axis called `\"x\"`. For 8 GPUs, the mesh maps device 0 through device 7 to positions 0–7 along the `\"x\"` axis.\n", + "\n", + "**2. Define the sharding spec.** `PartitionSpec` tells JAX how to slice each tensor dimension across the mesh:\n", + "\n", + "```python\n", + "sharding = P(None, None, \"x\")\n", + "```\n", + "\n", + "For our 3-D tensors with shape `(elems_per_thread, threads_per_block, num_blocks)`:\n", + "- `None` — don't shard the first dimension (elems per thread, stays local)\n", + "- `None` — don't shard the second dimension (threads per block, stays local)\n", + "- `\"x\"` — shard the third dimension (blocks) across devices on the `\"x\"` axis\n", + "\n", + "So with 8 devices and 128 total blocks, each device gets a tensor of shape `(1, 256, 16)` — its 16 local blocks.\n", + "\n", + "**3. Use `shard_map` to run per-device code.** `shard_map` wraps a function so that each device receives only its local shard:\n", + "\n", + "```python\n", + "@partial(\n", + " shard_map,\n", + " mesh=mesh,\n", + " in_specs=(sharding, sharding),\n", + " out_specs=sharding,\n", + ")\n", + "def _add(a_shard, b_shard):\n", + " call = cjax.cutlass_call(\n", + " launch_vector_add,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(a_shard.shape, a_shard.dtype),\n", + " use_static_tensors=True,\n", + " )\n", + " return call(a_shard, b_shard)\n", + "```\n", + "\n", + "Inside `_add`, the code is identical to single-GPU — it sees a regular tensor and calls the same CUTLASS kernel. The kernel has no idea it's running on multiple GPUs. JAX handles splitting inputs before the kernel and reassembling outputs afterward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "from functools import partial\n", + "from jax.sharding import PartitionSpec as P\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\", DeprecationWarning)\n", + " from jax.experimental.shard_map import shard_map\n", + "\n", + "num_devices = len(jax.devices())\n", + "print(f\"Number of devices: {num_devices}\")\n", + "\n", + "if num_devices > 1:\n", + " mesh = jax.make_mesh((num_devices,), \"x\")\n", + " # Kernel expects 3-D tensors: (elems_per_thread, threads, blocks)\n", + " # Shard along the blocks axis (last dim)\n", + " sharding = P(None, None, \"x\")\n", + "\n", + " @jax.jit\n", + " def sharded_vector_add(a, b):\n", + " @partial(\n", + " shard_map,\n", + " mesh=mesh,\n", + " in_specs=(sharding, sharding),\n", + " out_specs=sharding,\n", + " )\n", + " def _add(a_shard, b_shard):\n", + " call = cjax.cutlass_call(\n", + " launch_vector_add,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(\n", + " a_shard.shape, a_shard.dtype\n", + " ),\n", + " use_static_tensors=True,\n", + " )\n", + " return call(a_shard, b_shard)\n", + " return _add(a, b)\n", + "\n", + " # Create 3-D tensors: (1, 256, total_blocks) with total_blocks divisible by device count\n", + " blocks_per_device = 16\n", + " total_blocks = blocks_per_device * num_devices\n", + " shape = (1, BLOCK, total_blocks)\n", + " a_m = jax.random.normal(jax.random.PRNGKey(10), shape, dtype=jnp.float32)\n", + " b_m = jax.random.normal(jax.random.PRNGKey(11), shape, dtype=jnp.float32)\n", + "\n", + " c_m = sharded_vector_add(a_m, b_m)\n", + " np.testing.assert_allclose(np.array(c_m), np.array(a_m + b_m), rtol=1e-5)\n", + " N_total = int(np.prod(shape))\n", + " print(f\"Sharded Vector Add PASSED across {num_devices} devices (N={N_total})\")\n", + "else:\n", + " print(\"Only 1 device detected. Skipping multi-GPU example.\")\n", + " print(\"On a multi-GPU system, shard_map distributes CUTLASS kernels across devices.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exporting CUTLASS kernels with `jax.export`\n", + "\n", + "So far, every kernel we've written lives inside a `@jax.jit` function — it compiles and runs within the current Python process. But what if you want to **save** a compiled JAX function containing a CUTLASS kernel, ship it to another machine, or load it in a non-Python runtime?\n", + "\n", + "That's what `jax.export` does. It takes a JIT-compiled function and produces a **standalone, serialized artifact** that you can save to disk, send over the network, and reload later — even after the original Python program has exited. Without `jax.export`, JAX functions are only compiled and callable inside the same Python process through `jit`.\n", + "\n", + "With `jax.export` you get:\n", + "\n", + "- **Serialization** — turn your staged JAX computation into a blob that can be stored and reused\n", + "- **Interoperability** — future tools could invoke this from non-Python runtimes (TensorFlow, C++, other frameworks)\n", + "- **Stable HLO output** — useful for ahead-of-time (AOT) compilation, deployment, and cross-platform interoperability\n", + "\n", + "For CUTLASS kernels specifically:\n", + "\n", + "- The exported function includes **custom calls** to CUTLASS kernels — these aren't part of JAX's built-in compilation pipeline. `get_export_disabled_safety_checks()` tells JAX that these custom calls are safe to include in the exported output.\n", + "- With **symbolic shapes**, the exported artifact works for multiple input sizes without recompilation. The kernel doesn't have to be recompiled for new input shapes after export.\n", + "\n", + "### What `jax.export` gives you\n", + "\n", + "- **A StableHLO representation** of the compiled function (the lowered intermediate representation)\n", + "- **Metadata** about the function's inputs and outputs\n", + "- **A serialized blob** you can save to disk or transmit over the network\n", + "- **A callable object** (`rehydrated.call(...)`) that works independently of the code that built it\n", + "\n", + "### How it works\n", + "\n", + "The flow is straightforward:\n", + "\n", + "```python\n", + "from jax import export\n", + "from cutlass.jax import get_export_disabled_safety_checks\n", + "\n", + "# 1. Export the JIT-compiled function\n", + "exported = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks())\n", + "\n", + "# 2. Specialize to a signature (concrete or symbolic shapes) and serialize\n", + "traced = exported(shape_dtype_spec, shape_dtype_spec)\n", + "blob = traced.serialize()\n", + "\n", + "# 3. Later: deserialize and call with real data\n", + "rehydrated = export.deserialize(blob)\n", + "result = rehydrated.call(a, b)\n", + "```\n", + "\n", + "The following example is adapted from [NVIDIA's official export example](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/jax/cutlass_call_export.py). It exports a function that adds two matrices with a CUTLASS kernel and applies `sigmoid`, then serializes, deserializes, and verifies the result." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import export\n", + "from cutlass.jax import get_export_disabled_safety_checks\n", + "\n", + "# Define a function that uses a CUTLASS kernel + JAX ops.\n", + "# We use launch_elementwise_add which accepts 2-D tensors directly\n", + "# with flat indexing — compatible with jax.export's tracing.\n", + "@jax.jit\n", + "def f(a, b):\n", + " call = cjax.cutlass_call(launch_elementwise_add, output_shape_dtype=a)\n", + " return jax.nn.sigmoid(call(a, b))\n", + "\n", + "# Reference implementation (pure JAX)\n", + "@jax.jit\n", + "def ref_f(a, b):\n", + " return jax.nn.sigmoid(a + b)\n", + "\n", + "# --- Export with concrete shapes ---\n", + "M, N = 512, 256\n", + "export_shape_dtype = jax.ShapeDtypeStruct((M, N), jnp.float32)\n", + "\n", + "print(f\"Exporting with input signature: ({export_shape_dtype}, {export_shape_dtype})\")\n", + "\n", + "# Export the function — get_export_disabled_safety_checks() tells JAX\n", + "# that CUTLASS custom call targets are safe to include\n", + "exported = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks())\n", + "traced = exported(export_shape_dtype, export_shape_dtype)\n", + "\n", + "# Serialize to a byte blob\n", + "blob = traced.serialize()\n", + "print(f\"Serialized computation: {len(blob):,} bytes\")\n", + "\n", + "# Deserialize and run — this works independently of the original function\n", + "rehydrated = export.deserialize(blob)\n", + "\n", + "key = jax.random.PRNGKey(1123)\n", + "a = jax.random.normal(key, (M, N), dtype=jnp.float32)\n", + "b = jax.random.normal(jax.random.PRNGKey(456), (M, N), dtype=jnp.float32)\n", + "\n", + "c = rehydrated.call(a, b)\n", + "c_ref = ref_f(a, b)\n", + "\n", + "np.testing.assert_allclose(np.array(c), np.array(c_ref), rtol=1e-5)\n", + "print(f\"Export + Deserialize PASSED (M={M}, N={N})\")\n", + "print(f\" Max error: {float(jnp.max(jnp.abs(c - c_ref))):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Exporting with symbolic shapes\n", + "\n", + "With concrete shapes, the exported artifact only works for the exact dimensions it was traced with. **Symbolic shapes** lift this restriction — they let you export once and call with any compatible dimensions, without recompilation.\n", + "\n", + "`export.symbolic_shape(\"a, b\")` creates symbolic dimension variables. The exported function is parameterized over these variables, so the same serialized blob works for `(512, 256)`, `(1024, 1024)`, or any other shape." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --- Export with symbolic shapes ---\n", + "a_sym, b_sym = export.symbolic_shape(\"a, b\")\n", + "symbolic_shape_dtype = jax.ShapeDtypeStruct((a_sym, b_sym), jnp.float32)\n", + "\n", + "print(f\"Exporting with symbolic signature: ({symbolic_shape_dtype}, {symbolic_shape_dtype})\")\n", + "\n", + "exported_sym = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks())\n", + "traced_sym = exported_sym(symbolic_shape_dtype, symbolic_shape_dtype)\n", + "blob_sym = traced_sym.serialize()\n", + "print(f\"Serialized computation: {len(blob_sym):,} bytes\")\n", + "\n", + "rehydrated_sym = export.deserialize(blob_sym)\n", + "\n", + "# Call with different shapes — no recompilation needed.\n", + "# The same serialized blob works for any (M, N) where M*N is a\n", + "# multiple of the kernel's block size (256).\n", + "for shape in [(512, 256), (1024, 512), (2048, 1024)]:\n", + " a = jax.random.normal(jax.random.PRNGKey(42), shape, dtype=jnp.float32)\n", + " b = jax.random.normal(jax.random.PRNGKey(43), shape, dtype=jnp.float32)\n", + " c = rehydrated_sym.call(a, b)\n", + " c_ref = ref_f(a, b)\n", + " np.testing.assert_allclose(np.array(c), np.array(c_ref), rtol=1e-5)\n", + " print(f\" Symbolic export PASSED for shape {shape}\")\n", + "\n", + "print(\"All symbolic shape tests passed.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this notebook you learned to:\n", + "\n", + "- Define GPU kernels in Python with **`@cute.kernel`** and **`@cute.jit`**\n", + "- Bridge CuTe DSL kernels into JAX via **`cutlass.jax.cutlass_call`**\n", + "- Pass both tensor and scalar arguments to custom kernels\n", + "- Write **ReLU** and **Fused Bias+ReLU** activation kernels for deep learning\n", + "- Demonstrate **kernel fusion** — combining multiple ops into a single GPU kernel\n", + "- Build a **tiled GEMM** kernel using CuTe DSL abstractions\n", + "- Distribute CUTLASS kernels across GPUs with **`jax.shard_map`**\n", + "- **Export and serialize** JAX functions containing CUTLASS kernels with **`jax.export`**\n", + "\n", + "CuTe DSL is the right tool when you need direct control over tensor core matrix multiply-accumulate (MMA) instructions, shared memory layouts, and warp-level operations." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/python/CuTeDSL/jax/cute_dsl_jax_kernels.py b/examples/python/CuTeDSL/jax/cute_dsl_jax_kernels.py new file mode 100644 index 000000000..708936871 --- /dev/null +++ b/examples/python/CuTeDSL/jax/cute_dsl_jax_kernels.py @@ -0,0 +1,366 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import cutlass +import cutlass.cute as cute +import cutlass.jax as cjax +import cuda.bindings.driver as cuda + +""" +CuTe DSL kernels used by the ``cute_dsl_jax.ipynb`` notebook. + +This module defines GPU kernels written in CuTe DSL (CUTLASS 4.x Python DSL) +that are called from JAX via ``cutlass.jax.cutlass_call``. ``cutlass_call`` is a +JAX primitive that triggers compilation of the kernel during lowering and embeds +it into the HLO computation, so XLA can launch it efficiently without callback +to Python. + +Kernels provided: + +- ``vector_add`` — element-wise c = a + b (3-D CuTe layout) +- ``saxpy`` — y = alpha * x + y +- ``relu`` — element-wise ReLU with flat indexing +- ``fused_bias_relu`` — fused bias addition + ReLU +- ``gemm`` — tiled matrix multiplication +- ``elementwise_add`` — 2-D element-wise add (flat indexing, ``jax.export``-compatible) + +The notebook imports these kernels and wraps each one with ``cutlass_call`` +inside ``@jax.jit`` functions. See ``cute_dsl_jax.ipynb`` for usage, validation, +and step-by-step explanations. + +This module is imported by the notebook and by ``cute_dsl_jax.py``. It can also +be run directly to validate every kernel: + +.. code-block:: bash + + # Interactive notebook (recommended for learning) + jupyter lab cute_dsl_jax.ipynb + + # Full demo as a standalone script + python cute_dsl_jax_kernels.py +""" + + +# ------------------------------------------------------------------ # +# Vector Add: c = a + b # +# ------------------------------------------------------------------ # +@cute.kernel +def vector_add_kernel(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor): + """Per-thread kernel: each thread adds one element.""" + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type) + frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type) + frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type) + + cute.autovec_copy(a[None, tidx, bidx], frgA) + cute.autovec_copy(b[None, tidx, bidx], frgB) + frgC.store(frgA.load() + frgB.load()) + cute.autovec_copy(frgC, c[None, tidx, bidx]) + + +@cute.jit +def launch_vector_add( + stream: cuda.CUstream, + a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, +): + vector_add_kernel(a, b, c).launch( + grid=[a.shape[-1], 1, 1], + block=[a.shape[-2], 1, 1], + stream=stream, + ) + + +# ------------------------------------------------------------------ # +# SAXPY: y = alpha * x + y # +# ------------------------------------------------------------------ # +@cute.kernel +def saxpy_kernel(x: cute.Tensor, y: cute.Tensor, out: cute.Tensor, alpha: float): + """SAXPY: out[i] = alpha * x[i] + y[i].""" + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + frgX = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type) + frgY = cute.make_rmem_tensor(cute.size(y, mode=[0]), y.element_type) + frgO = cute.make_rmem_tensor(cute.size(out, mode=[0]), out.element_type) + + cute.autovec_copy(x[None, tidx, bidx], frgX) + cute.autovec_copy(y[None, tidx, bidx], frgY) + frgO.store(alpha * frgX.load() + frgY.load()) + cute.autovec_copy(frgO, out[None, tidx, bidx]) + + +@cute.jit +def launch_saxpy( + stream: cuda.CUstream, + x: cute.Tensor, y: cute.Tensor, out: cute.Tensor, + *, alpha: float, +): + saxpy_kernel(x, y, out, alpha).launch( + grid=[x.shape[-1], 1, 1], + block=[x.shape[-2], 1, 1], + stream=stream, + ) + + +# ------------------------------------------------------------------ # +# ReLU: out = max(0, x) # +# ------------------------------------------------------------------ # +@cute.kernel +def relu_kernel(x: cute.Tensor, out: cute.Tensor, N: int): + """Per-thread kernel: each thread computes ReLU of one element.""" + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdx, _, _ = cute.arch.block_dim() + + idx = bidx * bdx + tidx + if idx < N: + val = x[idx] + out[idx] = cutlass.max(val, cutlass.Float32(0.0)) + + +@cute.jit +def launch_relu( + stream: cuda.CUstream, + x: cute.Tensor, out: cute.Tensor, + *, N: int, +): + BLOCK_SIZE = 256 + grid_size = (N + BLOCK_SIZE - 1) // BLOCK_SIZE + relu_kernel(x, out, N).launch( + grid=[grid_size, 1, 1], + block=[BLOCK_SIZE, 1, 1], + stream=stream, + ) + + +# ------------------------------------------------------------------ # +# Fused Bias + ReLU: out = max(0, x + bias[col]) # +# ------------------------------------------------------------------ # +@cute.kernel +def fused_bias_relu_kernel( + x: cute.Tensor, bias: cute.Tensor, out: cute.Tensor, N: int, width: int, +): + """Per-thread: out[i] = max(0, x[i] + bias[i % width]).""" + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdx, _, _ = cute.arch.block_dim() + + idx = bidx * bdx + tidx + if idx < N: + col = idx % width + val = x[idx] + bias[col] + out[idx] = cutlass.max(val, cutlass.Float32(0.0)) + + +@cute.jit +def launch_fused_bias_relu( + stream: cuda.CUstream, + x: cute.Tensor, bias: cute.Tensor, out: cute.Tensor, + *, N: int, width: int, +): + BLOCK_SIZE = 256 + grid_size = (N + BLOCK_SIZE - 1) // BLOCK_SIZE + fused_bias_relu_kernel(x, bias, out, N, width).launch( + grid=[grid_size, 1, 1], + block=[BLOCK_SIZE, 1, 1], + stream=stream, + ) + + +# ------------------------------------------------------------------ # +# GEMM: D = A @ B # +# ------------------------------------------------------------------ # +@cute.kernel +def gemm_kernel( + A: cute.Tensor, B: cute.Tensor, D: cute.Tensor, + M: int, N: int, K: int, BLOCK_M: int, BLOCK_N: int, +): + """Tiled GEMM: each thread accumulates output elements.""" + tidx, _, _ = cute.arch.thread_idx() + bm, bn, _ = cute.arch.block_idx() + bdx, _, _ = cute.arch.block_dim() + + for i in cutlass.range(tidx, BLOCK_M * BLOCK_N, bdx): + row = i // BLOCK_N + col = i % BLOCK_N + m_idx = bm * BLOCK_M + row + n_idx = bn * BLOCK_N + col + if m_idx < M and n_idx < N: + acc = cutlass.Float32(0.0) + for k in cutlass.range(K): + acc += A[m_idx * K + k] * B[k * N + n_idx] + D[m_idx * N + n_idx] = acc + + +@cute.jit +def launch_gemm( + stream: cuda.CUstream, + A: cute.Tensor, B: cute.Tensor, D: cute.Tensor, + *, M: int, N: int, K: int, +): + BLOCK_M, BLOCK_N = 64, 64 + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + gemm_kernel(A, B, D, M, N, K, BLOCK_M, BLOCK_N).launch( + grid=[grid_m, grid_n, 1], + block=[256, 1, 1], + stream=stream, + ) + + +# ------------------------------------------------------------------ # +# Element-wise Add (2-D, flat indexing) # +# ------------------------------------------------------------------ # +@cute.kernel +def elementwise_add_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): + """Per-thread kernel: 2-D element-wise add using flat indexing.""" + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + + thread_idx = bidx * bdim + tidx + + m, n = gA.shape + ni = thread_idx % n + mi = thread_idx // n + + a_val = gA[mi, ni] + b_val = gB[mi, ni] + gC[mi, ni] = a_val + b_val + + +@cute.jit +def launch_elementwise_add( + stream: cuda.CUstream, + mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, +): + num_threads_per_block = 256 + m, n = mA.shape + elementwise_add_kernel(mA, mB, mC).launch( + grid=((m * n) // num_threads_per_block, 1, 1), + block=(num_threads_per_block, 1, 1), + stream=stream, + ) + + +# ------------------------------------------------------------------ # +# Self-tests # +# ------------------------------------------------------------------ # +if __name__ == '__main__': + import os + os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2") + + import jax + import jax.numpy as jnp + import numpy as np + + BLOCK = 256 + N_BLOCKS = 4 + + # ── Vector Add ──────────────────────────────────────────────────── + # 3-D CuTe layout: (elems_per_thread, threads_per_block, num_blocks) + a = jax.random.normal(jax.random.PRNGKey(0), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) + b = jax.random.normal(jax.random.PRNGKey(1), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) + call = cjax.cutlass_call( + launch_vector_add, + output_shape_dtype=jax.ShapeDtypeStruct(a.shape, a.dtype), + use_static_tensors=True, + ) + c = jax.jit(call)(a, b) + np.testing.assert_allclose(np.array(c), np.array(a + b), rtol=1e-5, atol=1e-5) + print('vector_add: PASSED') + + # ── SAXPY ───────────────────────────────────────────────────────── + x = jax.random.normal(jax.random.PRNGKey(2), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) + y = jax.random.normal(jax.random.PRNGKey(3), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) + alpha = 2.5 + call = cjax.cutlass_call( + launch_saxpy, + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype), + use_static_tensors=True, + alpha=alpha, + ) + out = jax.jit(call)(x, y) + np.testing.assert_allclose(np.array(out), np.array(alpha * x + y), rtol=1e-5, atol=1e-5) + print('saxpy: PASSED') + + # ── ReLU ────────────────────────────────────────────────────────── + N_ELEM = BLOCK * N_BLOCKS + x = jax.random.normal(jax.random.PRNGKey(4), (N_ELEM,), dtype=jnp.float32) + call = cjax.cutlass_call( + launch_relu, + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype), + N=N_ELEM, + ) + out = jax.jit(call)(x) + np.testing.assert_allclose(np.array(out), np.array(jnp.maximum(x, 0)), rtol=1e-5, atol=1e-5) + print('relu: PASSED') + + # ── Fused Bias + ReLU ───────────────────────────────────────────── + ROWS, COLS = 16, 64 + x = jax.random.normal(jax.random.PRNGKey(5), (ROWS * COLS,), dtype=jnp.float32) + bias = jax.random.normal(jax.random.PRNGKey(6), (COLS,), dtype=jnp.float32) + call = cjax.cutlass_call( + launch_fused_bias_relu, + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype), + N=ROWS * COLS, width=COLS, + ) + out = jax.jit(call)(x, bias) + ref = jnp.maximum(x.reshape(ROWS, COLS) + bias, 0).reshape(-1) + np.testing.assert_allclose(np.array(out), np.array(ref), rtol=1e-5, atol=1e-5) + print('fused_bias_relu: PASSED') + + # ── GEMM ────────────────────────────────────────────────────────── + M, N, K = 128, 128, 64 + A = jax.random.normal(jax.random.PRNGKey(7), (M * K,), dtype=jnp.float32) + B = jax.random.normal(jax.random.PRNGKey(8), (K * N,), dtype=jnp.float32) + call = cjax.cutlass_call( + launch_gemm, + output_shape_dtype=jax.ShapeDtypeStruct((M * N,), A.dtype), + M=M, N=N, K=K, + ) + D = jax.jit(call)(A, B) + ref = A.reshape(M, K) @ B.reshape(K, N) + np.testing.assert_allclose(np.array(D.reshape(M, N)), np.array(ref), rtol=1e-2, atol=1e-2) + print('gemm: PASSED') + + # ── Elementwise Add (2-D) ───────────────────────────────────────── + M, N = 16, 256 + a = jax.random.normal(jax.random.PRNGKey(9), (M, N), dtype=jnp.float32) + b = jax.random.normal(jax.random.PRNGKey(10), (M, N), dtype=jnp.float32) + call = cjax.cutlass_call( + launch_elementwise_add, + output_shape_dtype=jax.ShapeDtypeStruct(a.shape, a.dtype), + ) + c = jax.jit(call)(a, b) + np.testing.assert_allclose(np.array(c), np.array(a + b), rtol=1e-5, atol=1e-5) + print('elementwise_add: PASSED') + + print('\nAll kernels passed.')