From c5da0377fbf4c2ab92a9d62f711bbc832481a92d Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 20 Jan 2020 10:20:03 -0600 Subject: [PATCH] Added bwd data v3r1 v4r1, tweaking v1 (#10) * Added bwd data v3r1: breaking down compute into a series of load balanced GEMM, and launch in a single kernel * Added bwd data v4r1: like v3r1, but launch GEMMs in multiple kernels * Tweaked v1r1 and v1r2 (atomic) on AMD GPU --- ...data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp | 16 +- ..._v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp | 32 +- ...data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp | 141 ++++--- ...data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp | 393 ++++++++++++++++++ ...data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 333 +++++++++++++++ ..._v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp | 7 +- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 7 +- .../multi_index_transform.hpp | 284 +++++++++---- .../tensor_description/tensor_coordinate.hpp | 47 ++- .../tensor_description/tensor_descriptor.hpp | 59 ++- .../tensor_operation/gridwise_gemm.hpp | 51 ++- .../threadwise_generic_tensor_slice_copy.hpp | 12 +- .../include/utility/amd_buffer_addressing.hpp | 40 ++ .../include/utility/config.amd.hpp.in | 5 + .../include/utility/functional2.hpp | 5 +- .../utility/in_memory_operation.amd.hpp.in | 5 + composable_kernel/include/utility/math.hpp | 6 + driver/CMakeLists.txt | 12 +- driver/include/device.hpp | 58 ++- driver/include/device_col2im_eb_nchw.hpp | 3 +- ...data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp | 32 +- ...data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp | 32 +- ...data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp | 123 +++++- ...data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp | 186 +++++++++ ...data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 232 +++++++++++ ...e_convolution_direct_v2_nchw_kcyx_nkhw.hpp | 14 +- ...lution_implicit_gemm_v1_chwn_cyxk_khwn.hpp | 3 +- ...implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp | 3 +- ...lution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp | 3 +- ...lution_implicit_gemm_v2_chwn_cyxk_khwn.hpp | 3 +- ...lution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp | 3 +- ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 8 +- ...it_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp | 41 +- ...tion_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp | 3 +- ...tion_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp | 3 +- ...tion_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 39 +- ...it_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp | 3 +- ...onvolution_2_vectorized_nchw_kcyx_nkhw.hpp | 2 +- driver/src/conv_bwd_data_driver.cpp | 227 ++++------ driver/src/conv_driver.cpp | 81 ++-- script/cmake-cuda.sh | 6 +- script/cmake-cuda_docker.sh | 6 +- script/extract_asm-cuda.sh | 6 +- 43 files changed, 2123 insertions(+), 452 deletions(-) create mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp create mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp create mode 100644 driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp create mode 100644 driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 0b089f3987..8221f32358 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -111,8 +111,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw in_n_c_hip_wip_global_desc, make_tuple(PassThrough{}, PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), + Embed, + Sequence>{}, + Embed, + Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); @@ -122,7 +126,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - // GEMM: atomic add + // GEMM + constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW) + ? InMemoryDataOperation::none + : InMemoryDataOperation::atomic_add; + constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1>{}, UnMerge>{}, - Embed, Sequence>{}, - Embed, Sequence>{}), + Embed, + Sequence>{}, + Embed, + Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); @@ -422,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl InThreadCopyDstDataPerWrite_B, AddressSpace::vgpr, AddressSpace::global, - InMemoryDataOperation::atomic_add>({0, 0, 0, 0, 0, 0}, - {e_thread_data_on_global / E1, - e_thread_data_on_global % E1, - 0, - b_thread_data_on_global / B1, - b_thread_data_on_global % B1, - 0}) + in_memory_op>({0, 0, 0, 0, 0, 0}, + {e_thread_data_on_global / E1, + e_thread_data_on_global % E1, + 0, + b_thread_data_on_global / B1, + b_thread_data_on_global % B1, + 0}) .Run(p_in_thread, p_in_global); } } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index 0e24a80c85..4615fae759 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -9,7 +9,7 @@ namespace ck { // GemmM = C * Ytilda * Xtilda; -// GemmN = N * Htilda * Wtilda; +// GemmN = N * HtildaNonZero * WtildaNonZero; // GemmK = K * Ydot * Xdot; template {}, PassThrough{}, - Pad, - Sequence<0, 0>, - Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( - wei_k_c_yp_xp_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, + Embed, Sequence>{}, - Embed, + Embed, Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); @@ -123,55 +129,96 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw make_tuple(Sequence<0>{}, Sequence<1>{})); // output tensor - constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor( - out_n_k_ho_wo_global_desc, - make_tuple( - PassThrough{}, - PassThrough{}, - Pad, Sequence<0, 0>, Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( - out_n_k_hop_wop_global_desc, + out_n_k_ho_wo_global_desc, make_tuple(PassThrough{}, PassThrough{}, - Embed, + Embed, Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{}, - Embed, + Embed, Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( - out_n_k_ydot_htilda_xdot_wtilda_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc = + transform_tensor_descriptor( + out_n_k_ydot_htilda_xdot_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + + constexpr auto out_gemmk_gemmn_global_desc = + transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc, + make_tuple(Merge>{}, + Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + +#if 1 // debug + constexpr bool in_skip_all_out_of_bound_check = false; +#else + constexpr bool in_skip_all_out_of_bound_check = true; +#endif // input tensor constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Pad, InputLeftPads, InputRightPads>{}), + make_tuple( + PassThrough{}, + PassThrough{}, + Pad, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; + constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; + constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor( in_n_c_hip_wip_global_desc, make_tuple(PassThrough{}, PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), + Embed, + Sequence, + in_skip_all_out_of_bound_check>{}, + Embed, + Sequence, + in_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = + transform_tensor_descriptor( + in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + + constexpr auto in_gemmm_gemmn_global_desc = + transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc, + make_tuple(Merge>{}, + Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); // GEMM constexpr auto gridwise_gemm = diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..a0c94a892e --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,393 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP +#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm.hpp" + +namespace ck { + +// Ytilda*Xtilda number of GEMMs +// GemmM = C; +// GemmN = N * HtildaNonZero * WtildaNonZero; +// GemmK = K * YdotNonZero * XdotNonZero; +template +struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw +{ + // this is a hack, should query this info from gridwise_gemm instead of duplicate its logic + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned( + Sequence{}, Number{}); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned( + Sequence{}, Number{}); + + // LDS allocation for A and B: be careful of alignment + constexpr index_t a_block_space = + math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align); + + constexpr index_t b_block_space = + math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align); + + return 2 * (a_block_space + b_block_space) * sizeof(Float); + } + + __device__ void Run(Float* __restrict__ p_in_global, + const Float* __restrict__ p_wei_global, + const Float* __restrict__ p_out_global) const + { + constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; + constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; + constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; + constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; + constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; + + constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; + constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; + constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; + + constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; + constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + +#if 0 // debug + // sanity-check for vectorized memory load + // TODO: this logic may not be correct for bwd-data + static_assert( + (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) && + (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0), + "wrong! aligment requirement for vectorized global load of input tensor will " + "be violated"); +#endif + + constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); + constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + + constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + + constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); + constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); + + constexpr index_t Htilda = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t Wtilda = + Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + constexpr index_t HtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); + constexpr index_t WtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + + constexpr index_t HtildaRight = math::min( + Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WtildaRight = math::min( + Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; + constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; + + constexpr bool wei_skip_all_out_of_bound_check = true; + + // weight tensor + constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( + wei_k_c_y_x_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence, + wei_skip_all_out_of_bound_check>{}, + Embed, + Sequence, + wei_skip_all_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + +#if 1 // debug + constexpr bool out_skip_all_out_of_bound_check = false; +#else + constexpr bool out_skip_all_out_of_bound_check = true; +#endif + + // output tensor + constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( + out_n_k_ho_wo_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>, + out_skip_all_out_of_bound_check>{}, + Embed, + Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>, + out_skip_all_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc = + transform_tensor_descriptor( + out_n_k_ydot_htilda_xdot_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + +#if 1 // debug + constexpr bool in_skip_all_out_of_bound_check = false; +#else + constexpr bool in_skip_all_out_of_bound_check = true; +#endif + + // input tensor + constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple( + PassThrough{}, + PassThrough{}, + Pad, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + + constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; + constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; + + constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence, + in_skip_all_out_of_bound_check>{}, + Embed, + Sequence, + in_skip_all_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = + transform_tensor_descriptor( + in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + + // GEMMs + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); + + __shared__ Float p_shared_block[shared_block_size]; + +#if 1 // debug + static_for<0, Ytilda, 1>{}([&](auto ytilda_) { + static_for<0, Xtilda, 1>{}([&](auto xtilda_) { +#else + static_for<0, 1, 1>{}([&](auto ytilda_) { + static_for<0, 1, 1>{}([&](auto xtilda_) { +#endif + constexpr index_t ytilda = decltype(ytilda_){}; + constexpr index_t xtilda = decltype(xtilda_){}; + + constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot; + constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot; + + // A matrix + constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc = + transform_tensor_descriptor( + wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Slice, + Sequence<0, 0>, + Sequence>{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{})); + + constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc, + make_tuple(Merge>{}, + Merge>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B matrix + constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc = + transform_tensor_descriptor( + out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence<0, 0>, + Sequence>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2, 4>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2, 4>{})); + + constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( + out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc, + make_tuple(Merge>{}, + Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C matrix + constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = + transform_tensor_descriptor( + in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2, 4>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2, 4>{})); + + constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( + in_n_c_1_htildatrim_1_wtildatrim_global_desc, + make_tuple(Merge>{}, + Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1< + GridSize, + BlockSize, + Float, + AccFloat, + decltype(wei_gemmk_gemmm_global_desc), + decltype(out_gemmk_gemmn_global_desc), + decltype(in_gemmm_gemmn_global_desc), + InMemoryDataOperation::none, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + Sequence<0, 1, 2, 3>, + 3, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + + gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block); + + // is synchronization necessary? + __syncthreads(); + }); + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..f96e99af6f --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,333 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP +#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm.hpp" + +namespace ck { + +// GemmM = C +// GemmN = N * Htilda * Wtilda; +// GemmK = K * YdotNonZero * XdotNonZero +template +struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw +{ + __device__ void Run(Float* __restrict__ p_in_global, + const Float* __restrict__ p_wei_global, + const Float* __restrict__ p_out_global) const + { + constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; + constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; + constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; + constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; + constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; + + constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; + constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; + constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; + + constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; + constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + +#if 0 // debug + // sanity-check for vectorized memory load + // TODO: this logic may not be correct for bwd-data + static_assert( + (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) && + (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0), + "wrong! aligment requirement for vectorized global load of input tensor will " + "be violated"); +#endif + + constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); + constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + + constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + + constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); + constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); + + constexpr index_t Htilda = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t Wtilda = + Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + constexpr index_t HtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); + constexpr index_t WtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + + constexpr index_t HtildaRight = math::min( + Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WtildaRight = math::min( + Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; + constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; + + constexpr bool wei_skip_all_out_of_bound_check = true; + + // weight tensor + constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( + wei_k_c_y_x_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence, + wei_skip_all_out_of_bound_check>{}, + Embed, + Sequence, + wei_skip_all_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + +#if 1 // debug + constexpr bool out_skip_all_out_of_bound_check = false; +#else + constexpr bool out_skip_all_out_of_bound_check = true; +#endif + + // output tensor + constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( + out_n_k_ho_wo_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>, + out_skip_all_out_of_bound_check>{}, + Embed, + Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>, + out_skip_all_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc = + transform_tensor_descriptor( + out_n_k_ydot_htilda_xdot_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + +#if 1 // debug + constexpr bool in_skip_all_out_of_bound_check = false; +#else + constexpr bool in_skip_all_out_of_bound_check = true; +#endif + + // input tensor + constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple( + PassThrough{}, + PassThrough{}, + Pad, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + + constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; + constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; + + constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence, + in_skip_all_out_of_bound_check>{}, + Embed, + Sequence, + in_skip_all_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = + transform_tensor_descriptor( + in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + + // GEMM + constexpr index_t ytilda = Iter_ytilda; + constexpr index_t xtilda = Iter_xtilda; + + constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot; + constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot; + + // A matrix + constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc = + transform_tensor_descriptor( + wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Slice, + Sequence<0, 0>, + Sequence>{}, + Slice, + Sequence, + Sequence>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{})); + + constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc, + make_tuple(Merge>{}, Merge>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B matrix + constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc = + transform_tensor_descriptor( + out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence<0, 0>, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); + + constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( + out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc, + make_tuple(Merge>{}, + Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C matrix + constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = transform_tensor_descriptor( + in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); + + constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( + in_n_c_1_htildatrim_1_wtildatrim_global_desc, + make_tuple(Merge>{}, Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + constexpr auto gridwise_gemm = + GridwiseGemmTransposedANormalBNormalC_v1, + Sequence<0, 1>, + 1, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + Sequence<0, 1, 2, 3>, + 3, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + + gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp index 9cbc2ce14e..c8830e310d 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -181,12 +181,15 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; + constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; + constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( in_n_c_hip_wip_global_desc, make_tuple(UnMerge>{}, PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), + Embed, Sequence>{}, + Embed, Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 5a4c4a1930..099756997c 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -97,12 +97,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; + constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; + constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( in_n_c_hip_wip_global_desc, make_tuple(PassThrough{}, PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), + Embed, Sequence>{}, + Embed, Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index bd69c402b5..1091c90130 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -41,15 +41,17 @@ struct PassThrough __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - __host__ __device__ static constexpr bool - IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */) + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; } }; // LowerLengths: Sequence<...> -template +template struct Pad { static constexpr index_t nDim = LowerLengths::Size(); @@ -57,6 +59,13 @@ struct Pad using LowerIndex = MultiIndex; using UpperIndex = MultiIndex; + __host__ __device__ explicit constexpr Pad() + { + static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim && + RightPads::GetSize() == nDim, + "wrong! # of dimensions not consistent"); + } + __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } @@ -81,20 +90,83 @@ struct Pad __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - __host__ __device__ constexpr bool - IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { +#if 1 // debug + if(SkipIsValidCheck) + { + return true; + } +#endif bool flag = true; - static_for<0, nDim, 1>{}([&](auto idim) { - flag = flag && (idx_up[idim] >= LeftPads::At(idim)) && - (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim)); - }); + for(index_t i = 0; i < nDim; ++i) + { + flag = flag && LeftPads::At(i) == 0 && RightPads::At(i) == 0; + } return flag; } }; +// LowerLengths: Sequence<...> +// SliceBegins: Sequence<...> +// SliceEnds: Sequence<...> +template +struct Slice +{ + static constexpr index_t nDim = LowerLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex; + + __host__ __device__ explicit constexpr Slice() + { + static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim && + SliceEnds::GetSize() == nDim, + "wrong! # of dimensions not consistent"); + +#if 0 + // TODO: would not compile, error on constexpr + static_for<0, nDim, 1>{}([&](auto idim) { + static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) && + SliceBegins::At(idim) >= 0 && + SliceEnds::At(idim) <= LowerLengths::At(idim), + "wrong! Slice config is wrong"); + }); +#endif + } + + __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() + { + return SliceEnds{} - SliceBegins{}; + } + + __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) + { + return idx_up + SliceBegins{}; + } + + __host__ __device__ static constexpr auto + CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, + const UpperIndex& /* idx_up_old */, + const LowerIndex& /* idx_low_old */) + { + return idx_up_diff; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } +}; + // LowerLengths: Sequence<...> template struct Merge @@ -165,85 +237,101 @@ struct Merge const UpperIndex& /* idx_up_old */, const LowerIndex& idx_low_old) { - // do nothing if idx_up_diff == 0 if(idx_up_diff[0] == 0) { return make_zero_array(); } - - // CalculateLowerIndex(idx_up_diff) has multiple integer divisions. - // If idx_up_diff is known at compile-time, the calculation can - // be done at compile-time. However, if idx_up_diff is only known - // at run-time, then the calculation will also be computed at - // run-time, and can be very expensive. - LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff); - - if(idx_up_diff[0] > 0) + else { - bool carry = false; + // CalculateLowerIndex(idx_up_diff) has multiple integer divisions. + // If idx_up_diff is known at compile-time, the calculation can + // be done at compile-time. However, if idx_up_diff is only known + // at run-time, then the calculation will also be computed at + // run-time, and can be very expensive. + LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff); - // do carry check in reversed order, starting from lowest dimension - // don't check the highest dimension - static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) { - constexpr index_t i = nDimLow - 1 - ireverse; + // find out the last low dimension that changed + index_t last_changed_low_dim = 0; + static_for<0, nDimLow, 1>{}([&](auto i) { + if(idx_low_diff_tmp[i] != 0) + { + last_changed_low_dim = i; + } + }); + + LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp; + + if(idx_up_diff[0] > 0) + { + // do carry check on each low dimension in reversed order + // starting from the first digit that changed + // don't check the highest dimension + bool carry = false; + + static_for{}([&](auto i) { + if(i <= last_changed_low_dim) + { + if(carry) + { + ++idx_low_new(i); + } + + carry = false; + + if(idx_low_new[i] >= LowerLengths::At(i)) + { + idx_low_new(i) -= LowerLengths::At(i); + carry = true; + } + } + }); + + // highest dimension, no out-of-bound check if(carry) { - ++idx_low_new(i); + ++idx_low_new(0); } - - carry = false; - - if(idx_low_new[i] >= LowerLengths::At(i)) - { - idx_low_new(i) -= LowerLengths::At(i); - carry = true; - } - }); - - // highest dimension, no out-of-bound check - if(carry) - { - ++idx_low_new(0); } - } - else if(idx_up_diff[0] < 0) - { - bool borrow = false; + else + { + // do borrow check on each low dimension in reversed order + // starting from the first digit that changed + // don't check the highest dimension + bool borrow = false; - // do borrow check in reversed order, starting from lowest dimension - // don't check the highest dimension - static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) { - constexpr index_t i = nDimLow - 1 - ireverse; + static_for{}([&](auto i) { + if(i <= last_changed_low_dim) + { + if(borrow) + { + --idx_low_new(i); + } + borrow = false; + + if(idx_low_new[i] < 0) + { + idx_low_new(i) += LowerLengths::At(i); + borrow = true; + } + } + }); + + // highest dimension, no out-of-bound check if(borrow) { - --idx_low_new(i); + --idx_low_new(0); } - - borrow = false; - - if(idx_low_new[i] < 0) - { - idx_low_new(i) += LowerLengths::At(i); - borrow = true; - } - }); - - // highest dimension, no out-of-bound check - if(borrow) - { - --idx_low_new(0); } - } - return idx_low_new - idx_low_old; + return idx_low_new - idx_low_old; + } } __host__ __device__ static constexpr bool IsLinearTransform() { return false; } - __host__ __device__ static constexpr bool - IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */) + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; } @@ -290,8 +378,7 @@ struct UnMerge __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - __host__ __device__ static constexpr bool - IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */) + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; } @@ -300,7 +387,10 @@ struct UnMerge // UpperLengths: Sequence<...> // Coefficients: Sequence<...> // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp] -template +template struct Embed { static constexpr index_t nDimLow = 1; @@ -325,8 +415,10 @@ struct Embed { LowerIndex idx_low(Coefficients{}[nDimUp]); - static_for<0, nDimUp, 1>{}( - [&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; }); + for(index_t i = 0; i < nDimUp; ++i) + { + idx_low(0) += idx_up[i] * Coefficients{}[i]; + } return idx_low; } @@ -338,18 +430,55 @@ struct Embed { LowerIndex idx_low_diff{0}; - static_for<0, nDimUp, 1>{}( - [&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; }); + for(index_t i = 0; i < nDimUp; ++i) + { + idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i]; + } return idx_low_diff; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - __host__ __device__ static constexpr bool - IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */) + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { - return true; +#if 1 // debug + if(SkipIsValidCheck) + { + return true; + } +#endif + bool flag = true; + + index_t ncorner = 1; + + for(index_t idim = 0; idim < nDimUp; ++idim) + { + ncorner *= 2; + } + + // loop over each corner of the upper tensor + for(index_t icorner = 0; icorner < ncorner; ++icorner) + { + // generate upper index for each corner + auto idx_up = make_zero_array(); + + index_t itmp = icorner; + + for(index_t idim = nDimUp - 1; idim >= 0; --idim) + { + idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1; + itmp /= 2; + } + + // calculate lower index + auto idx_low = CalculateLowerIndex(idx_up); + + // judge if lower index is valid + flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength; + } + + return flag; } }; @@ -389,8 +518,7 @@ struct Vectorize __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - __host__ __device__ static constexpr bool - IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */) + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() { return true; } diff --git a/composable_kernel/include/tensor_description/tensor_coordinate.hpp b/composable_kernel/include/tensor_description/tensor_coordinate.hpp index 4e5c5cc8ec..f796dac880 100644 --- a/composable_kernel/include/tensor_description/tensor_coordinate.hpp +++ b/composable_kernel/include/tensor_description/tensor_coordinate.hpp @@ -53,6 +53,8 @@ struct NativeTensorCoordinate __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; } + __host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; } + __host__ __device__ constexpr const Index& GetIndex() const { return mIndex; } __host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; } @@ -98,7 +100,24 @@ struct NativeTensorCoordinate return tensor_desc_type::CalculateOffsetDiff(idx_diff); } - __host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; } + // evaluated at run-time + __host__ __device__ constexpr bool IsUpperIndexValid() const + { + return tensor_desc_type::IsUpperIndexValid(GetUpperIndex()); + } + + // evaluated at run-time + __host__ __device__ constexpr bool IsOffsetValid() const + { + // For native tensor, offset is valid if upper-index is valid + return IsUpperIndexValid(); + } + + // evaluated at compile-time + __host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid() + { + return true; + } private: // mIndex may be saved and updated, however, the value of some (or all) of its entries may @@ -206,10 +225,30 @@ struct TransformedTensorCoordinate return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff); } - __host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const + // evaluated at run-time + __host__ __device__ constexpr bool IsUpperIndexValid() const { - return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) && - mCoordLow.IsUpperIndexMappedToValidOffset(); + return tensor_desc_type::IsUpperIndexValid(GetUpperIndex()); + } + + // evaluted at run-time + __host__ __device__ constexpr bool IsOffsetValid() const + { + return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid(); + } + + // most evaluatation is done at comile-time + __host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const + { + return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid( + GetLowerCoordinate().GetIndex()); + } + + // most evaluatation is done at comile-time + __host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const + { + return IsLowerIndexValidAssumingUpperIndexIsValid() && + GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid(); } private: diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index dec7e2b8da..de525748c7 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -120,10 +120,17 @@ struct NativeTensorDescriptor return Tuple<>{}; } - __host__ __device__ static constexpr bool - IsUpperIndexMappedToValidOffset(const Index& /* idx */) + // a multi-index is valid if there is a corresponding point for it in the tensor + __host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx) { - return true; + bool flag = true; + + for(index_t i = 0; i < nDim; ++i) + { + flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i]; + } + + return flag; } }; @@ -467,33 +474,49 @@ struct TransformedTensorDescriptor } #endif + // a multi-index is valid if there is a corresponding point for it in the tensor + __host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const + { + bool flag = true; + + for(index_t i = 0; i < nDimUp; ++i) + { + flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i]; + } + + return flag; + } + + // this function is for optimization purpose, it's called by tensor coordinate + // this function tells you: If a lower-index is valid or not, assuming upper index is valid __host__ __device__ static constexpr bool - IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) + IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low) { bool flag = true; static_for<0, nTransform, 1>{}([&](auto itran) { constexpr auto tran = Transforms{}.At(itran); - const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran)); + // check a indtransformation if it does not always has a valid mapping + constexpr bool is_valid_up_always_mapped_to_valid_low = + decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex(); - flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part)); + if(!is_valid_up_always_mapped_to_valid_low) + { + constexpr auto low_dims_part = LowDimensionIds{}.At(itran); + constexpr auto low_lengths_part = + GetLowerTensorDescriptor().GetLengths(low_dims_part); + const auto idx_low_part = to_array(pick_array_element(idx_low, low_dims_part)); + + for(index_t i = 0; i < low_dims_part.Size(); ++i) + { + flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i]; + } + } }); return flag; } - - // Whenever this function is called, it will call CalculateLowerIndex() recursively. - // If you have created a tensor coordinate already, instead of calling this function, - // you should call TensorCoordinate::IsUpperIndexMappedToValidOffset() which would - // be less expensive. - __host__ __device__ static constexpr bool - IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up) - { - return IsUpperIndexMappedToValidLowerIndex(idx_up) && - GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset( - CalculateLowerIndex(idx_up)); - } }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp index 8a6a3f72c3..56d779616f 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp @@ -50,9 +50,37 @@ template struct GridwiseGemmTransposedANormalBNormalC_v1 { + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, + BBlockCopyDstDataPerWrite_N, + ThreadGemmDataPerReadM, + ThreadGemmDataPerReadN); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned( + Sequence{}, Number{}); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned( + Sequence{}, Number{}); + + // LDS allocation for A and B: be careful of alignment + constexpr index_t a_block_space = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); + + constexpr index_t b_block_space = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); + + return 2 * (a_block_space + b_block_space) * sizeof(Float); + } + __device__ void Run(const Float* __restrict__ p_a_global, const Float* __restrict__ p_b_global, - Float* __restrict__ p_c_global) const + Float* __restrict__ p_c_global, + Float* __restrict__ p_shared_block) const { constexpr auto True = integral_constant{}; @@ -64,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 constexpr auto M = a_k_m_global_desc.GetLengths()[1]; constexpr auto N = b_k_n_global_desc.GetLengths()[1]; + // don't do anything if K == 0 + if(K == 0) + { + return; + } + // lds max alignment constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, BBlockCopyDstDataPerWrite_N, @@ -184,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 constexpr index_t b_block_space = math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); - __shared__ Float p_a_block_double[2 * a_block_space]; - __shared__ Float p_b_block_double[2 * b_block_space]; + Float* p_a_block_double = p_shared_block; + Float* p_b_block_double = p_shared_block + 2 * a_block_space; // register allocation for output AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; @@ -329,6 +363,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 .Run(p_c_thread, p_c_global); } } + + __device__ void Run(const Float* __restrict__ p_a_global, + const Float* __restrict__ p_b_global, + Float* __restrict__ p_c_global) const + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); + + __shared__ Float p_shared_block[shared_block_size]; + + Run(p_a_global, p_b_global, p_c_global, p_shared_block); + } }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index b3a4d46fc9..ce18a92d86 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -110,7 +110,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 // Check src data's valid mapping situation, only check the first data in this src // vector. It's user's responsiblity to make sure all data in the src vector // has the valid/invalid mapping situation - if(src_coord.IsUpperIndexMappedToValidOffset()) + if(src_coord.IsOffsetValidAssumingUpperIndexIsValid()) { move_data::MemoryType index_t dst_thread_data_offset, index_t dst_const_data_offset); +template +__device__ void +amd_intrinsic_buffer_atomic_add(const typename vector_type::MemoryType& src, + T* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset); + template <> __device__ float amd_intrinsic_buffer_load(const float* p_src_block, index_t src_thread_data_offset, @@ -289,5 +303,31 @@ __device__ void amd_intrinsic_buffer_store(const float4_t& src, #endif } +template <> +__device__ void amd_intrinsic_buffer_atomic_add(const float& src, + float* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); + index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float); + + BufferLoadStoreDwordConfig dst_block_config; + + // fill in byte 0 - 1 + dst_block_config.address[0] = p_dst_block; + // fill in byte 2 + dst_block_config.range[2] = -1; + // fill in byte 3 + dst_block_config.range[3] = 0x00027000; + +#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC + __llvm_amdgcn_buffer_atomic_add( + src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false); +#else + static_assert(false, " wrong! not implemented"); +#endif +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index 052142679f..adf32ae32d 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -29,6 +29,11 @@ #define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1 #endif +// only support gfx908 +#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD +#define CK_USE_AMD_BUFFER_ATOMIC_ADD 0 +#endif + // AMD XDLOPS #ifndef CK_USE_AMD_XDLOPS #define CK_USE_AMD_XDLOPS 0 diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp index 68706a2973..ed0ce1ce0e 100644 --- a/composable_kernel/include/utility/functional2.hpp +++ b/composable_kernel/include/utility/functional2.hpp @@ -29,9 +29,10 @@ struct static_for { __host__ __device__ constexpr static_for() { - static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd"); - static_assert((NEnd - NBegin) % Increment == 0, + static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0, "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); + static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd), + "wrongs! should have NBegin <= NEnd"); } template diff --git a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in index 190e2b6132..6ffe96a83a 100644 --- a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in @@ -52,8 +52,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in static_if{}( [&](auto) { +#if CK_USE_AMD_BUFFER_ATOMIC_ADD + amd_intrinsic_buffer_atomic_add( + *reinterpret_cast(&p_src[src_offset]), p_dst, dst_offset, 0); +#else atomicAdd(reinterpret_cast(&p_dst[dst_offset]), *reinterpret_cast(&p_src[src_offset])); +#endif }) .Else([&](auto fwd) { static_assert(fwd(false), "atomic_add doesn't support this memory space"); diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index feb7393945..7960f3ccee 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -49,6 +49,12 @@ struct integer_divide_ceiler } }; +template +__host__ __device__ constexpr auto integer_divide_floor(X x, Y y) +{ + return x / y; +} + template __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) { diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 4f435cfae4..d64d0adbc9 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -24,9 +24,9 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA") set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu) endif() -add_executable(conv ${CONV_SOURCE}) -add_executable(col2im ${COL2IM_SOURCE}) -add_executable(conv_bwd_data ${CONV_BWD_DATA_SOURCE}) -target_link_libraries(conv PRIVATE host) -target_link_libraries(col2im PRIVATE host) -target_link_libraries(conv_bwd_data PRIVATE host) +add_executable(conv_driver ${CONV_SOURCE}) +add_executable(col2im_driver ${COL2IM_SOURCE}) +add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE}) +target_link_libraries(conv_driver PRIVATE host) +target_link_libraries(col2im_driver PRIVATE host) +target_link_libraries(conv_bwd_data_driver PRIVATE host) diff --git a/driver/include/device.hpp b/driver/include/device.hpp index c43f14b751..09812fae58 100644 --- a/driver/include/device.hpp +++ b/driver/include/device.hpp @@ -30,33 +30,81 @@ struct KernelTimer std::unique_ptr impl; }; +#if CK_DEVICE_BACKEND_AMD +using device_stream_t = hipStream_t; + template -float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) +void launch_kernel(F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + hipStream_t stream_id, + Args... args) +{ + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); +} + +template +float launch_and_time_kernel(F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + hipStream_t stream_id, + Args... args) { KernelTimer timer; -#if CK_DEVICE_BACKEND_AMD timer.Start(); - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, 0, args...); + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); timer.End(); hipGetErrorString(hipGetLastError()); + + return timer.GetElapsedTime(); +} + #elif CK_DEVICE_BACKEND_NVIDIA +using device_stream_t = cudaStream_t; + +template +void launch_kernel(F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + cudaStream_t stream_id, + Args... args) +{ + const void* f = reinterpret_cast(kernel); + void* p_args[] = {&args...}; + + cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id); +} + +template +float launch_and_time_kernel(F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + cudaStream_t stream_id, + Args... args) +{ + KernelTimer timer; + const void* f = reinterpret_cast(kernel); void* p_args[] = {&args...}; timer.Start(); - cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, 0); + cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id); timer.End(); checkCudaErrors(error); -#endif return timer.GetElapsedTime(); } +#endif #endif diff --git a/driver/include/device_col2im_eb_nchw.hpp b/driver/include/device_col2im_eb_nchw.hpp index 187cb4eaf5..0dde1c15bd 100644 --- a/driver/include/device_col2im_eb_nchw.hpp +++ b/driver/include/device_col2im_eb_nchw.hpp @@ -88,7 +88,8 @@ void device_col2im_eb_nchw(ColDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(run_gridwise_operation, dim3(GridSize), diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 0d85048a54..4545488aa2 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -5,6 +5,10 @@ #include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp" +namespace launcher { + +using namespace ck; + template , - dim3(GridSize), - dim3(BlockSize), - 0, - gridwise_conv, - const_cast( - static_cast(in_nchw_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(wei_kcyx_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(out_nkhw_device_buf.GetDeviceBuffer()))); + float time = launch_and_time_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + gridwise_conv, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); printf("Elapsed time : %f ms, %f TFlop/s\n", time, @@ -145,3 +147,5 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i in_nchw_device_buf.FromDevice(in_nchw.mData.data()); } + +} // namespace launcher diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp index affd41a017..89f19725bf 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp @@ -5,6 +5,10 @@ #include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp" +namespace launcher { + +using namespace ck; + template , - dim3(GridSize), - dim3(BlockSize), - 0, - gridwise_conv, - const_cast( - static_cast(in_nchw_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(wei_kcyx_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(out_nkhw_device_buf.GetDeviceBuffer()))); + float time = launch_and_time_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + gridwise_conv, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); printf("Elapsed time : %f ms, %f TFlop/s\n", time, @@ -153,3 +155,5 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i in_nchw_device_buf.FromDevice(in_nchw.mData.data()); } + +} // namespace launcher diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index 9e5e0cc553..2a4dfecbf3 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -5,6 +5,10 @@ #include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp" +namespace launcher { + +using namespace ck; + template ; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // BlockSize = 256, each thread hold 64 data + // for 1x1 weight, 8x8 input + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; #endif constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); @@ -92,14 +161,24 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); - constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda); - constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda); + constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - constexpr index_t Htilda = Ho + right_pad_ho; - constexpr index_t Wtilda = Wo + right_pad_wo; + constexpr index_t HtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); + constexpr index_t WtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + + constexpr index_t HtildaRight = math::min( + Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WtildaRight = math::min( + Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; + constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; constexpr index_t GemmM = C * Ytilda * Xtilda; - constexpr index_t GemmN = N * Htilda * Wtilda; + constexpr index_t GemmN = N * HtildaTrim * WtildaTrim; constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * math::integer_divide_ceil(GemmN, GemmNPerBlock); @@ -142,20 +221,18 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - gridwise_conv, - const_cast( - static_cast(in_nchw_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(wei_kcyx_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(out_nkhw_device_buf.GetDeviceBuffer()))); + float time = launch_and_time_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + gridwise_conv, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); printf("Elapsed time : %f ms, %f TFlop/s\n", time, @@ -166,3 +243,5 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i in_nchw_device_buf.FromDevice(in_nchw.mData.data()); } + +} // namespace launcher diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..ac2a247a67 --- /dev/null +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,186 @@ +#pragma once +#include +#include "device.hpp" +#include "tensor.hpp" +#include "gridwise_operation_wrapper.hpp" +#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp" + +namespace launcher { + +using namespace ck; + +template +void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, + Tensor& in_nchw, + WeiDesc wei_kcyx_desc, + const Tensor& wei_kcyx, + OutDesc out_nkhw_desc, + const Tensor& out_nkhw, + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + std::size_t nrepeat) +{ + using namespace ck; + + constexpr index_t N = out_nkhw_desc.GetLengths()[0]; + constexpr index_t K = out_nkhw_desc.GetLengths()[1]; + constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; + + constexpr index_t Hi = in_nchw_desc.GetLengths()[2]; + constexpr index_t Wi = in_nchw_desc.GetLengths()[3]; + + constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; + constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; + + constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; + constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + std::size_t data_sz = sizeof(T); + DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); + DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); + DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); + + in_nchw_device_buf.ToDevice(in_nchw.mData.data()); + wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); + out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); + +#if 1 + // BlockSize = 256, each thread hold 64 data + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#endif + + constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); + constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + + constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + + constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); + constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); + + constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + constexpr index_t HtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); + constexpr index_t WtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + + constexpr index_t HtildaRight = math::min( + Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WtildaRight = math::min( + Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; + constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; + + constexpr index_t GemmM = C; + constexpr index_t GemmN = N * HtildaTrim * WtildaTrim; + + constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * + math::integer_divide_ceil(GemmN, GemmNPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw< + GridSize, + BlockSize, + T, + T, + decltype(in_nchw_desc), + decltype(wei_kcyx_desc), + decltype(out_nkhw_desc), + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + + for(index_t i = 0; i < nrepeat; ++i) + { + float time = launch_and_time_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + gridwise_conv, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + + printf("Elapsed time : %f ms, %f TFlop/s\n", + time, + (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / time); + usleep(std::min(time * 1000, float(10000))); + } + + in_nchw_device_buf.FromDevice(in_nchw.mData.data()); +} + +} // namespace launcher diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..f6ee9d71a5 --- /dev/null +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,232 @@ +#pragma once +#include +#include "device.hpp" +#include "tensor.hpp" +#include "gridwise_operation_wrapper.hpp" +#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" + +namespace launcher { + +using namespace ck; + +template +void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, + Tensor& in_nchw, + WeiDesc wei_kcyx_desc, + const Tensor& wei_kcyx, + OutDesc out_nkhw_desc, + const Tensor& out_nkhw, + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + std::size_t nrepeat) +{ + constexpr index_t N = out_nkhw_desc.GetLengths()[0]; + constexpr index_t K = out_nkhw_desc.GetLengths()[1]; + constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; + + constexpr index_t Hi = in_nchw_desc.GetLengths()[2]; + constexpr index_t Wi = in_nchw_desc.GetLengths()[3]; + + constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; + constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; + + constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; + constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + std::size_t data_sz = sizeof(T); + DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); + DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); + DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); + + in_nchw_device_buf.ToDevice(in_nchw.mData.data()); + wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); + out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); + +#if 1 + // BlockSize = 256, each thread hold 64 data + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // BlockSize = 256, each thread hold 64 data + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#endif + + constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH); + constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW); + + constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; + constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; + + constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); + constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); + + constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + constexpr index_t HtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); + constexpr index_t WtildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + + constexpr index_t HtildaRight = math::min( + Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WtildaRight = math::min( + Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; + constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; + + constexpr index_t GemmM = C; + constexpr index_t GemmN = N * HtildaTrim * WtildaTrim; + + constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * + math::integer_divide_ceil(GemmN, GemmNPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + for(index_t i = 0; i < nrepeat; ++i) + { + KernelTimer timer; + + timer.Start(); + + static_for<0, Ytilda, 1>{}([&](auto ytilda_) { + static_for<0, Xtilda, 1>{}([&](auto xtilda_) { + constexpr index_t ytilda = decltype(ytilda_){}; + constexpr index_t xtilda = decltype(xtilda_){}; + + constexpr auto gridwise_conv = + GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw< + GridSize, + BlockSize, + T, + T, + decltype(in_nchw_desc), + decltype(wei_kcyx_desc), + decltype(out_nkhw_desc), + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + ytilda, + xtilda, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThreadSubC, + GemmNPerThreadSubC, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmKPerThreadLoop, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + + launch_and_time_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + gridwise_conv, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + }); + }); + + timer.End(); + + float time = timer.GetElapsedTime(); + + printf("Elapsed time : %f ms, %f TFlop/s\n", + time, + (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / time); + usleep(std::min(time * 1000, float(10000))); + } + + in_nchw_device_buf.FromDevice(in_nchw.mData.data()); +} + +} // namespace launcher diff --git a/driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp index e19051a9bd..5840947a45 100644 --- a/driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp @@ -82,13 +82,13 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc, WoPerThread, InBlockCopyDataPerRead, WeiBlockCopyDataPerRead>; - float time = launch_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer())); + float time = launch_and_time_kernel(run_gridwise_convolution_kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer())); printf("Elapsed time : %f ms\n", time); usleep(std::min(time * 1000, float(10000))); diff --git a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp index b1068d2a5e..39a05db992 100644 --- a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp +++ b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp @@ -458,7 +458,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, diff --git a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp index f95235821a..34a10e2d46 100644 --- a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp +++ b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp @@ -161,7 +161,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, diff --git a/driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp index 03cbc204c7..3b192c9a86 100644 --- a/driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp @@ -354,7 +354,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, WeiBlockCopyDataPerRead_K, OutThreadCopyDataPerWrite_W>{}; - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, diff --git a/driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp b/driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp index a26347d032..50da0a7df5 100644 --- a/driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp +++ b/driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp @@ -306,7 +306,8 @@ void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc, WeiBlockCopyDataPerRead, OutThreadCopyDataPerWrite>{}; - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, diff --git a/driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp index 7e0134069f..23cef570fc 100644 --- a/driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp @@ -135,7 +135,8 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc, WeiBlockCopyClusterLengths_C_K, WeiBlockCopyDataPerAccess_K>{}; - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 7efa7ef91b..2bb353a825 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); -#if 0 +#if 1 // BlockSize = 256, EperBlock = 8, each thread hold 64 data constexpr index_t BlockSize = 256; @@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 +#elif 1 // BlockSize = 64, each thread hold 64 data constexpr index_t BlockSize = 64; @@ -258,13 +258,15 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(run_gridwise_operation, dim3(GridSize), dim3(BlockSize), 0, + 0, gridwise_conv, const_cast( static_cast(in_nchw_device_buf.GetDeviceBuffer())), diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp index f741b4abf6..ab309670c6 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp @@ -81,6 +81,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc, using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // BlockSize = 256, EPerBlock = 16, each thread hold 64 data + constexpr index_t BlockSize = 256; + + constexpr index_t BPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t EPerBlock = 16; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; + using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; #elif 0 @@ -247,10 +284,12 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, + 0, static_cast(in_nchw_device_buf.GetDeviceBuffer()), static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), static_cast(out_nkhw_device_buf.GetDeviceBuffer())); diff --git a/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp index 6b08c99678..1a67f48477 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp @@ -200,7 +200,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>{}; - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, diff --git a/driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp index 6e9d240d02..f905eaec5a 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp @@ -158,7 +158,8 @@ void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc, WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>{}; - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 4077846c0e..24f46cfa8d 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); -#if 1 +#if 0 // BlockSize = 256, GemmKPerBlock = 8 constexpr index_t BlockSize = 256; @@ -83,6 +83,37 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // BlockSize = 256, GemmKPerBlock = 16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #elif 0 // BlockSize = 256, GemmKPerBlock = 8 @@ -116,7 +147,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; -#elif 0 +#elif 1 // BlockSize = 256, GemmKPerBlock = 16 // 1x1 filter, 8x8 image constexpr index_t BlockSize = 256; @@ -225,10 +256,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, + 0, static_cast(in_nchw_device_buf.GetDeviceBuffer()), static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), static_cast(out_nkhw_device_buf.GetDeviceBuffer())); diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp index cb51bfc1de..646d59dbf4 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp @@ -205,7 +205,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(run_gridwise_convolution_kernel, + float time = + launch_and_time_kernel(run_gridwise_convolution_kernel, dim3(GridSize), dim3(BlockSize), 0, diff --git a/driver/include/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp b/driver/include/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp index c6be195213..7158032e8e 100644 --- a/driver/include/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp @@ -178,7 +178,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel( + float time = launch_and_time_kernel( gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw; - using ConvDilations = Sequence<2, 2>; + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; @@ -41,7 +44,7 @@ int main(int argc, char* argv[]) constexpr index_t C = 256; constexpr index_t HI = 34; constexpr index_t WI = 34; - constexpr index_t K = 128; + constexpr index_t K = 256; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -51,27 +54,27 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; #elif 0 - // 1x1 filter, 8x8 image - constexpr index_t N = 64; - constexpr index_t C = 1536; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; + // 3x3, 28x28 + constexpr index_t N = 128; + constexpr index_t C = 1024; + constexpr index_t HI = 28; + constexpr index_t WI = 28; + constexpr index_t K = 1024; + constexpr index_t Y = 3; + constexpr index_t X = 3; using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<1, 1>; #elif 0 // 1x1 filter, 8x8 image - constexpr index_t N = 128; - constexpr index_t C = 2048; + constexpr index_t N = 256; + constexpr index_t C = 1024; constexpr index_t HI = 8; constexpr index_t WI = 8; - constexpr index_t K = 384; + constexpr index_t K = 1024; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -83,25 +86,10 @@ int main(int argc, char* argv[]) #elif 0 // 1x1 filter, 7x7 image constexpr index_t N = 128; - constexpr index_t C = 832; + constexpr index_t C = 1024; constexpr index_t HI = 7; constexpr index_t WI = 7; - constexpr index_t K = 384; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 8x8 image - constexpr index_t N = 128; - constexpr index_t C = 1280; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 384; + constexpr index_t K = 1024; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -123,27 +111,12 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 8x8 image - constexpr index_t N = 64; - constexpr index_t C = 1536; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 384; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 28x28 image constexpr index_t N = 128; - constexpr index_t C = 256; + constexpr index_t C = 128; constexpr index_t HI = 28; constexpr index_t WI = 28; constexpr index_t K = 128; @@ -153,105 +126,30 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 7x7 image - constexpr index_t N = 128; - constexpr index_t C = 832; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; #elif 0 // 1x1 filter, 17x17 input constexpr index_t N = 128; - constexpr index_t C = 768; + constexpr index_t C = 1024; constexpr index_t HI = 17; constexpr index_t WI = 17; - constexpr index_t K = 128; + constexpr index_t K = 1024; constexpr index_t Y = 1; constexpr index_t X = 1; using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 14x14 image - constexpr index_t N = 128; - constexpr index_t C = 528; - constexpr index_t HI = 14; - constexpr index_t WI = 14; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 14x14 image - constexpr index_t N = 128; - constexpr index_t C = 528; - constexpr index_t HI = 14; - constexpr index_t WI = 14; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 7x7 image - constexpr index_t N = 128; - constexpr index_t C = 832; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output - constexpr index_t N = 128; - constexpr index_t C = 288; - constexpr index_t HI = 35; - constexpr index_t WI = 35; - constexpr index_t K = 384; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; #elif 0 // 5x5 filter, 2x2 pad, 7x7 input constexpr index_t N = 128; - constexpr index_t C = 48; + constexpr index_t C = 1024; constexpr index_t HI = 7; constexpr index_t WI = 7; - constexpr index_t K = 128; + constexpr index_t K = 1024; constexpr index_t Y = 5; constexpr index_t X = 5; @@ -260,28 +158,13 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>; -#elif 0 - // 7x1 filter, 3x0 pad, 17x17 input - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 128; - constexpr index_t Y = 7; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<3, 0>; - using RightPads = Sequence<3, 0>; #elif 1 // 1x7 filter, 0x3 pad, 17x17 input constexpr index_t N = 128; - constexpr index_t C = 128; + constexpr index_t C = 1024; constexpr index_t HI = 17; constexpr index_t WI = 17; - constexpr index_t K = 128; + constexpr index_t K = 1024; constexpr index_t Y = 1; constexpr index_t X = 7; @@ -290,6 +173,36 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>; +#elif 0 + // 7x1 filter, 3x0 pad, 17x17 input + constexpr index_t N = 128; + constexpr index_t C = 1024; + constexpr index_t HI = 17; + constexpr index_t WI = 17; + constexpr index_t K = 1024; + constexpr index_t Y = 7; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<3, 0>; + using RightPads = Sequence<3, 0>; +#elif 0 + // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output + constexpr index_t N = 128; + constexpr index_t C = 1024; + constexpr index_t HI = 35; + constexpr index_t WI = 35; + constexpr index_t K = 128; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #endif constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence{}); @@ -337,8 +250,12 @@ int main(int argc, char* argv[]) device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw #elif 0 device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw -#else +#elif 1 device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw +#elif 0 + device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw +#elif 1 + device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw #endif (in_nchw_desc, in_nchw_device, diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index 3762392847..bf3f598288 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -30,32 +30,63 @@ int main(int argc, char* argv[]) using namespace ck; #if 0 - constexpr index_t N = 8; - constexpr index_t C = 32; - constexpr index_t HI = 28; - constexpr index_t WI = 28; - constexpr index_t K = 32; - constexpr index_t Y = 5; - constexpr index_t X = 5; + // 1x1 + constexpr index_t N = 256; + constexpr index_t C = 1024; + constexpr index_t HI = 8; + constexpr index_t WI = 8; + constexpr index_t K = 1024; + constexpr index_t Y = 1; + constexpr index_t X = 1; using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; -#elif 1 +#elif 0 + // 1x7 + constexpr index_t N = 128; + constexpr index_t C = 1024; + constexpr index_t HI = 17; + constexpr index_t WI = 17; + constexpr index_t K = 1024; + constexpr index_t Y = 1; + constexpr index_t X = 7; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 3>; + using RightPads = Sequence<0, 3>; +#elif 0 // 3x3, 34x34 constexpr index_t N = 64; constexpr index_t C = 256; constexpr index_t HI = 34; constexpr index_t WI = 34; - constexpr index_t K = 128; + constexpr index_t K = 256; constexpr index_t Y = 3; constexpr index_t X = 3; using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 + // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output + constexpr index_t N = 128; + constexpr index_t C = 128; + constexpr index_t HI = 35; + constexpr index_t WI = 35; + constexpr index_t K = 128; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; #elif 0 @@ -282,21 +313,6 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>; #elif 0 - // 7x1 filter, 3x0 pad, 17x17 input - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 128; - constexpr index_t Y = 7; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<3, 0>; - using RightPads = Sequence<3, 0>; -#elif 1 // 1x7 filter, 0x3 pad, 17x17 input constexpr index_t N = 128; constexpr index_t C = 128; @@ -311,6 +327,21 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>; +#elif 1 + // 7x1 filter, 3x0 pad, 17x17 input + constexpr index_t N = 128; + constexpr index_t C = 128; + constexpr index_t HI = 17; + constexpr index_t WI = 17; + constexpr index_t K = 128; + constexpr index_t Y = 7; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<3, 0>; + using RightPads = Sequence<3, 0>; #endif auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence{}); diff --git a/script/cmake-cuda.sh b/script/cmake-cuda.sh index 4828c22fc8..759564b8ee 100755 --- a/script/cmake-cuda.sh +++ b/script/cmake-cuda.sh @@ -14,13 +14,11 @@ cmake -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D DEVICE_BACKEND=NVIDIA \ -D CUDA_COMMON_INCLUDE_DIR="/package/install/cuda/10.1/NVIDIA_CUDA-10.1_Samples/common/inc" \ --D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \ +-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ ${MY_PROJECT_SOURCE} #-D BOOST_ROOT="/package/install/boost_1.67.0" \ #-D CMAKE_CUDA_COMPILER="/package/install/cuda_10.0/bin/nvcc" \ #-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -Xptxas -v -maxrregcount=128" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \ +#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ diff --git a/script/cmake-cuda_docker.sh b/script/cmake-cuda_docker.sh index 592608166f..d414bd873d 100755 --- a/script/cmake-cuda_docker.sh +++ b/script/cmake-cuda_docker.sh @@ -15,11 +15,9 @@ cmake -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D DEVICE_BACKEND=NVIDIA \ -D CUDA_COMMON_INCLUDE_DIR="/root/NVIDIA_CUDA-10.1_Samples/common/inc" \ --D CMAKE_CUDA_FLAGS="-ccbin clang++-6.0 -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_60,code=sm_60 -Xptxas -v -gencode=arch=compute_70,code=sm_70" \ +-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ ${MY_PROJECT_SOURCE} #-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -Xptxas -v -maxrregcount=128" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \ +#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ diff --git a/script/extract_asm-cuda.sh b/script/extract_asm-cuda.sh index 4041ddc0ef..879e0b1a3d 100755 --- a/script/extract_asm-cuda.sh +++ b/script/extract_asm-cuda.sh @@ -1,3 +1,3 @@ -cuobjdump -xelf sm_60 ./driver/driver && nvdisasm --print-code -g driver.sm_60.cubin > driver.sm_60.asm -cuobjdump -xelf sm_61 ./driver/driver && nvdisasm --print-code -g driver.sm_61.cubin > driver.sm_61.asm -cuobjdump -xelf sm_70 ./driver/driver && nvdisasm --print-code -g driver.sm_70.cubin > driver.sm_70.asm +DRIVER=$1 +ARCH=$2 +cuobjdump -xelf $ARCH ./driver/$DRIVER && nvdisasm --print-code -g $DRIVER.$ARCH.cubin > $DRIVER.$ARCH.asm