diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp index 0b08b818a3..b104d2cdbb 100644 --- a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -21,16 +21,9 @@ struct DeviceBinaryElementwise : public BaseOperator { static constexpr auto I0 = Number<0>{}; - static auto MakeDescriptor_M0_1d(const std::vector& shape, - const std::vector& stride, - index_t gridSize, - index_t threadPerBlock) + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t threadPerBlock) { - // 1d desc - [m] - const auto desc_m0 = - make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0])); - - // pad const auto m0 = desc_m0.GetLength(I0); const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector; const auto pad = math::integer_least_multiple(m0, loop_step) - m0; @@ -42,6 +35,17 @@ struct DeviceBinaryElementwise : public BaseOperator return desc_m0_pad; } + static auto MakeDescriptor_M0_1d(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) + { + const auto desc_m0 = + make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0])); + + return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock); + } + static auto MakeDescriptor_M0_2d(const std::vector& shape, const std::vector& stride, index_t gridSize, @@ -61,16 +65,7 @@ struct DeviceBinaryElementwise : public BaseOperator make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0>{})); - // pad - const auto m0 = desc_m0.GetLength(I0); - const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector; - const auto pad = math::integer_least_multiple(m0, loop_step) - m0; - const auto desc_m0_pad = - transform_tensor_descriptor(desc_m0, - make_tuple(make_right_pad_transform(m0, pad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - return desc_m0_pad; + return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock); } static auto MakeDescriptor_M0(const std::vector& shape,