mirror of https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00

New RMS Norm example with unit tests (#2917)

* Add rmsnorm example
* Address reviewer comments: (1) use the cute.runtime definition directly; (2) use the nvvm_wrapper's warp reduce directly
* Separate out reduce.py
* Change copyright notice years
533 examples/python/CuTeDSL/blackwell/reduce.py (Normal file)
@@ -0,0 +1,533 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
Hierarchical Reduction Utilities for CuTe-DSL Kernels
=====================================================

This module provides reusable reduction primitives for GPU kernels that need to
reduce values across warps, thread blocks, and clusters (SM90+).

Overview
--------
GPU reductions typically follow a hierarchical pattern:

1. **Warp Reduction**: Threads within a warp reduce using shuffle instructions.
   Use `cute.arch.warp_reduction()` from the CuTe-DSL library.

2. **Block Reduction**: Multiple warps within a block reduce using shared memory.
   Use `block_reduce()` from this module.

3. **Cluster Reduction** (SM90+): Multiple CTAs in a cluster reduce using
   distributed shared memory and mbarrier synchronization.
   Use `cluster_reduce()` from this module.

4. **Row Reduction**: Orchestrates all levels based on problem configuration.
   Use `row_reduce()` from this module.

Shared Memory Buffer Layout Assumptions
---------------------------------------

For `block_reduce`:
- Buffer shape: (rows_per_block, warps_per_row)
- Each warp's lane 0 writes its reduced value to buffer[row_idx, col_idx]
- Thread mapping: row_idx = warp_idx // warps_per_row
                  col_idx = warp_idx % warps_per_row

Example for 8 warps, 2 rows, 4 warps per row:
    Warp 0 -> buffer[0, 0]    Warp 4 -> buffer[1, 0]
    Warp 1 -> buffer[0, 1]    Warp 5 -> buffer[1, 1]
    Warp 2 -> buffer[0, 2]    Warp 6 -> buffer[1, 2]
    Warp 3 -> buffer[0, 3]    Warp 7 -> buffer[1, 3]

For `cluster_reduce`:
- Buffer shape: (rows_per_block, (warps_per_row, cluster_n))
- The second dimension is hierarchical: (local_warp_slot, cta_rank)
- Each CTA contributes to its own slot in the cluster dimension

Example for cluster_n=4, 2 warps per row:
    CTA 0, Warp 0 -> buffer[row, (0, 0)]
    CTA 0, Warp 1 -> buffer[row, (1, 0)]
    CTA 1, Warp 0 -> buffer[row, (0, 1)]
    CTA 1, Warp 1 -> buffer[row, (1, 1)]
    ... etc for CTAs 2, 3

Mbarrier Requirements (Cluster Reduction)
-----------------------------------------
For cluster reduction, the caller must:
1. Allocate an mbarrier in shared memory
2. Initialize it with `cute.arch.mbarrier_init(mbar_ptr, thread_count)`
3. Pass the mbarrier pointer to `cluster_reduce()`

The cluster_reduce function handles:
- Setting up the expected transaction count
- Performing async cross-CTA stores
- Waiting for all stores to complete
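
A minimal setup sketch (it mirrors what the rmsnorm example in this PR does;
the `SmemAllocator`-style allocation and the use of `Int64` as a stand-in for
the 8-byte mbarrier word are assumptions about the caller's scheme):

.. code-block:: python

    mbar_ptr = smem.allocate_array(Int64, num_elems=1)  # 8-byte mbarrier word
    if tidx == 0:
        cute.arch.mbarrier_init(mbar_ptr, 1)  # one arriving thread
    cute.arch.mbarrier_init_fence()           # make the init visible
    cute.arch.cluster_arrive_relaxed()        # sync CTAs before first use
    cute.arch.cluster_wait()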

Usage Example
-------------

.. code-block:: python

    from reduce import row_reduce, block_reduce, cluster_reduce

    @cute.jit
    def my_kernel(...):
        # Allocate shared memory for reduction
        # Shape depends on warps_per_row and cluster_n
        if cluster_n > 1:
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, (warps_per_row, cluster_n))),
                Float32
            )
        else:
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, warps_per_row)),
                Float32
            )

        # Perform row reduction
        result = row_reduce(
            tensor_ssa,
            cute.ReductionOp.ADD,
            threads_per_row,
            reduction_buffer,
            mbar_ptr,
            cluster_n,
            init_val=Float32(0.0)
        )

References
----------
The cluster synchronization primitives (set_block_rank, store_shared_remote)
are inspired by Quack: https://github.com/Dao-AILab/quack
"""

import operator
from collections.abc import Callable

import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32
from cutlass._mlir.dialects import llvm
from cutlass.cutlass_dsl import T, dsl_user_op


# =============================================================================
# Inline PTX Operations for Cluster Communication
# =============================================================================
#
# These operations enable cross-CTA communication within a cluster (SM90+).
# They use inline PTX assembly for functionality not yet exposed in MLIR.
# =============================================================================


@dsl_user_op
def set_block_rank(
    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
) -> Int32:
    """
    Map a shared memory pointer to the equivalent address in another CTA's
    shared memory within the same cluster.

    This uses the PTX `mapa.shared::cluster` instruction to translate a local
    shared memory address to the corresponding address in a peer CTA's shared
    memory space.

    Args:
        smem_ptr: Pointer to local shared memory
        peer_cta_rank_in_cluster: Target CTA's rank within the cluster (0 to cluster_size-1)

    Returns:
        Int32 representing the mapped address in the peer CTA's shared memory

    Note:
        This operation requires SM90+ with cluster support enabled.
        The cluster must be launched with the appropriate cluster dimensions.
    """
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    return Int32(
        llvm.inline_asm(
            T.i32(),
            [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
            "mapa.shared::cluster.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def store_shared_remote(
    val: Float32,
    smem_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Asynchronously store a Float32 value to shared memory on a remote CTA
    within the cluster, with mbarrier completion tracking.

    This uses the PTX `st.async.shared::cluster` instruction which:
    1. Translates the local smem address to the peer CTA's address space
    2. Performs an asynchronous store to the remote shared memory
    3. Signals the mbarrier when the store completes

    Args:
        val: The Float32 value to store
        smem_ptr: Pointer to the destination in local shared memory coordinates
        mbar_ptr: Pointer to the mbarrier that tracks completion
        peer_cta_rank_in_cluster: Target CTA's rank within the cluster

    Note:
        - The mbarrier must be initialized with the expected transaction byte count
        - Use `cute.arch.mbarrier_arrive_and_expect_tx()` to set up the transaction
        - Use `cute.arch.mbarrier_wait()` to wait for all stores to complete
        - This operation requires SM90+ with cluster support enabled
    """
    remote_smem_ptr_i32 = set_block_rank(
        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    remote_mbar_ptr_i32 = set_block_rank(
        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    llvm.inline_asm(
        None,
        [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
        "st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];",
        "r,f,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def elem_pointer(x: cute.Tensor, coord, *, loc=None, ip=None) -> cute.Pointer:
    """
    Get a pointer to an element at the specified coordinate in a tensor.

    This is useful for getting the shared memory address of a specific element
    when performing cross-CTA stores in cluster reduction.

    Args:
        x: The tensor (typically a shared memory tensor)
        coord: The coordinate tuple, can be hierarchical like (row, (col, cluster_idx))

    Returns:
        Pointer to the element at the specified coordinate
    """
    return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
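
# For example, with a cluster reduction buffer of shape
# (rows_per_block, (warps_per_row, cluster_n)), the slot written by warp
# `col_idx` in row `row_idx` into CTA-rank slot `rank` (hypothetical names)
# is addressed as:
#
#     ptr = elem_pointer(reduction_buffer, (row_idx, (col_idx, rank)))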


# =============================================================================
# Block-Level Reduction
# =============================================================================


@cute.jit
def block_reduce(
    val: Float32,
    op: Callable,
    reduction_buffer: cute.Tensor,
    init_val: Float32,
) -> Float32:
    """
    Reduce values across all warps within a thread block using shared memory.

    This function assumes each warp has already performed a warp-level reduction
    and is contributing a single value (from lane 0). The function then:
    1. Writes each warp's value to shared memory
    2. Synchronizes the block
    3. Performs a final warp reduction across the collected values

    Args:
        val: The warp-reduced value (only lane 0's value is used)
        op: Binary reduction operator, e.g., `operator.add` or `cute.arch.fmax`
        reduction_buffer: Shared memory tensor with shape (rows_per_block, warps_per_row)
        init_val: Identity element for the reduction (0 for sum, -inf for max)

    Returns:
        The block-reduced result (same value across all threads)

    Buffer Layout:
        - Shape: (rows_per_block, warps_per_row)
        - warps_per_row is inferred from reduction_buffer.shape[1]
        - Thread mapping:
            row_idx = warp_idx // warps_per_row
            col_idx = warp_idx % warps_per_row

    Example:
        For a block with 8 warps processing 2 rows (4 warps per row):

        .. code-block:: python

            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((2, 4)),  # 2 rows, 4 warps per row
                Float32
            )
            result = block_reduce(warp_val, operator.add, reduction_buffer, Float32(0.0))
    """
    lane_idx = cute.arch.lane_idx()
    warp_idx = cute.arch.warp_idx()
    warps_per_row = cute.size(reduction_buffer.shape[1])
    row_idx = warp_idx // warps_per_row
    col_idx = warp_idx % warps_per_row

    # Lane 0 of each warp writes its value to shared memory
    if lane_idx == 0:
        reduction_buffer[row_idx, col_idx] = val
    cute.arch.barrier()

    # All lanes participate in reading and reducing
    # Only lanes < warps_per_row have valid data
    block_reduce_val = init_val
    if lane_idx < warps_per_row:
        block_reduce_val = reduction_buffer[row_idx, lane_idx]
    return cute.arch.warp_reduction(block_reduce_val, op)
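
# Worked example: with 8 warps and warps_per_row=4 (so 2 rows per block),
# warp 5 maps to row_idx=1, col_idx=1 and its lane 0 writes to
# reduction_buffer[1, 1]; after the barrier, lanes 0-3 of each warp reload
# that row's four slots and a final warp_reduction combines them.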


# =============================================================================
# Cluster-Level Reduction (SM90+)
# =============================================================================


@cute.jit
def cluster_reduce(
    val: Float32,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: cute.Pointer,
    cluster_n: cutlass.Constexpr[int],
    init_val: Float32,
) -> Float32:
    """
    Reduce values across all CTAs within a cluster using distributed shared memory.

    This function extends block reduction to work across multiple CTAs in a cluster
    using asynchronous cross-CTA stores and mbarrier synchronization. It:
    1. Sets up the mbarrier with the expected transaction count
    2. Asynchronously stores each warp's value to all peer CTAs
    3. Waits for all stores to complete
    4. Reduces across all collected values

    Args:
        val: The warp-reduced value (only lane 0's value is used for stores)
        op: Binary reduction operator, e.g., `operator.add` or `cute.arch.fmax`
        reduction_buffer: Shared memory tensor with hierarchical shape
            (rows_per_block, (warps_per_row, cluster_n))
        mbar_ptr: Pointer to an initialized mbarrier in shared memory
        cluster_n: Number of CTAs in the cluster (compile-time constant)
        init_val: Identity element for the reduction (0 for sum, -inf for max)

    Returns:
        The cluster-reduced result (same value across all threads in all CTAs)

    Buffer Layout:
        - Shape: (rows_per_block, (warps_per_row, cluster_n))
        - The second dimension is hierarchical:
            - First level: warps_per_row (local warp slots)
            - Second level: cluster_n (one slot per CTA in cluster)
        - Access pattern: buffer[row_idx, (col_idx, cta_rank)]

    Requirements:
        - SM90+ with cluster support
        - Mbarrier must be initialized before calling
        - Kernel must be launched with appropriate cluster dimensions

    Example:
        For a cluster of 4 CTAs, each with 2 warps per row:

        .. code-block:: python

            # Allocate buffer with cluster dimension
            reduction_buffer = cute.make_smem_tensor(
                cute.make_layout((rows_per_block, (2, 4))),  # 2 warps, 4 CTAs
                Float32
            )

            # Initialize mbarrier (once per kernel)
            mbar = cute.make_smem_tensor(cute.make_layout((1,)), cute.arch.Mbarrier)
            cute.arch.mbarrier_init(mbar.iterator, thread_count)

            # Perform cluster reduction
            result = cluster_reduce(
                warp_val, operator.add, reduction_buffer,
                mbar.iterator, cluster_n=4, init_val=Float32(0.0)
            )
    """
    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
    lane_idx = cute.arch.lane_idx()
    warp_idx = cute.arch.warp_idx()

    rows_per_block = reduction_buffer.shape[0]
    warps_per_row = reduction_buffer.shape[1][0]

    row_idx = warp_idx // warps_per_row
    col_idx = warp_idx % warps_per_row

    # Warp 0, lane 0 sets up mbarrier with expected transaction count
    # Each warp sends cluster_n stores (one to each CTA), each store is 4 bytes
    if warp_idx == 0:
        with cute.arch.elect_one():
            num_warps = rows_per_block * warps_per_row
            expected_bytes = num_warps * cluster_n * 4  # 4 bytes per Float32
            cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, expected_bytes)

    # Each lane < cluster_n writes to a different CTA's shared memory
    # This distributes the warp's value to all CTAs in the cluster
    if lane_idx < cluster_n:
        store_shared_remote(
            val,
            elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
            mbar_ptr,
            peer_cta_rank_in_cluster=lane_idx,
        )

    # Wait for all cross-CTA stores to complete
    cute.arch.mbarrier_wait(mbar_ptr, phase=0)

    # Now each CTA has all values from all CTAs in the cluster
    # Reduce across all collected values
    num_total = warps_per_row * cluster_n
    num_iter = cute.ceil_div(num_total, 32)

    block_reduce_val = init_val
    for i in cutlass.range_constexpr(num_iter):
        idx = lane_idx + i * 32
        if idx < num_total:
            block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])

    return cute.arch.warp_reduction(block_reduce_val, op)
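
# Worked transaction-count example: with rows_per_block=4, warps_per_row=2,
# and cluster_n=4, num_warps = 8 and each CTA expects
# 8 * 4 * 4 = 128 bytes of remote Float32 stores before mbarrier_wait returns.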


# =============================================================================
# Row Reduction (Orchestration Function)
# =============================================================================


@cute.jit
def row_reduce(
    x: cute.TensorSSA,
    op: cute.ReductionOp,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: cute.Tensor,
    mbar_ptr,
    cluster_n: cutlass.Constexpr[int],
    init_val: Float32,
):
    """
    Perform hierarchical row reduction with automatic selection of reduction strategy.

    This function orchestrates the full reduction pipeline:
    1. Local reduction: Each thread reduces its portion of the row
    2. Warp reduction: Threads within a warp reduce using shuffles
    3. Block reduction: If needed, warps reduce using shared memory
    4. Cluster reduction: If needed, CTAs reduce using distributed shared memory

    The function automatically selects the appropriate reduction level based on
    `threads_per_row` and `cluster_n`.

    Args:
        x: TensorSSA containing the values to reduce (in registers)
        op: Reduction operation (cute.ReductionOp.ADD or cute.ReductionOp.MAX)
        threads_per_row: Number of threads cooperating on each row (compile-time)
        reduction_buffer: Shared memory tensor for block/cluster reduction
        mbar_ptr: Mbarrier pointer (only used if cluster_n > 1)
        cluster_n: Number of CTAs in cluster (1 for single-CTA reduction)
        init_val: Identity element for the reduction

    Returns:
        The fully reduced result for each row

    Reduction Strategy Selection:
        - threads_per_row <= 32, cluster_n == 1: Warp reduction only
        - threads_per_row > 32, cluster_n == 1: Warp + block reduction
        - cluster_n > 1: Warp + cluster reduction (handles all cases)

    Example:
        .. code-block:: python

            # Sum reduction across 128 threads per row, single CTA
            result = row_reduce(
                tensor_ssa,
                cute.ReductionOp.ADD,
                threads_per_row=128,
                reduction_buffer=smem_buffer,
                mbar_ptr=None,
                cluster_n=1,
                init_val=Float32(0.0)
            )

            # Max reduction across 256 threads per row, 4 CTAs in cluster
            result = row_reduce(
                tensor_ssa,
                cute.ReductionOp.MAX,
                threads_per_row=256,
                reduction_buffer=smem_buffer,
                mbar_ptr=mbar.iterator,
                cluster_n=4,
                init_val=Float32.neg_inf
            )
    """
    # Step 1: Local reduction - each thread reduces its register values
    local_val = x.reduce(op, init_val=init_val, reduction_profile=0)

    # Map ReductionOp enum to binary operator for warp/block reductions
    warp_op = {
        cute.ReductionOp.ADD: operator.add,
        cute.ReductionOp.MAX: cute.arch.fmax,
    }[op]

    # Step 2: Warp reduction
    # If threads_per_row < 32, only use that many threads in the reduction
    warp_width = min(threads_per_row, 32)
    warp_val = cute.arch.warp_reduction(local_val, warp_op, threads_in_group=warp_width)

    # Determine if we need additional reduction levels
    warps_per_row = max(threads_per_row // 32, 1)

    # Step 3 & 4: Block or cluster reduction (if needed)
    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
        if cutlass.const_expr(cluster_n == 1):
            # Single CTA: use block reduction
            return block_reduce(warp_val, warp_op, reduction_buffer, init_val)
        else:
            # Multiple CTAs: use cluster reduction
            return cluster_reduce(
                warp_val, warp_op, reduction_buffer, mbar_ptr, cluster_n, init_val
            )
    else:
        # Single warp handles entire row: warp reduction is sufficient
        return warp_val
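
# Strategy example: threads_per_row=128 with cluster_n=1 gives warp_width=32
# and warps_per_row=4, so row_reduce does a shuffle reduction followed by
# block_reduce; with threads_per_row=32 and cluster_n=1 it returns the
# warp-reduced value directly.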
840 examples/python/CuTeDSL/blackwell/rmsnorm.py (Normal file)
@@ -0,0 +1,840 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import ctypes
import functools
import math
from typing import Optional, Tuple, Union

import cuda.bindings.driver as cuda
import torch

import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass import Boolean, Float32, Int32, Int64
from cutlass.cute.runtime import make_ptr

# Support both direct execution and module import
try:
    from .reduce import row_reduce
except ImportError:
    from reduce import row_reduce

"""
RMSNorm: Root Mean Square Layer Normalization for Hopper & Blackwell (SM90+)
============================================================================

A high-performance RMSNorm implementation using CuTe DSL with cluster-based
reduction for large hidden dimensions.

RMSNorm computes: y = x / sqrt(mean(x²) + eps) * weight

Key Features:
-------------
1. CLUSTER SYNCHRONIZATION (SM90+)
   - Multiple CTAs cooperate to process large N dimensions
   - Each CTA handles N/cluster_n elements, then reduces across the cluster
   - Uses mbarrier for efficient cross-CTA synchronization

2. ARCHITECTURE-SPECIFIC TUNING
   - SM80 (Ampere): Single-CTA execution (cluster_n=1)
   - SM90 (Hopper): Cluster support enabled for large N
   - SM100 (Blackwell): Same as SM90

3. VECTORIZED MEMORY ACCESS
   - 128-bit vectorized loads/stores for optimal memory throughput
   - TiledCopy abstraction for organized gmem↔smem↔rmem transfers

Cluster Size Selection (FP16):
------------------------------
- N <= 16K: cluster_n = 1 (single CTA)
- N <= 32K: cluster_n = 2
- N <= 64K: cluster_n = 4
- N <= 128K: cluster_n = 8
- Larger: cluster_n = 16
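
For example, under these thresholds a BFloat16 problem with N = 32768 selects
cluster_n = 2, so each CTA reduces N_per_cta = 16384 elements before the
cross-CTA combine.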

To run this example:

.. code-block:: bash

    python examples/python/CuTeDSL/blackwell/rmsnorm.py --M 2048 --N 4096 --dtype BFloat16
    python examples/python/CuTeDSL/blackwell/rmsnorm.py --M 2048 --N 4096 --dtype BFloat16 --benchmark
    python examples/python/CuTeDSL/blackwell/rmsnorm.py --M 2048 --N 32768 --dtype BFloat16 --benchmark

To collect performance data with the NCU profiler:

.. code-block:: bash

    ncu python examples/python/CuTeDSL/blackwell/rmsnorm.py --M 2048 --N 4096 --dtype BFloat16 --skip_ref_check
"""

# =============================================================================
# Architecture Detection
# =============================================================================


@functools.lru_cache(maxsize=16)
def get_sm_version(device: Optional[Union[int, torch.device, str]] = None) -> int:
    """Get the SM (compute capability) version of a CUDA device."""
    if not torch.cuda.is_available():
        return 80  # Default fallback
    props = torch.cuda.get_device_properties(device)
    return props.major * 10 + props.minor


def supports_cluster() -> bool:
    """Check if the current device supports cluster operations (SM90+)."""
    return get_sm_version() >= 90


# =============================================================================
# Predicate Utility
# =============================================================================


@cute.jit
def predicate_k(tXcX: cute.Tensor, limit: int) -> cute.Tensor:
    """Create a Boolean predicate tensor marking which column coordinates of
    the partitioned tile fall below `limit`, for bounds-checked copies."""
    tXpX = cute.make_rmem_tensor(
        cute.make_layout(
            (cute.size(tXcX, mode=[0, 1]), cute.size(tXcX, mode=[1]), cute.size(tXcX, mode=[2])),
            stride=(cute.size(tXcX, mode=[2]), 0, 1),
        ),
        Boolean,
    )
    for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
        for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
            tXpX[rest_v, 0, rest_k] = cute.elem_less(tXcX[(0, rest_v), 0, rest_k][1], limit)
    return tXpX


# =============================================================================
# RMSNorm Configuration Class
# =============================================================================


class RMSNormConfig:
    """
    Configuration for the RMSNorm kernel.

    This class encapsulates all kernel configuration computed once at initialization,
    following CuTe-DSL conventions from official CUTLASS examples.
    """

    COPY_BITS = 128  # 128-bit vectorized loads

    def __init__(
        self,
        dtype: type[cutlass.Numeric],
        N: int,
        has_weight: bool = True,
        sm_version: int | None = None,
    ):
        self.dtype = dtype
        self.N = N
        self.has_weight = has_weight
        self.sm_version = sm_version if sm_version is not None else get_sm_version()

        # Vector size for 128-bit loads
        self.vec_size = self.COPY_BITS // dtype.width

        # Compute cluster size (SM90+ only)
        self.cluster_n = self._compute_cluster_n(N, dtype, self.sm_version)

        # N per CTA for cluster case
        self.N_per_cta = N // self.cluster_n

        # Thread configuration using static methods
        self.threads_per_row = self._compute_threads_per_row(self.N_per_cta)
        self.num_threads = self._compute_num_threads(self.N_per_cta)

        # Derived values
        self.num_vec_blocks = max(
            1, (self.N_per_cta // self.vec_size + self.threads_per_row - 1) // self.threads_per_row
        )
        self.rows_per_block = self.num_threads // self.threads_per_row
        self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row
        self.warps_per_row = max(self.threads_per_row // 32, 1)

    @staticmethod
    def _compute_cluster_n(N: int, dtype: type[cutlass.Numeric], sm_version: int) -> int:
        """Compute optimal cluster size based on N and architecture."""
        if sm_version < 90:
            return 1

        if dtype.width == 16:  # FP16/BF16
            if N <= 16 * 1024:
                return 1
            elif N <= 32 * 1024:
                return 2
            elif N <= 64 * 1024:
                return 4
            elif N <= 128 * 1024:
                return 8
            else:
                return 16
        else:  # FP32
            if N <= 32 * 1024:
                return 1
            elif N <= 64 * 1024:
                return 2
            elif N <= 128 * 1024:
                return 4
            elif N <= 256 * 1024:
                return 8
            else:
                return 16

    @staticmethod
    def _compute_threads_per_row(N_per_cta: int) -> int:
        """Compute optimal threads per row based on N per CTA."""
        if N_per_cta <= 64:
            return 8
        elif N_per_cta <= 128:
            return 16
        elif N_per_cta <= 3072:
            return 32
        elif N_per_cta <= 6144:
            return 64
        elif N_per_cta <= 16384:
            return 128
        else:
            return 256

    @staticmethod
    def _compute_num_threads(N_per_cta: int) -> int:
        """Compute total threads per block."""
        return 128 if N_per_cta <= 16384 else 256

    @staticmethod
    def _make_tv_layout(
        threads_per_row: int,
        rows_per_block: int,
        vec_size: int,
        num_vec_blocks: int,
    ) -> tuple:
        """Create Thread-Value layout for coalesced vectorized memory access."""
        shape = (
            (threads_per_row, rows_per_block),
            (vec_size, num_vec_blocks),
        )
        stride = (
            (vec_size * rows_per_block, 1),
            (rows_per_block, rows_per_block * vec_size * threads_per_row),
        )
        return shape, stride
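
    # Worked example (hypothetical numbers): threads_per_row=32,
    # rows_per_block=4, vec_size=8, num_vec_blocks=2 yields
    #     shape  = ((32, 4), (8, 2))
    #     stride = ((32, 1), (4, 1024))
    # in (thread, value) -> tile-offset form.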

    def smem_size_in_bytes(self) -> int:
        """Calculate shared memory requirement in bytes."""
        tile_bytes = self.rows_per_block * self.cols_per_tile * (self.dtype.width // 8)
        reduction_bytes = self.rows_per_block * self.warps_per_row * self.cluster_n * 4
        mbar_bytes = 8 if self.cluster_n > 1 else 0
        return tile_bytes + reduction_bytes + mbar_bytes
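
    # Worked example (hypothetical config): BFloat16 (2 bytes/elem) with
    # rows_per_block=4, cols_per_tile=4096, warps_per_row=4, cluster_n=1:
    #     tile_bytes      = 4 * 4096 * 2 = 32768
    #     reduction_bytes = 4 * 4 * 1 * 4 = 64
    #     mbar_bytes      = 0
    # for a total of 32832 bytes of shared memory.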


# =============================================================================
# RMSNorm Kernel Class
# =============================================================================


class RMSNormKernel:
    """
    RMSNorm kernel with cluster synchronization for large N.

    Features:
    - Cluster-based reduction for large N (SM90+)
    - Multiple CTAs cooperate via mbarrier
    - Single reduction (sum of squares) with cluster-level aggregation

    Example:
        >>> kernel = RMSNormKernel(cutlass.Float16, N=4096)
        >>> kernel(x_ptr, w_ptr, o_ptr, M, eps, stream)
    """

    def __init__(
        self,
        dtype: cutlass.Numeric,
        N: int,
        has_weight: bool = True,
        config: RMSNormConfig | None = None,
    ):
        # Use provided config or create new one
        if config is not None:
            self.cfg = config
        else:
            self.cfg = RMSNormConfig(dtype, N, has_weight)

        # Expose key attributes for convenience
        self.dtype = self.cfg.dtype
        self.N = self.cfg.N
        self.has_weight = self.cfg.has_weight
        self.cluster_n = self.cfg.cluster_n

    @cute.jit
    def __call__(
        self,
        x_ptr: cute.Pointer,
        w_ptr: cute.Pointer | None,
        o_ptr: cute.Pointer,
        M: Int32,
        eps: Float32,
        stream: cuda.CUstream,
    ):
        """Host function to launch the RMSNorm kernel."""
        cfg = self.cfg

        # Create CuTe tensors from raw pointers
        mX = cute.make_tensor(
            x_ptr,
            cute.make_layout((M, cfg.N), stride=(cfg.N, 1)),
        )
        mO = cute.make_tensor(
            o_ptr,
            cute.make_layout((M, cfg.N), stride=(cfg.N, 1)),
        )

        if cutlass.const_expr(cfg.has_weight and w_ptr is not None):
            mW = cute.make_tensor(
                w_ptr,
                cute.make_layout((cfg.N,), stride=(1,)),
            )
        else:
            mW = None

        # Create TV layout using static helper
        tv_shape, tv_stride = RMSNormConfig._make_tv_layout(
            cfg.threads_per_row,
            cfg.rows_per_block,
            cfg.vec_size,
            cfg.num_vec_blocks,
        )
        tv_layout = cute.make_layout(tv_shape, stride=tv_stride)
        tiler_mn = (cfg.rows_per_block, cfg.cols_per_tile)

        self.kernel(mX, mW, mO, eps, tv_layout, tiler_mn).launch(
            grid=[cute.ceil_div(M, cfg.rows_per_block), cfg.cluster_n, 1],
            block=[cfg.num_threads, 1, 1],
            cluster=[1, cfg.cluster_n, 1] if cutlass.const_expr(cfg.cluster_n > 1) else None,
            smem=cfg.smem_size_in_bytes(),
            stream=stream,
        )
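
        # Launch-shape example (hypothetical sizes): M=2048 with
        # rows_per_block=4 gives grid (512, cluster_n, 1); the cluster_n CTAs
        # along y form one cluster and cooperate on the same block of rows.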

    @cute.kernel
    def kernel(
        self,
        mX: cute.Tensor,
        mW: cute.Tensor | None,
        mO: cute.Tensor,
        eps: Float32,
        tv_layout: cute.Layout,
        tiler_mn: cute.Shape,
    ):
        """Device kernel implementing RMSNorm with cluster support."""
        cfg = self.cfg
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()

        if cutlass.const_expr(cfg.cluster_n > 1):
            cluster_y = cute.arch.block_idx()[1]
        else:
            cluster_y = cutlass.const_expr(0)

        M = mX.shape[0]
        threads_per_row = tv_layout.shape[0][0]
        warps_per_row = max(threads_per_row // 32, 1)
        rows_per_block = tiler_mn[0]

        # =====================================================================
        # Allocate shared memory
        # =====================================================================
        smem = utils.SmemAllocator()

        sX = smem.allocate_tensor(
            mX.element_type,
            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
            byte_alignment=16,
        )

        if cutlass.const_expr(cfg.cluster_n == 1):
            reduction_buffer = smem.allocate_tensor(
                Float32,
                cute.make_layout((rows_per_block, warps_per_row)),
                byte_alignment=4,
            )
            mbar_ptr = None
        else:
            reduction_buffer = smem.allocate_tensor(
                Float32,
                cute.make_layout((rows_per_block, (warps_per_row, cfg.cluster_n))),
                byte_alignment=4,
            )
            mbar_ptr = smem.allocate_array(Int64, num_elems=1)

        # =====================================================================
        # Initialize cluster
        # =====================================================================
        if cutlass.const_expr(cfg.cluster_n > 1):
            if tidx == 0:
                cute.arch.mbarrier_init(mbar_ptr, 1)
            cute.arch.mbarrier_init_fence()
            cute.arch.cluster_arrive_relaxed()
            cute.arch.cluster_wait()

        # =====================================================================
        # Create identity tensor and partition
        # =====================================================================
        idX = cute.make_identity_tensor(mX.shape)

        gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y))
        gO = cute.local_tile(mO, tiler_mn, (bidx, cluster_y))
        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

        if cutlass.const_expr(cfg.has_weight and mW is not None):
            mW_expanded_layout = cute.prepend(
                mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
            )
            mW_2d = cute.make_tensor(mW.iterator, mW_expanded_layout)
            gW = cute.local_tile(mW_2d, tiler_mn, (0, cluster_y))

        # =====================================================================
        # Create TiledCopy operations
        # =====================================================================
        copy_atom_load_async = cute.make_copy_atom(
            cute.nvgpu.cpasync.CopyG2SOp(),
            mX.element_type,
            num_bits_per_copy=RMSNormConfig.COPY_BITS,
        )

        copy_atom_load_W = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            mX.element_type,
            num_bits_per_copy=RMSNormConfig.COPY_BITS,
        )

        copy_atom_store = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            mO.element_type,
            num_bits_per_copy=RMSNormConfig.COPY_BITS,
        )

        tiled_copy_load = cute.make_tiled_copy(copy_atom_load_async, tv_layout, tiler_mn)
        tiled_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn)
        tiled_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)

        thr_copy_X = tiled_copy_load.get_slice(tidx)
        thr_copy_W = tiled_copy_W.get_slice(tidx)
        thr_copy_O = tiled_copy_store.get_slice(tidx)

        # Partition tensors
        tXgX = thr_copy_X.partition_S(gX)
        tXsX = thr_copy_X.partition_D(sX)
        tXgO = thr_copy_O.partition_D(gO)
        tXcX = thr_copy_X.partition_S(cX)

        # Register fragments
        tXrX = cute.make_fragment_like(tXgX)
        tXrO = cute.make_fragment_like(tXgO)

        if cutlass.const_expr(cfg.has_weight and mW is not None):
            tWgW = thr_copy_W.partition_S(gW)
            tWrW = cute.make_fragment_like(tWgW)
            tXrW = thr_copy_X.retile(tWrW)

        # =====================================================================
        # Bounds checking
        # =====================================================================
        tXpX = predicate_k(tXcX, limit=cfg.N)

        row_coord = tXcX[(0, 0), 0, 0]
        row_in_bounds = row_coord[0] < M

        # =====================================================================
        # Async copy global → shared
        # =====================================================================
        if row_in_bounds:
            cute.copy(copy_atom_load_async, tXgX, tXsX, pred=tXpX)

        cute.arch.cp_async_commit_group()

        # Load weight while waiting
        if cutlass.const_expr(cfg.has_weight and mW is not None):
            tWpW = predicate_k(thr_copy_W.partition_S(cX), limit=cfg.N)
            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)

        cute.arch.cp_async_wait_group(0)

        # =====================================================================
        # Pass 1: Compute sum of squares with cluster reduction
        # =====================================================================
        cute.autovec_copy(tXsX, tXrX)
        x = tXrX.load().to(Float32)

        x_sq = x * x
        sum_sq = row_reduce(
            x_sq,
            cute.ReductionOp.ADD,
            threads_per_row,
            reduction_buffer,
            mbar_ptr,
            cfg.cluster_n,
            Float32(0.0),
        )

        # rstd = 1 / sqrt(mean(x²) + eps)
        mean_sq = sum_sq / cfg.N
        rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True)

        # Sync after reduction
        if cutlass.const_expr(cfg.cluster_n > 1):
            cute.arch.cluster_arrive_relaxed()
            cute.arch.cluster_wait()
        else:
            cute.arch.barrier()

        # =====================================================================
        # Pass 2: Normalize and output
        # =====================================================================
        cute.autovec_copy(tXsX, tXrX)
        x = tXrX.load().to(Float32)

        y = x * rstd

        # Apply weight if present
        if cutlass.const_expr(cfg.has_weight and mW is not None):
            w = tXrW.load().to(Float32)
            y = y * w

        # Store to global memory
        tXrO.store(y.to(cfg.dtype))

        if row_in_bounds:
            cute.copy(copy_atom_store, tXrO, tXgO, pred=tXpX)


# =============================================================================
# Kernel Compilation and Caching
# =============================================================================

# Mapping from torch dtype to cutlass dtype
_torch_to_cutlass_dtype = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}

# Cache for compiled kernels
_compile_cache: dict = {}


def get_compiled_kernel(
    dtype: type[cutlass.Numeric],
    N: int,
    has_weight: bool,
    stream: cuda.CUstream,
):
    """
    Get or compile the RMSNorm kernel for the given configuration.

    Uses a compilation cache to avoid recompiling for the same
    (dtype, N, has_weight) tuple.

    :param dtype: Data type (Float16, BFloat16, Float32)
    :type dtype: type[cutlass.Numeric]
    :param N: Hidden dimension size
    :type N: int
    :param has_weight: Whether weight is applied
    :type has_weight: bool
    :param stream: CUDA stream
    :type stream: cuda.CUstream
    :return: Compiled kernel function
    """
    key = (dtype, N, has_weight)
    if key not in _compile_cache:
        kernel_obj = RMSNormKernel(dtype, N, has_weight)

        # Compile with representative arguments
        compiled_kernel = cute.compile(
            kernel_obj,
            make_ptr(dtype, 16, cute.AddressSpace.gmem, assumed_align=16),  # x_ptr
            make_ptr(dtype, 16, cute.AddressSpace.gmem, assumed_align=16)
            if has_weight
            else None,  # w_ptr
            make_ptr(dtype, 16, cute.AddressSpace.gmem, assumed_align=16),  # o_ptr
            Int32(1),  # M (dummy)
            Float32(1e-6),  # eps (dummy)
            stream,
        )

        _compile_cache[key] = compiled_kernel

    return _compile_cache[key]
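
# Usage sketch (mirrors run() below):
#
#     kernel = get_compiled_kernel(cutlass.BFloat16, 4096, True, stream)
#     kernel(x_ptr, w_ptr, out_ptr, Int32(M), Float32(1e-6), stream)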


# =============================================================================
# Tensor Creation Utilities
# =============================================================================


def create_tensors(
    M: int,
    N: int,
    dtype: type[cutlass.Numeric],
    has_weight: bool,
) -> Tuple:
    """Create input, weight, and output tensors for RMSNorm."""
    torch.manual_seed(42)
    torch_dtype = cutlass_torch.dtype(dtype)

    x = torch.randn(M, N, device="cuda", dtype=torch_dtype)
    weight = torch.randn(N, device="cuda", dtype=torch_dtype) if has_weight else None
    out = torch.empty_like(x)

    return x, weight, out


def rmsnorm_ref(
    x: torch.Tensor,
    weight: torch.Tensor | None = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Reference RMSNorm implementation in PyTorch."""
    x_f32 = x.float()
    rms = torch.sqrt(torch.mean(x_f32**2, dim=-1, keepdim=True) + eps)
    x_norm = x_f32 / rms
    if weight is not None:
        x_norm = x_norm * weight.float()
    return x_norm.to(x.dtype)
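
# Quick check (hypothetical shapes):
#
#     x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
#     w = torch.ones(8, device="cuda", dtype=torch.float16)
#     y = rmsnorm_ref(x, w)  # same shape and dtype as x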


# =============================================================================
# Run Function
# =============================================================================


def run(
    M: int,
    N: int,
    dtype: type[cutlass.Numeric],
    has_weight: bool = True,
    eps: float = 1e-6,
    tolerance: float = 1e-2,
    warmup_iterations: int = 2,
    iterations: int = 100,
    skip_ref_check: bool = False,
    benchmark: bool = False,
) -> float:
    """
    Execute RMSNorm and optionally benchmark performance.

    :param M: Number of rows (batch size * sequence length)
    :type M: int
    :param N: Hidden dimension size
    :type N: int
    :param dtype: Data type (Float16, BFloat16, Float32)
    :type dtype: type[cutlass.Numeric]
    :param has_weight: Whether to apply learnable weight
    :type has_weight: bool
    :param eps: Epsilon for numerical stability
    :type eps: float
    :param tolerance: Tolerance for correctness check
    :type tolerance: float
    :param warmup_iterations: Warmup iterations for benchmarking
    :type warmup_iterations: int
    :param iterations: Number of benchmark iterations
    :type iterations: int
    :param skip_ref_check: Skip reference correctness check
    :type skip_ref_check: bool
    :param benchmark: Enable benchmarking
    :type benchmark: bool
    :return: Execution time in microseconds (if benchmark=True, else 0)
    :rtype: float
    """
    print("Running RMSNorm test with:")
    print(f"  M: {M}, N: {N}")
    print(f"  dtype: {dtype}")
    print(f"  has_weight: {has_weight}")
    print(f"  eps: {eps}")
    print(f"  SM version: {get_sm_version()}")

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA GPU is required to run this example!")

    # Get CUDA stream
    torch_stream = torch.cuda.current_stream()
    stream = cuda.CUstream(torch_stream.cuda_stream)

    # Create tensors
    x, weight, out = create_tensors(M, N, dtype, has_weight)

    # Get configuration info
    config = RMSNormConfig(dtype, N, has_weight)
    print(f"  cluster_n: {config.cluster_n}")
    print(f"  threads_per_row: {config.threads_per_row}")
    print(f"  rows_per_block: {config.rows_per_block}")

    # Get compiled kernel
    compiled_kernel = get_compiled_kernel(dtype, N, has_weight, stream)

    # Create pointers for kernel call
    x_ptr = make_ptr(dtype, x.data_ptr())
    w_ptr = make_ptr(dtype, weight.data_ptr()) if weight is not None else None
    out_ptr = make_ptr(dtype, out.data_ptr())

    # Run kernel and verify
    if not skip_ref_check:
        compiled_kernel(x_ptr, w_ptr, out_ptr, Int32(M), Float32(eps), stream)
        torch.cuda.synchronize()

        ref = rmsnorm_ref(x, weight, eps)
        torch.testing.assert_close(out, ref, atol=tolerance, rtol=tolerance)
        print("Correctness check passed!")

    if not benchmark:
        return 0.0

    # Benchmark
    print(f"\nBenchmarking with {warmup_iterations} warmup, {iterations} iterations...")

    def generate_tensors():
        x, weight, out = create_tensors(M, N, dtype, has_weight)
        x_ptr = make_ptr(dtype, x.data_ptr())
        w_ptr = make_ptr(dtype, weight.data_ptr()) if weight is not None else None
        out_ptr = make_ptr(dtype, out.data_ptr())
        return testing.JitArguments(x_ptr, w_ptr, out_ptr, Int32(M), Float32(eps), stream)

    exec_time_us = testing.benchmark(
        compiled_kernel,
        workspace_generator=generate_tensors,
        workspace_count=10,
        warmup_iterations=warmup_iterations,
        iterations=iterations,
        stream=stream,
    )

    # Calculate throughput
    torch_dtype = cutlass_torch.dtype(dtype)
    bytes_per_elem = torch.tensor([], dtype=torch_dtype).element_size()
    total_bytes = M * N * bytes_per_elem * 2  # read x, write out
    if has_weight:
        total_bytes += N * bytes_per_elem  # read weight (amortized across M)

    throughput_gbps = (total_bytes / (exec_time_us / 1e6)) / 1e9

    print(f"Kernel execution time: {exec_time_us:.2f} us")
    print(f"Memory throughput: {throughput_gbps:.2f} GB/s")

    return exec_time_us


# =============================================================================
# Main Entry Point
# =============================================================================


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RMSNorm kernel example for Blackwell (SM100)"
    )

    parser.add_argument("--M", type=int, default=2048, help="Number of rows")
    parser.add_argument("--N", type=int, default=4096, help="Hidden dimension size")
    parser.add_argument(
        "--dtype",
        type=cutlass.dtype,
        default=cutlass.BFloat16,
        help="Data type (Float16, BFloat16, Float32)",
    )
    parser.add_argument(
        "--has_weight",
        action="store_true",
        default=True,
        help="Apply learnable weight",
    )
    parser.add_argument(
        "--no_weight",
        action="store_true",
        help="Disable learnable weight",
    )
    parser.add_argument(
        "--eps",
        type=float,
        default=1e-6,
        help="Epsilon for numerical stability",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-2,
        help="Tolerance for correctness check",
    )
    parser.add_argument(
        "--warmup_iterations",
        type=int,
        default=2,
        help="Warmup iterations for benchmarking",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=100,
        help="Number of benchmark iterations",
    )
    parser.add_argument(
        "--skip_ref_check",
        action="store_true",
        help="Skip reference correctness check",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Enable benchmarking",
    )

    args = parser.parse_args()

    # Handle weight flag
    has_weight = args.has_weight and not args.no_weight

    run(
        M=args.M,
        N=args.N,
        dtype=args.dtype,
        has_weight=has_weight,
        eps=args.eps,
        tolerance=args.tolerance,
        warmup_iterations=args.warmup_iterations,
        iterations=args.iterations,
        skip_ref_check=args.skip_ref_check,
        benchmark=args.benchmark,
    )

    print("PASS")