diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 75381eb76f..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,268 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP -#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm.hpp" - -namespace ck { - -// GemmM = C * YTilda * XTilda; -// GemmN = N * HTildaSlice * WTildaSlice; -// GemmK = K * YDot * XDot; -template -struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw -{ - __device__ void Run(Float* __restrict__ p_in_global, - const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global) const - { - constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; - constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; - constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; - - constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - -#if 0 // debug - // sanity-check for vectorized memory load - // TODO: this logic may not be correct for bwd-data - static_assert( - (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) && - (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0), - "wrong! 
aligment requirement for vectorized global load of input tensor will " - "be violated"); -#endif - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t HTilda = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = - Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - constexpr index_t HTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t WTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t HTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; - constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; - - // weight tensor - constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( - wei_k_c_y_x_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence>{}, - Embed, - Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( - out_n_k_ho_wo_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>>{}, - Embed, - Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc = - transform_tensor_descriptor( - out_n_k_ydot_htilda_xdot_wtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); - - constexpr auto out_gemmk_gemmn_global_desc = - transform_tensor_descriptor(out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, - make_tuple(Merge>{}, - Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - -#if 1 // debug - constexpr bool in_skip_all_out_of_bound_check = false; -#else - constexpr bool in_skip_all_out_of_bound_check = true; -#endif - - // input tensor - constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple( - 
PassThrough{}, - PassThrough{}, - Pad, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; - constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; - - constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence, - in_skip_all_out_of_bound_check>{}, - Embed, - Sequence, - in_skip_all_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc = - transform_tensor_descriptor( - in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); - - constexpr auto in_gemmm_gemmn_global_desc = - transform_tensor_descriptor(in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc, - make_tuple(Merge>{}, - Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // GEMM - constexpr auto gridwise_gemm = - GridwiseGemmTransposedANormalBNormalC_v1, - Sequence<0, 1>, - 1, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - Sequence<0, 1, 2, 3>, - 3, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; - - gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index a36e7edba0..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,388 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP -#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm.hpp" - -namespace ck { - -// Number of GEMMs: YTilda * XTilda -// GemmM = C -// GemmN = N * HTildaSlice * WTildaSlice -// GemmK = K * YDotSlice * XDotSlice -template -struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw -{ - // this is a hack, should query this info from gridwise_gemm instead of duplicate its logic - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyDstDataPerWrite_GemmN, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN); - - // A matrix in LDS memory, dst of blockwise copy - // be careful 
of LDS alignment - constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space = - math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align); - - constexpr index_t b_block_space = - math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align); - - return 2 * (a_block_space + b_block_space) * sizeof(Float); - } - - __device__ void Run(Float* __restrict__ p_in_global, - const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global) const - { - constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; - constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; - constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; - - constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - -#if 0 // debug - // sanity-check for vectorized memory load - // TODO: this logic may not be correct for bwd-data - static_assert( - (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) && - (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0), - "wrong! 
aligment requirement for vectorized global load of input tensor will " - "be violated"); -#endif - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t HTilda = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = - Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - constexpr index_t HTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t WTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t HTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; - constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; - - constexpr bool wei_skip_all_out_of_bound_check = true; - - // weight tensor - constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( - wei_k_c_y_x_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence, - wei_skip_all_out_of_bound_check>{}, - Embed, - Sequence, - wei_skip_all_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - -#if 1 // debug - constexpr bool out_skip_all_out_of_bound_check = false; -#else - constexpr bool out_skip_all_out_of_bound_check = true; -#endif - - // output tensor - constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( - out_n_k_ho_wo_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>, - out_skip_all_out_of_bound_check>{}, - Embed, - Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>, - out_skip_all_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc = - transform_tensor_descriptor( - out_n_k_ydot_htilda_xdot_wtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); - -#if 1 // debug - constexpr bool in_skip_all_out_of_bound_check = false; -#else - constexpr bool in_skip_all_out_of_bound_check = true; -#endif - - // input tensor - constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple( - PassThrough{}, - PassThrough{}, - Pad, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2, 3>{})); - - constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; - constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; - - constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence, - in_skip_all_out_of_bound_check>{}, - Embed, - Sequence, - in_skip_all_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc = - transform_tensor_descriptor( - in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); - - // GEMMs - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); - - __shared__ Float p_shared_block[shared_block_size]; - - static_for<0, YTilda, 1>{}([&](auto iYTilda_) { - static_for<0, XTilda, 1>{}([&](auto iXTilda_) { - constexpr index_t iYTilda = decltype(iYTilda_){}; - constexpr index_t iXTilda = decltype(iXTilda_){}; - - constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; - constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; - - // A matrix - constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc = - transform_tensor_descriptor( - wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Slice, - Sequence<0, 0>, - Sequence>{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{})); - - constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc, - make_tuple(Merge>{}, - Merge>{}), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // B matrix - constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc = - transform_tensor_descriptor( - out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence<0, 0>, - Sequence>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<5>{}, - Sequence<2, 4>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<5>{}, - Sequence<2, 4>{})); - - constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( - out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc, - make_tuple(Merge>{}, - Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // C matrix - constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc = - transform_tensor_descriptor( - in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<5>{}, - Sequence<2, 4>{}), 
- make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<5>{}, - Sequence<2, 4>{})); - - constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc, - make_tuple(Merge>{}, - Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1< - GridSize, - BlockSize, - Float, - AccFloat, - decltype(wei_gemmk_gemmm_global_desc), - decltype(out_gemmk_gemmn_global_desc), - decltype(in_gemmm_gemmn_global_desc), - InMemoryDataOperation::Set, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, - GemmABlockCopyThreadSliceLengths_GemmK_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK_GemmM, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - Sequence<0, 1, 2, 3>, - 3, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; - - gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block); - - // is synchronization necessary? - __syncthreads(); - }); - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 1eaf724f0f..bc18872b38 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -167,9 +167,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t ConvDilationH = ConvDilations{}[0]; constexpr index_t ConvDilationW = ConvDilations{}[1]; - //\todo static_assert for global vector load/store - // statc_assert(); - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); @@ -179,6 +176,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); + constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; + constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? 
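+        // Note: the last iYTilda/iXTilda group may cover fewer than YDot/XDot
+        // filter taps. Illustrative numbers (not taken from this PR): Y = 3 with
+        // ConvStrideH = 2 and ConvDilationH = 1 gives YTilda = 2 and
+        // YDot = ceil(3 / 2) = 2, so iYTilda = 0 takes YDotSlice = 2 taps and
+        // iYTilda = 1 takes the remaining Y % YDot = 1 tap; the slices always
+        // sum to Y.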
XDot : X % XDot; + constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); constexpr index_t WTilda = @@ -198,10 +198,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; + // A matrix: weight // weight out-of-bound check can be skipped constexpr bool wei_skip_out_of_bound_check = true; - // weight tensor constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( wei_k_c_y_x_global_desc, make_tuple(PassThrough{}, @@ -217,15 +217,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + constexpr auto wei_k_c_ydotslice_xdotslice_global_desc = transform_tensor_descriptor( + wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, + make_tuple( + PassThrough{}, + PassThrough{}, + Slice, Sequence<0, 0>, Sequence>{}, + Freeze, Sequence>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<>{})); + + constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + wei_k_c_ydotslice_xdotslice_global_desc, + make_tuple(Merge>{}, PassThrough{}), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + +// B matrix: output tensor +// TODO sometimes output tensor out-of-bound check can be skipped, find out all such +// situations #if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK constexpr bool out_skip_out_of_bound_check = false; #else - //\todo sometimes output tensor out-of-bound check can be skipped, find out all such - // situations constexpr bool out_skip_out_of_bound_check = true; #endif - // output tensor constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( out_n_k_ho_wo_global_desc, make_tuple(PassThrough{}, @@ -246,8 +262,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw out_n_k_ydot_htilda_xdot_wtilda_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, + PassThrough{}, + PassThrough{}, Slice, Sequence, Sequence>{}), @@ -256,14 +272,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc = + transform_tensor_descriptor( + out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, + make_tuple( + PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, Sequence<0, 0>, Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); + + constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( + out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc, + make_tuple(Merge>{}, + Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + +// C matrix: input tensor +// TODO sometimes input out-of-bound check can be skipped, find out all such situations #if 
!CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK constexpr bool in_skip_out_of_bound_check = false; #else - //\todo sometimes input out-of-bound check can be skipped, find out all such situations - constexpr bool in_skip_out_of_bound_check = true; + constexpr bool in_skip_out_of_bound_check = true; #endif - // input tensor constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, make_tuple( @@ -291,87 +328,21 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc = - transform_tensor_descriptor( - in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); - - // GEMM - constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; - constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; - - // A matrix - constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc = - transform_tensor_descriptor( - wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, - make_tuple( - PassThrough{}, - PassThrough{}, - Slice, Sequence<0, 0>, Sequence>{}, - Slice, - Sequence, - Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{})); - - constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // B matrix - constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc = - transform_tensor_descriptor( - out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, - make_tuple( - PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, Sequence<0, 0>, Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); - - constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( - out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc, - make_tuple(Merge>{}, - Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // C matrix - constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc = - transform_tensor_descriptor( - in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); + constexpr auto in_n_c_htildaslice_wtildaslice_global_desc = transform_tensor_descriptor( + in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + 
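+                       // The Freeze transform introduced by this PR pins the
+                       // ytilda/xtilda coordinates to the compile-time
+                       // (iYTilda, iXTilda) point of the current GEMM and drops
+                       // those two dimensions entirely, so the GemmM merge
+                       // below degenerates to a PassThrough over C instead of
+                       // a Merge over singleton slices.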
Freeze, Sequence>{}, + Slice, + Sequence, + Sequence>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<>{}, Sequence<2, 3>{})); constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + in_n_c_htildaslice_wtildaslice_global_desc, + make_tuple(PassThrough{}, Merge>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); constexpr auto gridwise_gemm = diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..24422daeda --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,406 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP +#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm.hpp" + +namespace ck { + +// Number of GEMMs = YTilda * XTilda +// GemmM = C +// GemmN = N * HTildaSlice * WTildaSlice +// GemmK0 = YDotSlice +// GemmK1 = XDotSlice +// GemmK2 = K +template +struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk +{ + __host__ __device__ static constexpr index_t GetNumberOfGemm() + { + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + return YTilda * XTilda; + } + + __host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda) + { + constexpr index_t N = InGlobalDesc::GetLengths()[0]; + constexpr index_t Hi = InGlobalDesc::GetLengths()[1]; + constexpr index_t Wi = InGlobalDesc::GetLengths()[2]; + constexpr index_t C = InGlobalDesc::GetLengths()[3]; + + constexpr index_t Ho = OutGlobalDesc::GetLengths()[1]; + constexpr index_t Wo = OutGlobalDesc::GetLengths()[2]; + constexpr index_t K = OutGlobalDesc::GetLengths()[3]; + + constexpr index_t Y = WeiGlobalDesc::GetLengths()[1]; + constexpr index_t X = WeiGlobalDesc::GetLengths()[2]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); + + constexpr 
index_t HTilda = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = + Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + constexpr index_t iHTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t iWTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); + + constexpr index_t iHTildaRight = math::min( + HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t iWTildaRight = math::min( + WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; + constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; + + // GemmM and GemmN + constexpr index_t GemmM = C; + constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; + + // GemmK is different for each GEMM + index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; + index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; + + index_t GemmK0 = YDotSlice; + index_t GemmK1 = XDotSlice; + index_t GemmK2 = K; + + return Array{GemmM, GemmN, GemmK0, GemmK1, GemmK2}; + } + + __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id) + { + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + index_t iYTilda = gemm_id / XTilda; + index_t iXTilda = gemm_id % XTilda; + + return GetGemmSizeImpl(iYTilda, iXTilda); + } + + template + __device__ static void RunImpl(Float* __restrict__ p_in_global, + const Float* __restrict__ p_wei_global, + const Float* __restrict__ p_out_global) + { + constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{}; + constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[0]; + constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[1]; + constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[2]; + constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[3]; + + constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[1]; + constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[2]; + constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[3]; + + constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[1]; + constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[2]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); + + constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? 
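+        // Illustrative numbers for the HTilda window arithmetic below (not
+        // taken from this PR): Hi = 28, Ho = 14, Y = 3, ConvStrideH = 2,
+        // ConvDilationH = 1 and InLeftPads{}[0] = 1 give
+        // HTilda = 14 + ceil(2 / 2) = 15,
+        // iHTildaLeft = floor(max(0, 1 - 1) / 2) = 0 and
+        // iHTildaRight = min(15, ceil(28 / 2) + 1) = 15, so HTildaSlice = 15.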
YDot : Y % YDot; + constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; + + constexpr index_t HTilda = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = + Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + constexpr index_t iHTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t iWTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); + + constexpr index_t iHTildaRight = math::min( + HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t iWTildaRight = math::min( + WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; + constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; + + // A matrix: weight + // weight out-of-bound check can be skipped + constexpr bool wei_skip_out_of_bound_check = true; + + constexpr auto wei_k_ydot_ytilda_xdot_xtilda_c_global_desc = transform_tensor_descriptor( + wei_k_y_x_c_global_desc, + make_tuple(PassThrough{}, + Embed, + Sequence, + wei_skip_out_of_bound_check>{}, + Embed, + Sequence, + wei_skip_out_of_bound_check>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + constexpr auto wei_k_ydotslice_xdotslice_c_global_desc = transform_tensor_descriptor( + wei_k_ydot_ytilda_xdot_xtilda_c_global_desc, + make_tuple( + PassThrough{}, + Slice, Sequence<0, 0>, Sequence>{}, + Freeze, Sequence>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<>{}, Sequence<3>{})); + + constexpr auto wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc = + reorder_tensor_descriptor_given_lower2upper(wei_k_ydotslice_xdotslice_c_global_desc, + Sequence<2, 0, 1, 3>{}); + +// B matrix: output tensor +// TODO sometimes output tensor out-of-bound check can be skipped, find out all such +// situations +#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK + constexpr bool out_skip_out_of_bound_check = false; +#else + constexpr bool out_skip_out_of_bound_check = true; +#endif + + constexpr auto out_n_ydot_htilda_xdot_wtilda_k_global_desc = transform_tensor_descriptor( + out_n_ho_wo_k_global_desc, + make_tuple(PassThrough{}, + Embed, + Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>, + out_skip_out_of_bound_check>{}, + Embed, + Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>, + out_skip_out_of_bound_check>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + constexpr auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc = + transform_tensor_descriptor( + out_n_ydot_htilda_xdot_wtilda_k_global_desc, + make_tuple( + PassThrough{}, + Slice, Sequence<0, 0>, Sequence>{}, + Slice, + Sequence, + Sequence>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{})); + + constexpr auto 
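+        // Splitting the reduction dimension into
+        // (GemmK0, GemmK1, GemmK2) = (YDotSlice, XDotSlice, K) lets the GEMM
+        // walk the two filter-tap dimensions with slice-window moves while K,
+        // the fastest-varying dimension of the NHWK output tensor, stays
+        // contiguous for vectorized global reads
+        // (see GemmBBlockCopySrcDataPerRead_GemmK2).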
out_gemmk0_gemmk1_gemmk2_gemmn_global_desc = transform_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + Merge>{}), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + +// C matrix: input tensor +// TODO sometimes input out-of-bound check can be skipped, find out all such situations +#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK + constexpr bool in_skip_out_of_bound_check = false; +#else + constexpr bool in_skip_out_of_bound_check = true; +#endif + + constexpr auto in_n_hip_wip_c_global_desc = transform_tensor_descriptor( + in_n_hi_wi_c_global_desc, + make_tuple(PassThrough{}, + Pad, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[1]; + constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[2]; + + constexpr auto in_n_ytilda_htilda_xtilda_wtilda_c_global_desc = transform_tensor_descriptor( + in_n_hip_wip_c_global_desc, + make_tuple(PassThrough{}, + Embed, + Sequence, + in_skip_out_of_bound_check>{}, + Embed, + Sequence, + in_skip_out_of_bound_check>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + constexpr auto in_n_htildaslice_wtildaslice_c_global_desc = transform_tensor_descriptor( + in_n_ytilda_htilda_xtilda_wtilda_c_global_desc, + make_tuple(PassThrough{}, + Freeze, Sequence>{}, + Slice, + Sequence, + Sequence>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1, 2>{}, Sequence<3>{})); + + constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( + in_n_htildaslice_wtildaslice_c_global_desc, + make_tuple(PassThrough{}, Merge>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // call GEMM + constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v2< + GridSize, + BlockSize, + Float, + AccFloat, + decltype(wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc), + decltype(out_gemmk0_gemmk1_gemmk2_gemmn_global_desc), + decltype(in_gemmm_gemmn_global_desc), + InMemoryDataOperation::Set, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + ThreadGemmDataPerRead_GemmM, + ThreadGemmDataPerRead_GemmN, + GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 3, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN, + Sequence<0, 1, 3, 2>, + Sequence<0, 1, 3, 2>, + 2, + GemmBBlockCopySrcDataPerRead_GemmK2, + GemmBBlockCopyDstDataPerWrite_GemmN, + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + + gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); + } + + template + __device__ static void Run(Float* 
__restrict__ p_in_global, + const Float* __restrict__ p_wei_global, + const Float* __restrict__ p_out_global, + Number) + { + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + constexpr index_t iYTilda = GemmId / XTilda; + constexpr index_t iXTilda = GemmId % XTilda; + + static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYtilda, iXtilda"); + + RunImpl(p_in_global, p_wei_global, p_out_global); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index 681426c4d5..15a052ea31 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -488,6 +488,49 @@ struct Embed } }; +// LowerLengths: Sequence<...> +// LowerFreezePoint: Sequence<...> +template +struct Freeze +{ + static constexpr index_t nDimLow = LowerLengths::Size(); + static constexpr index_t nDimUp = 0; + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex; + + __host__ __device__ explicit constexpr Freeze() + { + // TODO: sanity check: LowerFreezePoint should be within range of LowerLengths + } + + __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<0>{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<>{}; } + + __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /*idx_up*/) + { + return to_array(LowerFreezePoint{}); + } + + __host__ __device__ static constexpr auto + CalculateLowerIndexDiff(const UpperIndex& /* idx_up_diff */, + const UpperIndex& /* idx_up_old */, + const LowerIndex& /* idx_low_old */) + { + return make_zero_array(); + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } +}; + template struct Vectorize { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp index d4cbee1ced..fbf2bfe911 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp @@ -376,5 +376,400 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 } }; +template +struct GridwiseGemmTransposedANormalBNormalC_v2 +{ + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, + BBlockCopyDstDataPerWrite_N, + ThreadGemmAThreadCopySrcDataPerRead_M, + ThreadGemmBThreadCopySrcDataPerRead_N); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned( + Sequence{}, Number{}); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment 
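+        // The factor of 2 in the byte count returned below reserves space for
+        // the double-buffered LDS scheme used in Run(): two A tiles and two B
+        // tiles are resident so the blockwise GEMM on one buffer can overlap
+        // the global loads into the other.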
+        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence{}, Number{});
+
+        // LDS allocation for A and B: be careful of alignment
+        constexpr index_t a_block_space =
+            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
+
+        constexpr index_t b_block_space =
+            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
+
+        return 2 * (a_block_space + b_block_space) * sizeof(Float);
+    }
+
+    __device__ void Run(const Float* __restrict__ p_a_global,
+                        const Float* __restrict__ p_b_global,
+                        Float* __restrict__ p_c_global,
+                        Float* __restrict__ p_shared_block) const
+    {
+        constexpr auto True = integral_constant{};
+        constexpr auto False = integral_constant{};
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I2 = Number<2>{};
+
+        constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{};
+        constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{};
+        constexpr auto c_m_n_global_desc = CGlobalDesc{};
+
+        constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[0];
+        constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[1];
+        constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[2];
+        constexpr auto M = c_m_n_global_desc.GetLengths()[0];
+        constexpr auto N = c_m_n_global_desc.GetLengths()[1];
+
+        // don't do anything if K == 0
+        if(K == 0)
+        {
+            return;
+        }
+
+        // lds max alignment
+        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
+                                                    BBlockCopyDstDataPerWrite_N,
+                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
+                                                    ThreadGemmBThreadCopySrcDataPerRead_N);
+
+        // divide block work by [M, N]
+        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
+                      "wrong! cannot divide work evenly among blocks");
+
+        constexpr index_t MBlockWork = M / MPerBlock;
+        constexpr index_t NBlockWork = N / NPerBlock;
+
+        constexpr auto block_work_desc =
+            make_cluster_descriptor(Sequence{});
+
+        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
+
+        const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
+        const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
+
+        // A matrix in LDS memory, dst of blockwise copy
+        // be careful of LDS alignment
+        constexpr auto a_k0_k1_k2_m_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence<1, 1, KPerBlock, MPerBlock>{}, Number{});
+
+        // A matrix blockwise copy
+        auto a_blockwise_copy =
+            BlockwiseGenericTensorSliceCopy_v4,
+                                              ABlockCopySrcVectorReadDim,
+                                              3,
+                                              ABlockCopySrcDataPerRead,
+                                              ABlockCopyDstDataPerWrite_M,
+                                              AddressSpace::Global,
+                                              AddressSpace::Vgpr,
+                                              AddressSpace::Lds,
+                                              InMemoryDataOperation::Set>(
+                {0, 0, 0, m_block_data_on_global}, {0, 0, 0, 0});
+
+        // B matrix in LDS memory, dst of blockwise copy
+        // be careful of LDS alignment
+        constexpr auto b_k0_k1_k2_n_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence<1, 1, KPerBlock, NPerBlock>{}, Number{});
+
+        // B matrix blockwise copy
+        auto b_blockwise_copy =
+            BlockwiseGenericTensorSliceCopy_v4,
+                                              BBlockCopySrcVectorReadDim,
+                                              3,
+                                              BBlockCopySrcDataPerRead,
+                                              BBlockCopyDstDataPerWrite_N,
+                                              AddressSpace::Global,
+                                              AddressSpace::Vgpr,
+                                              AddressSpace::Lds,
+                                              InMemoryDataOperation::Set>(
+                {0, 0, 0, n_block_data_on_global}, {0, 0, 0, 0});
+
+        // GEMM definition
+        //   c_mtx += transpose(a_mtx) * b_mtx
+        //   a_mtx[KPerBlock, MPerBlock] is in LDS
+        //   b_mtx[KPerBlock, NPerBlock] is in LDS
+        //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+        //   register
+        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
+            unfold_tensor_descriptor(a_k0_k1_k2_m_block_desc, I0, I2));
+        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
+            unfold_tensor_descriptor(b_k0_k1_k2_n_block_desc, I0, I2));
+
+        // sanity check
+        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
+                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
+                      "wrong!");
+
+        constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
+        constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
+
+        // c_thread_mtx definition: this is a mess
+        // TODO: find a more elegant way of defining c_thread_mtx
+        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
+            Number{}, Number{});
+
+        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
+            BlockSize,
+            decltype(a_k_m_block_mtx_desc),
+            decltype(b_k_n_block_mtx_desc),
+            decltype(c_m0m1_n0n1_thread_mtx_desc),
+            MPerThread,
+            NPerThread,
+            KPerThread,
+            MLevel0Cluster,
+            NLevel0Cluster,
+            MLevel1Cluster,
+            NLevel1Cluster,
+            ThreadGemmAThreadCopySrcDataPerRead_M,
+            ThreadGemmBThreadCopySrcDataPerRead_N>{};
+
+        // LDS allocation for A and B: be careful of alignment
+        constexpr index_t a_block_space =
+            math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);
+
+        constexpr index_t b_block_space =
+            math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);
+
+        Float* p_a_block_double = p_shared_block;
+        Float* p_b_block_double = p_shared_block + 2 * a_block_space;
+
+        // register allocation for output
+        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
+
+        // zero out threadwise output
+        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
+
+        for(index_t k0 = 0; k0 < K0; ++k0)
+        {
+            for(index_t k1 = 0; k1 < K1; ++k1)
+            {
+                // LDS double buffer: preload data into LDS
+                {
+                    a_blockwise_copy.Run(p_a_global, p_a_block_double);
+                    b_blockwise_copy.Run(p_b_global, p_b_block_double);
+                }
+
+                constexpr auto a_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
+                constexpr auto b_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
+
+                // LDS double buffer: main body
+                for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
+                    k_block_data_begin += 2 * KPerBlock)
+                {
+#pragma unroll
+                    for(index_t iloop = 0; iloop < 2; ++iloop)
+                    {
+                        const bool even_loop = (iloop % 2 == 0);
+
+                        Float* p_a_block_now =
+                            even_loop ? p_a_block_double : p_a_block_double + a_block_space;
+                        Float* p_b_block_now =
+                            even_loop ? p_b_block_double : p_b_block_double + b_block_space;
+
+                        Float* p_a_block_next =
+                            even_loop ? p_a_block_double + a_block_space : p_a_block_double;
+                        Float* p_b_block_next =
+                            even_loop ? p_b_block_double + b_block_space : p_b_block_double;
+
+                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
+                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
+
+                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+
+                        __syncthreads();
+
+                        // LDS double buffer: load next data from device mem
+                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+                        // LDS double buffer: GEMM on current data
+                        blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
+
+                        // LDS double buffer: store next data to LDS
+                        a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
+                        b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
+                    }
+                }
+
+                // LDS double buffer: tail
+                {
+                    constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
+
+                    if(has_two_iteration_left) // 2 iterations left
+                    {
+                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
+                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
+
+                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+
+                        __syncthreads();
+
+                        // LDS double buffer: load last data from device mem
+                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+                        // LDS double buffer: GEMM on 2nd-last data
+                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+
+                        // LDS double buffer: store last data to LDS
+                        a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
+                                                              p_a_block_double + a_block_space);
+                        b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
+                                                              p_b_block_double + b_block_space);
+
+                        __syncthreads();
+
+                        // LDS double buffer: GEMM on last data
+                        blockwise_gemm.Run(p_a_block_double + a_block_space,
+                                           p_b_block_double + b_block_space,
+                                           p_c_thread);
+                    }
+                    else // 1 iteration left
+                    {
+                        __syncthreads();
+
+                        // LDS double buffer: GEMM on last data
+                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+                    }
+                }
+
+                // reset slice window on K2 dimension, then move forward on K1 dimension
+                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
+                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
+
+                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
+                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
+            }
+
+            // reset slice window on K1 dimension, then move forward on K0 dimension
+            a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);
+            b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);
+
+            a_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
+            b_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
+        }
+
+        // C matrix: register to global memory
+        {
+            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
+            constexpr index_t M0 = M / M1;
+
+            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
+            constexpr index_t N0 = N / N1;
+
+            // define C tensor descriptor for threadwise copy
+            // thread C tensor, src of threadwise copy
+            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
+                Sequence{});
+
+            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
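+                // Illustrative numbers (not taken from this PR): MPerThread = 4,
+                // MLevel0Cluster = 4, MLevel1Cluster = 4 give
+                // M1 = 4 * 4 * 4 = 64, so a GemmM of 128 is viewed as
+                // (M0, M1) = (2, 64); the UnMerge below applies the same split
+                // to N so each thread can scatter its register tile back to the
+                // correct rows and columns of the global C matrix.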
c_m_n_global_desc, + make_tuple(UnMerge>{}, UnMerge>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + // calculate origin of thread input tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t m_thread_data_on_global = + m_block_data_on_global + c_thread_mtx_on_block.row; + + const index_t n_thread_data_on_global = + n_block_data_on_global + c_thread_mtx_on_block.col; + + ThreadwiseGenericTensorSliceCopy_v4r2( + {0, 0, 0, 0}, + {m_thread_data_on_global / M1, + m_thread_data_on_global % M1, + n_thread_data_on_global / N1, + n_thread_data_on_global % N1}) + .Run(p_c_thread, p_c_global); + } + } + + __device__ void Run(const Float* __restrict__ p_a_global, + const Float* __restrict__ p_b_global, + Float* __restrict__ p_c_global) const + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); + + __shared__ Float p_shared_block[shared_block_size]; + + Run(p_a_global, p_b_global, p_c_global, p_shared_block); + } +}; + } // namespace ck #endif diff --git a/composable_kernel/include/utility/array.hpp b/composable_kernel/include/utility/array.hpp index 213b20530d..0f68ec7d58 100644 --- a/composable_kernel/include/utility/array.hpp +++ b/composable_kernel/include/utility/array.hpp @@ -12,6 +12,7 @@ struct Array using type = Array; using data_type = TData; + // TODO: implement empty Array index_t mData[NSize]; __host__ __device__ explicit constexpr Array() {} diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 27098cb3e8..dcf0be1674 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -24,6 +24,7 @@ #if CK_USE_AMD_XDLOPS #include "amd_xdlops.hpp" +#include "amd_xdlops_inline_asm.hpp" #endif #endif diff --git a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in index 4b274401eb..83ecae161c 100644 --- a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in @@ -108,8 +108,12 @@ struct SetData { const auto zeros = vector_t(0); - amd_buffer_store( - src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, dst_valid, dst_range); + amd_buffer_store(src_valid ? &(p_src[src_offset]) + : reinterpret_cast(&zeros), + p_dst, + dst_offset, + dst_valid, + dst_range); } #endif }; @@ -145,19 +149,17 @@ struct AtomicAddData template <> __device__ void Run(const T* p_src, index_t src_offset, + bool src_valid, index_t /* src_range */, - bool src_valid T* p_dst, + T* p_dst, index_t dst_offset, bool dst_valid, index_t dst_range) const { const auto zeros = vector_t(0); - amd_buffer_atomic_add(src_valid ? &(p_src[src_offset]) : &zeros, - p_dst, - dst_offset, - dst_valid, - index_t dst_range); + amd_buffer_atomic_add( + src_valid ? 
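+        // The removed overload above could not compile: its parameter list was
+        // missing a comma after src_valid, and the old call passed the
+        // declaration "index_t dst_range" as an argument; the rewrite below
+        // fixes both and forwards the arguments unchanged.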
&(p_src[src_offset]) : &zeros, p_dst, dst_offset, dst_valid, dst_range); } #endif }; diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index a986b14e1d..10bb32f938 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -16,15 +16,14 @@ install(TARGETS host LIBRARY DESTINATION lib) if(DEVICE_BACKEND STREQUAL "AMD") set(CONV_SOURCE src/conv_driver.cpp) - set(COL2IM_SOURCE src/col2im_driver.cpp) set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cpp) elseif(DEVICE_BACKEND STREQUAL "NVIDIA") set(CONV_SOURCE src/conv_driver.cu) - set(COL2IM_SOURCE src/col2im_driver.cu) set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu) endif() add_executable(conv_driver ${CONV_SOURCE}) add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE}) + target_link_libraries(conv_driver PRIVATE host) target_link_libraries(conv_bwd_data_driver PRIVATE host) diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 92ad30c568..0000000000 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,257 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp" - -namespace launcher { - -using namespace ck; - -template -void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, - Tensor& in_nchw, - WeiDesc wei_kcyx_desc, - const Tensor& wei_kcyx, - OutDesc out_nkhw_desc, - const Tensor& out_nkhw, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - std::size_t nrepeat) -{ - using namespace ck; - - constexpr index_t N = out_nkhw_desc.GetLengths()[0]; - constexpr index_t K = out_nkhw_desc.GetLengths()[1]; - constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; - - constexpr index_t Hi = in_nchw_desc.GetLengths()[2]; - constexpr index_t Wi = in_nchw_desc.GetLengths()[3]; - - constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; - constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; - - constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; - constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 1 - // BlockSize = 256, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t 
GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#elif 1 - // BlockSize = 256, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#elif 1 - // BlockSize = 256, each thread hold 64 data - // for 1x1 weight, 8x8 input - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; -#endif - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - 
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - constexpr index_t HTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t WTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t HTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; - constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; - - constexpr index_t GemmM = C * YTilda * XTilda; - constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; - - constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * - math::integer_divide_ceil(GemmN, GemmNPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, - GemmABlockCopyThreadSliceLengths_GemmK_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK_GemmM, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - launch_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - in_nchw_device_buf.FromDevice(in_nchw.mData.data()); -} - -} // namespace launcher diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index ba68390326..0000000000 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,196 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp" - -namespace launcher { - -using namespace ck; - -template -void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, - Tensor& in_nchw, - WeiDesc wei_kcyx_desc, - const Tensor& wei_kcyx, - OutDesc out_nkhw_desc, - const Tensor& out_nkhw, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - std::size_t nrepeat) -{ - using namespace ck; - - constexpr index_t N = out_nkhw_desc.GetLengths()[0]; - constexpr index_t K = out_nkhw_desc.GetLengths()[1]; - constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; - - constexpr index_t Hi = in_nchw_desc.GetLengths()[2]; - constexpr index_t Wi = in_nchw_desc.GetLengths()[3]; - - constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; - constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; - - constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; - constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 1 - // BlockSize = 256, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; - using 
GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#endif - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - constexpr index_t HTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t WTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t HTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; - constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; - - constexpr index_t GemmM = C; - constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; - - constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * - math::integer_divide_ceil(GemmN, GemmNPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, - GemmABlockCopyThreadSliceLengths_GemmK_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK_GemmM, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - launch_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - in_nchw_device_buf.FromDevice(in_nchw.mData.data()); -} - -} // namespace launcher diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index e870990a72..032fd375b6 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -57,8 +57,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); -#if 1 - // BlockSize = 256, each thread hold 64 data +#if 0 + // cdata = 64, BlockSize = 256, 128x128x8 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -86,6 +86,36 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #endif diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..789ebc4b9d --- /dev/null +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,266 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "gridwise_operation_wrapper.hpp" +#include 
"gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp" + +namespace launcher { + +using namespace ck; + +template +void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc in_nchw_desc, + Tensor& in_nchw, + WeiDesc wei_kcyx_desc, + const Tensor& wei_kcyx, + OutDesc out_nkhw_desc, + const Tensor& out_nkhw, + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + std::size_t nrepeat) +{ + constexpr index_t N = out_nkhw_desc.GetLengths()[0]; + constexpr index_t K = out_nkhw_desc.GetLengths()[1]; + constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; + + constexpr index_t Hi = in_nchw_desc.GetLengths()[2]; + constexpr index_t Wi = in_nchw_desc.GetLengths()[3]; + + constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; + constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; + + constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; + constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr auto in_nhwc_desc = make_native_tensor_descriptor_packed(Sequence{}); + constexpr auto wei_kyxc_desc = make_native_tensor_descriptor_packed(Sequence{}); + constexpr auto out_nhwk_desc = make_native_tensor_descriptor_packed(Sequence{}); + + Tensor in_nhwc(make_HostTensorDescriptor(in_nhwc_desc)); + Tensor wei_kyxc(make_HostTensorDescriptor(wei_kyxc_desc)); + Tensor out_nhwk(make_HostTensorDescriptor(out_nhwk_desc)); + + auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) { + in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi); + }; + + auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) { + wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x); + }; + + auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) { + out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo); + }; + + make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency()); + + std::size_t data_sz = sizeof(T); + DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace()); + DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace()); + DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace()); + + in_nhwc_device_buf.ToDevice(in_nhwc.mData.data()); + wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data()); + out_nhwk_device_buf.ToDevice(out_nhwk.mData.data()); + +#if 0 + // cdata = 64, BlockSize = 256, 128x128x8 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 1, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; + constexpr index_t 
GemmABlockCopyDstDataPerWrite_GemmM = 4; + + using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 2, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; + + using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 8, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#endif + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); + + constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + constexpr index_t HTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t WTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); + + constexpr index_t HTildaRight = math::min( + HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WTildaRight = math::min( + WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; + constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; + + constexpr index_t GemmM = C; + constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; + + constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * + math::integer_divide_ceil(GemmN, GemmNPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
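+    // Worked example for the "3x3 filter, 2x2 stride, 35x35 input, 17x17 output"
+    // config in conv_bwd_data_driver.cpp (N = 128, C = 256, K = 1280; zero padding,
+    // since (35 - 3) / 2 + 1 = 17, and assuming unit dilation):
+    //   YTilda = XTilda = 2 / gcd(2, 1) = 2
+    //   HTilda = WTilda = 17 + ceil(1 * 2 / 2) = 18, and the H/W slices stay full (18)
+    //   GemmM = C = 256, GemmN = 128 * 18 * 18 = 41472
+    //   GridSize = ceil(256 / 128) * ceil(41472 / 128) = 2 * 324 = 648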
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t i = 0; i < nrepeat; ++i) + { + using GridwiseConvBwdData = + GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk< + GridSize, + BlockSize, + T, + T, + decltype(in_nhwc_desc), + decltype(wei_kyxc_desc), + decltype(out_nhwk_desc), + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN, + GemmBBlockCopySrcDataPerRead_GemmK2, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmCThreadCopyDstDataPerWrite_GemmN1>; + + static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) { + constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id); + constexpr index_t gemm_k2 = gemm_sizes.At(4); + constexpr bool is_gemm_not_empty = gemm_k2 > 0; + + // only compile and run if GEMM is no empty + static_if{}([&](auto fwd) { + launch_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nhwc_device_buf.GetDeviceBuffer()), + static_cast(wei_kyxc_device_buf.GetDeviceBuffer()), + static_cast(out_nhwk_device_buf.GetDeviceBuffer()), + fwd(gemm_id)); + }); + }); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + in_nhwc_device_buf.FromDevice(in_nhwc.mData.data()); + + auto f_nhwc2nchw = [&](auto n, auto c, auto hi, auto wi) { + in_nchw(n, c, hi, wi) = in_nhwc(n, hi, wi, c); + }; + + make_ParallelTensorFunctor(f_nhwc2nchw, N, C, Hi, Wi)(std::thread::hardware_concurrency()); +} + +} // namespace launcher diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 32d01136ca..080aa2006f 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 +#elif 0 // cdata = 64, BlockSize = 256, 128x128x8 constexpr index_t BlockSize = 256; @@ -172,7 +172,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 +#elif 1 // cdata = 64, BlockSize = 256, 128x128x16 constexpr index_t BlockSize = 256; diff --git a/driver/src/conv_bwd_data_driver.cpp b/driver/src/conv_bwd_data_driver.cpp index 5cc7d6621b..a5248bfb1a 100644 --- a/driver/src/conv_bwd_data_driver.cpp +++ b/driver/src/conv_bwd_data_driver.cpp @@ -15,9 +15,8 @@ #include "host_conv_bwd_data.hpp" #include 
"device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" +#include "device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp" int main(int argc, char* argv[]) { @@ -55,7 +54,7 @@ int main(int argc, char* argv[]) #elif 0 // 3x3, 28x28 constexpr index_t N = 128; - constexpr index_t C = 1024; + constexpr index_t C = 256; constexpr index_t HI = 28; constexpr index_t WI = 28; constexpr index_t K = 1024; @@ -160,7 +159,7 @@ int main(int argc, char* argv[]) #elif 0 // 1x7 filter, 0x3 pad, 17x17 input constexpr index_t N = 128; - constexpr index_t C = 1024; + constexpr index_t C = 256; constexpr index_t HI = 17; constexpr index_t WI = 17; constexpr index_t K = 1024; @@ -175,7 +174,7 @@ int main(int argc, char* argv[]) #elif 0 // 7x1 filter, 3x0 pad, 17x17 input constexpr index_t N = 128; - constexpr index_t C = 1024; + constexpr index_t C = 256; constexpr index_t HI = 17; constexpr index_t WI = 17; constexpr index_t K = 1024; @@ -190,10 +189,10 @@ int main(int argc, char* argv[]) #elif 1 // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output constexpr index_t N = 128; - constexpr index_t C = 128; + constexpr index_t C = 256; constexpr index_t HI = 35; constexpr index_t WI = 35; - constexpr index_t K = 1024; + constexpr index_t K = 1280; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -247,14 +246,12 @@ int main(int argc, char* argv[]) #if 0 device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw -#elif 1 +#elif 0 device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw #elif 0 - device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw -#elif 0 - device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw -#elif 1 device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw +#elif 1 + device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk #endif (in_nchw_desc, in_nchw_device,