mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 01:10:08 +00:00
425 lines
16 KiB
Python
425 lines
16 KiB
Python
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
import argparse
|
|
import torch
|
|
import time
|
|
import math
|
|
from typing import Type, Optional
|
|
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.cute.testing as testing
|
|
import cutlass.torch as cutlass_torch
|
|
from cutlass.cute.runtime import from_dlpack
|
|
|
|
"""
|
|
CTA-level LayerNorm / RMSNorm Example using CuTe DSL.
|
|
|
|
This example implements a CTA-level normalization kernel (LayerNorm or RMSNorm)
|
|
using the CuTe DSL. Each CTA processes one row of the input tensor and performs
|
|
the full normalization pipeline, including global memory loads, reduction,
|
|
normalization, and global memory stores.
|
|
|
|
In this kernel:
|
|
|
|
- Threads are arranged linearly within a CTA.
|
|
- Vectorized 128-bit loads/stores are used to maximize memory bandwidth.
|
|
|
|
To run this example:
|
|
|
|
.. code-block:: bash
|
|
|
|
python examples/python/CuTeDSL/hopper/cta_norm.py
|
|
python examples/python/CuTeDSL/hopper/cta_norm.py \
|
|
--M 4096 --N 8192 --dtype fp16 --threads 256 \
|
|
--norm_type rms --benchmark
|
|
|
|
To collect performance with NCU profiler:
|
|
|
|
.. code-block:: bash
|
|
ncu -k regex:".*cutlass.*" python examples/python/CuTeDSL/hopper/cta_norm.py
|
|
"""
|
|
|
|
# Maps the --dtype command-line choice to its CUTLASS numeric type.
DTYPE_MAP = dict(
    fp16=cutlass.Float16,
    bf16=cutlass.BFloat16,
    fp32=cutlass.Float32,
)
|
|
|
|
class CtaNorm:
    """CTA-level LayerNorm / RMSNorm written in the CuTe DSL.

    One CTA processes one row of a (M, N) tensor: threads cooperatively
    load the row with vectorized copies, reduce it (warp shuffles plus a
    small shared-memory scratchpad), normalize, and store the result.
    """

    def __init__(
        self, N: int,
        norm_type: str,
        threads_per_cta: Optional[int] = None,
    ):
        """Configure the kernel.

        Args:
            N: hidden size (row length); each CTA normalizes one row.
            norm_type: "layer" for LayerNorm or "rms" for RMSNorm.
            threads_per_cta: CTA size; falsy/None selects a heuristic.
        """
        self.N = N  # hidden_size
        self.norm_type = norm_type  # "layer" or "rms"
        # Elements moved per thread per copy iteration (8 x fp16 = 128 bits).
        self.elems_per_thread = 8
        self.warp_size = 32
        self.threads_per_cta = threads_per_cta or self.heuristic_threads()
        # Round the thread count up to whole warps.
        self.warps_per_cta = (self.threads_per_cta + 31) // self.warp_size

    def heuristic_threads(self) -> int:
        """Pick a CTA size from N.

        Sizes the CTA so each thread runs roughly 4 vectorized
        iterations over the row, rounded to an even number of warps
        (minimum one warp).
        """
        elems_per_warp = self.elems_per_thread * self.warp_size
        heu_warps = (self.N + elems_per_warp - 1) // elems_per_warp // 4
        heu_warps = max(heu_warps, 1)  # at least one warp
        heu_warps = (heu_warps + 1) // 2 * 2  # be multiple of 2
        heu_threads = heu_warps * 32
        return heu_threads

    @cute.jit
    def __call__(
        self,
        mY,
        mX,
        mWeight,
        mBias,
        eps: cutlass.Float32 = 1e-6,
    ):
        """Host-side entry point: build the tiled copy and launch the kernel.

        Args:
            mY: output tensor, shape (M, N).
            mX: input tensor, shape (M, N).
            mWeight: per-channel scale, length N.
            mBias: per-channel shift, length N (used only for LayerNorm;
                may be None for RMSNorm).
            eps: numerical-stability epsilon added before rsqrt.
        """
        # The [DSL INFO] prints execute at trace/compile time, not per launch.
        print("[DSL INFO] Input Tensors:")
        print(f"[DSL INFO] mY = {mY.type}")
        print(f"[DSL INFO] mX = {mX.type}")
        print(f"[DSL INFO] mWeight = {mWeight.type}")
        if cutlass.const_expr(self.norm_type == "layer"):
            print(f"[DSL INFO] mBias = {mBias.type}")
        M, _ = mX.shape
        # 128-bit copy atom for vectorized global loads/stores.
        atom_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            mX.element_type,
            num_bits_per_copy=128,
        )
        t_layout = cute.make_layout(self.threads_per_cta)  # thread layout within a CTA
        v_layout = cute.make_layout(self.elems_per_thread)  # per-thread vector layout
        tiled_copy = cute.make_tiled_copy_tv(atom_copy, t_layout, v_layout)
        print("[DSL INFO] Tiling Parameters:")
        print(f"[DSL INFO] tiled_copy = {tiled_copy}")
        # One CTA per row: grid x-dimension is the row count M.
        self.kernel(mY, mX, mWeight, mBias, tiled_copy, eps).launch(
            grid=[M, 1, 1],
            block=[self.warps_per_cta * self.warp_size, 1, 1],
        )

    @cute.kernel
    def kernel(
        self,
        mY: cute.Tensor,
        mX: cute.Tensor,
        mWeight: Optional[cute.Tensor],
        mBias: Optional[cute.Tensor],
        tiled_copy: cute.TiledCopy,
        eps: cute.Float,
    ):
        """Device kernel: load one row, normalize it, store the result.

        Each CTA (block index) owns one row; each thread owns a strided
        set of `elems_per_thread`-wide vectors within that row, guarded
        by a per-vector predicate for the row tail.
        """
        tidx, _, _ = cute.arch.thread_idx()  # thread index
        bidx, _, _ = cute.arch.block_idx()  # cta index
        thr_copy = tiled_copy.get_slice(tidx)

        # Slice out this CTA's row from the full (M, N) tensors.
        gY = cute.local_tile(mY, tiler=(1, self.N), coord=(bidx, 0))
        gX = cute.local_tile(mX, tiler=(1, self.N), coord=(bidx, 0))
        # Drop the unit row mode so the tiles are 1-D over N.
        gY, gX = gY[0, None], gX[0, None]
        print("[DSL INFO] Tiled Tensors:")
        print(f"[DSL INFO] gY = {gY.type}")
        print(f"[DSL INFO] gX = {gX.type}")
        tYgY = thr_copy.partition_S(gY)
        # One predicate per per-thread vector: is its start inside the row?
        pred = cute.make_rmem_tensor(
            cute.size(tYgY, mode=[1]), cutlass.Boolean,
        )
        for i in range(cute.size(pred)):
            offset = (i * self.threads_per_cta + tidx) * self.elems_per_thread
            pred[i] = offset < self.N
        tXgX = thr_copy.partition_S(gX)
        tWgW = thr_copy.partition_S(mWeight)
        if cutlass.const_expr(self.norm_type == "layer"):
            tBgB = thr_copy.partition_S(mBias)
        tXrX = cute.make_fragment_like(tXgX)
        # Zero-filled so out-of-bounds lanes contribute nothing to the
        # whole-fragment reductions in apply_layernorm.
        tXrX.fill(0)  # initialize rmem fragment to zero to simplify reduction code
        tWrW = cute.make_fragment_like(tWgW)
        if cutlass.const_expr(self.norm_type == "layer"):
            tBrB = cute.make_fragment_like(tBgB)
        print("[DSL INFO] Sliced Tensors per thread:")
        print(f"[DSL INFO] tYgY = {tYgY.type}")
        print(f"[DSL INFO] tXgX = {tXgX.type}")
        print(f"[DSL INFO] tWgW = {tWgW.type}")
        if cutlass.const_expr(self.norm_type == "layer"):
            print(f"[DSL INFO] tBgB = {tBgB.type}")
        print(f"[DSL INFO] pred = {pred.type}")
        # Predicated vectorized loads of x, weight (and bias for LayerNorm).
        for i in range(cute.size(tXrX, mode=[1])):
            if pred[i]:
                cute.autovec_copy(tXgX[None, i], tXrX[None, i])  # Global load
                cute.autovec_copy(tWgW[None, i], tWrW[None, i])  # Global load
                if cutlass.const_expr(self.norm_type == "layer"):
                    cute.autovec_copy(tBgB[None, i], tBrB[None, i])  # Global load
        # norm_type is a compile-time constant, so only one branch is traced.
        if cutlass.const_expr(self.norm_type == "layer"):
            tYrY = self.apply_layernorm(tXrX, tWrW, tBrB, eps, tidx, pred)
        elif cutlass.const_expr(self.norm_type == "rms"):
            tYrY = self.apply_rmsnorm(tXrX, tWrW, eps, tidx, pred)
        else:
            # Raised at trace time for an unsupported norm_type.
            raise ValueError("norm_type must be 'layer' or 'rms'.")
        # Predicated vectorized stores of the normalized row.
        for i in range(cute.size(tXrX, mode=[1])):
            if pred[i]:
                cute.autovec_copy(tYrY[None, i], tYgY[None, i])  # STG.128

    @cute.jit
    def warp_reduce(self, val, reduce_size = 32):
        """Butterfly-shuffle sum across `reduce_size` lanes of a warp.

        `reduce_size` must be a power of two (the iteration count is
        int(log2(reduce_size))). Returns the sum in every participating lane.
        """
        iters = int(math.log2(reduce_size))
        for i in range(iters):
            val = val + cute.arch.shuffle_sync_bfly(val, offset=1<<i)
        return val

    @cute.jit
    def cta_reduce(self, val, acc, tidx):
        """CTA-wide sum of per-warp partial sums via shared memory.

        `val` must already be warp-reduced (identical across each warp's
        lanes). `acc` is a shared scratchpad of warps_per_cta + 1 floats:
        slot w holds warp w's partial, and the last slot broadcasts the
        final sum to every thread.
        NOTE(review): warp 0 gathers the partials with its 32 lanes, so
        this assumes warps_per_cta <= 32 — holds for <= 1024 threads.
        """
        warp_id = tidx >> 5
        lane_id = tidx & 31
        if lane_id == 0:
            # Lane 0 of each warp publishes its warp's partial sum.
            acc[warp_id] = val
        cute.arch.sync_threads()
        if warp_id == 0:
            # Warp 0 sums the per-warp partials; idle lanes contribute 0.
            val = acc[lane_id] if lane_id < self.warps_per_cta else cutlass.Float32(0)
            val = self.warp_reduce(val)
            acc[self.warps_per_cta] = val
        cute.arch.sync_threads()
        # Broadcast the final sum to all threads.
        val = acc[self.warps_per_cta]
        return val

    @cute.jit
    def apply_layernorm(
        self,
        x: cute.Tensor,
        weight: cute.Tensor,
        bias: cute.Tensor,
        eps: cute.Float,
        tidx: cutlass.Int32,
        pred: cute.Tensor,
    ):
        """LayerNorm over one row held in per-thread register fragments.

        mean = sum(x) / D
        var = sum((x - mean) ^ 2) / D
        y[i] = (x[i] - mean) / sqrt(var + eps) * weight[i] + bias[i]

        Args:
            x: per-thread fragment of the row; out-of-bounds lanes are
                pre-zeroed by the caller, so the mean loop can skip
                predication.
            weight, bias: per-thread fragments of the affine parameters.
            eps: stability epsilon.
            tidx: thread index within the CTA.
            pred: per-vector in-bounds predicate.
        Returns:
            A fragment like `x` holding the normalized values.
        """
        smem = cutlass.utils.SmemAllocator()
        acc = smem.allocate_tensor(cutlass.Float32, self.warps_per_cta + 1)
        # Reduce x
        val = cute.Float32(0.0)
        # No predicate needed: out-of-bounds entries were zero-filled.
        for idx in range(cute.size(x)):
            # Accumulate in FP32 to improve numerical precision.
            val += x[idx].to(cutlass.Float32)
        val = self.warp_reduce(val)
        val = self.cta_reduce(val, acc, tidx)
        mean = val / self.N
        # Reduce (x - mean) ^ 2
        val = cute.Float32(0.0)
        # Predicated here: (0 - mean)^2 would corrupt the variance.
        for i in range(cute.size(x, mode=[1])):
            if pred[i]:
                for idx in range(cute.size(x[None, i])):
                    # Accumulate in FP32 to improve numerical precision.
                    x_fp32 = x[None, i][idx].to(cutlass.Float32)
                    val += (x_fp32 - mean) * (x_fp32 - mean)
        val = self.warp_reduce(val)
        val = self.cta_reduce(val, acc, tidx)
        factor = cute.rsqrt(val / self.N + eps)
        # Normalize
        normed = cute.make_fragment_like(x)
        value = (x.load() - mean) * factor * weight.load() + bias.load()
        normed.store(value.to(normed.element_type))
        return normed

    @cute.jit
    def apply_rmsnorm(
        self,
        x: cute.Tensor,
        weight: cute.Tensor,
        eps: cute.Float,
        tidx: cutlass.Int32,
        pred: cute.Tensor,
    ):
        """RMSNorm over one row held in per-thread register fragments.

        y[i] = x[i] / sqrt(sum(x ^ 2) / D + eps) * w[i]

        Args:
            x: per-thread fragment of the row.
            weight: per-thread fragment of the scale parameters.
            eps: stability epsilon.
            tidx: thread index within the CTA.
            pred: per-vector in-bounds predicate.
        Returns:
            A fragment like `x` holding the normalized values.
        """
        smem = cutlass.utils.SmemAllocator()
        acc = smem.allocate_tensor(cutlass.Float32, self.warps_per_cta + 1)
        val = cute.Float32(0.0)
        for i in range(cute.size(x, mode=[1])):
            if pred[i]:
                for idx in range(cute.size(x[None, i])):
                    # Accumulate in FP32 to improve numerical precision.
                    x_fp32 = x[None, i][idx].to(cutlass.Float32)
                    val += x_fp32 * x_fp32
        val = self.warp_reduce(val)
        acc_sq = self.cta_reduce(val, acc, tidx)
        factor = cute.rsqrt(acc_sq / self.N + eps)
        tNrN = cute.make_fragment_like(x)
        tNrN.store((x.load() * factor * weight.load()).to(tNrN.element_type))
        return tNrN
|
|
|
|
|
|
def run_layernorm(
    M: int,
    N: int,
    threads_per_cta: int,
    norm_type: str,
    dtype: Type[cutlass.Numeric],
    skip_ref_check=False,
    benchmark=True,
    warmup_iterations=2,
    iterations=200,
    eps=1e-6,
):
    """Compile, verify, and optionally benchmark the CtaNorm kernel.

    Args:
        M: number of rows (one CTA per row).
        N: hidden size; must be a multiple of 8 (128-bit vectorization).
        threads_per_cta: CTA size, or None to use CtaNorm's heuristic.
        norm_type: "layer" or "rms".
        dtype: CUTLASS element type for input/output tensors.
        skip_ref_check: skip the PyTorch reference comparison.
        benchmark: run the timing loops after (optional) verification.
        warmup_iterations: benchmark warmup iterations.
        iterations: benchmark timed iterations.
        eps: numerical-stability epsilon.

    Raises:
        ValueError: if N is not a multiple of 8, or threads_per_cta is
            given but is not a positive multiple of 32 up to 1024.
    """
    if N % 8 > 0:
        raise ValueError(f"N = {N} must be a multiple of 8 for this example.")
    if threads_per_cta is not None:
        # CTA size must be whole warps and within the 1024-thread limit.
        if threads_per_cta % 32 != 0 or not (0 < threads_per_cta <= 1024):
            raise ValueError(f"Invalid threads_per_cta = {threads_per_cta}")

    print("Running CtaNorm test with:")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")

    torch_dtype = cutlass_torch.dtype(dtype)
    x = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
    weight = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
    bias = None
    if norm_type == "layer":
        # Bias is a LayerNorm-only parameter; RMSNorm passes None through.
        bias = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
    y = torch.empty_like(x)

    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    print(f"weight: {weight.shape}, dtype: {weight.dtype}")
    if norm_type == "layer":
        print(f"bias: {bias.shape}, dtype: {bias.dtype}")
    print(f"y: {y.shape}, dtype: {y.dtype}\n")

    # Wrap torch tensors as CuTe tensors (16-byte alignment for STG.128).
    _x = from_dlpack(x, assumed_align=16, enable_tvm_ffi=True)
    _weight = from_dlpack(weight, assumed_align=16, enable_tvm_ffi=True)
    _bias = None
    if norm_type == "layer":
        _bias = from_dlpack(bias, assumed_align=16, enable_tvm_ffi=True)
    _y = from_dlpack(y, assumed_align=16, enable_tvm_ffi=True)

    print("Compiling kernel with cute.compile ...")
    start_time = time.time()
    layernorm = CtaNorm(N, norm_type, threads_per_cta)
    # _bias is already None for RMSNorm, so a single compile call covers
    # both norm types (the original duplicated this call in an if/else).
    compiled_func = cute.compile(
        layernorm, _y, _x, _weight, _bias,
        options="--generate-line-info --enable-tvm-ffi",
    )
    compilation_time = time.time() - start_time
    print(f"Compilation time: {compilation_time:.4f} seconds")

    print(f"Executing {norm_type}norm kernel...")
    if not skip_ref_check:
        compiled_func(y, x, weight, bias, eps)
        print("Verifying results...")
        if norm_type == "layer":
            ref = torch.layer_norm(x, (N,), weight, bias, eps)
        else:
            ref = torch.rms_norm(x, (N,), weight, eps)
        torch.testing.assert_close(y, ref, atol=1e-3, rtol=1e-3)
        print("Results verified successfully!")

    if not benchmark:
        return

    def generate_tensors():
        # Fresh tensors per workspace so repeated benchmark iterations are
        # not all served from the same warm cache lines.
        x = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype)
        weight = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
        bias = None
        if norm_type == "layer":
            bias = torch.randn(N, device=torch.device("cuda"), dtype=torch_dtype)
        y = torch.empty_like(x)
        return testing.JitArguments(y, x, weight, bias, eps)

    def torch_ref(y, x, weight, bias, eps):
        # PyTorch baseline; the result is discarded — only timing matters.
        if norm_type == "layer":
            torch.layer_norm(x, (N,), weight, bias, eps)
        else:
            torch.rms_norm(x, (N,), weight, eps)

    # Renamed from `eval` to avoid shadowing the builtin.
    def bench(func, name):
        avg_time_us = testing.benchmark(
            func,
            workspace_generator=generate_tensors,
            workspace_count=10,
            warmup_iterations=warmup_iterations,
            iterations=iterations,
        )

        # Print execution results
        print(f"\n{name}")
        print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms")
        # Bytes moved: read x + write y (2*M*N elements), plus the weight
        # read and, for LayerNorm, the bias read. (The original formula
        # counted weight twice and never counted bias.)
        num_elems = 2 * x.numel() + weight.numel()
        if bias is not None:
            num_elems += bias.numel()
        print(
            f"Achieved memory throughput: {(num_elems * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
        )

    bench(compiled_func, f"CuTe {norm_type}norm kernel")
    bench(torch_ref, f"PyTorch {norm_type}norm reference")
|
|
|
|
|
|
if __name__ == "__main__":
    # Description fixed: this is the norm example, not elementwise add.
    parser = argparse.ArgumentParser(
        description="CTA-level LayerNorm / RMSNorm example using the CuTe DSL"
    )
    parser.add_argument("--M", default=4096, type=int, help="Number of rows")
    parser.add_argument("--N", default=4096, type=int, help="Hidden size (multiple of 8)")
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=DTYPE_MAP.keys(),
        # Help text now matches the actual choices (fp16/bf16/fp32).
        help="Data type for input/output tensors (fp16, bf16, fp32)",
    )
    parser.add_argument("--norm_type", choices=["layer", "rms"], default="layer", type=str)
    parser.add_argument("--threads", default=None, type=int)
    parser.add_argument("--warmup_iterations", default=2, type=int)
    parser.add_argument("--iterations", default=100, type=int)
    parser.add_argument("--skip_ref_check", action="store_true")
    parser.add_argument("--benchmark", action="store_true")

    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run this example!")

    run_layernorm(
        args.M,
        args.N,
        args.threads,
        args.norm_type,
        # Bug fix: honor --dtype instead of hard-coding cutlass.Float16,
        # which made the flag (and DTYPE_MAP) dead code.
        dtype=DTYPE_MAP[args.dtype],
        skip_ref_check=args.skip_ref_check,
        benchmark=args.benchmark,
        warmup_iterations=args.warmup_iterations,
        iterations=args.iterations,
    )

    print("\nPASS")
|