mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 09:15:56 +00:00
[CuTeDSL] implment a cta-level norm example (both layernorm and rmsnorm) (#3009)
* kernel impl * add copyright
This commit is contained in:
424
examples/python/CuTeDSL/hopper/cta_norm.py
Normal file
424
examples/python/CuTeDSL/hopper/cta_norm.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import time
|
||||
import math
|
||||
from typing import Type, Optional
|
||||
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
CTA-level LayerNorm / RMSNorm Example using CuTe DSL.
|
||||
|
||||
This example implements a CTA-level normalization kernel (LayerNorm or RMSNorm)
|
||||
using the CuTe DSL. Each CTA processes one row of the input tensor and performs
|
||||
the full normalization pipeline, including global memory loads, reduction,
|
||||
normalization, and global memory stores.
|
||||
|
||||
In this kernel:
|
||||
|
||||
- Threads are arranged linearly within a CTA.
|
||||
- Vectorized 128-bit loads/stores are used to maximize memory bandwidth.
|
||||
|
||||
To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/python/CuTeDSL/hopper/cta_norm.py
|
||||
python examples/python/CuTeDSL/hopper/cta_norm.py \
|
||||
--M 4096 --N 8192 --dtype fp16 --threads 256 \
|
||||
--norm_type rms --benchmark
|
||||
|
||||
To collect performance with NCU profiler:
|
||||
|
||||
.. code-block:: bash
|
||||
ncu -k regex:".*cutlass.*" python examples/python/CuTeDSL/hopper/cta_norm.py
|
||||
"""
|
||||
|
||||
DTYPE_MAP = {
|
||||
"fp16": cutlass.Float16,
|
||||
"bf16": cutlass.BFloat16,
|
||||
"fp32": cutlass.Float32,
|
||||
}
|
||||
|
||||
class CtaNorm:
|
||||
def __init__(
|
||||
self, N: int,
|
||||
norm_type: str,
|
||||
threads_per_cta: Optional[int] = None,
|
||||
):
|
||||
self.N = N # hidden_size
|
||||
self.norm_type = norm_type # "layer" or "rms"
|
||||
self.elems_per_thread = 8
|
||||
self.warp_size = 32
|
||||
self.threads_per_cta = threads_per_cta or self.heuristic_threads()
|
||||
self.warps_per_cta = (self.threads_per_cta + 31) // self.warp_size
|
||||
|
||||
def heuristic_threads(self):
|
||||
elems_per_warp = self.elems_per_thread * self.warp_size
|
||||
heu_warps = (self.N + elems_per_warp - 1) // elems_per_warp // 4
|
||||
heu_warps = max(heu_warps, 1) # at least one warp
|
||||
heu_warps = (heu_warps + 1) // 2 * 2 # be multiple of 2
|
||||
heu_threads = heu_warps * 32
|
||||
return heu_threads
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
mY,
|
||||
mX,
|
||||
mWeight,
|
||||
mBias,
|
||||
eps: cutlass.Float32 = 1e-6,
|
||||
):
|
||||
print("[DSL INFO] Input Tensors:")
|
||||
print(f"[DSL INFO] mY = {mY.type}")
|
||||
print(f"[DSL INFO] mX = {mX.type}")
|
||||
print(f"[DSL INFO] mWeight = {mWeight.type}")
|
||||
if cutlass.const_expr(self.norm_type == "layer"):
|
||||
print(f"[DSL INFO] mBias = {mBias.type}")
|
||||
M, _ = mX.shape
|
||||
atom_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
mX.element_type,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
t_layout = cute.make_layout(self.threads_per_cta) # thread layout within a CTA
|
||||
v_layout = cute.make_layout(self.elems_per_thread) # per-thread vector layout
|
||||
tiled_copy = cute.make_tiled_copy_tv(atom_copy, t_layout, v_layout)
|
||||
print("[DSL INFO] Tiling Parameters:")
|
||||
print(f"[DSL INFO] tiled_copy = {tiled_copy}")
|
||||
self.kernel(mY, mX, mWeight, mBias, tiled_copy, eps).launch(
|
||||
grid=[M, 1, 1],
|
||||
block=[self.warps_per_cta * self.warp_size, 1, 1],
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
mY: cute.Tensor,
|
||||
mX: cute.Tensor,
|
||||
mWeight: Optional[cute.Tensor],
|
||||
mBias: Optional[cute.Tensor],
|
||||
tiled_copy: cute.TiledCopy,
|
||||
eps: cute.Float,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx() # thread index
|
||||
bidx, _, _ = cute.arch.block_idx() # cta index
|
||||
thr_copy = tiled_copy.get_slice(tidx)
|
||||
|
||||
gY = cute.local_tile(mY, tiler=(1, self.N), coord=(bidx, 0))
|
||||
gX = cute.local_tile(mX, tiler=(1, self.N), coord=(bidx, 0))
|
||||
gY, gX = gY[0, None], gX[0, None]
|
||||
print("[DSL INFO] Tiled Tensors:")
|
||||
print(f"[DSL INFO] gY = {gY.type}")
|
||||
print(f"[DSL INFO] gX = {gX.type}")
|
||||
tYgY = thr_copy.partition_S(gY)
|
||||
pred = cute.make_rmem_tensor(
|
||||
cute.size(tYgY, mode=[1]), cutlass.Boolean,
|
||||
)
|
||||
for i in range(cute.size(pred)):
|
||||
offset = (i * self.threads_per_cta + tidx) * self.elems_per_thread
|
||||
pred[i] = offset < self.N
|
||||
tXgX = thr_copy.partition_S(gX)
|
||||
tWgW = thr_copy.partition_S(mWeight)
|
||||
if cutlass.const_expr(self.norm_type == "layer"):
|
||||
tBgB = thr_copy.partition_S(mBias)
|
||||
tXrX = cute.make_fragment_like(tXgX)
|
||||
tXrX.fill(0) # initialize rmem fragment to zero to simplify reduction code
|
||||
tWrW = cute.make_fragment_like(tWgW)
|
||||
if cutlass.const_expr(self.norm_type == "layer"):
|
||||
tBrB = cute.make_fragment_like(tBgB)
|
||||
print("[DSL INFO] Sliced Tensors per thread:")
|
||||
print(f"[DSL INFO] tYgY = {tYgY.type}")
|
||||
print(f"[DSL INFO] tXgX = {tXgX.type}")
|
||||
print(f"[DSL INFO] tWgW = {tWgW.type}")
|
||||
if cutlass.const_expr(self.norm_type == "layer"):
|
||||
print(f"[DSL INFO] tBgB = {tBgB.type}")
|
||||
print(f"[DSL INFO] pred = {pred.type}")
|
||||
for i in range(cute.size(tXrX, mode=[1])):
|
||||
if pred[i]:
|
||||
cute.autovec_copy(tXgX[None, i], tXrX[None, i]) # LDG.128
|
||||
cute.autovec_copy(tWgW[None, i], tWrW[None, i]) # LDG.128
|
||||
if cutlass.const_expr(self.norm_type == "layer"):
|
||||
cute.autovec_copy(tBgB[None, i], tBrB[None, i]) # LDG.128
|
||||
if cutlass.const_expr(self.norm_type == "layer"):
|
||||
tYrY = self.apply_layernorm(tXrX, tWrW, tBrB, eps, tidx, pred)
|
||||
elif cutlass.const_expr(self.norm_type == "rms"):
|
||||
tYrY = self.apply_rmsnorm(tXrX, tWrW, eps, tidx, pred)
|
||||
else:
|
||||
raise ValueError("norm_type must be 'layer' or 'rms'.")
|
||||
for i in range(cute.size(tXrX, mode=[1])):
|
||||
if pred[i]:
|
||||
cute.autovec_copy(tYrY[None, i], tYgY[None, i]) # STG.128
|
||||
|
||||
|
||||
@cute.jit
|
||||
def warp_reduce(self, val, reduce_size = 32):
|
||||
iters = int(math.log2(reduce_size))
|
||||
for i in range(iters):
|
||||
val = val + cute.arch.shuffle_sync_bfly(val, offset=1<<i)
|
||||
return val
|
||||
|
||||
@cute.jit
|
||||
def cta_reduce(self, val, acc, tidx):
|
||||
warp_id = tidx >> 5
|
||||
lane_id = tidx & 31
|
||||
if lane_id == 0:
|
||||
acc[warp_id] = val
|
||||
cute.arch.sync_threads()
|
||||
if warp_id == 0:
|
||||
val = acc[lane_id] if lane_id < self.warps_per_cta else cutlass.Float32(0)
|
||||
val = self.warp_reduce(val)
|
||||
acc[self.warps_per_cta] = val
|
||||
cute.arch.sync_threads()
|
||||
val = acc[self.warps_per_cta]
|
||||
return val
|
||||
|
||||
@cute.jit
|
||||
def apply_layernorm(
|
||||
self,
|
||||
x: cute.Tensor,
|
||||
weight: cute.Tensor,
|
||||
bias: cute.Tensor,
|
||||
eps: cute.Float,
|
||||
tidx: cutlass.Int32,
|
||||
pred: cute.Tensor,
|
||||
):
|
||||
"""
|
||||
mean = sum(x) / D
|
||||
var = sum((x - mean) ^ 2) / D
|
||||
y[i] = (x[i] - mean) / sqrt(var + eps) * weight[i] + bias[i]
|
||||
"""
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
acc = smem.allocate_tensor(cutlass.Float32, self.warps_per_cta + 1)
|
||||
# Reduce x
|
||||
val = cute.Float32(0.0)
|
||||
for idx in range(cute.size(x)):
|
||||
# Accumulate in FP32 to improve numerical precision.
|
||||
val += x[idx].to(cutlass.Float32)
|
||||
val = self.warp_reduce(val)
|
||||
val = self.cta_reduce(val, acc, tidx)
|
||||
mean = val / self.N
|
||||
# Reduce (x - mean) ^ 2
|
||||
val = cute.Float32(0.0)
|
||||
for i in range(cute.size(x, mode=[1])):
|
||||
if pred[i]:
|
||||
for idx in range(cute.size(x[None, i])):
|
||||
# Accumulate in FP32 to improve numerical precision.
|
||||
x_fp32 = x[None, i][idx].to(cutlass.Float32)
|
||||
val += (x_fp32 - mean) * (x_fp32 - mean)
|
||||
val = self.warp_reduce(val)
|
||||
val = self.cta_reduce(val, acc, tidx)
|
||||
factor = cute.rsqrt(val / self.N + eps)
|
||||
# Normalize
|
||||
normed = cute.make_fragment_like(x)
|
||||
value = (x.load() - mean) * factor * weight.load() + bias.load()
|
||||
normed.store(value.to(normed.element_type))
|
||||
return normed
|
||||
|
||||
|
||||
@cute.jit
|
||||
def apply_rmsnorm(
|
||||
self,
|
||||
x: cute.Tensor,
|
||||
weight: cute.Tensor,
|
||||
eps: cute.Float,
|
||||
tidx: cutlass.Int32,
|
||||
pred: cute.Tensor,
|
||||
):
|
||||
"""
|
||||
y[i] = x[i] / sqrt(sum(x ^ 2) / D + eps) * w[i]
|
||||
"""
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
acc = smem.allocate_tensor(cutlass.Float32, self.warps_per_cta + 1)
|
||||
val = cute.Float32(0.0)
|
||||
for i in range(cute.size(x, mode=[1])):
|
||||
if pred[i]:
|
||||
for idx in range(cute.size(x[None, i])):
|
||||
# Accumulate in FP32 to improve numerical precision.
|
||||
x_fp32 = x[None, i][idx].to(cutlass.Float32)
|
||||
val += x_fp32 * x_fp32
|
||||
val = self.warp_reduce(val)
|
||||
acc_sq = self.cta_reduce(val, acc, tidx)
|
||||
factor = cute.rsqrt(acc_sq / self.N + eps)
|
||||
tNrN = cute.make_fragment_like(x)
|
||||
tNrN.store((x.load() * factor * weight.load()).to(tNrN.element_type))
|
||||
return tNrN
|
||||
|
||||
|
||||
def run_layernorm(
|
||||
M: int,
|
||||
N: int,
|
||||
threads_per_cta: int,
|
||||
norm_type: str,
|
||||
dtype: Type[cutlass.Numeric],
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
warmup_iterations=2,
|
||||
iterations=200,
|
||||
eps=1e-6,
|
||||
):
|
||||
if N % 8 > 0:
|
||||
raise ValueError(f"N = {N} must be a multiple of 8 for this example.")
|
||||
if threads_per_cta is not None:
|
||||
if threads_per_cta % 32 != 0 or not (0 < threads_per_cta <= 1024):
|
||||
raise ValueError(f"Invalid threads_per_cta = {threads_per_cta}")
|
||||
|
||||
print("Running CtaNorm test with:")
|
||||
print(f"Tensor dimensions: [{M}, {N}]")
|
||||
print(f"Input and Output Data type: {dtype}")
|
||||
|
||||
torch_dtype = cutlass_torch.dtype(dtype)
|
||||
x = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
weight = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
bias = None
|
||||
if norm_type == "layer":
|
||||
bias = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
y = torch.empty_like(x)
|
||||
|
||||
print("Input tensor shapes:")
|
||||
print(f"x: {x.shape}, dtype: {x.dtype}")
|
||||
print(f"weight: {weight.shape}, dtype: {weight.dtype}")
|
||||
if norm_type == "layer":
|
||||
print(f"bias: {bias.shape}, dtype: {bias.dtype}")
|
||||
print(f"y: {y.shape}, dtype: {y.dtype}\n")
|
||||
|
||||
_x = from_dlpack(x, assumed_align=16, enable_tvm_ffi=True)
|
||||
_weight = from_dlpack(weight, assumed_align=16, enable_tvm_ffi=True)
|
||||
_bias = None
|
||||
if norm_type == "layer":
|
||||
_bias = from_dlpack(bias, assumed_align=16, enable_tvm_ffi=True)
|
||||
_y = from_dlpack(y, assumed_align=16, enable_tvm_ffi=True)
|
||||
|
||||
print("Compiling kernel with cute.compile ...")
|
||||
start_time = time.time()
|
||||
layernorm = CtaNorm(N, norm_type, threads_per_cta)
|
||||
if norm_type == "layer":
|
||||
compiled_func = cute.compile(
|
||||
layernorm, _y, _x, _weight, _bias, options="--generate-line-info --enable-tvm-ffi",
|
||||
)
|
||||
else:
|
||||
compiled_func = cute.compile(
|
||||
layernorm, _y, _x, _weight, _bias, options="--generate-line-info --enable-tvm-ffi",
|
||||
)
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
print("Executing vector add kernel...")
|
||||
if not skip_ref_check:
|
||||
compiled_func(y, x, weight, bias, eps)
|
||||
print("Verifying results...")
|
||||
if norm_type == "layer":
|
||||
ref = torch.layer_norm(x, (N,), weight, bias, eps)
|
||||
else:
|
||||
ref = torch.rms_norm(x, (N,), weight, eps)
|
||||
torch.testing.assert_close(y, ref, atol=1e-3, rtol=1e-3)
|
||||
print("Results verified successfully!")
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
def generate_tensors():
|
||||
x = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
weight = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
bias = None
|
||||
if norm_type == "layer":
|
||||
bias = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
|
||||
y = torch.empty_like(x)
|
||||
return testing.JitArguments(y, x, weight, bias, eps)
|
||||
|
||||
def torch_ref(y, x, weight, bias, eps):
|
||||
if norm_type == "layer":
|
||||
y = torch.layer_norm(x, (N,), weight, bias, eps)
|
||||
else:
|
||||
y = torch.rms_norm(x, (N,), weight, eps)
|
||||
|
||||
def eval(func, name):
|
||||
avg_time_us = testing.benchmark(
|
||||
func,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=10,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
# Print execution results
|
||||
print(f"\n{name}")
|
||||
print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
|
||||
print(
|
||||
f"Achieved memory throughput: {(2 * (x.numel() + weight.numel()) * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
|
||||
)
|
||||
|
||||
eval(compiled_func, f"CuTe {norm_type}norm kernel")
|
||||
eval(torch_ref, f"PyTorch {norm_type}norm reference")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels"
|
||||
)
|
||||
parser.add_argument("--M", default=4096, type=int)
|
||||
parser.add_argument("--N", default=4096, type=int)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="fp16",
|
||||
choices=DTYPE_MAP.keys(),
|
||||
help="Data type for input/output tensors (e.g. float16, bf16, float32)",
|
||||
)
|
||||
parser.add_argument("--norm_type", choices=["layer", "rms"], default="layer", type=str)
|
||||
parser.add_argument("--threads", default=None, type=int)
|
||||
parser.add_argument("--warmup_iterations", default=2, type=int)
|
||||
parser.add_argument("--iterations", default=100, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("GPU is required to run this example!")
|
||||
|
||||
run_layernorm(
|
||||
args.M,
|
||||
args.N,
|
||||
args.threads,
|
||||
args.norm_type,
|
||||
dtype=cutlass.Float16,
|
||||
skip_ref_check=args.skip_ref_check,
|
||||
benchmark=args.benchmark,
|
||||
warmup_iterations=args.warmup_iterations,
|
||||
iterations=args.iterations,
|
||||
)
|
||||
print("\nPASS")
|
||||
Reference in New Issue
Block a user