mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
Treat negative zero as equivalent to positive zero in sm90_sparse_gemm_compressor.hpp (#2110)
* Treat negative zero as zero in the sparse gemm compressor Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> * format Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> * Apply patch Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> * sm90_sparse_gemm_compressor.hpp * test/unit/transform/CMakeLists.txt * test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp * include/cutlass/numeric_types.h --------- Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com> Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3fe62887d8
commit
8c4d1dc47d
@@ -34,8 +34,17 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/numeric_size.h"
|
||||
#include "cute/util/type_traits.hpp"
|
||||
|
||||
#include "cutlass/numeric_size.h"
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/tfloat32.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/uint128.h"
|
||||
#include "cutlass/exmy_base.h"
|
||||
#include "cutlass/float_subbyte.h"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
@@ -56,6 +65,30 @@ struct index_sequence_helper<0, 0, Next...> {
|
||||
template <size_t N>
|
||||
using make_index_sequence = typename index_sequence_helper<N>::type;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Default case - no negative zero
|
||||
template <typename T>
|
||||
struct has_negative_zero : CUTE_STL_NAMESPACE::false_type{};
|
||||
|
||||
// Float types that support negative zero
|
||||
template <> struct has_negative_zero<mx_float4_t<float_e2m1_t>> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<mx_float6_t<float_e2m3_t>> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<mx_float8_t<float_e4m3_t>> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<mx_float8_t<float_e5m2_t>> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<float_e2m1_t> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<float_e2m3_t> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<float_e4m3_t> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<float_e5m2_t> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<half_t> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<bfloat16_t> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<float> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<double> : CUTE_STL_NAMESPACE::true_type{};
|
||||
template <> struct has_negative_zero<tfloat32_t> : CUTE_STL_NAMESPACE::true_type{};
|
||||
|
||||
// Helper variable template
|
||||
template <typename T>
|
||||
inline constexpr bool has_negative_zero_v = has_negative_zero<T>::value;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
@@ -75,14 +108,3 @@ struct get_unpacked_element_type {
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/tfloat32.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/uint128.h"
|
||||
#include "cutlass/exmy_base.h"
|
||||
#include "cutlass/float_subbyte.h"
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@
|
||||
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
|
||||
#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo
|
||||
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
|
||||
#include "cutlass/numeric_types.h" // cutlass::has_negative_zero_v
|
||||
#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter
|
||||
|
||||
namespace cutlass::transform::kernel {
|
||||
@@ -376,6 +377,15 @@ private:
|
||||
copy_vec_pred<true, LayoutATag>(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta);
|
||||
}
|
||||
|
||||
// Construct a sign bit mask for handling negative zeros
|
||||
ElementAMmaRawUnit sign_mask = ElementAMmaRawUnit{ 0 };
|
||||
if constexpr (has_negative_zero_v<ElementA>) {
|
||||
ElementAMmaRawUnit one_sign_mask = static_cast<ElementAMmaRawUnit>(~(ElementAMmaRawUnit{ 1 } << (cute::sizeof_bits_v<ElementA> - 1)));
|
||||
for (int i = 0; i < sizeof(ElementAMmaRawUnit) / sizeof(ElementAUint); ++i) {
|
||||
sign_mask = static_cast<ElementAMmaRawUnit>((int32_t)sign_mask | (int32_t)one_sign_mask << (i * cute::sizeof_bits_v<ElementA>));
|
||||
}
|
||||
}
|
||||
|
||||
// * Compress
|
||||
// cACsAC is always row major order
|
||||
// TensorEAtomM threads perform the compression, each thread compress one row
|
||||
@@ -401,7 +411,14 @@ private:
|
||||
CUTE_UNROLL
|
||||
for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) {
|
||||
ElementAMmaRawUnit elem_A = tAsA[elt_log_idx];
|
||||
if ( elem_A != ElementAMmaRawUnit{0} ) {
|
||||
|
||||
// Handle negative 0
|
||||
ElementAMmaRawUnit masked_elem_A = elem_A;
|
||||
if constexpr (has_negative_zero_v<ElementA>) {
|
||||
masked_elem_A = elem_A & sign_mask;
|
||||
}
|
||||
|
||||
if (masked_elem_A != ElementAMmaRawUnit{0}) {
|
||||
non_zero_elt_log_idx[non_zero_cnt] = elt_log_idx;
|
||||
tACsAC[non_zero_cnt] = elem_A;
|
||||
non_zero_cnt++;
|
||||
@@ -489,6 +506,7 @@ private:
|
||||
|
||||
constexpr bool IsRowMajor = cute::is_same_v<LayoutTag, cutlass::layout::RowMajor>;
|
||||
using Element = typename TensorSrc::element_type;
|
||||
|
||||
CUTE_STATIC_ASSERT(cute::is_static_v<decltype(shape(dSrc))>, "shape(dSrc) needs to be static");
|
||||
CUTE_STATIC_ASSERT(cute::is_static_v<decltype(shape(dDst))>, "shape(dDst) needs to be static");
|
||||
CUTE_STATIC_ASSERT(cute::sizeof_bits_v<typename TensorSrc::element_type> == cute::sizeof_bits_v<typename TensorDst::element_type>,
|
||||
@@ -530,7 +548,7 @@ private:
|
||||
for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) {
|
||||
const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr;
|
||||
const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr;
|
||||
if constexpr ( (not pred)
|
||||
if constexpr ( (not pred)
|
||||
) {
|
||||
dDst(row_i, col_i) = dSrc(row_i, col_i);
|
||||
}
|
||||
|
||||
@@ -35,16 +35,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm> // std::fill
|
||||
#include <array> // std::array
|
||||
#include <random> // std::mt19937
|
||||
#include <algorithm> // std::fill
|
||||
#include <array> // std::array
|
||||
#include <random> // std::mt19937
|
||||
|
||||
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
|
||||
#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor
|
||||
#include "cutlass/arch/arch.h" // cutlass::arch::SmXY
|
||||
#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t
|
||||
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
|
||||
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
|
||||
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
|
||||
#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor
|
||||
#include "cutlass/arch/arch.h" // cutlass::arch::SmXY
|
||||
#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false
|
||||
#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t
|
||||
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
|
||||
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
|
||||
|
||||
#include "cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp"
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
add_subdirectory(threadblock)
|
||||
add_subdirectory(device)
|
||||
add_subdirectory(kernel)
|
||||
|
||||
add_custom_target(
|
||||
|
||||
@@ -155,7 +155,9 @@ namespace detail {
|
||||
int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx;
|
||||
subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0);
|
||||
|
||||
if (subchunk_elems[elem_idx] != ElementA(0)) {
|
||||
ElementA zero = static_cast<ElementA>(0);
|
||||
ElementA minus_zero = static_cast<ElementA>(ElementA(1) << cutlass::sizeof_bits_v<ElementA> - 1);
|
||||
if (subchunk_elems[elem_idx] != zero && subchunk_elems[elem_idx] != minus_zero) {
|
||||
if (non_zero_cnt >= PhysicalSubChunk) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
asm volatile ("brkpt;\n" ::);
|
||||
|
||||
Reference in New Issue
Block a user