From e2b4c5b469db934af4b8b81d860df56be69764ce Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 5 Dec 2019 12:36:36 -0600 Subject: [PATCH] update implicit GEMM forward v4r4 to use gridwise gemm (#9) * updated fwd v4r4 to use gridwise gemm * updated gridwise gemm api calls in bwd-data v1r1 and v2r1 --- ...data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp | 113 ++--- ...data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp | 93 ++-- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 167 +++++++ ..._v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp | 408 ------------------ .../blockwise_generic_tensor_slice_copy.hpp | 22 +- .../tensor_operation/gridwise_gemm.hpp | 66 +-- .../threadwise_generic_tensor_slice_copy.hpp | 140 +++--- ...data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp | 48 ++- ...data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp | 83 ++-- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 291 ++++++------- driver/src/conv_bwd_data_driver.cpp | 2 +- driver/src/conv_driver.cpp | 8 +- 12 files changed, 599 insertions(+), 842 deletions(-) create mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp delete mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 06d413278f..0b089f3987 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -8,6 +8,9 @@ namespace ck { +// GemmM = C * Y * X +// GemmN = N * Ho * Wo +// GemmK = K template + typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + index_t GemmABlockCopySrcDataPerRead_GemmN, + index_t GemmABlockCopyDstDataPerWrite_GemmN, + typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + index_t GemmBBlockCopySrcDataPerRead_GemmN, + index_t GemmBBlockCopyDstDataPerWrite_GemmN, + index_t GemmCThreadCopyDstDataPerWrite_GemmN1> struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw { __device__ void Run(Float* __restrict__ p_in_global, @@ -49,8 +54,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto True = integral_constant{}; - constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; @@ -73,14 +76,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw constexpr index_t ConvDilationH = ConvDilations{}[0]; constexpr index_t ConvDilationW = ConvDilations{}[1]; - constexpr index_t E = C * Y * X; - constexpr index_t B = N * Ho * Wo; - // sanity-check for vectorized memory load - static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) && - (X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0), - "wrong! 
aligment requirement for vectorized global load of input tensor will " - "be violated"); + // TODO: this logic may not be correct for bwd-data + static_assert( + (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) && + (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0), + "wrong! aligment requirement for vectorized global load of input tensor will " + "be violated"); // output tensor constexpr auto out_n_k_howo_global_desc = @@ -99,8 +101,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw // input tensor constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, - make_tuple( - PassThrough{}, PassThrough{}, Pad, LeftPads, RightPads>{}), + make_tuple(PassThrough{}, + PassThrough{}, + Pad, InLeftPads, InRightPads>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); @@ -121,33 +124,43 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw // GEMM: atomic add constexpr auto gridwise_gemm = - GridwiseGemmTransposedANormalBNormalC_v1r1{}; + GridwiseGemmTransposedANormalBNormalC_v1, + Sequence<0, 1>, + 1, + GemmABlockCopySrcDataPerRead_GemmN, + GemmABlockCopyDstDataPerWrite_GemmN, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + Sequence<0, 1, 2, 3>, + 3, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index ecd501b293..0e24a80c85 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -8,9 +8,9 @@ namespace ck { -// GemmK = K * Ydot * Xdot; // GemmM = C * Ytilda * Xtilda; // GemmN = N * Htilda * Wtilda; +// GemmK = K * Ydot * Xdot; template + typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + index_t GemmABlockCopySrcDataPerRead_GemmM, + index_t GemmABlockCopyDstDataPerWrite_GemmM, + typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + index_t GemmBBlockCopySrcDataPerRead_GemmN, + index_t GemmBBlockCopyDstDataPerWrite_GemmN, + index_t GemmCThreadCopyDstDataPerWrite_GemmN1> struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw { __device__ void Run(Float* __restrict__ p_in_global, @@ -71,10 +72,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw constexpr index_t ConvDilationW = ConvDilations{}[1]; // sanity-check for vectorized memory load - static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) && - (X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0), - "wrong! 
aligment requirement for vectorized global load of input tensor will " - "be violated"); + // TODO: this logic may not be correct for bwd-data + static_assert( + (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) && + (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0), + "wrong! aligment requirement for vectorized global load of input tensor will " + "be violated"); constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); @@ -172,33 +175,43 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw // GEMM constexpr auto gridwise_gemm = - GridwiseGemmTransposedANormalBNormalC_v1r1{}; + GridwiseGemmTransposedANormalBNormalC_v1, + Sequence<0, 1>, + 1, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + Sequence<0, 1, 2, 3>, + 3, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..5a4c4a1930 --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -0,0 +1,167 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP +#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; + constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; + constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; + constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; + constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; + + constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; + constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; + constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; + + constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; + constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + // sanity-check for vectorized memory load + static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) && + 
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) && + InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 && + InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0, + "wrong! aligment requirement for vectorized global load of input tensor will " + "be violated"); + + // weight tensor + constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower( + unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{}); + + // input tensor + constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Pad, InLeftPads, InRightPads>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + + constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, Sequence>{}, + Embed, Sequence>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + constexpr auto in_e_b_global_desc = transform_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(Merge>{}, Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + constexpr auto out_k_b_global_desc = + transform_tensor_descriptor(out_n_k_ho_wo_global_desc, + make_tuple(PassThrough{}, Merge>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // GEMM + constexpr auto gridwise_gemm = + GridwiseGemmTransposedANormalBNormalC_v1, + Sequence<1, 0>, + 0, + GemmABlockCopySrcDataPerRead_GemmK, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + Sequence<0, 1, 2, 3>, + 3, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + + gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp deleted file mode 100644 index d739ebd9a6..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ /dev/null @@ -1,408 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { -// B = merge(N, Ho, Wo) -template -struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = 
Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto True = integral_constant{}; - - constexpr auto in_n_c_hi_wi_global_desc = - make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides()); - constexpr auto wei_k_c_y_x_global_desc = - make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides()); - constexpr auto out_n_k_ho_wo_global_desc = - make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides()); - - constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0); - constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1); - constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - - constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1); - constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); - constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3); - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t E = C * Y * X; - constexpr index_t B = N * Ho * Wo; - - // sanity-check for vectorized memory load - static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) && - (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0), - "wrong! aligment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0, - "wrong! 
cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_cluster_descriptor(Sequence{}); - - const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_id[1] * BPerBlock; - - // input tensor - // global mem - constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple( - PassThrough{}, PassThrough{}, Pad, LeftPads, RightPads>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto in_e_b_global_desc = transform_tensor_descriptor( - in_n_c_y_ho_x_wo_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // LDS mem - // be careful of LDS alignment - constexpr auto in_e_b_block_desc = - make_native_tensor_descriptor_packed(Sequence{}); - - // input blockwise copy - auto blockwise_in_copy = - BlockwiseGenericTensorSliceCopy_v4( - {0, b_block_data_on_global}, {0, 0}); - - // weight tensor - // global mem - constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower( - unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{}); - - // LDS - // be careful of LDS alignment - constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, - Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0, - "GemmDataPerReadA alignment requirement is not satisfied"); - - // weight blockwise copy - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v4( - {0, k_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[EPerBlock, KPerBlock] is in LDS - // b_mtx[EPerBlocl, BPerBlock] is in LDS - // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in - // register - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); - constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc); - - // sanity check - static_assert( - KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 && - BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - constexpr index_t GemmNRepeat = - BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = 
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_e_k_block_mtx_desc), - decltype(b_e_b_block_mtx_desc), - decltype(c_k0k1_b0b1_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B, - WeiBlockCopyDstDataPerWrite_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = - math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); - - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread); - - // LDS double buffer: preload data into LDS - { - blockwise_in_copy.Run(p_in_global, p_in_block_double); - blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; - e_block_data_begin += 2 * EPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? 
p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer); - blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next); - } - } - - // LDS double buffer: tail - { - constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0); - - if(has_two_iteration_left) // if has 2 iteration left - { - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer); - blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store last data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, - p_wei_block_double + wei_block_space); - - __syncthreads(); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - } - } - - // copy output: register to global memory - { - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col; - - // src descriptor - constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); - - // dst descriptor - constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster; - - constexpr index_t K0 = K / K1; - constexpr index_t B0 = B / B1; - - constexpr auto out_k_b_global_desc = transform_tensor_descriptor( - out_n_k_ho_wo_global_desc, - make_tuple(PassThrough{}, Merge>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor( - out_k_b_global_desc, - make_tuple(UnMerge>{}, UnMerge>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}), - 
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // output threadwise copy - ThreadwiseGenericTensorSliceCopy_v4r2< - decltype(out_k0_k1_b0_b1_thread_desc), - decltype(out_k0_k1_b0_b1_global_desc), - decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()), - arithmetic_sequence_gen<0, 4, 1>::type, - 3, - OutThreadCopyDataPerAccess_B, - OutThreadCopyDataPerAccess_B, - AddressSpace::vgpr, - AddressSpace::global>({0, 0, 0, 0}, - {k_thread_data_on_global / K1, - k_thread_data_on_global % K1, - b_thread_data_on_global / B1, - b_thread_data_on_global % B1}) - .Run(p_out_thread, p_out_global); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp index 66b7b6a66f..5149be0739 100644 --- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp @@ -9,6 +9,12 @@ namespace ck { +// This threadwise copy allow vector access of src and dst. +// It allows the vector size to be different on src and dst. +// The dimension of vector access can be different for src and dst. +// The dimension access order can be different for src and dst. +// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping +// Will do valid mapping check on dst data: No write if dst data has a invalid mapping template ; diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp index 2f203f0ec4..8a6a3f72c3 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp @@ -31,14 +31,24 @@ template -struct GridwiseGemmTransposedANormalBNormalC_v1r1 + typename ABlockCopyThreadSliceLengths_K_M, + typename ABlockCopyThreadClusterLengths_K_M, + typename ABlockCopyThreadClusterArrangeOrder, + typename ABlockCopySrcAccessOrder, + index_t ABlockCopySrcVectorReadDim, + index_t ABlockCopySrcDataPerRead, + index_t ABlockCopyDstDataPerWrite_M, + typename BBlockCopyThreadSliceLengths_K_N, + typename BBlockCopyThreadClusterLengths_K_N, + typename BBlockCopyThreadClusterArrangeOrder, + typename BBlockCopySrcAccessOrder, + index_t BBlockCopySrcVectorReadDim, + index_t BBlockCopySrcDataPerRead, + index_t BBlockCopyDstDataPerWrite_N, + typename CThreadCopySrcDstAccessOrder, + index_t CThreadCopySrcDstVectorReadWriteDim, + index_t CThreadCopyDstDataPerWrite> +struct GridwiseGemmTransposedANormalBNormalC_v1 { __device__ void Run(const Float* __restrict__ p_a_global, const Float* __restrict__ p_b_global, @@ -55,8 +65,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 constexpr auto N = b_k_n_global_desc.GetLengths()[1]; // lds max alignment - constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M, - BBlockCopyDataPerAccess_N, + constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, + BBlockCopyDstDataPerWrite_N, ThreadGemmDataPerReadM, ThreadGemmDataPerReadN); @@ -86,15 +96,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 decltype(a_k_m_global_desc), decltype(a_k_m_block_desc), decltype(a_k_m_block_desc.GetLengths()), - ABlockCopySubLengths_K_M, - ABlockCopyClusterLengths_K_M, - Sequence<0, 1>, - Sequence<0, 1>, + ABlockCopyThreadSliceLengths_K_M, + ABlockCopyThreadClusterLengths_K_M, + ABlockCopyThreadClusterArrangeOrder, + ABlockCopySrcAccessOrder, 
Sequence<0, 1>, + ABlockCopySrcVectorReadDim, 1, - 1, - ABlockCopyDataPerAccess_M, - ABlockCopyDataPerAccess_M, + ABlockCopySrcDataPerRead, + ABlockCopyDstDataPerWrite_M, AddressSpace::global, AddressSpace::vgpr, AddressSpace::lds, @@ -112,15 +122,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 decltype(b_k_n_global_desc), decltype(b_k_n_block_desc), decltype(b_k_n_block_desc.GetLengths()), - BBlockCopySubLengths_K_N, - BBlockCopyClusterLengths_K_N, - Sequence<0, 1>, - Sequence<0, 1>, + BBlockCopyThreadSliceLengths_K_N, + BBlockCopyThreadClusterLengths_K_N, + BBlockCopyThreadClusterArrangeOrder, + BBlockCopySrcAccessOrder, Sequence<0, 1>, + BBlockCopySrcVectorReadDim, 1, - 1, - BBlockCopyDataPerAccess_N, - BBlockCopyDataPerAccess_N, + BBlockCopySrcDataPerRead, + BBlockCopyDstDataPerWrite_N, AddressSpace::global, AddressSpace::vgpr, AddressSpace::lds, @@ -304,10 +314,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1 ThreadwiseGenericTensorSliceCopy_v4r2, - 3, - CThreadCopyDataPerAccess_N, - CThreadCopyDataPerAccess_N, + CThreadCopySrcDstAccessOrder, + CThreadCopySrcDstVectorReadWriteDim, + 1, + CThreadCopyDstDataPerWrite, AddressSpace::vgpr, AddressSpace::global, CGlobalMemoryDataOperation>( diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index 784bc1a3a1..b3a4d46fc9 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -8,20 +8,19 @@ namespace ck { -// This version use multi-index transformation // This threadwise copy allow vector access of src and dst. // It allows the vector size to be different on src and dst. // The dimensions of vector access should be the same on src and dst. // The dimension access order should be the same on src and dst. -// It is designed for cases, where one of src and dst is register, and -// the other is device memory or LDS +// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping +// Will do valid mapping check on dst data: No write if dst data has a invalid mapping template @@ -39,16 +38,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 { static_assert(nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() && - nDim == DimAccessOrder::Size(), + nDim == SrcDstDimAccessOrder::Size(), "wrong! # of dimensions not the same"); - static_assert(is_valid_sequence_map{}, "wrong! map is not valid"); + static_assert(is_valid_sequence_map{}, "wrong! map is not valid"); - static_assert( - SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0, - "wrong! cannot evenly divide"); + static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] % + math::lcm(SrcDataPerRead, DstDataPerWrite) == + 0, + "wrong! cannot evenly divide"); - // TODO:: sanity-check if vectorized memory access is allowed on src and dst + // TODO:: sanity-check if vectorized memory read/write is allowed on src and dst } __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2() @@ -67,22 +67,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 mDstSliceOrigin = dst_slice_origin; } - // Will do padding check on src data: Read 0 if src data is in padding area. - // Will do padding check on dst data: No write if dst data is in paddin area. 
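    // Worked example (hypothetical template arguments, not taken from any instantiation
    // in this patch) of how Run() below traverses the slice:
    //   SliceLengths               = Sequence<4, 8>
    //   SrcDstVectorReadWriteDim   = 1
    //   SrcDataPerRead = 4, DstDataPerWrite = 2
    //   long_vector_size           = lcm(4, 2) = 4
    //   long_vector_access_lengths = Sequence<4, 2>   // 8 / 4 along dimension 1
    // Each long vector is then filled by 4 / 4 = 1 vectorized src read and drained by
    // 4 / 2 = 2 vectorized dst writes, and the constructor's static_assert
    // (SliceLengths[SrcDstVectorReadWriteDim] % lcm(SrcDataPerRead, DstDataPerWrite) == 0,
    // here 8 % 4 == 0) guarantees the slice divides evenly into such long vectors.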
template __device__ void Run(const SrcData* p_src, DstData* p_dst) const { - constexpr auto vector_access_dim = Number{}; + constexpr auto vector_access_dim = Number{}; - constexpr auto src_data_per_access = Number{}; - constexpr auto dst_data_per_access = Number{}; + constexpr auto src_data_per_access = Number{}; + constexpr auto dst_data_per_access = Number{}; - constexpr auto long_vector_size = Number{}; + constexpr auto long_vector_size = Number{}; constexpr auto long_vector_access_lengths = SliceLengths::Modify( vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); - ford{}([&]( + ford{}([&]( auto long_vector_access_id) { // data id w.r.t slicing-window @@ -109,13 +107,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id); - // Check src vector's padding situation, only check the first data in this src + // Check src data's valid mapping situation, only check the first data in this src // vector. It's user's responsiblity to make sure all data in the src vector - // has the same padding situation + // has the valid/invalid mapping situation if(src_coord.IsUpperIndexMappedToValidOffset()) { move_data( @@ -141,13 +139,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id); - // Check dst vector's padding situation, only check the first data in this dst + // Check dst data's valid mapping situation, only check the first data in this dst // vector. It's user's responsiblity to make sure all data in the dst vector - // has the same padding situation + // has the valid/invalid mapping situation if(dst_coord.IsUpperIndexMappedToValidOffset()) { move_data( @@ -165,20 +163,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 return Sequence<(Mask ? Lengths : 1)...>{}; } - // Will do padding check on src data: Read 0 if src data is in padding area. - // Will do padding check on dst data: No write if dst data is in paddin area. + // Will do valid mapping check on src data: Read 0 if src data has a invalid mapping + // Will do valid mapping check on dst data: No write if dst data has a invalid mapping // This version is optimized for address calculation of src tensor // TODO: this function is not compiled to expected ISA template __device__ void Run_optimized_src_address_calculation(const SrcData* p_src, DstData* p_dst) const { - constexpr auto vector_access_dim = Number{}; + constexpr auto vector_access_dim = Number{}; - constexpr auto src_data_per_access = Number{}; - constexpr auto dst_data_per_access = Number{}; + constexpr auto src_data_per_access = Number{}; + constexpr auto dst_data_per_access = Number{}; - constexpr auto long_vector_size = Number{}; + constexpr auto long_vector_size = Number{}; constexpr auto long_vector_access_lengths = SliceLengths::Modify( vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); @@ -187,10 +185,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 constexpr auto src_linear_dim_mask = SrcDesc::GetLinearDimensionMask(); constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask(); - static_assert(src_linear_dim_mask.At(VectorAccessDim) || - long_vector_size == SrcDataPerAccess, - "Warning! VectorAccessDim is not SrcDesc's linear dimension, performance " - "would drop"); + static_assert( + src_linear_dim_mask.At(SrcDstVectorReadWriteDim) || long_vector_size == SrcDataPerRead, + "Warning! 
SrcDstVectorReadWriteDim is not SrcDesc's linear dimension, performance " + "would drop"); // separate steps into linear and non-linear components, accoording to src tensor constexpr auto linear_long_vector_access_lengths = @@ -230,13 +228,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 p_src_long_vector[i] = 0; } - // Loop over VectorAccessDim, and load data from src to the + // Loop over SrcDstVectorReadWriteDim, and load data from src to the // long-vector buffer. - // If VectorAccessDim is src's linear dimension, then src's + // If SrcDstVectorReadWriteDim is src's linear dimension, then src's // offset-diff due to this looping is known at compile-time. If - // VectorAccessDim is src's nonlinear dimension, then src's + // SrcDstVectorReadWriteDim is src's nonlinear dimension, then src's // offset-diff due to this looping is only known at run-time. For best - // performance, VectorAccessDim, should be src's linear dimension + // performance, SrcDstVectorReadWriteDim, should be src's linear dimension for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i) { auto scalar_id = make_zero_array(); @@ -258,13 +256,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 src_coord.GetOffset() - src_nonlinear_coord.GetOffset(); #endif - // Check src vector's padding situation, only check the first data in - // this src vector. It's user's responsiblity to make sure all data in - // the src vector has the same padding situation + // Check src data's valid mapping situation, only check the first data in this + // src + // vector. It's user's responsiblity to make sure all data in the src vector + // has the valid/invalid mapping situation if(src_coord.IsUpperIndexMappedToValidOffset()) { move_data(p_src, @@ -296,13 +295,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps + linear_dim_data_steps + scalar_id); - // Check dst vector's padding situation, only check the first data in - // this dst vector. It's user's responsiblity to make sure all data in - // the dst vector has the same padding situation + // Check dst data's valid mapping situation, only check the first data in this + // dst + // vector. It's user's responsiblity to make sure all data in the dst vector + // has the valid/invalid mapping situation if(dst_coord.IsUpperIndexMappedToValidOffset()) { move_data( @@ -313,20 +313,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 }); } - // Will do padding check on src data: Read 0 if src data is in padding area. - // Will do padding check on dst data: No write if dst data is in paddin area. 
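    // Note on the two optimized paths: as the comments inside them state, an offset step
    // along a dimension the descriptor reports as "linear" is known at compile time,
    // while a step along a "non-linear" dimension (in this framework, typically one
    // produced by a Merge transform -- an assumption, not something this patch spells
    // out) must be recomputed at run time. The version above splits the traversal so the
    // run-time updates follow only the src tensor's non-linear dimensions; the version
    // below does the same for the dst tensor. The static_asserts in both paths guard the
    // case where SrcDstVectorReadWriteDim is not a linear dimension of the tensor being
    // optimized for, since the offset arithmetic along that dimension is then no longer
    // known at compile time and performance would drop.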
// This version is optimized for address calculation of dst tensor // TODO: this function is not compiled to expected ISA template __device__ void Run_optimized_dst_address_calculation(const SrcData* p_src, DstData* p_dst) const { - constexpr auto vector_access_dim = Number{}; + constexpr auto vector_access_dim = Number{}; - constexpr auto src_data_per_access = Number{}; - constexpr auto dst_data_per_access = Number{}; + constexpr auto src_data_per_access = Number{}; + constexpr auto dst_data_per_access = Number{}; - constexpr auto long_vector_size = Number{}; + constexpr auto long_vector_size = Number{}; constexpr auto long_vector_access_lengths = SliceLengths::Modify( vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); @@ -335,10 +333,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask(); constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask(); - static_assert(dst_linear_dim_mask.At(VectorAccessDim) || - long_vector_size == DstDataPerAccess, - "Warning! VectorAccessDim is not DstDesc's linear dimension, performance " - "would drop"); + static_assert( + dst_linear_dim_mask.At(SrcDstVectorReadWriteDim) || long_vector_size == DstDataPerWrite, + "Warning! SrcDstVectorReadWriteDim is not DstDesc's linear dimension, performance " + "would drop"); // separate steps into linear and non-linear components, accoording to dst tensor constexpr auto linear_long_vector_access_lengths = @@ -378,13 +376,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 p_src_long_vector[i] = 0; } - // Loop over VectorAccessDim, and load data from src to the + // Loop over SrcDstVectorReadWriteDim, and load data from src to the // long-vector buffer. - // If VectorAccessDim is dst's linear dimension, then dst's + // If SrcDstVectorReadWriteDim is dst's linear dimension, then dst's // offset-diff due to this looping is known at compile-time. If - // VectorAccessDim is dst's nonlinear dimension, then dst's + // SrcDstVectorReadWriteDim is dst's nonlinear dimension, then dst's // offset-diff due to this looping is only known at run-time. For best - // performance, VectorAccessDim, should be dst's linear dimension + // performance, SrcDstVectorReadWriteDim, should be dst's linear dimension for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i) { auto scalar_id = make_zero_array(); @@ -397,13 +395,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 const auto src_coord = mSrcSliceOrigin + (nonlinear_dim_data_steps + linear_dim_data_steps + scalar_id); - // Check src vector's padding situation, only check the first data in - // this src vector. It's user's responsiblity to make sure all data in - // the src vector has the same padding situation + // Check src data's valid mapping situation, only check the first data in this + // src + // vector. It's user's responsiblity to make sure all data in the src vector + // has the valid/invalid mapping situation if(src_coord.IsUpperIndexMappedToValidOffset()) { move_data( @@ -441,13 +440,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset(); #endif - // Check dst vector's padding situation, only check the first data in - // this dst vector. It's user's responsiblity to make sure all data in - // the dst vector has the same padding situation + // Check dst data's valid mapping situation, only check the first data in this + // dst + // vector. 
It's user's responsiblity to make sure all data in the dst vector + // has the valid/invalid mapping situation if(dst_coord.IsUpperIndexMappedToValidOffset()) { move_data(p_dst_long_vector, diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 526abe37bd..0d85048a54 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -11,8 +11,8 @@ template + typename InLeftPads, + typename InRightPads> void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, Tensor& in_nchw, WeiDesc wei_kcyx_desc, @@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i const Tensor& out_nkhw, ConvStrides, ConvDilations, - LeftPads, - RightPads, + InLeftPads, + InRightPads, std::size_t nrepeat) { using namespace ck; @@ -62,24 +62,26 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; - using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M - using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N - using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #endif constexpr index_t GemmM = C * Y * X; constexpr index_t GemmN = N * Ho * Wo; - constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) * - ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock); + constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * + math::integer_divide_ceil(GemmN, GemmNPerBlock); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); @@ -93,8 +95,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i decltype(out_nkhw_desc), ConvStrides, ConvDilations, - LeftPads, - RightPads, + InLeftPads, + InRightPads, GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, @@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN, - GemmABlockCopySubLengths, - GemmABlockCopyClusterLengths, - GemmABlockCopyDataPerAccess, - GemmBBlockCopySubLengths, - GemmBBlockCopyClusterLengths, - GemmBBlockCopyDataPerAccess, - GemmCThreadCopyDataPerAccess>{}; + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + GemmABlockCopySrcDataPerRead_GemmM, + 
GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; for(index_t i = 0; i < nrepeat; ++i) { diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index c0b1fe9789..9e5e0cc553 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -11,8 +11,8 @@ template + typename InLeftPads, + typename InRightPads> void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, Tensor& in_nchw, WeiDesc wei_kcyx_desc, @@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i const Tensor& out_nkhw, ConvStrides, ConvDilations, - LeftPads, - RightPads, + InLeftPads, + InRightPads, std::size_t nrepeat) { using namespace ck; @@ -68,54 +68,26 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; - using GemmABlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-M - using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; - constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; - using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N - using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N -#elif 0 - // BlockSize = 256, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M - using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M - - constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M - - using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N - using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N - - constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N - - constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #endif 
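    // Worked example of the v2r1 GEMM shape (hypothetical convolution, not one of the
    // driver configurations above): for a 3x3 filter (Y = X = 3) with
    // ConvStrides = Sequence<2, 2> and ConvDilations = Sequence<1, 1>, hcf(2, 1) = 1, so
    //   Ytilda = 2 / 1 = 2,        Xtilda = 2,
    //   Ydot   = ceil(3 / 2) = 2,  Xdot   = 2,
    // which gives GemmM = C * Ytilda * Xtilda = 4 * C and GemmK = K * Ydot * Xdot = 4 * K,
    // while GemmN stays N * Htilda * Wtilda with Htilda/Wtilda derived from Ho/Wo below.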
- // TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for - // simplicity constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong - constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong + constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); @@ -126,12 +98,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i constexpr index_t Htilda = Ho + right_pad_ho; constexpr index_t Wtilda = Wo + right_pad_wo; - constexpr index_t GemmK = K * Ydot * Xdot; constexpr index_t GemmM = C * Ytilda * Xtilda; constexpr index_t GemmN = N * Htilda * Wtilda; - constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) * - ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock); + constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * + math::integer_divide_ceil(GemmN, GemmNPerBlock); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); @@ -145,8 +116,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i decltype(out_nkhw_desc), ConvStrides, ConvDilations, - LeftPads, - RightPads, + InLeftPads, + InRightPads, GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, @@ -159,13 +130,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN, - GemmABlockCopySubLengths, - GemmABlockCopyClusterLengths, - GemmABlockCopyDataPerAccess, - GemmBBlockCopySubLengths, - GemmBBlockCopyClusterLengths, - GemmBBlockCopyDataPerAccess, - GemmCThreadCopyDataPerAccess>{}; + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; for(index_t i = 0; i < nrepeat; ++i) { diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 767e2a824d..4077846c0e 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -3,7 +3,7 @@ #include "device.hpp" #include "tensor.hpp" #include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp" +#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" template + class InLeftPads, + class InRightPads> void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, const Tensor& in_nchw, WeiDesc, @@ -21,8 +21,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, Tensor& out_nkhw, ConvStrides, ConvDilations, - LeftPads, - RightPads, + InLeftPads, + InRightPads, ck::index_t nrepeat) { using namespace ck; @@ -32,9 +32,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr auto I2 = 
Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; + constexpr auto in_nchw_desc = + make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides()); + constexpr auto wei_kcyx_desc = + make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides()); + constexpr auto out_nkhw_desc = + make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides()); constexpr index_t N = out_nkhw_desc.GetLength(I0); constexpr index_t K = out_nkhw_desc.GetLength(I1); @@ -51,198 +54,174 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); #if 1 - // BlockSize = 256, EPerBlock = 8 + // BlockSize = 256, GemmKPerBlock = 8 constexpr index_t BlockSize = 256; - constexpr index_t BPerBlock = 128; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; - using InBlockCopySubLengths_E_B = Sequence<4, 1>; - using InBlockCopyClusterLengths_E_B = Sequence<2, 128>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B] - using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B] + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; - constexpr index_t InBlockCopyDataPerAccess_B = 1; + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - constexpr index_t OutThreadCopyDataPerAccess_B = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #elif 0 - // BlockSize = 256, EPerBlock = 8 + // BlockSize = 256, GemmKPerBlock = 8 // 1x1 filter, 8x8 image constexpr index_t BlockSize = 256; - constexpr index_t BPerBlock = 128; - constexpr index_t KPerBlock = 128; - 
constexpr index_t EPerBlock = 8;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA = 4;
-    constexpr index_t GemmDataPerReadB = 4;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using InBlockCopySubLengths_E_B = Sequence<1, 4>;
-    using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t InBlockCopyDataPerAccess_B = 4;
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-    using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-    constexpr index_t OutThreadCopyDataPerAccess_B = 4;
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 0
-    // BlockSize = 256, EPerBlock = 16
+    // BlockSize = 256, GemmKPerBlock = 16
     // 1x1 filter, 8x8 image
     constexpr index_t BlockSize = 256;
-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 16;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 16;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA = 4;
-    constexpr index_t GemmDataPerReadB = 4;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using InBlockCopySubLengths_E_B = Sequence<2, 4>;
-    using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
-    constexpr index_t InBlockCopyDataPerAccess_B = 4;
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-    using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
-    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-    constexpr index_t OutThreadCopyDataPerAccess_B = 4;
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 1
     // 1x1 filter, 14x14 image
     constexpr index_t BlockSize = 256;
-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA = 4;
-    constexpr index_t GemmDataPerReadB = 4;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;
-    using InBlockCopySubLengths_E_B = Sequence<2, 2>;
-    using InBlockCopyClusterLengths_E_B = Sequence<4, 64>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-    constexpr index_t InBlockCopyDataPerAccess_B = 2;
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-    using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
-    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
-    constexpr index_t OutThreadCopyDataPerAccess_B = 2;
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
 #endif
-    constexpr index_t B = N * Ho * Wo;
+    constexpr index_t GemmM = K;
+    constexpr index_t GemmN = N * Ho * Wo;
-    constexpr index_t GridSize =
-        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
+    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
+                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-    constexpr auto gridwise_conv =
-        GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
-            GridSize,
-            BlockSize,
-            T,
-            decltype(in_nchw_desc),
-            decltype(wei_kcyx_desc),
-            decltype(out_nkhw_desc),
-            ConvStrides,
-            ConvDilations,
-            LeftPads,
-            RightPads,
-            BPerBlock,
-            KPerBlock,
-            EPerBlock,
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
-            GemmMLevel0Cluster,
-            GemmNLevel0Cluster,
-            GemmMLevel1Cluster,
-            GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
-            GemmDataPerReadA,
-            GemmDataPerReadB,
-            InBlockCopySubLengths_E_B,
-            InBlockCopyClusterLengths_E_B,
-            InBlockCopyThreadClusterArrangeOrder,
-            InBlockCopySrcAccessOrder,
-            InBlockCopyDstAccessOrder,
-            InBlockCopyDataPerAccess_B,
-            WeiBlockCopySubLengths_E_K,
-            WeiBlockCopyClusterLengths_E_K,
-            WeiBlockCopyThreadClusterArrangeOrder,
-            WeiBlockCopySrcAccessOrder,
-            WeiBlockCopyDstAccessOrder,
-            WeiBlockCopySrcDataPerRead_E,
-            WeiBlockCopyDstDataPerWrite_K,
-            OutThreadCopyDataPerAccess_B>{};
+    constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
+        GridSize,
+        BlockSize,
+        T,
+        T,
+        decltype(in_nchw_desc),
+        decltype(wei_kcyx_desc),
+        decltype(out_nkhw_desc),
+        ConvStrides,
+        ConvDilations,
+        InLeftPads,
+        InRightPads,
+        GemmMPerBlock,
+        GemmNPerBlock,
+        GemmKPerBlock,
+        GemmMPerThreadSubC,
+        GemmNPerThreadSubC,
+        GemmMLevel0Cluster,
+        GemmNLevel0Cluster,
+        GemmMLevel1Cluster,
+        GemmNLevel1Cluster,
+        GemmKPerThreadLoop,
+        ThreadGemmDataPerReadM,
+        ThreadGemmDataPerReadN,
+        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+        GemmABlockCopySrcDataPerRead_GemmK,
+        GemmABlockCopyDstDataPerWrite_GemmM,
+        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+        GemmBBlockCopySrcDataPerRead_GemmN,
+        GemmBBlockCopyDstDataPerWrite_GemmN,
+        GemmCThreadCopyDstDataPerWrite_GemmN1>{};
     for(index_t i = 0; i < nrepeat; ++i)
     {
diff --git a/driver/src/conv_bwd_data_driver.cpp b/driver/src/conv_bwd_data_driver.cpp
index 2f0df590ae..3a655b6a56 100644
--- a/driver/src/conv_bwd_data_driver.cpp
+++ b/driver/src/conv_bwd_data_driver.cpp
@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
 {
     using namespace ck;
-#if 0
+#if 1
     constexpr index_t N = 8;
    constexpr index_t C = 128;
    constexpr index_t HI = 16;
diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp
index 5f87b738b1..3762392847 100644
--- a/driver/src/conv_driver.cpp
+++ b/driver/src/conv_driver.cpp
@@ -43,7 +43,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;
@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
     constexpr index_t N = 128;
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
-#elif 0
+#elif 1
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -403,7 +403,7 @@ int main(int argc, char* argv[])
                                                        ConvStrides{},
                                                        ConvDilations{},
                                                        nrepeat);
-#elif 1
+#elif 0
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
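
Note (illustration only, not part of the patch): the renamed tuning parameters above determine the forward v4r4 launch grid through GemmM = K and GemmN = N * Ho * Wo, with one workgroup per GemmMPerBlock x GemmNPerBlock tile. The standalone C++ sketch below spells out that arithmetic. The problem size (N, K, Ho, Wo) is a made-up example, integer_divide_ceil is re-implemented locally rather than taken from ck::math, and the static_asserts restate the tile-coverage constraint that the Sequence<> slice/cluster lengths of the selected "#elif 1" (1x1 filter, 14x14 image) config already satisfy.

#include <cstdio>

using index_t = int;

constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

int main()
{
    // hypothetical problem size (not from the patch)
    constexpr index_t N = 64, K = 256, Ho = 14, Wo = 14;

    // block tile from the "#elif 1" (1x1 filter, 14x14 image) config above
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;

    // GEMM view of the forward convolution, as introduced by the patch
    constexpr index_t GemmM = K;           // 256
    constexpr index_t GemmN = N * Ho * Wo; // 12544

    // one workgroup per GemmMPerBlock x GemmNPerBlock tile of the output matrix
    constexpr index_t GridSize =
        integer_divide_ceil(GemmM, GemmMPerBlock) * integer_divide_ceil(GemmN, GemmNPerBlock);

    std::printf("GemmM %d, GemmN %d, GridSize %d\n", GemmM, GemmN, GridSize); // 256, 12544, 196

    // Tile-coverage constraint: per-thread slice lengths times thread-cluster lengths
    // must equal the per-block tile in each dimension.
    // A block copy: Sequence<4, 1> slices x Sequence<2, 128> clusters -> {GemmKPerBlock, GemmMPerBlock}
    static_assert(4 * 2 == GemmKPerBlock && 1 * 128 == GemmMPerBlock, "A copy does not cover the tile");
    // B block copy: Sequence<2, 2> slices x Sequence<4, 64> clusters -> {GemmKPerBlock, GemmNPerBlock}
    static_assert(2 * 4 == GemmKPerBlock && 2 * 64 == GemmNPerBlock, "B copy does not cover the tile");

    return 0;
}

In each of the configs above the thread-cluster lengths also multiply out to BlockSize (2 * 128 = 8 * 32 = 4 * 64 = 256 threads), which is why the cluster Sequence pairs change together with BlockSize and the per-block tile sizes.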