From 53094f7fae538c44588294f05de6c5505decaee7 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 15 Sep 2019 12:13:58 -0500 Subject: [PATCH] clean up --- ..._v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp | 6 - ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 61 ++- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 13 +- driver/src/driver.cpp | 22 +- driver/src/driver.cu | 487 +++++++++++++++++- 5 files changed, 544 insertions(+), 45 deletions(-) mode change 120000 => 100644 driver/src/driver.cu diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp index 26f26ee358..bedaa0cadf 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -181,12 +181,6 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0}); -#if 0 - { - printf("id (%d %d), in offset: %d %d\n", get_block_1d_id(), get_thread_local_1d_id(), blockwise_in_copy.mThreadSrcOffset, blockwise_in_copy.mThreadDstOffset); - } -#endif - // weight tensor // tensor descriptor in device memory, src of blockwise copy constexpr auto wei_e_k_global_desc = diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 81f5e87960..7989f7fb19 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -44,7 +44,8 @@ template + index_t WeiBlockCopyDstDataPerWrite_K, + index_t OutThreadCopyDataPerAccess_B> struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw { __device__ void Run(const Float* const __restrict__ p_in_global, @@ -133,12 +134,16 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw BlockwiseGenericTensorSliceCopy_v2, - NormalTensorCoordinate, decltype(in_e_b_block_desc.GetLengths()), InBlockCopySubLengths_E_B, InBlockCopyClusterLengths_E_B, - InBlockCopyThreadClusterArrangeOrder>( + InBlockCopyThreadClusterArrangeOrder, + InBlockCopySrcAccessOrder, + InBlockCopyDstAccessOrder, + 1, + 1, + InBlockCopyDataPerAccess_B, + InBlockCopyDataPerAccess_B>( {0, b_block_data_on_global}, {0, 0}); // weight tensor @@ -155,16 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw // operator for blockwise copy of weight into LDS // slice a tensor, and copy it into another tensor // this copy operator already have blockwise offset built-in - auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2< - BlockSize, - decltype(wei_e_k_global_desc), - decltype(wei_e_k_block_desc), - NormalTensorCoordinate, - NormalTensorCoordinate, - decltype(wei_e_k_block_desc.GetLengths()), - WeiBlockCopySubLengths_E_K, - WeiBlockCopyClusterLengths_E_K, - WeiBlockCopyThreadClusterArrangeOrder>({0, k_block_data_on_global}, {0, 0}); + auto blockwise_wei_copy = + BlockwiseGenericTensorSliceCopy_v2( + {0, k_block_data_on_global}, {0, 0}); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -283,15 +293,20 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw using OutThreadCopySliceLengths = Sequence; - auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2< - decltype(out_k0_k1_b_thread_desc), - decltype(out_k0_k1_b_global_desc), - NormalTensorCoordinate, - MergedTensorCoordinate, - OutThreadCopySliceLengths>({0, 0, 0}, - {k_thread_data_on_global / K1, - k_thread_data_on_global % K1, - b_thread_data_on_global}); + auto threadwise_out_copy = + ThreadwiseGenericTensorSliceCopy_v2r1::type, + arithmetic_sequence_gen<0, 3, 1>::type, + 2, + 2, + OutThreadCopyDataPerAccess_B, + OutThreadCopyDataPerAccess_B>( + {0, 0, 0}, + {k_thread_data_on_global / K1, + k_thread_data_on_global % K1, + b_thread_data_on_global}); for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat) { diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 67d2bf7dcf..7e95e8bb68 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -3,7 +3,7 @@ #include "device.hpp" #include "tensor.hpp" #include "gridwise_convolution_kernel_wrapper.hpp" -//#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp" using namespace ck; @@ -33,18 +33,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr auto wei_kcyx_desc = WeiDesc{}; constexpr auto out_nkhw_desc = OutDesc{}; - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - constexpr index_t N = out_nkhw_desc.GetLength(I0); + constexpr index_t K = out_nkhw_desc.GetLength(I1); constexpr index_t Ho = out_nkhw_desc.GetLength(I2); constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - std::size_t data_sz = sizeof(T); DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); @@ -171,7 +164,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); constexpr auto gridwise_conv = -#if 0 +#if 1 GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw #else GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 55dacfa289..c792ead481 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -91,8 +91,8 @@ int main(int argc, char* argv[]) // 3x3, 34x34 constexpr index_t N = 64; constexpr index_t C = 256; - constexpr index_t HI = 32; - constexpr index_t WI = 32; + constexpr index_t HI = 34; + constexpr index_t WI = 34; constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -100,8 +100,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 8x8 image // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42% @@ -398,7 +398,7 @@ int main(int argc, char* argv[]) ConvStrides{}, ConvDilations{}, nrepeat); -#elif 1 +#elif 0 device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(in_nchw_desc, in_nchw, wei_kcyx_desc, @@ -440,6 +440,18 @@ int main(int argc, char* argv[]) ConvStrides{}, ConvDilations{}, nrepeat); +#elif 1 + device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); #endif if(do_verification) diff --git a/driver/src/driver.cu b/driver/src/driver.cu deleted file mode 120000 index 1ca4fea9d7..0000000000 --- a/driver/src/driver.cu +++ /dev/null @@ -1 +0,0 @@ -driver.cpp \ No newline at end of file diff --git a/driver/src/driver.cu b/driver/src/driver.cu new file mode 100644 index 0000000000..65f9afdf1f --- /dev/null +++ b/driver/src/driver.cu @@ -0,0 +1,486 @@ +#include +#include +#include +#include +#include +#include "config.hpp" +#include "ConstantTensorDescriptor.hpp" +#include "device.hpp" +#include "conv_common.hpp" +#include "host_conv.hpp" +#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp" +#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp" +#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp" +//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp" +//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp" +//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" +#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" +#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp" +//#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp" +//#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp" +#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" + +struct GeneratorTensor_1 +{ + template + double operator()(Is... is) + { + return 1; + } +}; + +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + double operator()(Is...) + { + return (std::rand() % (max_value - min_value)) + min_value; + } +}; + +struct GeneratorTensor_3 +{ + template + double operator()(Is... is) + { + std::array dims = {{static_cast(is)...}}; + + auto f_acc = [](auto a, auto b) { return 100 * a + b; }; + + return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc); + } +}; + +struct GeneratorTensor_Checkboard +{ + template + double operator()(Ts... Xs) const + { + std::array dims = {{Xs...}}; + return std::accumulate(dims.begin(), + dims.end(), + true, + [](bool init, index_t x) -> int { return init != (x % 2); }) + ? 1 + : -1; + } +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + +#if 0 + constexpr index_t N = 32; + constexpr index_t C = 8; + constexpr index_t HI = 1; + constexpr index_t WI = 1; + constexpr index_t K = 128; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<0, 0>; +#elif 1 + // 3x3, 34x34 + constexpr index_t N = 64; + constexpr index_t C = 256; + constexpr index_t HI = 34; + constexpr index_t WI = 34; + constexpr index_t K = 128; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 + // 1x1 filter, 8x8 image + // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42% + constexpr index_t N = 64; + constexpr index_t C = 1536; + constexpr index_t HI = 8; + constexpr index_t WI = 8; + constexpr index_t K = 256; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 8x8 image + // cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51% + constexpr index_t N = 128; + constexpr index_t C = 2048; + constexpr index_t HI = 8; + constexpr index_t WI = 8; + constexpr index_t K = 384; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 7x7 image + // cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64% + constexpr index_t N = 128; + constexpr index_t C = 832; + constexpr index_t HI = 7; + constexpr index_t WI = 7; + constexpr index_t K = 384; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 8x8 image + // cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65% + constexpr index_t N = 128; + constexpr index_t C = 1280; + constexpr index_t HI = 8; + constexpr index_t WI = 8; + constexpr index_t K = 384; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 14x14 image + // cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50% + constexpr index_t N = 128; + constexpr index_t C = 512; + constexpr index_t HI = 14; + constexpr index_t WI = 14; + constexpr index_t K = 128; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 8x8 image + // cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61% + constexpr index_t N = 64; + constexpr index_t C = 1536; + constexpr index_t HI = 8; + constexpr index_t WI = 8; + constexpr index_t K = 384; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 28x28 image + // cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69% + constexpr index_t N = 128; + constexpr index_t C = 256; + constexpr index_t HI = 28; + constexpr index_t WI = 28; + constexpr index_t K = 128; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 7x7 image + // cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62% + constexpr index_t N = 128; + constexpr index_t C = 832; + constexpr index_t HI = 7; + constexpr index_t WI = 7; + constexpr index_t K = 256; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output + // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% + constexpr index_t N = 128; + constexpr index_t C = 288; + constexpr index_t HI = 35; + constexpr index_t WI = 35; + constexpr index_t K = 384; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 1 + // 1x1 filter, 17x17 input + // cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76% + constexpr index_t N = 128; + constexpr index_t C = 768; + constexpr index_t HI = 17; + constexpr index_t WI = 17; + constexpr index_t K = 128; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 14x14 image + // cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64% + constexpr index_t N = 128; + constexpr index_t C = 528; + constexpr index_t HI = 14; + constexpr index_t WI = 14; + constexpr index_t K = 128; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 14x14 image + // cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75% + constexpr index_t N = 128; + constexpr index_t C = 528; + constexpr index_t HI = 14; + constexpr index_t WI = 14; + constexpr index_t K = 256; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 7x7 image + // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52% + constexpr index_t N = 128; + constexpr index_t C = 832; + constexpr index_t HI = 7; + constexpr index_t WI = 7; + constexpr index_t K = 128; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#endif + + auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence{}); + auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence{}); + auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor( + in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{}); + + ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); + ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); + ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); + + using in_data_t = float; + using out_data_t = float; + Tensor in_nchw(make_TensorDescriptor(in_nchw_desc)); + Tensor wei_kcyx(make_TensorDescriptor(wei_kcyx_desc)); + Tensor out_nkhw_host(make_TensorDescriptor(out_nkhw_desc)); + Tensor out_nkhw_device(make_TensorDescriptor(out_nkhw_desc)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(argc != 3) + { + printf("arg1: do_verification, arg2: nrepeat\n"); + exit(1); + } + + bool do_verification = atoi(argv[1]); + index_t nrepeat = atoi(argv[2]); + + if(do_verification) + { +#if 0 + in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); +#elif 0 + in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); +#elif 0 + in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); +#elif 1 + in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); +#elif 0 + in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + }; + wei_kcyx.GenerateTensorValue(gen_wei, num_thread); +#endif + } + +#if 0 + device_convolution_direct_v2_nchw_kcyx_nkhw + (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); +#elif 0 + device_convolution_implicit_gemm_v1_chwn_cyxk_khwn( + in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); +#elif 0 + device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + LeftPads{}, + RightPads{}, + nrepeat); +#elif 0 + device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw( + in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); +#elif 0 + device_convolution_implicit_gemm_v2_chwn_cyxk_khwn( + in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); +#elif 0 + device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( + (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); +#elif 0 + device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + nrepeat); +#elif 0 + device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); +#elif 0 + device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + nrepeat); +#elif 0 + device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + nrepeat); +#elif 1 + device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + nrepeat); +#elif 0 + device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); +#endif + + if(do_verification) + { +#if 1 + if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 && + ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1) + { + host_winograd_3x3_convolution( + in_nchw, wei_kcyx, out_nkhw_host, LeftPads{}, RightPads{}); + } + else +#endif + { + host_direct_convolution(in_nchw, + wei_kcyx, + out_nkhw_host, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}); + } + check_error(out_nkhw_host, out_nkhw_device); + +#if 0 + LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; + LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; + LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; + LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; +#endif + } +}