mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[TileEngine] Support for sparsity in codegen (#2128)
* Added sparsity flag in codegen * remove comments * clan formatted * added sparsity as runtime argument * updated README * updated stream config variable * fix typo for tail_num in hot loop
This commit is contained in:
20
tile_engine/ops/gemm/gemm_host_api.cpp
Normal file → Executable file
20
tile_engine/ops/gemm/gemm_host_api.cpp
Normal file → Executable file
@@ -10,12 +10,19 @@ void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify,
|
||||
bool structured_sparsity,
|
||||
KernelTraits& trait,
|
||||
ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
const ck_tile::stream_config& stream)
|
||||
{
|
||||
return GemmDispatcher::dispatch(
|
||||
c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s);
|
||||
return GemmDispatcher::dispatch(c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
structured_sparsity,
|
||||
trait,
|
||||
args,
|
||||
stream);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
@@ -43,6 +50,7 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
int verify = arg_parser.get_int("v");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -76,6 +84,11 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(structured_sparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
@@ -153,6 +166,7 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
structured_sparsity,
|
||||
trait,
|
||||
gemm_args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
Reference in New Issue
Block a user