mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_TILE] Add 2:4 structured sparsity support for fp16 gemm (#1957)
* add structured sparsity fp16 support for gemm * added reviewer suggestions * update changelog * update changelog * add reviewers suggestions * Minor fix * clang fix * fix doxygen
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp16
|
||||
@@ -64,6 +67,14 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK
|
||||
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
// fp16 2:4 structured sparsity
|
||||
|
||||
using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
|
||||
WarpGemmAttributeSmfmacImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl<WarpGemmAttributeSmfmac<
|
||||
WarpGemmAttributeSmfmacImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
// bf16
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
|
||||
|
||||
80
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp
Normal file
80
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp
Normal file
@@ -0,0 +1,80 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Class describing structured sparsity mfma instructions.
|
||||
*
|
||||
* @paragraph Overview "Overview"
|
||||
* Currently only 2:4 structured sparsity is supported, which is based on requirement that in every
|
||||
* groups of four continuous elements there are at most two non-zero, which results in processing
|
||||
* only half of elements in smfmac instruction. Because of structured sparsity A vector in smfmac
|
||||
* instruction will be smaller than B vector by the factor of CompressionRatio. The indexes of
|
||||
* non-zero elements are stored in `index` which is an additional parameter to assembly instruction.
|
||||
* Every pair of two bit indexes are containing information about which two elements in current
|
||||
* group of 4 values are non-zero and should be used inside smfmac instruction. Structured sparsity
|
||||
* format is supported only for A matrix for now.
|
||||
*/
|
||||
template <typename WarpGemmAttributeSmfmacImpl_>
|
||||
struct WarpGemmAttributeSmfmac
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeSmfmacImpl_>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using IdxDataType = typename Impl::IdxDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
static constexpr index_t kCompressionRatio = Impl::CompressionRatio;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
|
||||
"Multi-block WarpGemmAttributeSmfmacImpl is not supported");
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using BWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using CWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec, idx, bool_constant<post_nop_>{});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,114 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// fp16 2:4 structured sparsity
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using IdxDataType = int32_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<float, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t CompressionRatio = 2;
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = idx;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using IdxDataType = int32_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<float, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t CompressionRatio = 2;
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& idx,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = idx;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,7 +16,8 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
@@ -35,6 +36,10 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
@@ -70,7 +75,8 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
@@ -78,6 +84,7 @@ using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA>::Type;
|
||||
SwizzleA,
|
||||
UseStructuredSparsity>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
110
include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp
Normal file
110
include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp
Normal file
@@ -0,0 +1,110 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WarpGemmAttribute_>
|
||||
struct WarpGemmSmfmacImpl
|
||||
{
|
||||
using WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>;
|
||||
|
||||
static constexpr index_t kM = WarpGemmAttribute::kM;
|
||||
static constexpr index_t kN = WarpGemmAttribute::kN;
|
||||
static constexpr index_t kK = WarpGemmAttribute::kK;
|
||||
/// @brief The number of elements in K dimension processed by single thread in wavefront.
|
||||
///
|
||||
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
|
||||
/// In such situation this value reflects this fact.
|
||||
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
|
||||
|
||||
using ADataType = typename WarpGemmAttribute::ADataType;
|
||||
using BDataType = typename WarpGemmAttribute::BDataType;
|
||||
using CDataType = typename WarpGemmAttribute::CDataType;
|
||||
|
||||
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
|
||||
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
|
||||
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
|
||||
|
||||
using AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>;
|
||||
using BWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>;
|
||||
using CWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTensor = static_distributed_tensor<ADataType, AWarpDstr>;
|
||||
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
|
||||
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
|
||||
{
|
||||
return WarpGemmAttribute_::get_num_of_access();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
|
||||
/// elements into lower part of a_vec to half its effective size.
|
||||
///
|
||||
/// @param a_vec Vector to be compressed.
|
||||
///
|
||||
/// @return Four 2-bit indexes of non-zero elements locations
|
||||
///
|
||||
template <typename AVec>
|
||||
CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const
|
||||
{
|
||||
int32_t idx = 0b11101110;
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
|
||||
int32_t non_zero_pos = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto j) {
|
||||
if(a_vec[i * 4 + j] != 0.0f)
|
||||
{
|
||||
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
|
||||
idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos));
|
||||
idx |= j << 2 * (i * 2 + non_zero_pos);
|
||||
++non_zero_pos;
|
||||
}
|
||||
});
|
||||
a_vec[i * 2] = nonzero_elems[0];
|
||||
a_vec[i * 2 + 1] = nonzero_elems[1];
|
||||
});
|
||||
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using AVecCompressed =
|
||||
ext_vector_t<ADataType, ATensor::get_thread_buffer_size() / CompressionRatio>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
const int32_t idx = compress_a(a_vec);
|
||||
|
||||
// @TODO can we simply set a_vec_pruned to a_vec[0:3]?
|
||||
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]};
|
||||
|
||||
// c_vec += a_vec * b_vec[idx]
|
||||
WarpGemmAttribute{}(c_vec, a_vec_pruned, b_vec, idx, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user