refactor ConstantTensorDescriptor and functional

This commit is contained in:
Chao Liu
2019-04-16 17:36:18 -05:00
parent a2cf803c7e
commit 17f3d2d4bc
22 changed files with 390 additions and 276 deletions

View File

@@ -11,6 +11,7 @@
#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp"
//#include "device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp"
//#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
@@ -48,13 +49,10 @@ struct GeneratorTensor_3
#if 0
auto f_acc = std::plus<index_t>{};
#else
auto f_acc = [](auto a, auto b){ return 10*a + b;};
auto f_acc = [](auto a, auto b) { return 10 * a + b; };
#endif
return std::accumulate(dims.begin(),
dims.end(),
index_t(0),
f_acc);
return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
}
};
@@ -376,7 +374,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std::size_t ho = HoPerTile * htile + j;
for(int i = 0; i < WoPerTile; ++i)
{
std::size_t wo = WoPerTile * wtile + i;
std::size_t wo = WoPerTile * wtile + i;
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
}
}
@@ -435,13 +433,13 @@ int main(int argc, char* argv[])
constexpr index_t WPad = 0;
#elif 0
// 3x3, 56x56
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
@@ -505,7 +503,7 @@ int main(int argc, char* argv[])
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -666,6 +664,8 @@ int main(int argc, char* argv[])
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
#endif