Support any dimension for elementwise operation

This commit is contained in:
rocking
2022-05-18 04:56:10 +08:00
parent 06e52d902a
commit 7d44e782af
3 changed files with 137 additions and 43 deletions

View File

@@ -35,54 +35,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return desc_m0_pad;
}
// Describes a one-dimensional tensor using only the first entry of `shape`
// and `stride`, then forwards that descriptor to PadDescriptor_M0_1d together
// with the launch parameters `gridSize` and `threadPerBlock`.
static auto MakeDescriptor_M0_1d(const std::vector<int>& shape,
                                 const std::vector<int>& stride,
                                 index_t gridSize,
                                 index_t threadPerBlock)
{
    // Single-element length / stride tuples for the naive descriptor.
    const auto length_1d = make_tuple(shape[0]);
    const auto stride_1d = make_tuple(stride[0]);

    const auto flat_desc = make_naive_tensor_descriptor(length_1d, stride_1d);

    return PadDescriptor_M0_1d(flat_desc, gridSize, threadPerBlock);
}
// Describes a two-dimensional tensor from the first two entries of `shape`
// and `stride`, merges both dimensions into a single one, and forwards the
// merged descriptor to PadDescriptor_M0_1d together with `gridSize` and
// `threadPerBlock`.
static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
                                 const std::vector<int>& stride,
                                 index_t gridSize,
                                 index_t threadPerBlock)
{
    // [rows, cols] lengths and their matching strides.
    const auto lengths_2d = make_tuple(shape[0], shape[1]);
    const auto strides_2d = make_tuple(stride[0], stride[1]);

    // 2d view - [rows, cols]
    const auto desc_rows_cols = make_naive_tensor_descriptor(lengths_2d, strides_2d);

    // 1d view - [rows * cols]: source dims 0 and 1 are merged into target dim 0.
    const auto merge_ops   = make_tuple(make_merge_transform(lengths_2d));
    const auto source_dims = make_tuple(Sequence<0, 1>{});
    const auto target_dims = make_tuple(Sequence<0>{});
    const auto desc_merged =
        transform_tensor_descriptor(desc_rows_cols, merge_ops, source_dims, target_dims);

    return PadDescriptor_M0_1d(desc_merged, gridSize, threadPerBlock);
}
// Builds the one-dimensional (M0) descriptor for a tensor described by `shape`
// and `stride`, and hands it to PadDescriptor_M0_1d together with `gridSize`
// and `threadPerBlock`.
//
// NOTE(review): this listing comes from a diff hunk whose +/- markers are
// missing, so statements from the pre-change body (static_assert on Dim,
// dispatch to the _1d/_2d helpers, the make_tuple(0) fallback) and from the
// post-change body (generic N-d path) appear side by side. Do not read it as
// a single coherent implementation - confirm against the repository.
static auto MakeDescriptor_M0(const std::vector<int>& shape,
const std::vector<int>& stride,
index_t gridSize,
index_t threadPerBlock)
{
// Restricts Dim to 1 or 2 at compile time.
static_assert(Dim == 1 || Dim == 2,
"wrong! DeviceBinaryElementwise not support this dimension");
// Collect shape[I] / stride[I] into tuples, with I driven by Number<Dim>.
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
// TODO - 3D, 4D, 5D
// Dimension-specific dispatch to the dedicated 1d / 2d helpers.
if constexpr(Dim == 1)
return MakeDescriptor_M0_1d(shape, stride, gridSize, threadPerBlock);
else if constexpr(Dim == 2)
return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock);
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(Dim > 1)
{
// All source dimensions (sequence generated over Number<Dim>) are merged
// into target dimension 0.
const auto desc_m0 = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
}
else
// NOTE(review): two returns follow this `else`; presumably one belongs to
// each side of the diff - verify which one is current.
return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0));
return PadDescriptor_M0_1d(desc, gridSize, threadPerBlock);
}
// Descriptor type produced by MakeDescriptor_M0. decltype is an unevaluated
// context, so the literal arguments only select the call and are never run.
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
@@ -169,7 +145,7 @@ struct DeviceBinaryElementwise : public BaseOperator
if(pArg == nullptr)
return false;
// m * n
// shape[0] * shape[1] * shape[2] * ...
const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);
if(m0 % ScalarPerVector != 0)