[CK_TILE] Add 2:4 structured sparsity support for fp16 gemm (#1957)

* add structured sparsity fp16 support for gemm

* added reviewer suggestions

* update changelog

* update changelog

* add reviewers' suggestions

* Minor fix

* clang fix

* fix doxygen

[ROCm/composable_kernel commit: 6c61f4d237]
This commit is contained in:
jakpiase
2025-04-11 12:18:26 +02:00
committed by GitHub
parent cca9cca699
commit addcd203eb
13 changed files with 401 additions and 20 deletions

View File

@@ -93,7 +93,8 @@ struct GemmConfig
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;

View File

@@ -55,7 +55,8 @@ void permute_tensor_b(Tensor& tensor)
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC>;
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
@@ -185,13 +186,15 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
@@ -259,6 +262,11 @@ int run_gemm_example_with_layouts(int argc,
b_k_n.SetZero();
}
if(GemmConfig::UseStructuredSparsity)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

View File

@@ -46,7 +46,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC>;
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;