Unify Convolution FWD XDL 1D/2D implementation. (#93)

* Convolution ND

* Code unification across dimensions for generating tensor descriptors.
* Example
* Instances

* Move convnd f32 instance file to comply with repo structure.

* Conv 1D tensor layouts.

* Formatting and use ReferenceConv

* Reference ConvFwd supporting 1D and 2D convolution.

* Debug printing TensorLayout name.

* Conv fwd 1D instance f32

* Refactor conv ND example.

Needed to support various conv dimensions.

* Rename conv nd example directory to prevent conflicts.

* Refactor some common utilities into a single file.

Plus some tests.

* Refactor GetHostTensorDescriptor + UT.

* Add 1D test case.

* Test reference convolution 1d/2d

* Remove some leftovers.

* Fix convolution example error for 1D

* Refactor test check errors utility function.

* Test Conv2D Fwd XDL

* More UT for 1D case.

* Parameterize input & weight initializers.

* Rename example to prevent conflicts.

* Split convnd instance into separate files for 1d/2d

* Address review comments.

* Fix data type for flops/gbytes calculations.

* Assign example number 11.

Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Adam Osewski
2022-02-23 17:44:20 +01:00
committed by GitHub
parent 6dfb92bbef
commit 756a761727
17 changed files with 2698 additions and 108 deletions
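
The gist of the unification: one reference operator template now covers 1D and 2D (and, per the guard in the diff below, up to 3D) through the new NumDimSpatial parameter. A minimal usage sketch, assuming ck's PassThrough identity elementwise functor; the aliases are illustrative, not part of the diff:

    // Illustrative only. PassThrough is assumed to be ck's identity
    // elementwise functor from element_wise_operation.hpp.
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using RefConv1D = ck::tensor_operation::host::
        ReferenceConvFwd<float, float, float, PassThrough, PassThrough, PassThrough, 1>;
    using RefConv2D = ck::tensor_operation::host::
        ReferenceConvFwd<float, float, float, PassThrough, PassThrough, PassThrough, 2>;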

reference_conv_fwd.hpp

@@ -2,6 +2,7 @@
 #define REFERENCE_CONV_FWD_HPP
 #include <iostream>
+#include <type_traits>
 #include <sstream>
 #include "device_base.hpp"
 #include "host_tensor.hpp"
@@ -10,21 +11,38 @@ namespace ck {
 namespace tensor_operation {
 namespace host {

-// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
+//
+// @brief      Reference implementation for forward convolution.
+//
+// @paragraph  Supported tensor layouts. Input tensor supports NCHiWi data layout.
+//             Weights tensor supports KCYX data layout. Output tensor supports
+//             NKHoWo data layout.
+//
+// @tparam     InDataType               Input tensor data type.
+// @tparam     WeiDataType              Weights tensor data type.
+// @tparam     OutDataType              Output tensor data type.
+// @tparam     InElementwiseOperation   Functor for input tensor elementwise
+//                                      operation.
+// @tparam     WeiElementwiseOperation  Functor for weights tensor elementwise
+//                                      operation.
+// @tparam     NumDimSpatial            Number of spatial dimensions.
+//
 template <typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
-          typename OutElementwiseOperation>
+          typename OutElementwiseOperation,
+          ck::index_t NumDimSpatial = 2,
+          typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvFwd : public device::BaseOperator
 {
     // Argument
     struct Argument : public device::BaseArgument
     {
-        Argument(const Tensor<InDataType>& in_n_c_hi_wi,
-                 const Tensor<WeiDataType>& wei_k_c_y_x,
-                 Tensor<OutDataType>& out_n_k_ho_wo,
+        Argument(const Tensor<InDataType>& input,
+                 const Tensor<WeiDataType>& weight,
+                 Tensor<OutDataType>& output,
                  std::vector<ck::index_t> conv_filter_strides,
                  std::vector<ck::index_t> conv_filter_dilations,
                  std::vector<ck::index_t> input_left_pads,
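
The NumDimSpatial guard above uses the classic enable_if idiom. A self-contained sketch of its behavior, with OnlySpatial1To3 as a hypothetical stand-in for the real template head:

    #include <type_traits>

    // Hypothetical stand-in: instantiation is rejected unless 1 <= N <= 3,
    // because std::enable_if<false, bool> has no member 'type'.
    template <int NumDimSpatial,
              typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
    struct OnlySpatial1To3
    {
    };

    int main()
    {
        OnlySpatial1To3<2> ok{};   // compiles: 2 spatial dimensions are supported
        // OnlySpatial1To3<4>{};   // would fail to compile: guard rejects 4
        (void)ok;
        return 0;
    }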
@@ -32,9 +50,9 @@ struct ReferenceConvFwd : public device::BaseOperator
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op)
-            : in_n_c_hi_wi_{in_n_c_hi_wi},
-              wei_k_c_y_x_{wei_k_c_y_x},
-              out_n_k_ho_wo_{out_n_k_ho_wo},
+            : input_{input},
+              weight_{weight},
+              output_{output},
               conv_strides_{conv_filter_strides},
               conv_dilations_{conv_filter_dilations},
               in_left_pads_{input_left_pads},
@@ -45,9 +63,9 @@ struct ReferenceConvFwd : public device::BaseOperator
         {
         }

-        const Tensor<InDataType>& in_n_c_hi_wi_;
-        const Tensor<WeiDataType>& wei_k_c_y_x_;
-        Tensor<OutDataType>& out_n_k_ho_wo_;
+        const Tensor<InDataType>& input_;
+        const Tensor<WeiDataType>& weight_;
+        Tensor<OutDataType>& output_;

         std::vector<index_t> conv_strides_;
         std::vector<index_t> conv_dilations_;
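
The Argument now stores dimension-agnostic stride, dilation, and pad vectors, sized by NumDimSpatial; for a 1D convolution each holds a single entry. Hypothetical values for a strided, padded 1D case:

    // Hypothetical 1D parameters (one entry per spatial dimension):
    std::vector<ck::index_t> conv_filter_strides{2};   // stride along W
    std::vector<ck::index_t> conv_filter_dilations{1}; // dilation along W
    std::vector<ck::index_t> input_left_pads{1};       // left pad along W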
@@ -59,58 +77,98 @@ struct ReferenceConvFwd : public device::BaseOperator
         OutElementwiseOperation out_element_op_;
     };

     // Invoker
     struct Invoker : public device::BaseInvoker
     {
         using Argument = ReferenceConvFwd::Argument;

         float Run(const Argument& arg)
         {
-            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
-                float v_acc = 0;
-
-                for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
-                {
-                    for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
-                    {
-                        int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                                 arg.in_left_pads_[0];
-                        for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
-                        {
-                            int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                     arg.in_left_pads_[1];
-                            if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                               wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
-                            {
-                                float v_in;
-                                float v_wei;
-
-                                arg.in_element_op_(
-                                    v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
-                                arg.wei_element_op_(
-                                    v_wei, ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));
-
-                                v_acc += v_in * v_wei;
-                            }
-                        }
-                    }
-                }
-
-                float v_out;
-
-                arg.out_element_op_(v_out, v_acc);
-
-                arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
-            };
-
-            make_ParallelTensorFunctor(f_nchw,
-                                       arg.out_n_k_ho_wo_.mDesc.GetLengths()[0],
-                                       arg.out_n_k_ho_wo_.mDesc.GetLengths()[1],
-                                       arg.out_n_k_ho_wo_.mDesc.GetLengths()[2],
-                                       arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])(
-                std::thread::hardware_concurrency());
-
-            return 0;
+            if constexpr(NumDimSpatial == 1)
+            {
+                auto f_ncw = [&](auto n, auto k, auto wo) {
+                    float v_acc = 0;
+
+                    for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                    {
+                        for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
+                        {
+                            int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
+                                     arg.in_left_pads_[0];
+                            if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2])
+                            {
+                                float v_in;
+                                float v_wei;
+
+                                arg.in_element_op_(v_in,
+                                                   static_cast<const float>(arg.input_(n, c, wi)));
+                                arg.wei_element_op_(v_wei,
+                                                    static_cast<const float>(arg.weight_(k, c, x)));
+
+                                v_acc += v_in * v_wei;
+                            }
+                        }
+                    }
+
+                    float v_out;
+
+                    arg.out_element_op_(v_out, v_acc);
+                    arg.output_(n, k, wo) = v_out;
+                };
+
+                make_ParallelTensorFunctor(f_ncw,
+                                           arg.output_.mDesc.GetLengths()[0],
+                                           arg.output_.mDesc.GetLengths()[1],
+                                           arg.output_.mDesc.GetLengths()[2])(
+                    std::thread::hardware_concurrency());
+
+                return 0;
+            }
+            else if constexpr(NumDimSpatial == 2)
+            {
+                auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
+                    float v_acc = 0;
+
+                    for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                    {
+                        for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
+                        {
+                            int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
+                                     arg.in_left_pads_[0];
+                            for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
+                            {
+                                int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
+                                         arg.in_left_pads_[1];
+                                if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 &&
+                                   wi < arg.input_.mDesc.GetLengths()[3])
+                                {
+                                    float v_in;
+                                    float v_wei;
+
+                                    arg.in_element_op_(
+                                        v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
+                                    arg.wei_element_op_(
+                                        v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x)));
+
+                                    v_acc += v_in * v_wei;
+                                }
+                            }
+                        }
+                    }
+
+                    float v_out;
+
+                    arg.out_element_op_(v_out, v_acc);
+
+                    arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
+                };
+
+                make_ParallelTensorFunctor(f_nchw,
+                                           arg.output_.mDesc.GetLengths()[0],
+                                           arg.output_.mDesc.GetLengths()[1],
+                                           arg.output_.mDesc.GetLengths()[2],
+                                           arg.output_.mDesc.GetLengths()[3])(
+                    std::thread::hardware_concurrency());
+
+                return 0;
+            }
         }

         float Run(const device::BaseArgument* p_arg, int) override
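
The window arithmetic is the heart of both branches: output position wo maps to input position wi = wo * stride + x * dilation - left_pad, and out-of-range wi values act as zero padding via the bounds check. A small standalone illustration with hypothetical numbers:

    #include <cstdio>

    // With stride 2, dilation 1, left pad 1, and a width-3 filter, output
    // position wo = 0 touches input indices -1, 0, 1; wi = -1 is skipped.
    int main()
    {
        const int stride = 2, dilation = 1, left_pad = 1, filter_w = 3, in_w = 8;
        const int wo = 0;
        for(int x = 0; x < filter_w; ++x)
        {
            const int wi       = wo * stride + x * dilation - left_pad;
            const bool in_bounds = (wi >= 0 && wi < in_w);
            std::printf("x=%d -> wi=%d (%s)\n", x, wi, in_bounds ? "used" : "skipped");
        }
        return 0;
    }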
@@ -127,9 +185,9 @@ struct ReferenceConvFwd : public device::BaseOperator
     bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

-    static auto MakeArgument(const Tensor<InDataType>& in_n_c_hi_wi,
-                             const Tensor<WeiDataType>& wei_k_c_y_x,
-                             Tensor<OutDataType>& out_n_k_ho_wo,
+    static auto MakeArgument(const Tensor<InDataType>& input,
+                             const Tensor<WeiDataType>& weight,
+                             Tensor<OutDataType>& output,
                              std::vector<ck::index_t> conv_filter_strides,
                              std::vector<ck::index_t> conv_filter_dilations,
                              std::vector<ck::index_t> input_left_pads,
@@ -138,9 +196,9 @@ struct ReferenceConvFwd : public device::BaseOperator
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op)
     {
-        return Argument{in_n_c_hi_wi,
-                        wei_k_c_y_x,
-                        out_n_k_ho_wo,
+        return Argument{input,
+                        weight,
+                        output,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
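
For completeness, a hypothetical end-to-end use of the renamed MakeArgument, with RefConv1D and PassThrough as sketched before the diff; the tensor shapes, the trailing right-pads argument, and MakeInvoker are assumptions not visible in this excerpt:

    #include <cstddef>
    #include <vector>

    // Output width: Wo = (Wi + 2*pad - X)/stride + 1 = (64 + 2 - 3)/2 + 1 = 32.
    std::size_t N = 4, C = 16, K = 32, Wi = 64, X = 3, Wo = 32;

    Tensor<float> input(HostTensorDescriptor(std::vector<std::size_t>{N, C, Wi}));
    Tensor<float> weight(HostTensorDescriptor(std::vector<std::size_t>{K, C, X}));
    Tensor<float> output(HostTensorDescriptor(std::vector<std::size_t>{N, K, Wo}));

    auto ref = RefConv1D{};
    auto arg = ref.MakeArgument(input,
                                weight,
                                output,
                                {2},  // conv_filter_strides
                                {1},  // conv_filter_dilations
                                {1},  // input_left_pads
                                {1},  // input_right_pads (assumed parameter)
                                PassThrough{},
                                PassThrough{},
                                PassThrough{});
    ref.MakeInvoker().Run(arg);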