# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import math
import time
from typing import Optional, Type

import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
from cutlass.cute.runtime import from_dlpack

"""
CTA-level LayerNorm / RMSNorm Example using CuTe DSL.

This example implements a CTA-level normalization kernel (LayerNorm or RMSNorm)
using the CuTe DSL. Each CTA processes one row of the input tensor and performs
the full normalization pipeline, including global memory loads, reduction,
normalization, and global memory stores.

In this kernel:
- Threads are arranged linearly within a CTA.
- Vectorized 128-bit loads/stores are used to maximize memory bandwidth.

To run this example:

.. code-block:: bash

    python examples/python/CuTeDSL/hopper/cta_norm.py
    python examples/python/CuTeDSL/hopper/cta_norm.py \
        --M 4096 --N 8192 --dtype fp16 --threads 256 \
        --norm_type rms --benchmark

To collect performance with NCU profiler:

.. code-block:: bash

    ncu -k regex:".*cutlass.*" python examples/python/CuTeDSL/hopper/cta_norm.py
"""

# CLI dtype name -> CuTe DSL numeric type.
DTYPE_MAP = {
    "fp16": cutlass.Float16,
    "bf16": cutlass.BFloat16,
    "fp32": cutlass.Float32,
}


class CtaNorm:
    """CTA-level normalization: one CTA normalizes one row of an (M, N) tensor.

    Supports LayerNorm (``norm_type == "layer"``, uses weight and bias) and
    RMSNorm (``norm_type == "rms"``, uses weight only).
    """

    def __init__(
        self,
        N: int,
        norm_type: str,
        threads_per_cta: Optional[int] = None,
    ):
        self.N = N  # hidden_size (row length, reduced dimension)
        self.norm_type = norm_type  # "layer" or "rms"
        self.elems_per_thread = 8  # 8 x fp16 = 128-bit vector per load/store
        self.warp_size = 32
        # If the caller does not fix the CTA size, derive one from N.
        self.threads_per_cta = threads_per_cta or self.heuristic_threads()
        self.warps_per_cta = (self.threads_per_cta + 31) // self.warp_size

    def heuristic_threads(self):
        """Pick a CTA size from N.

        Sizes the CTA so each thread iterates over roughly four 8-element
        vectors (the ``// 4``), clamped to at least one warp and rounded up to
        an even warp count.
        """
        elems_per_warp = self.elems_per_thread * self.warp_size
        heu_warps = (self.N + elems_per_warp - 1) // elems_per_warp // 4
        heu_warps = max(heu_warps, 1)  # at least one warp
        heu_warps = (heu_warps + 1) // 2 * 2  # be multiple of 2
        heu_threads = heu_warps * 32
        return heu_threads

    @cute.jit
    def __call__(
        self,
        mY,
        mX,
        mWeight,
        mBias,
        eps: cutlass.Float32 = 1e-6,
    ):
        """Build the tiled copy and launch one CTA per row of mX into mY."""
        print("[DSL INFO] Input Tensors:")
        print(f"[DSL INFO] mY = {mY.type}")
        print(f"[DSL INFO] mX = {mX.type}")
        print(f"[DSL INFO] mWeight = {mWeight.type}")
        if cutlass.const_expr(self.norm_type == "layer"):
            print(f"[DSL INFO] mBias = {mBias.type}")

        M, _ = mX.shape

        # Universal copy atom moving 128 bits per instruction (vectorized).
        atom_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            mX.element_type,
            num_bits_per_copy=128,
        )
        t_layout = cute.make_layout(self.threads_per_cta)  # thread layout within a CTA
        v_layout = cute.make_layout(self.elems_per_thread)  # per-thread vector layout
        tiled_copy = cute.make_tiled_copy_tv(atom_copy, t_layout, v_layout)

        print("[DSL INFO] Tiling Parameters:")
        print(f"[DSL INFO] tiled_copy = {tiled_copy}")

        # One CTA per row: grid.x == M.
        self.kernel(mY, mX, mWeight, mBias, tiled_copy, eps).launch(
            grid=[M, 1, 1],
            block=[self.warps_per_cta * self.warp_size, 1, 1],
        )

    @cute.kernel
    def kernel(
        self,
        mY: cute.Tensor,
        mX: cute.Tensor,
        mWeight: Optional[cute.Tensor],
        mBias: Optional[cute.Tensor],
        tiled_copy: cute.TiledCopy,
        eps: cute.Float,
    ):
        """Per-CTA pipeline: load row -> reduce -> normalize -> store row."""
        tidx, _, _ = cute.arch.thread_idx()  # thread index
        bidx, _, _ = cute.arch.block_idx()  # cta index

        thr_copy = tiled_copy.get_slice(tidx)

        # Slice out this CTA's row (row bidx, all N columns).
        gY = cute.local_tile(mY, tiler=(1, self.N), coord=(bidx, 0))
        gX = cute.local_tile(mX, tiler=(1, self.N), coord=(bidx, 0))
        gY, gX = gY[0, None], gX[0, None]

        print("[DSL INFO] Tiled Tensors:")
        print(f"[DSL INFO] gY = {gY.type}")
        print(f"[DSL INFO] gX = {gX.type}")

        tYgY = thr_copy.partition_S(gY)

        # Predicate per 8-element vector: true iff the vector lies within N.
        pred = cute.make_rmem_tensor(
            cute.size(tYgY, mode=[1]),
            cutlass.Boolean,
        )
        for i in range(cute.size(pred)):
            offset = (i * self.threads_per_cta + tidx) * self.elems_per_thread
            pred[i] = offset < self.N

        tXgX = thr_copy.partition_S(gX)
        tWgW = thr_copy.partition_S(mWeight)
        if cutlass.const_expr(self.norm_type == "layer"):
            tBgB = thr_copy.partition_S(mBias)

        tXrX = cute.make_fragment_like(tXgX)
        tXrX.fill(0)  # initialize rmem fragment to zero to simplify reduction code
        tWrW = cute.make_fragment_like(tWgW)
        if cutlass.const_expr(self.norm_type == "layer"):
            tBrB = cute.make_fragment_like(tBgB)

        print("[DSL INFO] Sliced Tensors per thread:")
        print(f"[DSL INFO] tYgY = {tYgY.type}")
        print(f"[DSL INFO] tXgX = {tXgX.type}")
        print(f"[DSL INFO] tWgW = {tWgW.type}")
        if cutlass.const_expr(self.norm_type == "layer"):
            print(f"[DSL INFO] tBgB = {tBgB.type}")
        print(f"[DSL INFO] pred = {pred.type}")

        # Predicated vectorized loads from global memory into registers.
        for i in range(cute.size(tXrX, mode=[1])):
            if pred[i]:
                cute.autovec_copy(tXgX[None, i], tXrX[None, i])  # LDG.128
                cute.autovec_copy(tWgW[None, i], tWrW[None, i])  # LDG.128
                if cutlass.const_expr(self.norm_type == "layer"):
                    cute.autovec_copy(tBgB[None, i], tBrB[None, i])  # LDG.128

        if cutlass.const_expr(self.norm_type == "layer"):
            tYrY = self.apply_layernorm(tXrX, tWrW, tBrB, eps, tidx, pred)
        elif cutlass.const_expr(self.norm_type == "rms"):
            tYrY = self.apply_rmsnorm(tXrX, tWrW, eps, tidx, pred)
        else:
            raise ValueError("norm_type must be 'layer' or 'rms'.")

        # Predicated vectorized stores back to global memory.
        for i in range(cute.size(tXrX, mode=[1])):
            if pred[i]:
                cute.autovec_copy(tYrY[None, i], tYgY[None, i])  # STG.128

    @cute.jit
    def warp_reduce(self, val, reduce_size=32):
        """Butterfly-shuffle sum across a warp.

        After log2(reduce_size) rounds every participating lane holds the full
        sum. ``reduce_size`` must be a power of two.
        """
        iters = int(math.log2(reduce_size))
        for i in range(iters):
            val = val + cute.arch.shuffle_sync_bfly(val, offset=1 << i)
        return val

    @cute.jit
    def cta_reduce(self, val, acc, tidx):
        """Reduce per-warp partial sums across the CTA via shared memory.

        ``val`` must already be warp-reduced (every lane holds its warp's sum).
        ``acc`` is a shared-memory buffer of ``warps_per_cta + 1`` floats: slot
        per warp, plus one broadcast slot for the final result.
        NOTE(review): warp 0 reduces the per-warp slots with a single
        warp_reduce, which assumes warps_per_cta <= 32.
        """
        warp_id = tidx >> 5
        lane_id = tidx & 31
        # Lane 0 of each warp publishes its warp's sum.
        if lane_id == 0:
            acc[warp_id] = val
        cute.arch.sync_threads()
        # Warp 0 reduces across warps and broadcasts through the extra slot.
        if warp_id == 0:
            val = acc[lane_id] if lane_id < self.warps_per_cta else cutlass.Float32(0)
            val = self.warp_reduce(val)
            acc[self.warps_per_cta] = val
        cute.arch.sync_threads()
        val = acc[self.warps_per_cta]
        return val

    @cute.jit
    def apply_layernorm(
        self,
        x: cute.Tensor,
        weight: cute.Tensor,
        bias: cute.Tensor,
        eps: cute.Float,
        tidx: cutlass.Int32,
        pred: cute.Tensor,
    ):
        """
        mean = sum(x) / D
        var  = sum((x - mean) ^ 2) / D
        y[i] = (x[i] - mean) / sqrt(var + eps) * weight[i] + bias[i]
        """
        smem = cutlass.utils.SmemAllocator()
        acc = smem.allocate_tensor(cutlass.Float32, self.warps_per_cta + 1)

        # Reduce x. No predicate needed: tXrX was zero-filled, so
        # out-of-bounds lanes contribute 0 to the sum.
        val = cute.Float32(0.0)
        for idx in range(cute.size(x)):
            # Accumulate in FP32 to improve numerical precision.
            val += x[idx].to(cutlass.Float32)
        val = self.warp_reduce(val)
        val = self.cta_reduce(val, acc, tidx)
        mean = val / self.N

        # Reduce (x - mean) ^ 2. Predicated: out-of-bounds zeros would
        # otherwise contribute mean^2 each.
        val = cute.Float32(0.0)
        for i in range(cute.size(x, mode=[1])):
            if pred[i]:
                for idx in range(cute.size(x[None, i])):
                    # Accumulate in FP32 to improve numerical precision.
                    x_fp32 = x[None, i][idx].to(cutlass.Float32)
                    val += (x_fp32 - mean) * (x_fp32 - mean)
        val = self.warp_reduce(val)
        val = self.cta_reduce(val, acc, tidx)
        factor = cute.rsqrt(val / self.N + eps)

        # Normalize.
        normed = cute.make_fragment_like(x)
        value = (x.load() - mean) * factor * weight.load() + bias.load()
        normed.store(value.to(normed.element_type))
        return normed

    @cute.jit
    def apply_rmsnorm(
        self,
        x: cute.Tensor,
        weight: cute.Tensor,
        eps: cute.Float,
        tidx: cutlass.Int32,
        pred: cute.Tensor,
    ):
        """
        y[i] = x[i] / sqrt(sum(x ^ 2) / D + eps) * w[i]
        """
        smem = cutlass.utils.SmemAllocator()
        acc = smem.allocate_tensor(cutlass.Float32, self.warps_per_cta + 1)

        val = cute.Float32(0.0)
        for i in range(cute.size(x, mode=[1])):
            if pred[i]:
                for idx in range(cute.size(x[None, i])):
                    # Accumulate in FP32 to improve numerical precision.
                    x_fp32 = x[None, i][idx].to(cutlass.Float32)
                    val += x_fp32 * x_fp32
        val = self.warp_reduce(val)
        acc_sq = self.cta_reduce(val, acc, tidx)
        factor = cute.rsqrt(acc_sq / self.N + eps)

        tNrN = cute.make_fragment_like(x)
        tNrN.store((x.load() * factor * weight.load()).to(tNrN.element_type))
        return tNrN


def run_layernorm(
    M: int,
    N: int,
    threads_per_cta: int,
    norm_type: str,
    dtype: Type[cutlass.Numeric],
    skip_ref_check=False,
    benchmark=True,
    warmup_iterations=2,
    iterations=200,
    eps=1e-6,
):
    """Compile the CtaNorm kernel, verify it against PyTorch, and benchmark it.

    Raises ValueError if N is not a multiple of 8 (128-bit vector width for
    fp16) or threads_per_cta is not a positive multiple of 32 up to 1024.
    """
    if N % 8 > 0:
        raise ValueError(f"N = {N} must be a multiple of 8 for this example.")
    if threads_per_cta is not None:
        if threads_per_cta % 32 != 0 or not (0 < threads_per_cta <= 1024):
            raise ValueError(f"Invalid threads_per_cta = {threads_per_cta}")

    print("Running CtaNorm test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")

    torch_dtype = cutlass_torch.dtype(dtype)

    x = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
    weight = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
    bias = None
    if norm_type == "layer":
        bias = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
    y = torch.empty_like(x)

    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    print(f"weight: {weight.shape}, dtype: {weight.dtype}")
    if norm_type == "layer":
        print(f"bias: {bias.shape}, dtype: {bias.dtype}")
    print(f"y: {y.shape}, dtype: {y.dtype}\n")

    # Wrap torch tensors as CuTe tensors (16B alignment enables LDG/STG.128).
    _x = from_dlpack(x, assumed_align=16, enable_tvm_ffi=True)
    _weight = from_dlpack(weight, assumed_align=16, enable_tvm_ffi=True)
    _bias = None
    if norm_type == "layer":
        _bias = from_dlpack(bias, assumed_align=16, enable_tvm_ffi=True)
    _y = from_dlpack(y, assumed_align=16, enable_tvm_ffi=True)

    print("Compiling kernel with cute.compile ...")
    start_time = time.time()
    layernorm = CtaNorm(N, norm_type, threads_per_cta)
    # _bias is None for RMSNorm; the kernel only touches it for "layer", so a
    # single compile call covers both norm types.
    compiled_func = cute.compile(
        layernorm,
        _y,
        _x,
        _weight,
        _bias,
        options="--generate-line-info --enable-tvm-ffi",
    )
    compilation_time = time.time() - start_time
    print(f"Compilation time: {compilation_time:.4f} seconds")

    print("Executing normalization kernel...")
    if not skip_ref_check:
        compiled_func(y, x, weight, bias, eps)
        print("Verifying results...")
        if norm_type == "layer":
            ref = torch.layer_norm(x, (N,), weight, bias, eps)
        else:
            ref = torch.rms_norm(x, (N,), weight, eps)
        torch.testing.assert_close(y, ref, atol=1e-3, rtol=1e-3)
        print("Results verified successfully!")

    if not benchmark:
        return

    def generate_tensors():
        # Fresh workspace per benchmark repetition to avoid cache effects.
        x = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
        weight = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
        bias = None
        if norm_type == "layer":
            bias = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
        y = torch.empty_like(x)
        return testing.JitArguments(y, x, weight, bias, eps)

    def torch_ref(y, x, weight, bias, eps):
        # Reference timed with the same signature; the result is discarded.
        if norm_type == "layer":
            y = torch.layer_norm(x, (N,), weight, bias, eps)
        else:
            y = torch.rms_norm(x, (N,), weight, eps)

    def evaluate(func, name):
        avg_time_us = testing.benchmark(
            func,
            workspace_generator=generate_tensors,
            workspace_count=10,
            warmup_iterations=warmup_iterations,
            iterations=iterations,
        )
        # Print execution results.
        print(f"\n{name}")
        print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
        # Bytes moved: read x + write y (2 * x.numel()) plus the weight read.
        print(
            f"Achieved memory throughput: {(2 * (x.numel() + weight.numel()) * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
        )

    evaluate(compiled_func, f"CuTe {norm_type}norm kernel")
    evaluate(torch_ref, f"PyTorch {norm_type}norm reference")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CTA-level LayerNorm / RMSNorm example using the CuTe DSL"
    )
    parser.add_argument("--M", default=4096, type=int)
    parser.add_argument("--N", default=4096, type=int)
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=DTYPE_MAP.keys(),
        help="Data type for input/output tensors (fp16, bf16, fp32)",
    )
    parser.add_argument("--norm_type", choices=["layer", "rms"], default="layer", type=str)
    parser.add_argument("--threads", default=None, type=int)
    parser.add_argument("--warmup_iterations", default=2, type=int)
    parser.add_argument("--iterations", default=100, type=int)
    parser.add_argument("--skip_ref_check", action="store_true")
    parser.add_argument("--benchmark", action="store_true")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run this example!")

    run_layernorm(
        args.M,
        args.N,
        args.threads,
        args.norm_type,
        # Honor the CLI flag (was hard-coded to Float16, making --dtype a no-op).
        dtype=DTYPE_MAP[args.dtype],
        skip_ref_check=args.skip_ref_check,
        benchmark=args.benchmark,
        warmup_iterations=args.warmup_iterations,
        iterations=args.iterations,
    )
    print("\nPASS")