[TileEngine] Support for sparsity in codegen (#2128)

* Added sparsity flag in codegen

* remove comments

* clan formatted

* added sparsity as runtime argument

* updated README

* updated stream config variable

* fix typo for tail_num in hot loop
This commit is contained in:
Khushbu Agarwal
2025-04-28 18:19:23 -07:00
committed by GitHub
parent 4094ad158a
commit 768c99eca9
4 changed files with 56 additions and 35 deletions

20
tile_engine/ops/gemm/gemm_host_api.cpp Normal file → Executable file
View File

@@ -10,12 +10,19 @@ void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify,
bool structured_sparsity,
KernelTraits& trait,
ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
const ck_tile::stream_config& stream)
{
return GemmDispatcher::dispatch(
c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s);
return GemmDispatcher::dispatch(c_m_n_dev_buf,
c_m_n_host_result,
c_m_n_dev_result,
verify,
structured_sparsity,
trait,
args,
stream);
}
template <typename ADataType,
@@ -43,6 +50,7 @@ void run(const ck_tile::ArgParser& arg_parser)
int n_repeat = arg_parser.get_int("repeat");
int verify = arg_parser.get_int("v");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -76,6 +84,11 @@ void run(const ck_tile::ArgParser& arg_parser)
b_k_n.SetZero();
}
if(structured_sparsity)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
@@ -153,6 +166,7 @@ void run(const ck_tile::ArgParser& arg_parser)
c_m_n_host_result,
c_m_n_dev_result,
verify,
structured_sparsity,
trait,
gemm_args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});