added implicit gemm v1r3 lds_double_buffer NCHW * CYXK = KNHW, reworked static functionals

This commit is contained in:
Chao Liu
2019-04-23 17:51:14 -05:00
parent 87d8740bf5
commit 569ad66e2a
22 changed files with 2117 additions and 1107 deletions

View File

@@ -46,7 +46,7 @@ struct GeneratorTensor_3
#if 0
auto f_acc = std::plus<index_t>{};
#else
auto f_acc = [](auto a, auto b) { return 10 * a + b; };
auto f_acc = [](auto a, auto b) { return 100 * a + b; };
#endif
return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
@@ -390,8 +390,6 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
template <class T>
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
// printf("\n");
float error = 0;
float max_diff = -1;
float ref_value = 0, result_value = 0;
@@ -405,10 +403,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
ref_value = ref.mData[i];
result_value = result.mData[i];
}
// printf("{%f, %f}", double(ref.mData[i]), double(result.mData[i]));
}
// printf("\n");
std::cout << "error: " << error << std::endl;
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
@@ -416,38 +411,27 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
#if 0
constexpr index_t N = 128;
constexpr index_t C = 8;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
#if 1
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 56x56
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
@@ -499,7 +483,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 1
#elif 0
// 5x5 filter, 20x86 image
constexpr index_t N = 16;
constexpr index_t C = 256;
@@ -547,7 +531,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 10
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 512;
@@ -619,9 +603,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 1
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn