New RMS Norm example with unit tests (#2917)

* Add rmsnorm example

* Address reviewer comments. (1) use the cute.runtime definition directly. (2) use the nvvm_wrapper's warp reduce directly

* Separate out reduce.py

* Change copyright notice years
Author: Brian K. Ryu
Date: 2026-01-12 17:05:31 -08:00
Committed by: GitHub
Parent: 8c52459504
Commit: 147f5673d0

3 changed files with 1573 additions and 0 deletions

examples/python/CuTeDSL/blackwell/reduce.py

@@ -0,0 +1,533 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Hierarchical Reduction Utilities for CuTe-DSL Kernels
=====================================================
This module provides reusable reduction primitives for GPU kernels that need to
reduce values across warps, thread blocks, and clusters (SM90+).
Overview
--------
GPU reductions typically follow a hierarchical pattern:
1. **Warp Reduction**: Threads within a warp reduce using shuffle instructions.
Use `cute.arch.warp_reduction()` from the CuTe-DSL library.
2. **Block Reduction**: Multiple warps within a block reduce using shared memory.
Use `block_reduce()` from this module.
3. **Cluster Reduction** (SM90+): Multiple CTAs in a cluster reduce using
distributed shared memory and mbarrier synchronization.
Use `cluster_reduce()` from this module.
4. **Row Reduction**: Orchestrates all levels based on problem configuration.
Use `row_reduce()` from this module.
Shared Memory Buffer Layout Assumptions
---------------------------------------
For `block_reduce`:
- Buffer shape: (rows_per_block, warps_per_row)
- Each warp's lane 0 writes its reduced value to buffer[row_idx, col_idx]
- Thread mapping: row_idx = warp_idx // warps_per_row
col_idx = warp_idx % warps_per_row
Example for 8 warps, 2 rows, 4 warps per row:
Warp 0 -> buffer[0, 0] Warp 4 -> buffer[1, 0]
Warp 1 -> buffer[0, 1] Warp 5 -> buffer[1, 1]
Warp 2 -> buffer[0, 2] Warp 6 -> buffer[1, 2]
Warp 3 -> buffer[0, 3] Warp 7 -> buffer[1, 3]
For `cluster_reduce`:
- Buffer shape: (rows_per_block, (warps_per_row, cluster_n))
- The second dimension is hierarchical: (local_warp_slot, cta_rank)
- Each CTA contributes to its own slot in the cluster dimension
Example for cluster_n=4, 2 warps per row:
CTA 0, Warp 0 -> buffer[row, (0, 0)]
CTA 0, Warp 1 -> buffer[row, (1, 0)]
CTA 1, Warp 0 -> buffer[row, (0, 1)]
CTA 1, Warp 1 -> buffer[row, (1, 1)]
... etc for CTAs 2, 3
Mbarrier Requirements (Cluster Reduction)
-----------------------------------------
For cluster reduction, the caller must:
1. Allocate an mbarrier in shared memory
2. Initialize it with `cute.arch.mbarrier_init(mbar_ptr, thread_count)`
3. Pass the mbarrier pointer to `cluster_reduce()`
The cluster_reduce function handles:
- Setting up the expected transaction count
- Performing async cross-CTA stores
- Waiting for all stores to complete
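
A minimal setup sketch (mirroring the rmsnorm example kernel in this commit;
`smem` is a `cutlass.utils.SmemAllocator` and `tidx` is the thread index, both
assumed to be in scope):

.. code-block:: python

    # Allocate and initialize the mbarrier once per kernel launch
    mbar_ptr = smem.allocate_array(Int64, num_elems=1)
    if tidx == 0:
        cute.arch.mbarrier_init(mbar_ptr, 1)
    cute.arch.mbarrier_init_fence()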
Usage Example
-------------
.. code-block:: python
from reduce import row_reduce, block_reduce, cluster_reduce
@cute.jit
def my_kernel(...):
# Allocate shared memory for reduction
# Shape depends on warps_per_row and cluster_n
if cluster_n > 1:
reduction_buffer = cute.make_smem_tensor(
cute.make_layout((rows_per_block, (warps_per_row, cluster_n))),
Float32
)
else:
reduction_buffer = cute.make_smem_tensor(
cute.make_layout((rows_per_block, warps_per_row)),
Float32
)
# Perform row reduction
result = row_reduce(
tensor_ssa,
cute.ReductionOp.ADD,
threads_per_row,
reduction_buffer,
mbar_ptr,
cluster_n,
init_val=Float32(0.0)
)
References
----------
The cluster synchronization primitives (set_block_rank, store_shared_remote)
are inspired by Quack: https://github.com/Dao-AILab/quack
"""
import operator
from collections.abc import Callable
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32
from cutlass._mlir.dialects import llvm
from cutlass.cutlass_dsl import T, dsl_user_op
# =============================================================================
# Inline PTX Operations for Cluster Communication
# =============================================================================
#
# These operations enable cross-CTA communication within a cluster (SM90+).
# They use inline PTX assembly for functionality not yet exposed in MLIR.
# =============================================================================
@dsl_user_op
def set_block_rank(
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
) -> Int32:
"""
Map a shared memory pointer to the equivalent address in another CTA's
shared memory within the same cluster.
This uses the PTX `mapa.shared::cluster` instruction to translate a local
shared memory address to the corresponding address in a peer CTA's shared
memory space.
Args:
smem_ptr: Pointer to local shared memory
peer_cta_rank_in_cluster: Target CTA's rank within the cluster (0 to cluster_size-1)
Returns:
Int32 representing the mapped address in the peer CTA's shared memory
Note:
This operation requires SM90+ with cluster support enabled.
The cluster must be launched with the appropriate cluster dimensions.
"""
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
return Int32(
llvm.inline_asm(
T.i32(),
[smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
"mapa.shared::cluster.u32 $0, $1, $2;",
"=r,r,r",
has_side_effects=False,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
)
@dsl_user_op
def store_shared_remote(
val: Float32,
smem_ptr: cute.Pointer,
mbar_ptr: cute.Pointer,
peer_cta_rank_in_cluster: Int32,
*,
loc=None,
ip=None,
) -> None:
"""
Asynchronously store a Float32 value to shared memory on a remote CTA
within the cluster, with mbarrier completion tracking.
This uses the PTX `st.async.shared::cluster` instruction which:
1. Translates the local smem address to the peer CTA's address space
2. Performs an asynchronous store to the remote shared memory
3. Signals the mbarrier when the store completes
Args:
val: The Float32 value to store
smem_ptr: Pointer to the destination in local shared memory coordinates
mbar_ptr: Pointer to the mbarrier that tracks completion
peer_cta_rank_in_cluster: Target CTA's rank within the cluster
Note:
- The mbarrier must be initialized with the expected transaction byte count
- Use `cute.arch.mbarrier_arrive_and_expect_tx()` to set up the transaction
- Use `cute.arch.mbarrier_wait()` to wait for all stores to complete
- This operation requires SM90+ with cluster support enabled
"""
remote_smem_ptr_i32 = set_block_rank(
smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
).ir_value()
remote_mbar_ptr_i32 = set_block_rank(
mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
).ir_value()
llvm.inline_asm(
None,
[remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
"st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];",
"r,f,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
@dsl_user_op
def elem_pointer(x: cute.Tensor, coord, *, loc=None, ip=None) -> cute.Pointer:
"""
Get a pointer to an element at the specified coordinate in a tensor.
This is useful for getting the shared memory address of a specific element
when performing cross-CTA stores in cluster reduction.
Args:
x: The tensor (typically a shared memory tensor)
coord: The coordinate tuple, can be hierarchical like (row, (col, cluster_idx))
Returns:
Pointer to the element at the specified coordinate
"""
return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
# =============================================================================
# Block-Level Reduction
# =============================================================================
@cute.jit
def block_reduce(
val: Float32,
op: Callable,
reduction_buffer: cute.Tensor,
init_val: Float32,
) -> Float32:
"""
Reduce values across all warps within a thread block using shared memory.
This function assumes each warp has already performed a warp-level reduction
and is contributing a single value (from lane 0). The function then:
1. Writes each warp's value to shared memory
2. Synchronizes the block
3. Performs a final warp reduction across the collected values
Args:
val: The warp-reduced value (only lane 0's value is used)
op: Binary reduction operator, e.g., `operator.add` or `cute.arch.fmax`
reduction_buffer: Shared memory tensor with shape (rows_per_block, warps_per_row)
init_val: Identity element for the reduction (0 for sum, -inf for max)
Returns:
The block-reduced result (same value across all threads)
Buffer Layout:
- Shape: (rows_per_block, warps_per_row)
- warps_per_row is inferred from reduction_buffer.shape[1]
- Thread mapping:
row_idx = warp_idx // warps_per_row
col_idx = warp_idx % warps_per_row
Example:
For a block with 8 warps processing 2 rows (4 warps per row):
.. code-block:: python
reduction_buffer = cute.make_smem_tensor(
cute.make_layout((2, 4)), # 2 rows, 4 warps per row
Float32
)
result = block_reduce(warp_val, operator.add, reduction_buffer, Float32(0.0))
"""
lane_idx = cute.arch.lane_idx()
warp_idx = cute.arch.warp_idx()
warps_per_row = cute.size(reduction_buffer.shape[1])
row_idx = warp_idx // warps_per_row
col_idx = warp_idx % warps_per_row
# Lane 0 of each warp writes its value to shared memory
if lane_idx == 0:
reduction_buffer[row_idx, col_idx] = val
cute.arch.barrier()
# All lanes participate in reading and reducing
# Only lanes < warps_per_row have valid data
block_reduce_val = init_val
if lane_idx < warps_per_row:
block_reduce_val = reduction_buffer[row_idx, lane_idx]
return cute.arch.warp_reduction(block_reduce_val, op)
# =============================================================================
# Cluster-Level Reduction (SM90+)
# =============================================================================
@cute.jit
def cluster_reduce(
val: Float32,
op: Callable,
reduction_buffer: cute.Tensor,
mbar_ptr: cute.Pointer,
cluster_n: cutlass.Constexpr[int],
init_val: Float32,
) -> Float32:
"""
Reduce values across all CTAs within a cluster using distributed shared memory.
This function extends block reduction to work across multiple CTAs in a cluster
using asynchronous cross-CTA stores and mbarrier synchronization. It:
1. Sets up the mbarrier with expected transaction count
2. Asynchronously stores each warp's value to all peer CTAs
3. Waits for all stores to complete
4. Reduces across all collected values
Args:
        val: The warp-reduced value (all lanes hold it after the preceding
            warp reduction; lanes 0..cluster_n-1 issue the remote stores)
op: Binary reduction operator, e.g., `operator.add` or `cute.arch.fmax`
reduction_buffer: Shared memory tensor with hierarchical shape
(rows_per_block, (warps_per_row, cluster_n))
mbar_ptr: Pointer to an initialized mbarrier in shared memory
cluster_n: Number of CTAs in the cluster (compile-time constant)
init_val: Identity element for the reduction (0 for sum, -inf for max)
Returns:
The cluster-reduced result (same value across all threads in all CTAs)
Buffer Layout:
- Shape: (rows_per_block, (warps_per_row, cluster_n))
- The second dimension is hierarchical:
- First level: warps_per_row (local warp slots)
- Second level: cluster_n (one slot per CTA in cluster)
- Access pattern: buffer[row_idx, (col_idx, cta_rank)]
Requirements:
- SM90+ with cluster support
- Mbarrier must be initialized before calling
- Kernel must be launched with appropriate cluster dimensions
Example:
For a cluster of 4 CTAs, each with 2 warps per row:
.. code-block:: python
# Allocate buffer with cluster dimension
reduction_buffer = cute.make_smem_tensor(
cute.make_layout((rows_per_block, (2, 4))), # 2 warps, 4 CTAs
Float32
)
# Initialize mbarrier (once per kernel)
mbar = cute.make_smem_tensor(cute.make_layout((1,)), cute.arch.Mbarrier)
cute.arch.mbarrier_init(mbar.iterator, thread_count)
# Perform cluster reduction
result = cluster_reduce(
warp_val, operator.add, reduction_buffer,
mbar.iterator, cluster_n=4, init_val=Float32(0.0)
)
"""
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
lane_idx = cute.arch.lane_idx()
warp_idx = cute.arch.warp_idx()
rows_per_block = reduction_buffer.shape[0]
warps_per_row = reduction_buffer.shape[1][0]
row_idx = warp_idx // warps_per_row
col_idx = warp_idx % warps_per_row
# Warp 0, lane 0 sets up mbarrier with expected transaction count
# Each warp sends cluster_n stores (one to each CTA), each store is 4 bytes
if warp_idx == 0:
with cute.arch.elect_one():
num_warps = rows_per_block * warps_per_row
expected_bytes = num_warps * cluster_n * 4 # 4 bytes per Float32
cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, expected_bytes)
# Each lane < cluster_n writes to a different CTA's shared memory
# This distributes the warp's value to all CTAs in the cluster
if lane_idx < cluster_n:
store_shared_remote(
val,
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
mbar_ptr,
peer_cta_rank_in_cluster=lane_idx,
)
# Wait for all cross-CTA stores to complete
cute.arch.mbarrier_wait(mbar_ptr, phase=0)
# Now each CTA has all values from all CTAs in the cluster
# Reduce across all collected values
num_total = warps_per_row * cluster_n
num_iter = cute.ceil_div(num_total, 32)
block_reduce_val = init_val
for i in cutlass.range_constexpr(num_iter):
idx = lane_idx + i * 32
if idx < num_total:
block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
return cute.arch.warp_reduction(block_reduce_val, op)
# =============================================================================
# Row Reduction (Orchestration Function)
# =============================================================================
@cute.jit
def row_reduce(
x: cute.TensorSSA,
op: cute.ReductionOp,
threads_per_row: cutlass.Constexpr[int],
reduction_buffer: cute.Tensor,
mbar_ptr,
cluster_n: cutlass.Constexpr[int],
init_val: Float32,
):
"""
Perform hierarchical row reduction with automatic selection of reduction strategy.
This function orchestrates the full reduction pipeline:
1. Local reduction: Each thread reduces its portion of the row
2. Warp reduction: Threads within a warp reduce using shuffles
3. Block reduction: If needed, warps reduce using shared memory
4. Cluster reduction: If needed, CTAs reduce using distributed shared memory
The function automatically selects the appropriate reduction level based on
`threads_per_row` and `cluster_n`.
Args:
x: TensorSSA containing the values to reduce (in registers)
op: Reduction operation (cute.ReductionOp.ADD or cute.ReductionOp.MAX)
threads_per_row: Number of threads cooperating on each row (compile-time)
reduction_buffer: Shared memory tensor for block/cluster reduction
mbar_ptr: Mbarrier pointer (only used if cluster_n > 1)
cluster_n: Number of CTAs in cluster (1 for single-CTA reduction)
init_val: Identity element for the reduction
Returns:
The fully reduced result for each row
Reduction Strategy Selection:
- threads_per_row <= 32, cluster_n == 1: Warp reduction only
- threads_per_row > 32, cluster_n == 1: Warp + block reduction
- cluster_n > 1: Warp + cluster reduction (handles all cases)
Example:
.. code-block:: python
# Sum reduction across 128 threads per row, single CTA
result = row_reduce(
tensor_ssa,
cute.ReductionOp.ADD,
threads_per_row=128,
reduction_buffer=smem_buffer,
mbar_ptr=None,
cluster_n=1,
init_val=Float32(0.0)
)
# Max reduction across 256 threads per row, 4 CTAs in cluster
result = row_reduce(
tensor_ssa,
cute.ReductionOp.MAX,
threads_per_row=256,
reduction_buffer=smem_buffer,
mbar_ptr=mbar.iterator,
cluster_n=4,
init_val=Float32.neg_inf
)
"""
# Step 1: Local reduction - each thread reduces its register values
local_val = x.reduce(op, init_val=init_val, reduction_profile=0)
# Map ReductionOp enum to binary operator for warp/block reductions
warp_op = {
cute.ReductionOp.ADD: operator.add,
cute.ReductionOp.MAX: cute.arch.fmax,
}[op]
# Step 2: Warp reduction
# If threads_per_row < 32, only use that many threads in the reduction
warp_width = min(threads_per_row, 32)
warp_val = cute.arch.warp_reduction(local_val, warp_op, threads_in_group=warp_width)
# Determine if we need additional reduction levels
warps_per_row = max(threads_per_row // 32, 1)
# Step 3 & 4: Block or cluster reduction (if needed)
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
if cutlass.const_expr(cluster_n == 1):
# Single CTA: use block reduction
return block_reduce(warp_val, warp_op, reduction_buffer, init_val)
else:
# Multiple CTAs: use cluster reduction
return cluster_reduce(
warp_val, warp_op, reduction_buffer, mbar_ptr, cluster_n, init_val
)
else:
# Single warp handles entire row: warp reduction is sufficient
return warp_val

examples/python/CuTeDSL/blackwell/rmsnorm.py

@@ -0,0 +1,840 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import ctypes
import functools
import math
from typing import Optional, Tuple, Union
import cuda.bindings.driver as cuda
import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass import Boolean, Float32, Int32, Int64
from cutlass.cute.runtime import make_ptr
# Support both direct execution and module import
try:
from .reduce import row_reduce
except ImportError:
from reduce import row_reduce
"""
RMSNorm: Root Mean Square Layer Normalization for Hopper & Blackwell (SM90+)
============================================================================
A high-performance RMSNorm implementation using CuTe DSL with cluster-based
reduction for large hidden dimensions.
RMSNorm computes: y = x / sqrt(mean(x²) + eps) * weight
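
In PyTorch terms (matching the `rmsnorm_ref` reference implementation below):

.. code-block:: python

    # Accumulate in fp32, normalize, scale, and cast back to the input dtype
    y = (x.float() / torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps)
         * weight.float()).to(x.dtype)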
Key Features:
-------------
1. CLUSTER SYNCHRONIZATION (SM90+)
- Multiple CTAs cooperate to process large N dimensions
- Each CTA handles N/cluster_n elements, then reduces across the cluster
- Uses mbarrier for efficient cross-CTA synchronization
2. ARCHITECTURE-SPECIFIC TUNING
- SM80 (Ampere): Single-CTA execution (cluster_n=1)
- SM90 (Hopper): Cluster support enabled for large N
- SM100 (Blackwell): Same as SM90
3. VECTORIZED MEMORY ACCESS
- 128-bit vectorized loads/stores for optimal memory throughput
- TiledCopy abstraction for organized gmem↔smem↔rmem transfers
Cluster Size Selection (FP16/BF16):
-----------------------------------
- N <= 16K: cluster_n = 1 (single CTA)
- N <= 32K: cluster_n = 2
- N <= 64K: cluster_n = 4
- N <= 128K: cluster_n = 8
- Larger: cluster_n = 16
To run this example:
.. code-block:: bash
python examples/python/CuTeDSL/blackwell/rmsnorm.py --M 2048 --N 4096 --dtype BFloat16
python examples/python/CuTeDSL/blackwell/rmsnorm.py --M 2048 --N 4096 --dtype BFloat16 --benchmark
python examples/python/CuTeDSL/blackwell/rmsnorm.py --M 2048 --N 32768 --dtype BFloat16 --benchmark
To collect performance with NCU profiler:
.. code-block:: bash
ncu python examples/python/CuTeDSL/blackwell/rmsnorm.py --M 2048 --N 4096 --dtype BFloat16 --skip_ref_check
"""
# =============================================================================
# Architecture Detection
# =============================================================================
@functools.lru_cache(maxsize=16)
def get_sm_version(device: Optional[Union[int, torch.device, str]] = None) -> int:
"""Get the SM (compute capability) version of a CUDA device."""
if not torch.cuda.is_available():
return 80 # Default fallback
props = torch.cuda.get_device_properties(device)
return props.major * 10 + props.minor
def supports_cluster() -> bool:
"""Check if the current device supports cluster operations (SM90+)."""
return get_sm_version() >= 90
# =============================================================================
# Predicate Utility
# =============================================================================
@cute.jit
def predicate_k(tXcX: cute.Tensor, limit: int) -> cute.Tensor:
"""Create predicate tensor for bounds checking."""
tXpX = cute.make_rmem_tensor(
cute.make_layout(
(cute.size(tXcX, mode=[0, 1]), cute.size(tXcX, mode=[1]), cute.size(tXcX, mode=[2])),
stride=(cute.size(tXcX, mode=[2]), 0, 1),
),
Boolean,
)
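    # Each predicate entry guards one (vector, k-block) element: it is True
    # when the global column coordinate (second component of the identity
    # tensor coordinate) is below `limit`, i.e. within the N dimension.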
for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
tXpX[rest_v, 0, rest_k] = cute.elem_less(tXcX[(0, rest_v), 0, rest_k][1], limit)
return tXpX
# =============================================================================
# RMSNorm Configuration Class
# =============================================================================
class RMSNormConfig:
"""
Configuration for the RMSNorm kernel.
This class encapsulates all kernel configuration computed once at initialization,
following CuTe-DSL conventions from official CUTLASS examples.
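
    Example (values follow from the heuristics below; assumes an SM90+ device):

    .. code-block:: python

        cfg = RMSNormConfig(cutlass.BFloat16, N=32768)
        # cfg.vec_size == 8           (128-bit copies / 16-bit elements)
        # cfg.cluster_n == 2          (16-bit dtype, N <= 32K)
        # cfg.N_per_cta == 16384
        # cfg.threads_per_row == 128, cfg.num_threads == 128
        # cfg.rows_per_block == 1,    cfg.warps_per_row == 4
        # cfg.smem_size_in_bytes() == 32768 + 32 + 8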
"""
COPY_BITS = 128 # 128-bit vectorized loads
def __init__(
self,
dtype: type[cutlass.Numeric],
N: int,
has_weight: bool = True,
sm_version: int | None = None,
):
self.dtype = dtype
self.N = N
self.has_weight = has_weight
self.sm_version = sm_version if sm_version is not None else get_sm_version()
# Vector size for 128-bit loads
self.vec_size = self.COPY_BITS // dtype.width
# Compute cluster size (SM90+ only)
self.cluster_n = self._compute_cluster_n(N, dtype, self.sm_version)
# N per CTA for cluster case
self.N_per_cta = N // self.cluster_n
# Thread configuration using static methods
self.threads_per_row = self._compute_threads_per_row(self.N_per_cta)
self.num_threads = self._compute_num_threads(self.N_per_cta)
# Derived values
self.num_vec_blocks = max(
1, (self.N_per_cta // self.vec_size + self.threads_per_row - 1) // self.threads_per_row
)
self.rows_per_block = self.num_threads // self.threads_per_row
self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row
self.warps_per_row = max(self.threads_per_row // 32, 1)
@staticmethod
def _compute_cluster_n(N: int, dtype: type[cutlass.Numeric], sm_version: int) -> int:
"""Compute optimal cluster size based on N and architecture."""
if sm_version < 90:
return 1
if dtype.width == 16: # FP16/BF16
if N <= 16 * 1024:
return 1
elif N <= 32 * 1024:
return 2
elif N <= 64 * 1024:
return 4
elif N <= 128 * 1024:
return 8
else:
return 16
else: # FP32
if N <= 32 * 1024:
return 1
elif N <= 64 * 1024:
return 2
elif N <= 128 * 1024:
return 4
elif N <= 256 * 1024:
return 8
else:
return 16
@staticmethod
def _compute_threads_per_row(N_per_cta: int) -> int:
"""Compute optimal threads per row based on N per CTA."""
if N_per_cta <= 64:
return 8
elif N_per_cta <= 128:
return 16
elif N_per_cta <= 3072:
return 32
elif N_per_cta <= 6144:
return 64
elif N_per_cta <= 16384:
return 128
else:
return 256
@staticmethod
def _compute_num_threads(N_per_cta: int) -> int:
"""Compute total threads per block."""
return 128 if N_per_cta <= 16384 else 256
@staticmethod
def _make_tv_layout(
threads_per_row: int,
rows_per_block: int,
vec_size: int,
num_vec_blocks: int,
) -> tuple:
"""Create Thread-Value layout for coalesced vectorized memory access."""
shape = (
(threads_per_row, rows_per_block),
(vec_size, num_vec_blocks),
)
stride = (
(vec_size * rows_per_block, 1),
(rows_per_block, rows_per_block * vec_size * threads_per_row),
)
return shape, stride
def smem_size_in_bytes(self) -> int:
"""Calculate shared memory requirement in bytes."""
tile_bytes = self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8)
reduction_bytes = self.rows_per_block * self.warps_per_row * self.cluster_n * 4
mbar_bytes = 8 if self.cluster_n > 1 else 0
return tile_bytes + reduction_bytes + mbar_bytes
# =============================================================================
# RMSNorm Kernel Class
# =============================================================================
class RMSNormKernel:
"""
RMSNorm kernel with cluster synchronization for large N.
Features:
- Cluster-based reduction for large N (SM90+)
- Multiple CTAs cooperate via mbarrier
- Single reduction (sum of squares) with cluster-level aggregation
Example:
>>> kernel = RMSNormKernel(cutlass.Float16, N=4096)
>>> kernel(x_ptr, w_ptr, o_ptr, M, eps, stream)
"""
def __init__(
self,
dtype: cutlass.Numeric,
N: int,
has_weight: bool = True,
config: RMSNormConfig | None = None,
):
# Use provided config or create new one
if config is not None:
self.cfg = config
else:
self.cfg = RMSNormConfig(dtype, N, has_weight)
# Expose key attributes for convenience
self.dtype = self.cfg.dtype
self.N = self.cfg.N
self.has_weight = self.cfg.has_weight
self.cluster_n = self.cfg.cluster_n
@cute.jit
def __call__(
self,
x_ptr: cute.Pointer,
w_ptr: cute.Pointer | None,
o_ptr: cute.Pointer,
M: Int32,
eps: Float32,
stream: cuda.CUstream,
):
"""Host function to launch the RMSNorm kernel."""
cfg = self.cfg
# Create CuTe tensors from raw pointers
mX = cute.make_tensor(
x_ptr,
cute.make_layout((M, cfg.N), stride=(cfg.N, 1)),
)
mO = cute.make_tensor(
o_ptr,
cute.make_layout((M, cfg.N), stride=(cfg.N, 1)),
)
if cutlass.const_expr(cfg.has_weight and w_ptr is not None):
mW = cute.make_tensor(
w_ptr,
cute.make_layout((cfg.N,), stride=(1,)),
)
else:
mW = None
# Create TV layout using static helper
tv_shape, tv_stride = RMSNormConfig._make_tv_layout(
cfg.threads_per_row,
cfg.rows_per_block,
cfg.vec_size,
cfg.num_vec_blocks,
)
tv_layout = cute.make_layout(tv_shape, stride=tv_stride)
tiler_mn = (cfg.rows_per_block, cfg.cols_per_tile)
self.kernel(mX, mW, mO, eps, tv_layout, tiler_mn).launch(
grid=[cute.ceil_div(M, cfg.rows_per_block), cfg.cluster_n, 1],
block=[cfg.num_threads, 1, 1],
cluster=[1, cfg.cluster_n, 1] if cutlass.const_expr(cfg.cluster_n > 1) else None,
smem=cfg.smem_size_in_bytes(),
stream=stream,
)
@cute.kernel
def kernel(
self,
mX: cute.Tensor,
mW: cute.Tensor | None,
mO: cute.Tensor,
eps: Float32,
tv_layout: cute.Layout,
tiler_mn: cute.Shape,
):
"""Device kernel implementing RMSNorm with cluster support."""
cfg = self.cfg
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
if cutlass.const_expr(cfg.cluster_n > 1):
cluster_y = cute.arch.block_idx()[1]
else:
cluster_y = cutlass.const_expr(0)
M = mX.shape[0]
threads_per_row = tv_layout.shape[0][0]
warps_per_row = max(threads_per_row // 32, 1)
rows_per_block = tiler_mn[0]
# =====================================================================
# Allocate shared memory
# =====================================================================
smem = utils.SmemAllocator()
sX = smem.allocate_tensor(
mX.element_type,
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
byte_alignment=16,
)
if cutlass.const_expr(cfg.cluster_n == 1):
reduction_buffer = smem.allocate_tensor(
Float32,
cute.make_layout((rows_per_block, warps_per_row)),
byte_alignment=4,
)
mbar_ptr = None
else:
reduction_buffer = smem.allocate_tensor(
Float32,
cute.make_layout((rows_per_block, (warps_per_row, cfg.cluster_n))),
byte_alignment=4,
)
mbar_ptr = smem.allocate_array(Int64, num_elems=1)
# =====================================================================
# Initialize cluster
# =====================================================================
if cutlass.const_expr(cfg.cluster_n > 1):
if tidx == 0:
cute.arch.mbarrier_init(mbar_ptr, 1)
cute.arch.mbarrier_init_fence()
cute.arch.cluster_arrive_relaxed()
cute.arch.cluster_wait()
# =====================================================================
# Create identity tensor and partition
# =====================================================================
idX = cute.make_identity_tensor(mX.shape)
gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y))
gO = cute.local_tile(mO, tiler_mn, (bidx, cluster_y))
cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
if cutlass.const_expr(cfg.has_weight and mW is not None):
mW_expanded_layout = cute.prepend(
mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
)
mW_2d = cute.make_tensor(mW.iterator, mW_expanded_layout)
gW = cute.local_tile(mW_2d, tiler_mn, (0, cluster_y))
# =====================================================================
# Create TiledCopy operations
# =====================================================================
copy_atom_load_async = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(),
mX.element_type,
num_bits_per_copy=RMSNormConfig.COPY_BITS,
)
copy_atom_load_W = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
mX.element_type,
num_bits_per_copy=RMSNormConfig.COPY_BITS,
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
mO.element_type,
num_bits_per_copy=RMSNormConfig.COPY_BITS,
)
tiled_copy_load = cute.make_tiled_copy(copy_atom_load_async, tv_layout, tiler_mn)
tiled_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn)
tiled_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)
thr_copy_X = tiled_copy_load.get_slice(tidx)
thr_copy_W = tiled_copy_W.get_slice(tidx)
thr_copy_O = tiled_copy_store.get_slice(tidx)
# Partition tensors
tXgX = thr_copy_X.partition_S(gX)
tXsX = thr_copy_X.partition_D(sX)
tXgO = thr_copy_O.partition_D(gO)
tXcX = thr_copy_X.partition_S(cX)
# Register fragments
tXrX = cute.make_fragment_like(tXgX)
tXrO = cute.make_fragment_like(tXgO)
if cutlass.const_expr(cfg.has_weight and mW is not None):
tWgW = thr_copy_W.partition_S(gW)
tWrW = cute.make_fragment_like(tWgW)
tXrW = thr_copy_X.retile(tWrW)
# =====================================================================
# Bounds checking
# =====================================================================
tXpX = predicate_k(tXcX, limit=cfg.N)
row_coord = tXcX[(0, 0), 0, 0]
row_in_bounds = row_coord[0] < M
# =====================================================================
# Async copy global → shared
# =====================================================================
if row_in_bounds:
cute.copy(copy_atom_load_async, tXgX, tXsX, pred=tXpX)
cute.arch.cp_async_commit_group()
# Load weight while waiting
if cutlass.const_expr(cfg.has_weight and mW is not None):
tWpW = predicate_k(thr_copy_W.partition_S(cX), limit=cfg.N)
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
cute.arch.cp_async_wait_group(0)
# =====================================================================
# Pass 1: Compute sum of squares with cluster reduction
# =====================================================================
cute.autovec_copy(tXsX, tXrX)
x = tXrX.load().to(Float32)
x_sq = x * x
sum_sq = row_reduce(
x_sq,
cute.ReductionOp.ADD,
threads_per_row,
reduction_buffer,
mbar_ptr,
cfg.cluster_n,
Float32(0.0),
)
# rstd = 1 / sqrt(mean(x²) + eps)
mean_sq = sum_sq / cfg.N
rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True)
# Sync after reduction
if cutlass.const_expr(cfg.cluster_n > 1):
cute.arch.cluster_arrive_relaxed()
cute.arch.cluster_wait()
else:
cute.arch.barrier()
# =====================================================================
# Pass 2: Normalize and output
# =====================================================================
cute.autovec_copy(tXsX, tXrX)
x = tXrX.load().to(Float32)
y = x * rstd
# Apply weight if present
if cutlass.const_expr(cfg.has_weight and mW is not None):
w = tXrW.load().to(Float32)
y = y * w
# Store to global memory
tXrO.store(y.to(cfg.dtype))
if row_in_bounds:
cute.copy(copy_atom_store, tXrO, tXgO, pred=tXpX)
# =============================================================================
# Kernel Compilation and Caching
# =============================================================================
# Mapping from torch dtype to cutlass dtype
_torch_to_cutlass_dtype = {
torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16,
torch.float32: cutlass.Float32,
}
# Cache for compiled kernels
_compile_cache: dict = {}
def get_compiled_kernel(
dtype: type[cutlass.Numeric],
N: int,
has_weight: bool,
stream: cuda.CUstream,
):
"""
Get or compile the RMSNorm kernel for the given configuration.
Uses compilation cache to avoid recompiling for same (dtype, N, has_weight) tuples.
:param dtype: Data type (Float16, BFloat16, Float32)
:type dtype: type[cutlass.Numeric]
:param N: Hidden dimension size
:type N: int
:param has_weight: Whether weight is applied
:type has_weight: bool
:param stream: CUDA stream
:type stream: cuda.CUstream
:return: Compiled kernel function
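
    Example (sketch; the stream wraps the current torch stream, as in `run` below):

    .. code-block:: python

        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
        kernel = get_compiled_kernel(cutlass.BFloat16, 4096, True, stream)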
"""
key = (dtype, N, has_weight)
if key not in _compile_cache:
kernel_obj = RMSNormKernel(dtype, N, has_weight)
# Compile with representative arguments
compiled_kernel = cute.compile(
kernel_obj,
make_ptr(dtype, 16, cute.AddressSpace.gmem, assumed_align=16), # x_ptr
make_ptr(dtype, 16, cute.AddressSpace.gmem, assumed_align=16)
if has_weight
else None, # w_ptr
make_ptr(dtype, 16, cute.AddressSpace.gmem, assumed_align=16), # o_ptr
Int32(1), # M (dummy)
Float32(1e-6), # eps (dummy)
stream,
)
_compile_cache[key] = compiled_kernel
return _compile_cache[key]
# =============================================================================
# Tensor Creation Utilities
# =============================================================================
def create_tensors(
M: int,
N: int,
dtype: type[cutlass.Numeric],
has_weight: bool,
) -> Tuple:
"""Create input, weight, and output tensors for RMSNorm."""
torch.manual_seed(42)
torch_dtype = cutlass_torch.dtype(dtype)
x = torch.randn(M, N, device="cuda", dtype=torch_dtype)
weight = torch.randn(N, device="cuda", dtype=torch_dtype) if has_weight else None
out = torch.empty_like(x)
return x, weight, out
def rmsnorm_ref(
x: torch.Tensor,
weight: torch.Tensor | None = None,
eps: float = 1e-6,
) -> torch.Tensor:
"""Reference RMSNorm implementation in PyTorch."""
x_f32 = x.float()
rms = torch.sqrt(torch.mean(x_f32**2, dim=-1, keepdim=True) + eps)
x_norm = x_f32 / rms
if weight is not None:
x_norm = x_norm * weight.float()
return x_norm.to(x.dtype)
# =============================================================================
# Run Function
# =============================================================================
def run(
M: int,
N: int,
dtype: type[cutlass.Numeric],
has_weight: bool = True,
eps: float = 1e-6,
tolerance: float = 1e-2,
warmup_iterations: int = 2,
iterations: int = 100,
skip_ref_check: bool = False,
benchmark: bool = False,
) -> float:
"""
Execute RMSNorm and optionally benchmark performance.
:param M: Number of rows (batch size * sequence length)
:type M: int
:param N: Hidden dimension size
:type N: int
:param dtype: Data type (Float16, BFloat16, Float32)
:type dtype: type[cutlass.Numeric]
:param has_weight: Whether to apply learnable weight
:type has_weight: bool
:param eps: Epsilon for numerical stability
:type eps: float
:param tolerance: Tolerance for correctness check
:type tolerance: float
:param warmup_iterations: Warmup iterations for benchmarking
:type warmup_iterations: int
:param iterations: Number of benchmark iterations
:type iterations: int
:param skip_ref_check: Skip reference correctness check
:type skip_ref_check: bool
:param benchmark: Enable benchmarking
:type benchmark: bool
:return: Execution time in microseconds (if benchmark=True, else 0)
:rtype: float
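
    Example (mirrors the CLI defaults of the entry point below):

    .. code-block:: python

        run(M=2048, N=4096, dtype=cutlass.BFloat16, benchmark=True)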
"""
print("Running RMSNorm test with:")
print(f" M: {M}, N: {N}")
print(f" dtype: {dtype}")
print(f" has_weight: {has_weight}")
print(f" eps: {eps}")
print(f" SM version: {get_sm_version()}")
if not torch.cuda.is_available():
raise RuntimeError("CUDA GPU is required to run this example!")
# Get CUDA stream
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
# Create tensors
x, weight, out = create_tensors(M, N, dtype, has_weight)
# Get configuration info
config = RMSNormConfig(dtype, N, has_weight)
print(f" cluster_n: {config.cluster_n}")
print(f" threads_per_row: {config.threads_per_row}")
print(f" rows_per_block: {config.rows_per_block}")
# Get compiled kernel
compiled_kernel = get_compiled_kernel(dtype, N, has_weight, stream)
# Create pointers for kernel call
x_ptr = make_ptr(dtype, x.data_ptr())
w_ptr = make_ptr(dtype, weight.data_ptr()) if weight is not None else None
out_ptr = make_ptr(dtype, out.data_ptr())
# Run kernel and verify
if not skip_ref_check:
compiled_kernel(x_ptr, w_ptr, out_ptr, Int32(M), Float32(eps), stream)
torch.cuda.synchronize()
ref = rmsnorm_ref(x, weight, eps)
torch.testing.assert_close(out, ref, atol=tolerance, rtol=tolerance)
print("Correctness check passed!")
if not benchmark:
return 0.0
# Benchmark
print(f"\nBenchmarking with {warmup_iterations} warmup, {iterations} iterations...")
def generate_tensors():
x, weight, out = create_tensors(M, N, dtype, has_weight)
x_ptr = make_ptr(dtype, x.data_ptr())
w_ptr = make_ptr(dtype, weight.data_ptr()) if weight is not None else None
out_ptr = make_ptr(dtype, out.data_ptr())
return testing.JitArguments(x_ptr, w_ptr, out_ptr, Int32(M), Float32(eps), stream)
exec_time_us = testing.benchmark(
compiled_kernel,
workspace_generator=generate_tensors,
workspace_count=10,
warmup_iterations=warmup_iterations,
iterations=iterations,
stream=stream,
)
# Calculate throughput
torch_dtype = cutlass_torch.dtype(dtype)
bytes_per_elem = torch.tensor([], dtype=torch_dtype).element_size()
total_bytes = M * N * bytes_per_elem * 2 # read x, write out
if has_weight:
total_bytes += N * bytes_per_elem # read weight (amortized across M)
throughput_gbps = (total_bytes / (exec_time_us / 1e6)) / 1e9
print(f"Kernel execution time: {exec_time_us:.2f} us")
print(f"Memory throughput: {throughput_gbps:.2f} GB/s")
return exec_time_us
# =============================================================================
# Main Entry Point
# =============================================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(
        description="RMSNorm kernel example for Hopper & Blackwell (SM90+)"
)
parser.add_argument("--M", type=int, default=2048, help="Number of rows")
parser.add_argument("--N", type=int, default=4096, help="Hidden dimension size")
parser.add_argument(
"--dtype",
type=cutlass.dtype,
default=cutlass.BFloat16,
help="Data type (Float16, BFloat16, Float32)",
)
parser.add_argument(
"--has_weight",
action="store_true",
default=True,
help="Apply learnable weight",
)
parser.add_argument(
"--no_weight",
action="store_true",
help="Disable learnable weight",
)
parser.add_argument(
"--eps",
type=float,
default=1e-6,
help="Epsilon for numerical stability",
)
parser.add_argument(
"--tolerance",
type=float,
default=1e-2,
help="Tolerance for correctness check",
)
parser.add_argument(
"--warmup_iterations",
type=int,
default=2,
help="Warmup iterations for benchmarking",
)
parser.add_argument(
"--iterations",
type=int,
default=100,
help="Number of benchmark iterations",
)
parser.add_argument(
"--skip_ref_check",
action="store_true",
help="Skip reference correctness check",
)
parser.add_argument(
"--benchmark",
action="store_true",
help="Enable benchmarking",
)
args = parser.parse_args()
# Handle weight flag
has_weight = args.has_weight and not args.no_weight
run(
M=args.M,
N=args.N,
dtype=args.dtype,
has_weight=has_weight,
eps=args.eps,
tolerance=args.tolerance,
warmup_iterations=args.warmup_iterations,
iterations=args.iterations,
skip_ref_check=args.skip_ref_check,
benchmark=args.benchmark,
)
print("PASS")