From 3406a1148adf283f31a345549b63de633a4ff61e Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 27 Jan 2020 15:29:33 -0600 Subject: [PATCH] Update for recent MIOpen integration (#11) * update for MIOpen integration --- ...data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp | 6 +- ..._v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp | 2 +- ...data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp | 16 +-- ...data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp | 16 +-- ...data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 71 ++++++++--- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 7 +- .../multi_index_transform.hpp | 14 ++- .../threadwise_generic_tensor_slice_copy.hpp | 75 +++++------ .../include/utility/amd_inline_asm.hpp | 7 -- .../include/utility/config.amd.hpp.in | 7 +- .../utility/in_memory_operation.amd.hpp.in | 2 +- .../utility/in_memory_operation.nvidia.hpp.in | 2 +- composable_kernel/include/utility/math.hpp | 18 +-- ...data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp | 8 +- ...data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp | 8 +- ...data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 119 +++++++++--------- ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 42 ++++++- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 6 +- driver/src/conv_bwd_data_driver.cpp | 30 ++--- driver/src/conv_driver.cpp | 12 +- 20 files changed, 270 insertions(+), 198 deletions(-) diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 8221f32358..fa3c6f2ffb 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -49,7 +49,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw const Float* __restrict__ p_wei_global, const Float* __restrict__ p_out_global) const { - constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; @@ -85,11 +84,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw "be violated"); // output tensor - constexpr auto out_n_k_howo_global_desc = - unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3); - constexpr auto out_k_b_global_desc = - transform_tensor_descriptor(out_n_k_howo_global_desc, + transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3), make_tuple(PassThrough{}, Merge>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp index 0211df2b5a..6f244808ce 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -353,7 +353,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl } { -#if 1 // debug +#if 1 // debug // input: register to global memory, atomic add constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW) ? InMemoryDataOperation::none diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index 4615fae759..70a0738d8a 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -81,11 +81,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw "be violated"); #endif - constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); - constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); @@ -115,10 +115,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw PassThrough{}, Embed, - Sequence>{}, + Sequence>{}, Embed, - Sequence>{}), + Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); @@ -135,10 +135,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw PassThrough{}, Embed, - Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{}, + Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{}, Embed, - Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}), + Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp index a0c94a892e..068b8c1931 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp @@ -110,11 +110,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw "be violated"); #endif - constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); - constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); @@ -146,11 +146,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw PassThrough{}, Embed, - Sequence, + Sequence, wei_skip_all_out_of_bound_check>{}, Embed, - Sequence, + Sequence, wei_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); @@ -168,11 +168,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw PassThrough{}, Embed, - Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>, + Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>, out_skip_all_out_of_bound_check>{}, Embed, - Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>, + Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>, out_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index f96e99af6f..11d8f6540f 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -22,8 +22,6 @@ template struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw { - __device__ void Run(Float* __restrict__ p_in_global, - const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global) const + __host__ __device__ static constexpr index_t GetNumberOfGemm() + { + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; + + return Ytilda * Xtilda; + } + + template + __device__ static void RunImpl(Float* __restrict__ p_in_global, + const Float* __restrict__ p_wei_global, + const Float* __restrict__ p_out_global) { constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; @@ -83,11 +99,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw "be violated"); #endif - constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); - constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); @@ -119,11 +135,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw PassThrough{}, Embed, - Sequence, + Sequence, wei_skip_all_out_of_bound_check>{}, Embed, - Sequence, + Sequence, wei_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); @@ -141,11 +157,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw PassThrough{}, Embed, - Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>, + Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>, out_skip_all_out_of_bound_check>{}, Embed, - Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>, + Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>, out_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); @@ -215,8 +231,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); // GEMM - constexpr index_t ytilda = Iter_ytilda; - constexpr index_t xtilda = Iter_xtilda; + constexpr index_t ytilda = iYTilda; + constexpr index_t xtilda = iXTilda; constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot; constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot; @@ -327,6 +343,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); } + + template + __device__ static void Run(Float* __restrict__ p_in_global, + const Float* __restrict__ p_wei_global, + const Float* __restrict__ p_out_global) + { + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; + + constexpr index_t iYTilda = GemmId / Xtilda; + constexpr index_t iXTilda = GemmId % Xtilda; + + static_assert(iYTilda < Ytilda && iXTilda < Xtilda, "wrong! iYtilda, iXtilda"); + + RunImpl(p_in_global, p_wei_global, p_out_global); + } }; } // namespace ck diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 099756997c..e3db0193bc 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -49,7 +49,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw const Float* const __restrict__ p_wei_global, Float* const __restrict__ p_out_global) const { - constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; @@ -117,9 +116,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw // output tensor constexpr auto out_k_b_global_desc = - transform_tensor_descriptor(out_n_k_ho_wo_global_desc, - make_tuple(PassThrough{}, Merge>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), + transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3), + make_tuple(PassThrough{}, Merge>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // GEMM diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index 1091c90130..681426c4d5 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -47,6 +47,9 @@ struct PassThrough } }; +// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is +// necessary +// However, the check will be skipped if SkipIsValidCheck is set to true by user // LowerLengths: Sequence<...> template // Coefficients: Sequence<...> // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp] @@ -442,12 +448,12 @@ struct Embed __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { -#if 1 // debug + // skip valid check if user request it if(SkipIsValidCheck) { return true; } -#endif + bool flag = true; index_t ncorner = 1; diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index ce18a92d86..53331c10cd 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -112,11 +112,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 // has the valid/invalid mapping situation if(src_coord.IsOffsetValidAssumingUpperIndexIsValid()) { - move_data( + transfer_data( p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset); } } @@ -144,11 +144,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 // has the valid/invalid mapping situation if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid()) { - move_data( + transfer_data( p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset()); } } @@ -262,15 +262,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 // has the valid/invalid mapping situation if(src_coord.IsOffsetValidAssumingUpperIndexIsValid()) { - move_data(p_src, - src_nonlinear_coord.GetOffset() + - src_linear_offset, - p_src_long_vector, - buffer_offset); + transfer_data(p_src, + src_nonlinear_coord.GetOffset() + + src_linear_offset, + p_src_long_vector, + buffer_offset); } } @@ -301,11 +301,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 // has the valid/invalid mapping situation if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid()) { - move_data( + transfer_data( p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset()); } } @@ -401,11 +401,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 // has the valid/invalid mapping situation if(src_coord.IsOffsetValidAssumingUpperIndexIsValid()) { - move_data( + transfer_data( p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset); } } @@ -446,14 +446,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 // has the valid/invalid mapping situation if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid()) { - move_data(p_dst_long_vector, - buffer_offset, - p_dst, - dst_nonlinear_coord.GetOffset() + dst_linear_offset); + transfer_data(p_dst_long_vector, + buffer_offset, + p_dst, + dst_nonlinear_coord.GetOffset() + + dst_linear_offset); } } }); diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index 7be6b9fe46..51ebfb9065 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -8,19 +8,12 @@ namespace ck { // outer-product: c[i,j] += inner_product(a[i], b[j]) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) { -// disable inline asm due to the compiler issue: SWDEV-202749 -///\to-do: enable the inline asm after the compiler fix -#if CK_WORKAROUND_SWDEV_202749 - c0 += a * b0; - c1 += a * b1; -#else asm volatile("\n \ v_mac_f32 %0, %2, %3 \n \ v_mac_f32 %1, %2, %4 \n \ " : "=v"(c0), "=v"(c1) : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); -#endif } // outer-product: c[i,j] += inner_product(a[i], b[j]) diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index adf32ae32d..9d32cc81af 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -43,6 +43,10 @@ #define CK_USE_AMD_XDLOPS_INLINE_ASM 0 #endif +#ifndef CK_USE_AMD_XDLOPS_EMULATE +#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes +#endif + // experimental implementation #define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1 #define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0 @@ -51,9 +55,6 @@ #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0 -// workaround -#define CK_WORKAROUND_SWDEV_202749 1 - namespace ck { enum AddressSpace diff --git a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in index 6ffe96a83a..294be0536a 100644 --- a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in @@ -70,7 +70,7 @@ template -__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) +__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) { static_assert(DstInMemOp == InMemoryDataOperation::none || DstInMemOp == InMemoryDataOperation::atomic_add, diff --git a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in index d67059df0e..3a984a0e07 100644 --- a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in @@ -38,7 +38,7 @@ template -__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) +__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) { static_assert(DstInMemOp == InMemoryDataOperation::none || DstInMemOp == InMemoryDataOperation::atomic_add, diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 7960f3ccee..20f51552f6 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -103,9 +103,9 @@ __host__ __device__ constexpr T min(T x, Ts... xs) return x < y ? x : y; } -// highest common factor +// greatest common divisor, aka highest common factor template -__host__ __device__ constexpr T hcf(T x, T y) +__host__ __device__ constexpr T gcd(T x, T y) { if(x == 0) { @@ -124,30 +124,30 @@ __host__ __device__ constexpr T hcf(T x, T y) if(x > y) { - return hcf(x - y, y); + return gcd(x - y, y); } - return hcf(x, y - x); + return gcd(x, y - x); } template -__host__ __device__ constexpr auto hcf(Number, Number) +__host__ __device__ constexpr auto gcd(Number, Number) { - constexpr auto result = hcf(X, Y); + constexpr auto result = gcd(X, Y); return Number{}; } template -__host__ __device__ constexpr auto hcf(X x, Ys... ys) +__host__ __device__ constexpr auto gcd(X x, Ys... ys) { - return hcf(x, ys...); + return gcd(x, ys...); } // least common multiple template __host__ __device__ constexpr T lcm(T x, T y) { - return (x * y) / hcf(x, y); + return (x * y) / gcd(x, y); } template diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index 2a4dfecbf3..e82e72c179 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -152,11 +152,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; #endif - constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); - constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp index ac2a247a67..3b84c0ba9b 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp @@ -91,11 +91,11 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #endif - constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); - constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index f6ee9d71a5..c34953d919 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -2,13 +2,18 @@ #include #include "device.hpp" #include "tensor.hpp" -#include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" namespace launcher { using namespace ck; +template +__global__ void run_gridwise_convolution_backward_data_v4r1(Xs... xs) +{ + GridwiseOp::template Run(xs...); +} + template ; + KernelTimer timer; timer.Start(); - static_for<0, Ytilda, 1>{}([&](auto ytilda_) { - static_for<0, Xtilda, 1>{}([&](auto xtilda_) { - constexpr index_t ytilda = decltype(ytilda_){}; - constexpr index_t xtilda = decltype(xtilda_){}; + static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) { + constexpr index_t gemm_id = decltype(gemm_id_){}; - constexpr auto gridwise_conv = - GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - ytilda, - xtilda, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, - GemmABlockCopyThreadSliceLengths_GemmK_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK_GemmM, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; - - launch_and_time_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - gridwise_conv, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - }); + launch_kernel(run_gridwise_convolution_backward_data_v4r1, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); }); timer.End(); - float time = timer.GetElapsedTime(); printf("Elapsed time : %f ms, %f TFlop/s\n", diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 2bb353a825..07a3659856 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); -#if 1 +#if 0 // BlockSize = 256, EperBlock = 8, each thread hold 64 data constexpr index_t BlockSize = 256; @@ -127,7 +127,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; +#elif 1 + // BlockSize = 256, EPerBlock = 16, each thread hold 64 data + // for 1x1 + constexpr index_t BlockSize = 256; + + constexpr index_t BPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t EPerBlock = 16; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; + using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; #elif 1 // BlockSize = 64, each thread hold 64 data constexpr index_t BlockSize = 64; diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 24f46cfa8d..f775054b58 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -84,7 +84,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#elif 1 +#elif 0 // BlockSize = 256, GemmKPerBlock = 16 constexpr index_t BlockSize = 256; @@ -117,7 +117,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #elif 0 // BlockSize = 256, GemmKPerBlock = 8 - // 1x1 filter, 8x8 image + // for 1x1 filter, vector-read-b = 4 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -149,7 +149,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; #elif 1 // BlockSize = 256, GemmKPerBlock = 16 - // 1x1 filter, 8x8 image + // for 1x1 filter, vector-read-b = 4 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; diff --git a/driver/src/conv_bwd_data_driver.cpp b/driver/src/conv_bwd_data_driver.cpp index 9122498c37..17a0cd7e98 100644 --- a/driver/src/conv_bwd_data_driver.cpp +++ b/driver/src/conv_bwd_data_driver.cpp @@ -161,10 +161,10 @@ int main(int argc, char* argv[]) #elif 1 // 1x7 filter, 0x3 pad, 17x17 input constexpr index_t N = 128; - constexpr index_t C = 1024; + constexpr index_t C = 128; constexpr index_t HI = 17; constexpr index_t WI = 17; - constexpr index_t K = 1024; + constexpr index_t K = 128; constexpr index_t Y = 1; constexpr index_t X = 7; @@ -246,28 +246,28 @@ int main(int argc, char* argv[]) #endif } -#if 0 +#if 1 device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw #elif 0 device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw -#elif 1 +#elif 0 device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw #elif 0 device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw #elif 1 device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw #endif - (in_nchw_desc, - in_nchw_device, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); + (in_nchw_desc, + in_nchw_device, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); if(do_verification) { diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index bf3f598288..ae0dda3d4c 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -29,13 +29,13 @@ int main(int argc, char* argv[]) { using namespace ck; -#if 0 +#if 1 // 1x1 - constexpr index_t N = 256; - constexpr index_t C = 1024; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 1024; + constexpr index_t N = 64; + constexpr index_t C = 64; + constexpr index_t HI = 56; + constexpr index_t WI = 56; + constexpr index_t K = 256; constexpr index_t Y = 1; constexpr index_t X = 1;