diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp index 4191de7880..c21ffe500f 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -158,24 +158,20 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer // slice a merged tensor, reorder and copy to a normal tensor // this copy operator already has blockwise offset built-in auto blockwise_in_copy = -#if 0 - BlockwiseGenericTensorSliceCopy_v1 -#else - BlockwiseGenericTensorSliceCopy_v2 -#endif - ({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); + BlockwiseGenericTensorSliceCopy_v2( + {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); // weight tensor // tensor descriptor in device memory, src of blockwise copy @@ -192,24 +188,20 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer // slice a tensor, and copy it into another tensor // this copy operator already have blockwise offset built-in auto blockwise_wei_copy = -#if 0 - BlockwiseGenericTensorSliceCopy_v1 -#else - BlockwiseGenericTensorSliceCopy_v2 -#endif - ({0, k_block_data_on_global}, {0, 0}); + BlockwiseGenericTensorSliceCopy_v2( + {0, k_block_data_on_global}, {0, 0}); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp index 959a9112d3..f73557d438 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp @@ -51,7 +51,7 @@ template struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded { -#if 1 +#if 0 __device__ void Run(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, Float* const __restrict__ p_out_global) const @@ -437,6 +437,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded "wrong! aligment requirement for vectorized global load of input tensor will " "be violated"); + // input constexpr auto in_n_c_hi_wi_global_desc = make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides()); @@ -465,6 +466,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // weight + constexpr auto wei_e_k_global_desc = + transform_tensor_descriptor(wei_k_c_y_x_global_desc, + make_tuple(Merge>{}, PassThrough{}), + make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { @@ -487,8 +495,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded print_array("idx1: ", idx1); print_array("idx0: ", idx0); } +#else + index_t itmp = get_block_1d_id() + get_thread_local_1d_id(); + auto wei_coord1 = make_tensor_coordinate_v2(wei_e_k_global_desc, {itmp, itmp + 1}); + + auto step_sizes = make_multi_index(EPerBlock, 0); + + wei_coord1 += step_sizes; + + p_out_global[0] = wei_coord1.GetLowerCoordinate().GetIndex()[0]; + p_out_global[1] = wei_coord1.GetLowerCoordinate().GetIndex()[1]; + p_out_global[2] = wei_coord1.GetLowerCoordinate().GetIndex()[2]; + p_out_global[3] = wei_coord1.GetLowerCoordinate().GetIndex()[3]; #endif - p_out_global[0] = in_e_n1_b_n2_global_desc.CalculateOffset({0, 0, 10, 0}); } #endif }; diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index d888f87d81..bf56678b63 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -197,7 +197,7 @@ struct Merge // do carry check in reversed order, starting from lowest dimension // don't check the highest dimension - static_for<0, nDimLow, 1>{}([&](auto ireverse) { + static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) { constexpr index_t i = nDimLow - 1 - ireverse; if(carry) @@ -213,6 +213,12 @@ struct Merge carry = true; } }); + + // highest dimension, no out-of-bound check + if(carry) + { + ++idx_low_new(0); + } } else if(idx_up_diff[0] < 0) { @@ -220,7 +226,7 @@ struct Merge // do borrow check in reversed order, starting from lowest dimension // don't check the highest dimension - static_for<0, nDimLow, 1>{}([&](auto ireverse) { + static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) { constexpr index_t i = nDimLow - 1 - ireverse; if(borrow) @@ -236,6 +242,12 @@ struct Merge borrow = true; } }); + + // highest dimension, no out-of-bound check + if(borrow) + { + --idx_low_new(0); + } } return idx_low_new - idx_low_old; diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 36ca649aa4..51b9e511af 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -70,7 +70,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] constexpr index_t InBlockCopySrcDataPerRead_B = 1; diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp index 93f91873e7..965f84b612 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp @@ -74,7 +74,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc, using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] constexpr index_t InBlockCopySrcDataPerRead_B = 1; diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 710009c72f..045e796c38 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -74,12 +74,12 @@ int main(int argc, char* argv[]) { using namespace ck; -#if 1 - constexpr index_t N = 512; - constexpr index_t C = 16; +#if 0 + constexpr index_t N = 256; + constexpr index_t C = 64; constexpr index_t HI = 17; constexpr index_t WI = 17; - constexpr index_t K = 512; + constexpr index_t K = 256; constexpr index_t Y = 17; constexpr index_t X = 17; @@ -88,7 +88,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>; -#elif 1 +#elif 0 // 3x3, 34x34 constexpr index_t N = 64; constexpr index_t C = 256; @@ -117,8 +117,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 8x8 image // cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51% @@ -133,8 +133,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 7x7 image // cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64% @@ -149,8 +149,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 8x8 image // cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65% @@ -165,8 +165,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 14x14 image // cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50% @@ -181,8 +181,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 8x8 image // cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61% @@ -197,8 +197,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 28x28 image // cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69% @@ -213,8 +213,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 7x7 image // cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62% @@ -229,25 +229,9 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 - // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output - // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% - constexpr index_t N = 128; - constexpr index_t C = 288; - constexpr index_t HI = 35; - constexpr index_t WI = 35; - constexpr index_t K = 384; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; -#elif 1 // 1x1 filter, 17x17 input // cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76% constexpr index_t N = 128; @@ -261,8 +245,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 14x14 image // cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64% @@ -277,8 +261,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 14x14 image // cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75% @@ -293,8 +277,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 7x7 image // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52% @@ -309,8 +293,24 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 0; - constexpr index_t WPad = 0; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 1 + // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output + // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% + constexpr index_t N = 128; + constexpr index_t C = 288; + constexpr index_t HI = 35; + constexpr index_t WI = 35; + constexpr index_t K = 384; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #endif auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence{});