mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Extract pad
This commit is contained in:
@@ -21,16 +21,9 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
static auto MakeDescriptor_M0_1d(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
template <typename Desc_M0>
|
||||
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t threadPerBlock)
|
||||
{
|
||||
// 1d desc - [m]
|
||||
const auto desc_m0 =
|
||||
make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0]));
|
||||
|
||||
// pad
|
||||
const auto m0 = desc_m0.GetLength(I0);
|
||||
const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
|
||||
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
|
||||
@@ -42,6 +35,17 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
return desc_m0_pad;
|
||||
}
|
||||
|
||||
// Creates the 1-D descriptor for a tensor and returns its padded form.
// Only the first entries of `shape` and `stride` are used; `gridSize` and
// `threadPerBlock` are forwarded unchanged to PadDescriptor_M0_1d.
static auto MakeDescriptor_M0_1d(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
index_t threadPerBlock)
|
||||
{
|
||||
// 1d desc - [m], built from shape[0] / stride[0]
const auto desc_m0 =
|
||||
make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0]));
|
||||
|
||||
// Padding is handled by the shared helper rather than inline here.
return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
|
||||
const std::vector<int>& stride,
|
||||
index_t gridSize,
|
||||
@@ -61,16 +65,7 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
// pad
|
||||
const auto m0 = desc_m0.GetLength(I0);
|
||||
const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
|
||||
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
|
||||
const auto desc_m0_pad =
|
||||
transform_tensor_descriptor(desc_m0,
|
||||
make_tuple(make_right_pad_transform(m0, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m0_pad;
|
||||
return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock);
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M0(const std::vector<int>& shape,
|
||||
|
||||
Reference in New Issue
Block a user