mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Add elementwise with dynamic vector dim (#1198)
* Add elementwise with dynamic vector dim
* Reduce number of instances
* Fixes
* Fixes
[ROCm/composable_kernel commit: 9c052804a7]
This commit is contained in:
@@ -37,6 +37,20 @@ static void print_helper_msg()
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/// Computes per-dimension strides for a tensor with extents `lengths`, following
/// the dimension ordering given by `dims_order`.
///
/// `dims_order` is walked from its last entry to its first. The dimension named by
/// the last entry receives stride 1; every earlier entry receives the product of the
/// lengths of all dimensions visited before it.
///
/// @param lengths    Extent of each dimension, indexed by dimension id.
/// @param dims_order Dimension ids in layout order; the id listed last gets stride 1.
///                   Only the first `lengths.size()` entries are read.
/// @param strides    Output, indexed by dimension id. It is not resized here, so it
///                   must already be large enough for every id in `dims_order`.
/// @throws std::out_of_range if `dims_order` has fewer entries than `lengths`, or if
///         an entry (including a negative one) does not index into both `lengths`
///         and `strides`. Entries processed before the failure have already been
///         written to `strides`.
void init_strides(const std::vector<ck::index_t>& lengths,
                  const std::vector<ck::index_t>& dims_order,
                  std::vector<ck::index_t>& strides)
{
    using size_type = std::vector<ck::index_t>::size_type;

    ck::index_t stride = 1;
    // Unsigned countdown: visits lengths.size()-1 down to 0 and is a no-op when empty,
    // without narrowing size() into a signed index.
    for(auto pos = lengths.size(); pos-- > 0;)
    {
        // dims_order originates from user input; a negative id converts to a huge
        // unsigned value here, so the at() calls below reject it like any other
        // out-of-range id instead of writing outside the vectors.
        const auto dim = static_cast<size_type>(dims_order.at(pos));
        strides.at(dim) = stride;
        stride *= lengths.at(dim);
    }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int profile_permute_scale(int argc, char* argv[])
|
||||
@@ -58,16 +72,21 @@ int profile_permute_scale(int argc, char* argv[])
|
||||
const int num_dims = dims_argc / 3;
|
||||
|
||||
std::vector<ck::index_t> lengths(num_dims);
|
||||
std::vector<ck::index_t> input_strides(num_dims);
|
||||
std::vector<ck::index_t> output_strides(num_dims);
|
||||
std::vector<ck::index_t> input_dims_order(num_dims);
|
||||
std::vector<ck::index_t> output_dims_order(num_dims);
|
||||
|
||||
for(int i = 0; i < num_dims; i++)
|
||||
{
|
||||
lengths[i] = std::stoi(argv[control_argc + i]);
|
||||
input_strides[i] = std::stoi(argv[control_argc + num_dims + i]);
|
||||
output_strides[i] = std::stoi(argv[control_argc + 2 * num_dims + i]);
|
||||
lengths[i] = std::stoi(argv[control_argc + i]);
|
||||
input_dims_order[i] = std::stoi(argv[control_argc + num_dims + i]);
|
||||
output_dims_order[i] = std::stoi(argv[control_argc + 2 * num_dims + i]);
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> input_strides(num_dims);
|
||||
std::vector<ck::index_t> output_strides(num_dims);
|
||||
init_strides(lengths, input_dims_order, input_strides);
|
||||
init_strides(lengths, output_dims_order, output_strides);
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user