Treat negative zero as equivalent to positive zero in sm90_sparse_gemm_compressor.hpp (#2110)

* Treat negative zero as zero in the sparse gemm compressor

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

* format

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

* Apply patch

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

* sm90_sparse_gemm_compressor.hpp

* test/unit/transform/CMakeLists.txt

* test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp

* include/cutlass/numeric_types.h

---------

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
This commit is contained in:
Tyler Michael Smith
2025-03-20 22:44:17 -07:00
committed by GitHub
parent 3fe62887d8
commit 8c4d1dc47d
5 changed files with 68 additions and 24 deletions

View File

@@ -34,8 +34,17 @@
*/
#pragma once
#include "cutlass/numeric_size.h"
#include "cute/util/type_traits.hpp"
#include "cutlass/numeric_size.h"
#include "cutlass/integer_subbyte.h"
#include "cutlass/half.h"
#include "cutlass/bfloat16.h"
#include "cutlass/tfloat32.h"
#include "cutlass/float8.h"
#include "cutlass/uint128.h"
#include "cutlass/exmy_base.h"
#include "cutlass/float_subbyte.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
@@ -56,6 +65,30 @@ struct index_sequence_helper<0, 0, Next...> {
template <size_t N>
using make_index_sequence = typename index_sequence_helper<N>::type;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Default case - no negative zero
template <typename T>
struct has_negative_zero : CUTE_STL_NAMESPACE::false_type{};
// Float types that support negative zero
template <> struct has_negative_zero<mx_float4_t<float_e2m1_t>> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<mx_float6_t<float_e2m3_t>> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<mx_float8_t<float_e4m3_t>> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<mx_float8_t<float_e5m2_t>> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<float_e2m1_t> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<float_e2m3_t> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<float_e4m3_t> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<float_e5m2_t> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<half_t> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<bfloat16_t> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<float> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<double> : CUTE_STL_NAMESPACE::true_type{};
template <> struct has_negative_zero<tfloat32_t> : CUTE_STL_NAMESPACE::true_type{};
// Helper variable template
template <typename T>
inline constexpr bool has_negative_zero_v = has_negative_zero<T>::value;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
@@ -75,14 +108,3 @@ struct get_unpacked_element_type {
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/integer_subbyte.h"
#include "cutlass/half.h"
#include "cutlass/bfloat16.h"
#include "cutlass/tfloat32.h"
#include "cutlass/float8.h"
#include "cutlass/uint128.h"
#include "cutlass/exmy_base.h"
#include "cutlass/float_subbyte.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -46,6 +46,7 @@
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
#include "cutlass/numeric_types.h" // cutlass::has_negative_zero_v
#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter
namespace cutlass::transform::kernel {
@@ -376,6 +377,15 @@ private:
copy_vec_pred<true, LayoutATag>(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta);
}
// Construct a sign bit mask for handling negative zeros
ElementAMmaRawUnit sign_mask = ElementAMmaRawUnit{ 0 };
if constexpr (has_negative_zero_v<ElementA>) {
ElementAMmaRawUnit one_sign_mask = static_cast<ElementAMmaRawUnit>(~(ElementAMmaRawUnit{ 1 } << (cute::sizeof_bits_v<ElementA> - 1)));
for (int i = 0; i < sizeof(ElementAMmaRawUnit) / sizeof(ElementAUint); ++i) {
sign_mask = static_cast<ElementAMmaRawUnit>((int32_t)sign_mask | (int32_t)one_sign_mask << (i * cute::sizeof_bits_v<ElementA>));
}
}
// * Compress
// cACsAC is always row major order
// TensorEAtomM threads perform the compression, each thread compress one row
@@ -401,7 +411,14 @@ private:
CUTE_UNROLL
for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) {
ElementAMmaRawUnit elem_A = tAsA[elt_log_idx];
if ( elem_A != ElementAMmaRawUnit{0} ) {
// Handle negative 0
ElementAMmaRawUnit masked_elem_A = elem_A;
if constexpr (has_negative_zero_v<ElementA>) {
masked_elem_A = elem_A & sign_mask;
}
if (masked_elem_A != ElementAMmaRawUnit{0}) {
non_zero_elt_log_idx[non_zero_cnt] = elt_log_idx;
tACsAC[non_zero_cnt] = elem_A;
non_zero_cnt++;
@@ -489,6 +506,7 @@ private:
constexpr bool IsRowMajor = cute::is_same_v<LayoutTag, cutlass::layout::RowMajor>;
using Element = typename TensorSrc::element_type;
CUTE_STATIC_ASSERT(cute::is_static_v<decltype(shape(dSrc))>, "shape(dSrc) needs to be static");
CUTE_STATIC_ASSERT(cute::is_static_v<decltype(shape(dDst))>, "shape(dDst) needs to be static");
CUTE_STATIC_ASSERT(cute::sizeof_bits_v<typename TensorSrc::element_type> == cute::sizeof_bits_v<typename TensorDst::element_type>,
@@ -530,7 +548,7 @@ private:
for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) {
const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr;
const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr;
if constexpr ( (not pred)
if constexpr ( (not pred)
) {
dDst(row_i, col_i) = dSrc(row_i, col_i);
}

View File

@@ -35,16 +35,17 @@
#pragma once
#include <algorithm> // std::fill
#include <array> // std::array
#include <random> // std::mt19937
#include <algorithm> // std::fill
#include <array> // std::array
#include <random> // std::mt19937
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor
#include "cutlass/arch/arch.h" // cutlass::arch::SmXY
#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor
#include "cutlass/arch/arch.h" // cutlass::arch::SmXY
#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false
#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
#include "cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp"

View File

@@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
add_subdirectory(threadblock)
add_subdirectory(device)
add_subdirectory(kernel)
add_custom_target(

View File

@@ -155,7 +155,9 @@ namespace detail {
int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx;
subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0);
if (subchunk_elems[elem_idx] != ElementA(0)) {
ElementA zero = static_cast<ElementA>(0);
ElementA minus_zero = static_cast<ElementA>(ElementA(1) << cutlass::sizeof_bits_v<ElementA> - 1);
if (subchunk_elems[elem_idx] != zero && subchunk_elems[elem_idx] != minus_zero) {
if (non_zero_cnt >= PhysicalSubChunk) {
#ifdef __CUDA_ARCH__
asm volatile ("brkpt;\n" ::);