mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Support any dimension for elementwise operation
This commit is contained in:
@@ -35,54 +35,30 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
return desc_m0_pad;
|
||||
}
|
||||
|
||||
// Builds a one-dimensional tensor descriptor from the first entry of `shape`
// and `stride`, then forwards it to PadDescriptor_M0_1d.
//
// shape / stride   : only element [0] of each is read here.
// gridSize,
// threadPerBlock   : passed through unchanged to PadDescriptor_M0_1d.
// Returns          : whatever PadDescriptor_M0_1d returns (presumably a padded
//                    1-D descriptor, judging by its name -- TODO confirm).
//
// NOTE(review): the `|` / `||||` lines look like table residue from the
// rendered diff this text was captured from; they are not C++.
static auto MakeDescriptor_M0_1d(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
{
|
||||
// 1d desc - [shape[0]] with stride[0]
const auto desc_m0 =
|
||||
make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0]));
|
||||
|
||||
return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
|
||||
}
|
||||
|
||||
// Builds a two-dimensional [m, n] tensor descriptor from the first two entries
// of `shape` and `stride`, merges both dimensions into a single dimension, and
// forwards the merged descriptor to PadDescriptor_M0_1d.
//
// shape / stride   : elements [0] and [1] of each are read here.
// gridSize,
// threadPerBlock   : passed through unchanged to PadDescriptor_M0_1d.
// Returns          : whatever PadDescriptor_M0_1d returns (presumably a padded
//                    1-D descriptor, judging by its name -- TODO confirm).
//
// NOTE(review): the `|` / `||||` lines look like table residue from the
// rendered diff this text was captured from; they are not C++.
static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
{
|
||||
const int m = shape[0];
|
||||
const int n = shape[1];
|
||||
|
||||
// 2d desc - [m, n]
|
||||
const auto desc_m_n =
|
||||
make_naive_tensor_descriptor(make_tuple(m, n), make_tuple(stride[0], stride[1]));
|
||||
|
||||
// 1d desc - [m * n]
|
||||
// Source dimensions 0 and 1 (Sequence<0, 1>) are merged into the single
// destination dimension 0 (Sequence<0>).
const auto desc_m0 =
|
||||
transform_tensor_descriptor(desc_m_n,
|
||||
make_tuple(make_merge_transform(make_tuple(m, n))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
|
||||
}
|
||||
|
||||
// Entry point that turns a runtime `shape` / `stride` pair into the flattened
// descriptor handed to PadDescriptor_M0_1d, selected at compile time by `Dim`.
//
// NOTE(review): this span appears to be a rendered diff whose removed and added
// lines are interleaved without +/- markers (and separated by `|` / `||||`
// table residue). Read top to bottom it is NOT one coherent function body: it
// contains both a Dim == 1 / Dim == 2 dispatch and a generic N-d merge path,
// and two `return` statements follow the single trailing `else`. Confirm
// against the actual file before relying on any line below.
static auto MakeDescriptor_M0(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
{
|
||||
// Compile-time guard limiting Dim to 1 or 2. Presumably belongs to the
// pre-change version, since the generic path below handles Dim > 1 -- verify.
static_assert(Dim == 1 || Dim == 2,
|
||||
"wrong! DeviceBinaryElementwise not support this dimension");
|
||||
// Copy the first Dim entries of the runtime vectors into tuples.
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
|
||||
|
||||
// TODO - 3D, 4D, 5D
|
||||
// Dispatch to the fixed-rank helpers.
if constexpr(Dim == 1)
|
||||
return MakeDescriptor_M0_1d(shape, stride, gridSize, threadPerBlock);
|
||||
else if constexpr(Dim == 2)
|
||||
return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock);
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||
if constexpr(Dim > 1)
|
||||
{
|
||||
// All Dim source dimensions (sequence built from indices 0..Dim-1) are
// merged into the single destination dimension 0.
const auto desc_m0 = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
|
||||
}
|
||||
else
|
||||
// NOTE(review): two alternative `else` bodies follow. The first returns a
// placeholder descriptor built from make_tuple(0); the second pads the
// unmerged `desc` directly. Presumably one is the removed line and the
// other its replacement -- confirm which one the real file keeps.
return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0));
|
||||
return PadDescriptor_M0_1d(desc, gridSize, threadPerBlock);
|
||||
}
|
||||
|
||||
// Descriptor type produced by MakeDescriptor_M0, deduced with decltype. The
// arguments ({1, 1}, {1, 1}, 1, 1) are placeholders that only drive type
// deduction; decltype does not evaluate the call.
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
|
||||
@@ -169,7 +145,7 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
// m * n
|
||||
// shape[0] * shape[1] * shape[2] * ...
|
||||
const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);
|
||||
|
||||
if(m0 % ScalarPerVector != 0)
|
||||
|
||||
Reference in New Issue
Block a user