Add elementwise with dynamic vector dim (#1198)

* Add elementwise with dynamic vector dim

* Reduce number of instaces

* Fixes

* Fixes

[ROCm/composable_kernel commit: 9c052804a7]
This commit is contained in:
Bartłomiej Kocot
2024-03-22 10:40:43 +01:00
committed by GitHub
parent 5f84554b12
commit aa64a8be0a
28 changed files with 2157 additions and 359 deletions

View File

@@ -37,6 +37,20 @@ static void print_helper_msg()
// clang-format on
}
void init_strides(const std::vector<ck::index_t>& lengths,
const std::vector<ck::index_t>& dims_order,
std::vector<ck::index_t>& strides)
{
ck::index_t stride = 1;
for(ck::index_t d = lengths.size() - 1; d >= 0; d--)
{
ck::index_t dim = dims_order[d];
strides[dim] = stride;
stride *= lengths[dim];
}
}
} // namespace
int profile_permute_scale(int argc, char* argv[])
@@ -58,16 +72,21 @@ int profile_permute_scale(int argc, char* argv[])
const int num_dims = dims_argc / 3;
std::vector<ck::index_t> lengths(num_dims);
std::vector<ck::index_t> input_strides(num_dims);
std::vector<ck::index_t> output_strides(num_dims);
std::vector<ck::index_t> input_dims_order(num_dims);
std::vector<ck::index_t> output_dims_order(num_dims);
for(int i = 0; i < num_dims; i++)
{
lengths[i] = std::stoi(argv[control_argc + i]);
input_strides[i] = std::stoi(argv[control_argc + num_dims + i]);
output_strides[i] = std::stoi(argv[control_argc + 2 * num_dims + i]);
lengths[i] = std::stoi(argv[control_argc + i]);
input_dims_order[i] = std::stoi(argv[control_argc + num_dims + i]);
output_dims_order[i] = std::stoi(argv[control_argc + 2 * num_dims + i]);
}
std::vector<ck::index_t> input_strides(num_dims);
std::vector<ck::index_t> output_strides(num_dims);
init_strides(lengths, input_dims_order, input_strides);
init_strides(lengths, output_dims_order, output_strides);
using F32 = float;
using F16 = ck::half_t;