diff --git a/composable_kernel/include/kernel_algorithm/gridwise_col2im_eb_nchw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_col2im_eb_nchw.hpp index 2fbe301e7d..74a2b65571 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_col2im_eb_nchw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_col2im_eb_nchw.hpp @@ -114,10 +114,10 @@ struct GridwiseCol2Im_eb_nchw 1, BlockCopyDataPerAccess_B, BlockCopyDataPerAccess_B, - AddressSpace::vgpr, - AddressSpace::vgpr, - AddressSpace::global, - InMemoryDataOperation::atomic_add>( + AddressSpace::Vgpr, + AddressSpace::Vgpr, + AddressSpace::Global, + InMemoryDataOperation::AtomicAdd>( {e_block_data_on_global, b_block_data_on_global}, {e_block_data_on_global, b_block_data_on_global}); diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index fa3c6f2ffb..1c20af2279 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -25,15 +25,15 @@ template {}, Merge>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - // weight tensor - constexpr auto wei_k_e_global_desc = - unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3); - // input tensor constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, @@ -116,38 +111,42 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto in_e_b_global_desc = transform_tensor_descriptor( + constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( in_n_c_y_ho_x_wo_global_desc, make_tuple(Merge>{}, Merge>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // GEMM - constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW) - ? InMemoryDataOperation::none - : InMemoryDataOperation::atomic_add; + // \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need + // atomic, find out all of them + constexpr bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) and + (ConvStrideW >= ConvDilationW * (X - 1) + 1); + + constexpr auto in_memory_op = + not_need_atomic ? InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd;
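The predicate above generalizes the old test (which assumed unit dilation): one output pixel of the backward-data kernel scatters into input rows h0, h0 + ConvDilationH, ..., h0 + ConvDilationH * (Y - 1), a span of ConvDilationH * (Y - 1) + 1 rows, and stepping to the next output row shifts h0 by ConvStrideH. Once the stride covers that whole span along H (and likewise along W), no two output pixels write the same input element, so plain stores (Set) are safe; otherwise colliding contributions must be combined with AtomicAdd. A minimal standalone sketch of the same predicate, with hypothetical names that are not part of this patch:

    // Overlap test behind the Set-vs-AtomicAdd choice above; with dilation 1
    // it reduces to the old (Y <= ConvStrideH && X <= ConvStrideW) condition.
    constexpr bool writes_are_overlap_free(int stride_h, int dilation_h, int y,
                                           int stride_w, int dilation_w, int x)
    {
        return stride_h >= dilation_h * (y - 1) + 1 &&
               stride_w >= dilation_w * (x - 1) + 1;
    }

    // 3x3 filter, stride 2, dilation 1: row span = 1 * (3 - 1) + 1 = 3 > 2,
    // so neighboring windows collide and AtomicAdd is required.
    static_assert(!writes_are_overlap_free(2, 1, 3, 2, 1, 3), "");
    // 1x1 filter never overlaps, so Set suffices for any stride.
    static_assert(writes_are_overlap_free(1, 1, 1, 1, 1, 1), "");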
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1, diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp index 6f244808ce..0fdb15a440 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -147,10 +147,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl 2, OutBlockCopySrcDataPerRead_B, OutBlockCopyDstDataPerWrite_N0, - AddressSpace::global, - AddressSpace::vgpr, - AddressSpace::lds, - InMemoryDataOperation::none>( + AddressSpace::Global, + AddressSpace::Vgpr, + AddressSpace::Lds, + InMemoryDataOperation::Set>( {0, b_block_data_on_global, 0}, {0, 0, 0}); // weight tensor @@ -187,10 +187,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl 2, WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_C0, - AddressSpace::global, - AddressSpace::vgpr, - AddressSpace::lds, - InMemoryDataOperation::none>( + AddressSpace::Global, + AddressSpace::Vgpr, + AddressSpace::Lds, + InMemoryDataOperation::Set>( {0, e_block_data_on_global, 0}, {0, 0, 0}); // GEMM definition @@ -356,10 +356,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl #if 1 // debug // input: register to global memory, atomic add constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW) - ? InMemoryDataOperation::none - : InMemoryDataOperation::atomic_add; + ? 
InMemoryDataOperation::Set + : InMemoryDataOperation::AtomicAdd; #else - constexpr auto in_memory_op = InMemoryDataOperation::atomic_add; + constexpr auto in_memory_op = InMemoryDataOperation::AtomicAdd; #endif constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster; @@ -432,8 +432,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl 4, 1, InThreadCopyDstDataPerWrite_B, - AddressSpace::vgpr, - AddressSpace::global, + AddressSpace::Vgpr, + AddressSpace::Global, in_memory_op>({0, 0, 0, 0, 0, 0}, {e_thread_data_on_global / E1, e_thread_data_on_global % E1, diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index 70a0738d8a..75381eb76f 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -8,9 +8,9 @@ namespace ck { -// GemmM = C * Ytilda * Xtilda; -// GemmN = N * HtildaNonZero * WtildaNonZero; -// GemmK = K * Ydot * Xdot; +// GemmM = C * YTilda * XTilda; +// GemmN = N * HTildaSlice * WTildaSlice; +// GemmK = K * YDot * XDot; template {}, PassThrough{}, Embed, - Sequence>{}, + Sequence, + Sequence>{}, Embed, - Sequence>{}), + Sequence, + Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, - make_tuple(Merge>{}, Merge>{}), + make_tuple(Merge>{}, Merge>{}), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -134,33 +134,33 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw make_tuple(PassThrough{}, PassThrough{}, Embed, - Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{}, + Sequence, + Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>>{}, Embed, - Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}), + Sequence, + Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc = + constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc = transform_tensor_descriptor( out_n_k_ydot_htilda_xdot_wtilda_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); constexpr auto out_gemmk_gemmn_global_desc = - transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc, - make_tuple(Merge>{}, - Merge>{}), + transform_tensor_descriptor(out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, + make_tuple(Merge>{}, + Merge>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -188,35 +188,35 @@ struct 
GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw make_tuple(PassThrough{}, PassThrough{}, Embed, + Sequence, Sequence, in_skip_all_out_of_bound_check>{}, Embed, + Sequence, Sequence, in_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = + constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc = transform_tensor_descriptor( in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); constexpr auto in_gemmm_gemmn_global_desc = - transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc, - make_tuple(Merge>{}, - Merge>{}), + transform_tensor_descriptor(in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc, + make_tuple(Merge>{}, + Merge>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -229,17 +229,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw decltype(wei_gemmk_gemmm_global_desc), decltype(out_gemmk_gemmn_global_desc), decltype(in_gemmm_gemmn_global_desc), - InMemoryDataOperation::none, + InMemoryDataOperation::Set, GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN, GemmABlockCopyThreadSliceLengths_GemmK_GemmM, diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp index 068b8c1931..a36e7edba0 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp @@ -8,10 +8,10 @@ namespace ck { -// Ytilda*Xtilda number of GEMMs -// GemmM = C; -// GemmN = N * HtildaNonZero * WtildaNonZero; -// GemmK = K * YdotNonZero * XdotNonZero; +// Number of GEMMs: YTilda * XTilda +// GemmM = C +// GemmN = N * HTildaSlice * WTildaSlice +// GemmK = K * YDotSlice * XDotSlice template {}, PassThrough{}, Embed, - Sequence, + Sequence, + Sequence, wei_skip_all_out_of_bound_check>{}, Embed, - Sequence, + Sequence, + Sequence, wei_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); @@ -167,26 +167,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw make_tuple(PassThrough{}, PassThrough{}, Embed, - Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>, + Sequence, + Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>, out_skip_all_out_of_bound_check>{}, Embed, - Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>, + Sequence, + Sequence<-ConvDilationW 
/ GcdStrideDilationW, 1, 0>, out_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc = + constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc = transform_tensor_descriptor( out_n_k_ydot_htilda_xdot_wtilda_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple( @@ -216,26 +216,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw make_tuple(PassThrough{}, PassThrough{}, Embed, + Sequence, Sequence, in_skip_all_out_of_bound_check>{}, Embed, + Sequence, Sequence, in_skip_all_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = + constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc = transform_tensor_descriptor( in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple( @@ -246,54 +246,49 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw __shared__ Float p_shared_block[shared_block_size]; -#if 1 // debug - static_for<0, Ytilda, 1>{}([&](auto ytilda_) { - static_for<0, Xtilda, 1>{}([&](auto xtilda_) { -#else - static_for<0, 1, 1>{}([&](auto ytilda_) { - static_for<0, 1, 1>{}([&](auto xtilda_) { -#endif - constexpr index_t ytilda = decltype(ytilda_){}; - constexpr index_t xtilda = decltype(xtilda_){}; + static_for<0, YTilda, 1>{}([&](auto iYTilda_) { + static_for<0, XTilda, 1>{}([&](auto iXTilda_) { + constexpr index_t iYTilda = decltype(iYTilda_){}; + constexpr index_t iXTilda = decltype(iXTilda_){}; - constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot; - constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot; + constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; + constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? 
XDot : X % XDot; // A matrix - constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc = + constexpr auto wei_k_c_ydotslice_ytildaslice_xdotslice_xtildaslice_global_desc = transform_tensor_descriptor( wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, make_tuple(PassThrough{}, PassThrough{}, - Slice, + Slice, Sequence<0, 0>, - Sequence>{}, - Slice, - Sequence, - Sequence>{}), + Sequence>{}, + Slice, + Sequence, + Sequence>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{})); constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc, - make_tuple(Merge>{}, + wei_k_c_ydotslice_ytildaslice_xdotslice_xtildaslice_global_desc, + make_tuple(Merge>{}, Merge>{}), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // B matrix - constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc = + constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc = transform_tensor_descriptor( - out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc, + out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, + PassThrough{}, + PassThrough{}, + Slice, Sequence<0, 0>, - Sequence>{}), + Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, @@ -306,23 +301,23 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw Sequence<2, 4>{})); constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( - out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc, - make_tuple(Merge>{}, - Merge>{}), + out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc, + make_tuple(Merge>{}, + Merge>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // C matrix - constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = + constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc = transform_tensor_descriptor( - in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc, + in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, @@ -335,9 +330,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw Sequence<2, 4>{})); constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_1_htildatrim_1_wtildatrim_global_desc, + in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc, make_tuple(Merge>{}, - Merge>{}), + Merge>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -349,17 +344,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw decltype(wei_gemmk_gemmm_global_desc), decltype(out_gemmk_gemmn_global_desc), decltype(in_gemmm_gemmn_global_desc), - InMemoryDataOperation::none, + InMemoryDataOperation::Set, GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN, GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
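The v3r1 kernel above and the v4r1 kernel below both split backward data into YTilda * XTilda independent GEMMs, one per filter-tap phase: with G = gcd(ConvStrideH, ConvDilationH) there are YTilda = ConvStrideH / G phases along H, phase iYTilda owns taps {iYTilda, iYTilda + YTilda, ...}, and at most YDot = ceil(Y / YTilda) taps fall in a phase, trimmed to YDotSlice when the last group is partial. A small standalone sketch of that arithmetic (hypothetical helpers, not part of the patch):

    #include <numeric> // std::gcd (C++17)

    // One spatial axis: filter length y, stride s, dilation d.
    struct PhaseSplit { int tilda; int dot; };

    constexpr PhaseSplit split_filter(int y, int s, int d)
    {
        int g = std::gcd(s, d);            // GcdStrideDilation
        int tilda = s / g;                 // YTilda: number of GEMM phases
        int dot = (y + tilda - 1) / tilda; // YDot: ceil(y / tilda) taps per phase
        return {tilda, dot};
    }

    // YDotSlice for phase i, mirroring the (iYTilda + 1) * YDot <= Y test above.
    constexpr int dot_slice(int y, int i, PhaseSplit p)
    {
        return (i + 1) * p.dot <= y ? p.dot : y % p.dot;
    }

    // Y = 3, stride 2, dilation 1: two phases with at most ceil(3 / 2) = 2 taps;
    // phase 0 owns taps {0, 2} (slice 2), phase 1 owns tap {1} (slice 1).
    static_assert(split_filter(3, 2, 1).tilda == 2, "");
    static_assert(dot_slice(3, 0, split_filter(3, 2, 1)) == 2, "");
    static_assert(dot_slice(3, 1, split_filter(3, 2, 1)) == 1, "");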
diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 11d8f6540f..b8d47b409f 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -8,9 +8,10 @@ namespace ck { +// Number of GEMMs: YTilda * XTilda // GemmM = C -// GemmN = N * Htilda * Wtilda; -// GemmK = K * YdotNonZero * XdotNonZero +// GemmN = N * HTildaSlice * WTildaSlice +// GemmK = K * YDotSlice * XDotSlice template {GemmM, GemmN, GemmK}; + } + + __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id) + { + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + index_t iYTilda = gemm_id / XTilda; + index_t iXTilda = gemm_id % XTilda; + + return GetGemmSizeImpl(iYTilda, iXTilda); } template @@ -89,44 +167,39 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t ConvDilationH = ConvDilations{}[0]; constexpr index_t ConvDilationW = ConvDilations{}[1]; -#if 0 // debug - // sanity-check for vectorized memory load - // TODO: this logic may not be correct for bwd-data - static_assert( - (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) && - (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0), - "wrong! 
aligment requirement for vectorized global load of input tensor will " "be violated"); -#endif + //\todo static_assert for global vector load/store + // static_assert(); - constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); - constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - constexpr index_t HtildaLeft = math::integer_divide_floor( math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); - constexpr index_t WtildaLeft = math::integer_divide_floor( math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + constexpr index_t iHTildaLeft = math::integer_divide_floor( math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t iWTildaLeft = math::integer_divide_floor( math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - constexpr index_t HtildaRight = math::min( Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WtildaRight = math::min( Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + constexpr index_t iHTildaRight = math::min( HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t iWTildaRight = math::min( WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; - constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; + constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; + constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; - constexpr bool wei_skip_all_out_of_bound_check = true; + // weight out-of-bound check can be skipped + constexpr bool wei_skip_out_of_bound_check = true; // weight tensor constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( @@ -134,20 +207,22 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw make_tuple(PassThrough{}, PassThrough{}, Embed, - Sequence, - wei_skip_all_out_of_bound_check>{}, + Sequence, + Sequence, + wei_skip_out_of_bound_check>{}, Embed, - Sequence, - wei_skip_all_out_of_bound_check>{}), + Sequence, + Sequence, + wei_skip_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); -#if 1 // debug - constexpr bool out_skip_all_out_of_bound_check = 
false; +#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK + constexpr bool out_skip_out_of_bound_check = false; #else - constexpr bool out_skip_all_out_of_bound_check = true; + //\todo sometimes output tensor out-of-bound check can be skipped, find out all such + // situations + constexpr bool out_skip_out_of_bound_check = true; #endif // output tensor @@ -156,35 +231,36 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw make_tuple(PassThrough{}, PassThrough{}, Embed, - Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>, - out_skip_all_out_of_bound_check>{}, + Sequence, + Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>, + out_skip_out_of_bound_check>{}, Embed, - Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>, - out_skip_all_out_of_bound_check>{}), + Sequence, + Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>, + out_skip_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc = + constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc = transform_tensor_descriptor( out_n_k_ydot_htilda_xdot_wtilda_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); -#if 1 // debug - constexpr bool in_skip_all_out_of_bound_check = false; +#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK + constexpr bool in_skip_out_of_bound_check = false; #else - constexpr bool in_skip_all_out_of_bound_check = true; + //\todo sometimes input out-of-bound check can be skipped, find out all such situations + constexpr bool in_skip_out_of_bound_check = true; #endif // input tensor @@ -193,7 +269,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw make_tuple( PassThrough{}, PassThrough{}, - Pad, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}), + Pad, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); @@ -205,100 +281,96 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw make_tuple(PassThrough{}, PassThrough{}, Embed, + Sequence, Sequence, - in_skip_all_out_of_bound_check>{}, + in_skip_out_of_bound_check>{}, Embed, + Sequence, Sequence, - in_skip_all_out_of_bound_check>{}), + in_skip_out_of_bound_check>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = + constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc = transform_tensor_descriptor( in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, make_tuple(PassThrough{}, PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, 
Sequence<3, 5>{})); // GEMM - constexpr index_t ytilda = iYTilda; - constexpr index_t xtilda = iXTilda; - - constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot; - constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot; + constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; + constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; // A matrix - constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc = + constexpr auto wei_k_c_ydotslice_ytildaslice_xdotslice_xtildaslice_global_desc = transform_tensor_descriptor( wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Slice, - Sequence<0, 0>, - Sequence>{}, - Slice, - Sequence, - Sequence>{}), + make_tuple( + PassThrough{}, + PassThrough{}, + Slice, Sequence<0, 0>, Sequence>{}, + Slice, + Sequence, + Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{})); constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc, - make_tuple(Merge>{}, Merge>{}), + wei_k_c_ydotslice_ytildaslice_xdotslice_xtildaslice_global_desc, + make_tuple(Merge>{}, Merge>{}), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // B matrix - constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc = + constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc = transform_tensor_descriptor( - out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence<0, 0>, - Sequence>{}), + out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, + make_tuple( + PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, Sequence<0, 0>, Sequence>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( - out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc, - make_tuple(Merge>{}, - Merge>{}), + out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc, + make_tuple(Merge>{}, + Merge>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // C matrix - constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = transform_tensor_descriptor( - in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); + constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc = + transform_tensor_descriptor( + in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); constexpr auto in_gemmm_gemmn_global_desc = 
transform_tensor_descriptor( - in_n_c_1_htildatrim_1_wtildatrim_global_desc, - make_tuple(Merge>{}, Merge>{}), + in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc, + make_tuple(Merge>{}, Merge>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -310,19 +382,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw decltype(wei_gemmk_gemmm_global_desc), decltype(out_gemmk_gemmn_global_desc), decltype(in_gemmm_gemmn_global_desc), - InMemoryDataOperation::none, + InMemoryDataOperation::Set, GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, + ThreadGemmAThreadCopySrcDataPerRead_GemmM, + ThreadGemmBThreadCopySrcDataPerRead_GemmN, GemmABlockCopyThreadSliceLengths_GemmK_GemmM, GemmABlockCopyThreadClusterLengths_GemmK_GemmM, Sequence<0, 1>, @@ -355,16 +427,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t ConvDilationH = ConvDilations{}[0]; constexpr index_t ConvDilationW = ConvDilations{}[1]; - constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - constexpr index_t iYTilda = GemmId / Xtilda; - constexpr index_t iXTilda = GemmId % Xtilda; + constexpr index_t iYTilda = GemmId / XTilda; + constexpr index_t iXTilda = GemmId % XTilda; - static_assert(iYTilda < Ytilda && iXTilda < Xtilda, "wrong! iYtilda, iXtilda"); + static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! 
iYtilda, iXtilda"); RunImpl(p_in_global, p_wei_global, p_out_global); } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp index c8830e310d..a462c6b560 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -229,10 +229,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer 3, InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_N2, - AddressSpace::global, - AddressSpace::vgpr, - AddressSpace::lds, - InMemoryDataOperation::none>( + AddressSpace::Global, + AddressSpace::Vgpr, + AddressSpace::Lds, + InMemoryDataOperation::Set>( {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); // weight tensor @@ -269,10 +269,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer 1, WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K, - AddressSpace::global, - AddressSpace::vgpr, - AddressSpace::lds, - InMemoryDataOperation::none>( + AddressSpace::Global, + AddressSpace::Vgpr, + AddressSpace::Lds, + InMemoryDataOperation::Set>( {0, k_block_data_on_global}, {0, 0}); // GEMM definition @@ -344,6 +344,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); } + constexpr auto in_block_slice_copy_steps = Sequence{}; + constexpr auto wei_block_slice_copy_steps = Sequence{}; + // LDS double buffer: main body for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; e_block_data_begin += 2 * EPerBlock) @@ -366,8 +369,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); + blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True); + blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True); __syncthreads(); @@ -393,8 +396,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); + blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True); + blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True); __syncthreads(); @@ -482,14 +485,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer 3, 1, 1, - AddressSpace::vgpr, - AddressSpace::global, - InMemoryDataOperation::none>({0, 0, 0, 0, 0}, - {k_thread_data_on_global / K1, - k_thread_data_on_global % K1, - 0, - b_thread_data_on_global, - 0}) + AddressSpace::Vgpr, + AddressSpace::Global, + InMemoryDataOperation::Set>({0, 0, 0, 0, 0}, + {k_thread_data_on_global / K1, + k_thread_data_on_global % K1, + 0, + b_thread_data_on_global, + 0}) .Run(p_out_thread, p_out_global); } } diff --git 
diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp index b5fde21c9f..133a4635f0 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp @@ -94,9 +94,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep constexpr auto True = integral_constant{}; constexpr auto generic_address_space = - integral_constant{}; + integral_constant{}; constexpr auto global_address_space = - integral_constant{}; + integral_constant{}; static_assert(ConvDirection == ConvolutionDirection::Forward || ConvDirection == ConvolutionDirection::BackwardWeight, @@ -141,13 +141,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep constexpr index_t E = C * Y * X; // sanity-check for vectorized memory load - static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) && - (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0), - "wrong! aligment requirement for vectorized global load of input tensor will " - "be violated"); + static_assert( + (Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) && + (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0), + "wrong! alignment requirement for vectorized global load of input tensor will " + "be violated"); // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, + static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0, "wrong! 
cannot divide work evenly among block"); constexpr index_t KBlockWork = K / KPerBlock; @@ -357,37 +358,49 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep // LDS double buffer: tail { - // even iteration - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; + constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0); - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); + if(has_two_iteration_left) // if has 2 iteration left + { + // even iteration + Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; + Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - __syncthreads(); + blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); + blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadThreadBuffer( - p_in_global, p_in_thread_buffer, global_address_space, generic_address_space); - blockwise_wei_copy.RunLoadThreadBuffer( - p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space); + __syncthreads(); - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); + // LDS doubel buffer: load next data from device mem + blockwise_in_copy.RunLoadThreadBuffer( + p_in_global, p_in_thread_buffer, global_address_space, generic_address_space); + blockwise_wei_copy.RunLoadThreadBuffer( + p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space); - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, - p_wei_block_double + wei_block_space); + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - // odd iteration - __syncthreads(); + // LDS double buffer: store next data to LDS + blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, + p_in_block_double + in_block_space); + blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, + p_wei_block_double + wei_block_space); - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); + // odd iteration + __syncthreads(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(p_wei_block_double + wei_block_space, + p_in_block_double + in_block_space, + p_out_thread); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); + } } // copy output: register to global memory diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index e3db0193bc..31b340e4c5 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -25,15 +25,15 @@ template , diff --git 
a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp index 8b3f8445d6..c6e36d5973 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp @@ -251,9 +251,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep // LDS double buffer: preload data into LDS { - blockwise_in_copy.template Run(p_in_global, + blockwise_in_copy.template Run(p_in_global, p_in_block_double); - blockwise_wei_copy.template Run(p_wei_global, + blockwise_wei_copy.template Run(p_wei_global, p_wei_block_double); } @@ -285,9 +285,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep __syncthreads(); // LDS doubel buffer: load next data from device mem - blockwise_in_copy.template RunLoadThreadBuffer( + blockwise_in_copy.template RunLoadThreadBuffer( p_in_global, p_in_thread_buffer); - blockwise_wei_copy.template RunLoadThreadBuffer( + blockwise_wei_copy.template RunLoadThreadBuffer( p_wei_global, p_wei_thread_buffer); // LDS double buffer: GEMM on current data @@ -311,9 +311,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep __syncthreads(); // LDS doubel buffer: load next data from device mem - blockwise_in_copy.template RunLoadThreadBuffer( + blockwise_in_copy.template RunLoadThreadBuffer( p_in_global, p_in_thread_buffer); - blockwise_wei_copy.template RunLoadThreadBuffer( + blockwise_wei_copy.template RunLoadThreadBuffer( p_wei_global, p_wei_thread_buffer); // LDS double buffer: GEMM on current data @@ -390,7 +390,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat) { threadwise_out_copy - .template Run(p_out_thread, + .template Run(p_out_thread, p_out_global); threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True); diff --git a/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp index e2a5836edd..0ebd9dc4a1 100644 --- a/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp @@ -60,7 +60,7 @@ __host__ __device__ constexpr auto template __host__ __device__ constexpr auto -make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated) + make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated) { using TDesc = ConstantTensorDescriptor_deprecated; static_assert(TDesc::GetNumOfDimension() == 2, "wrong"); diff --git a/composable_kernel/include/tensor_description/tensor_coordinate.hpp b/composable_kernel/include/tensor_description/tensor_coordinate.hpp index f796dac880..a2d6bb3fb1 100644 --- a/composable_kernel/include/tensor_description/tensor_coordinate.hpp +++ b/composable_kernel/include/tensor_description/tensor_coordinate.hpp @@ -267,7 +267,7 @@ struct TensorCoordinate private: template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(NativeTensorDescriptor) + MakeDummyTensorCoordinate(NativeTensorDescriptor) { return NativeTensorCoordinate>( 
make_zero_array()); @@ -275,7 +275,7 @@ struct TensorCoordinate template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(TransformedTensorDescriptor) + MakeDummyTensorCoordinate(TransformedTensorDescriptor) { return TransformedTensorCoordinate>( make_zero_array()); diff --git a/composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp b/composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp index da02abdd52..69659445a0 100644 --- a/composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp +++ b/composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp @@ -327,14 +327,14 @@ struct TensorCoordinate_deprecated private: template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated) + MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated) { return NormalTensorCoordinate_deprecated>(); } template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated) + MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated) { return MergedTensorCoordinate_deprecated< ConstantMergedTensorDescriptor_deprecated>(); diff --git a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp index 1597e4c577..b65edf5d44 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp @@ -64,10 +64,10 @@ template __host__ __device__ constexpr auto -reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor, - Sequence, - Sequence, - Sequence) + reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor, + Sequence, + Sequence, + Sequence) { return TransformedTensorDescriptor...>, @@ -78,7 +78,7 @@ reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor, // reorder a NativeTensorDescriptor template __host__ __device__ constexpr auto -reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor, MapLower2Upper) + reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor, MapLower2Upper) { static_assert(is_valid_sequence_map{}, "wrong! MapLower2Upper is not a valid map"); @@ -96,7 +96,7 @@ reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor, MapLo // reorder a TransformedTensorDescriptor template __host__ __device__ constexpr auto -reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor, MapLower2Upper) + reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor, MapLower2Upper) { static_assert(is_valid_sequence_map{}, "wrong! MapLower2Upper is not a valid map"); @@ -152,9 +152,9 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript typename arithmetic_sequence_gen::type{}; constexpr auto right = typename arithmetic_sequence_gen::type{}; - // sanity-checknfoldable + // sanity-check if unfold-able static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)), - "wrong! not unfoldable"); + "wrong! 
not unfold-able"); // unfolded length, stride constexpr index_t unfold_length = diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp index 1c7bb92f6c..6106581896 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp @@ -23,8 +23,8 @@ template + index_t ThreadGemmADataPerRead_M, + index_t ThreadGemmBDataPerRead_N> struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 { struct MatrixIndex @@ -150,13 +150,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 decltype(a_thread_mtx), KPerThreadLoop, MPerThreadSubC, - DataPerReadA>{}; + ThreadGemmADataPerRead_M>{}; constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy{}; + ThreadGemmBDataPerRead_N>{}; constexpr auto threadwise_gemm = ThreadwiseGemmTransANormalBNormalC{}; + ThreadGemmADataPerRead_M>{}; constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy{}; + ThreadGemmBDataPerRead_N>{}; constexpr auto threadwise_gemm = ThreadwiseGemmTransANormalBNormalC + AddressSpace SrcAddressSpace = AddressSpace::Generic, + AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic, + AddressSpace DstAddressSpace = AddressSpace::Generic, + InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set> struct BlockwiseGenericTensorSliceCopy_v4 { static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension(); @@ -115,7 +115,7 @@ struct BlockwiseGenericTensorSliceCopy_v4 template __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const { - static_assert(ThreadBufferAddressSpace == AddressSpace::vgpr, + static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr, "wrong! This function use vgpr as its thread " "buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer " "to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. 
" @@ -157,7 +157,7 @@ struct BlockwiseGenericTensorSliceCopy_v4 1, SrcAddressSpace, ThreadBufferAddressSpace, - InMemoryDataOperation::none>; + InMemoryDataOperation::Set>; using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2{}; + integral_constant{}; RunLoadThreadBuffer( p_block_src, p_thread_buffer, generic_address_space, generic_address_space); @@ -529,7 +529,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated BlockDstData* p_block_dst) const { constexpr auto generic_address_space = - integral_constant{}; + integral_constant{}; RunStoreThreadBuffer( p_thread_buffer, p_block_dst, generic_address_space, generic_address_space); @@ -548,7 +548,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated BlockSrcData p_thread_buffer[GetThreadBufferSize()]; constexpr auto generic_address_space = - integral_constant{}; + integral_constant{}; RunLoadThreadBuffer( p_block_src, p_thread_buffer, block_src_address_space, generic_address_space); @@ -562,7 +562,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const { constexpr auto generic_address_space = - integral_constant{}; + integral_constant{}; Run(p_block_src, p_block_dst, generic_address_space, generic_address_space); } diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp index 56d779616f..e5c8e37495 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp @@ -22,15 +22,15 @@ template ( + AddressSpace::Global, + AddressSpace::Vgpr, + AddressSpace::Lds, + InMemoryDataOperation::Set>( {0, m_block_data_on_global}, {0, 0}); // B matrix in LDS memory, dst of blockwise copy @@ -165,10 +165,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 1, BBlockCopySrcDataPerRead, BBlockCopyDstDataPerWrite_N, - AddressSpace::global, - AddressSpace::vgpr, - AddressSpace::lds, - InMemoryDataOperation::none>( + AddressSpace::Global, + AddressSpace::Vgpr, + AddressSpace::Lds, + InMemoryDataOperation::Set>( {0, n_block_data_on_global}, {0, 0}); // GEMM definition @@ -181,35 +181,33 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc); // sanity check - static_assert(MPerBlock % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 && - NPerBlock % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0, + static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 && + NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0, "wrong!"); - constexpr index_t GemmMRepeat = - MPerBlock / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster); + constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); - constexpr index_t GemmNRepeat = - NPerBlock / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster); + constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); // c_thread_mtx definition: this is a mess // TODO:: more elegent way of defining c_thread_mtx constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); + Number{}, Number{}); const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< BlockSize, decltype(a_k_m_block_mtx_desc), decltype(b_k_n_block_mtx_desc), decltype(c_m0m1_n0n1_thread_mtx_desc), - MPerThreadSubC, - NPerThreadSubC, + MPerThread, + 
NPerThread, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster, - KPerThreadLoop, - ThreadGemmDataPerReadM, - ThreadGemmDataPerReadN>{}; + KPerThread, + ThreadGemmAThreadCopySrcDataPerRead_M, + ThreadGemmBThreadCopySrcDataPerRead_N>{}; // LDS allocation for A and B: be careful of alignment constexpr index_t a_block_space = @@ -233,6 +231,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 b_blockwise_copy.Run(p_b_global, p_b_block_double); } + constexpr auto a_block_slice_copy_steps = Sequence{}; + constexpr auto b_block_slice_copy_steps = Sequence{}; + // LDS double buffer: main body for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; k_block_data_begin += 2 * KPerBlock) @@ -255,8 +256,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - a_blockwise_copy.MoveSrcSliceWindow(Sequence{}, True); - b_blockwise_copy.MoveSrcSliceWindow(Sequence{}, True); + a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); + b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); __syncthreads(); @@ -282,8 +283,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - a_blockwise_copy.MoveSrcSliceWindow(Sequence{}, True); - b_blockwise_copy.MoveSrcSliceWindow(Sequence{}, True); + a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); + b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); __syncthreads(); @@ -317,16 +318,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 // input: register to global memory { - constexpr index_t M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster; constexpr index_t M0 = M / M1; - constexpr index_t N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster; constexpr index_t N0 = N / N1; // define input tensor descriptor for threadwise copy // thread input tensor, src of threadwise copy constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); + Sequence{}); constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor( c_m_n_global_desc, @@ -352,8 +353,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 CThreadCopySrcDstVectorReadWriteDim, 1, CThreadCopyDstDataPerWrite, - AddressSpace::vgpr, - AddressSpace::global, + AddressSpace::Vgpr, + AddressSpace::Global, CGlobalMemoryDataOperation>( {0, 0, 0, 0}, {m_thread_data_on_global / M1, diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index 53331c10cd..1538623e41 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -21,9 +21,9 @@ template + AddressSpace SrcAddressSpace = AddressSpace::Generic, + AddressSpace DstAddressSpace = AddressSpace::Generic, + InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set> struct ThreadwiseGenericTensorSliceCopy_v4r2 { static constexpr index_t nDim = SliceLengths::Size(); @@ -115,8 +115,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 transfer_data( + AddressSpace::Vgpr, + 
InMemoryDataOperation::Set>( p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset); } } @@ -146,7 +146,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 { transfer_data( p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset()); @@ -265,12 +265,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 transfer_data(p_src, - src_nonlinear_coord.GetOffset() + - src_linear_offset, - p_src_long_vector, - buffer_offset); + AddressSpace::Vgpr, + InMemoryDataOperation::Set>(p_src, + src_nonlinear_coord.GetOffset() + + src_linear_offset, + p_src_long_vector, + buffer_offset); } } @@ -303,7 +303,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 { transfer_data( p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset()); @@ -404,8 +404,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 transfer_data( + AddressSpace::Vgpr, + InMemoryDataOperation::Set>( p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset); } } @@ -448,7 +448,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 { transfer_data(p_dst_long_vector, buffer_offset, diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp index f28ef935b1..71460f33d2 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp @@ -333,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated // 2. src_normal_offset must be calculated at compile time (guaranteed by // algorithm) // 3. src_merged_offset can be runtime value (no assumption imposed) - static_if{}([&](auto fwd) { + static_if{}([&](auto fwd) { #if CK_USE_AMD_BUFFER_ADDRESSING vector_data = amd_intrinsic_buffer_load( fwd(p_src), src_merged_offset, src_normal_offset); @@ -442,7 +442,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated // 2. dst_normal_offset must be calculated at compile time (guaranteed by // algorithm) // 3.
dst_merged_offset can be runtime value (no assumption imposed) - static_if{}([&](auto fwd) { + static_if{}([&](auto fwd) { #if CK_USE_AMD_BUFFER_ADDRESSING amd_intrinsic_buffer_store( vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset); @@ -464,7 +464,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated __device__ void Run(const SrcData* p_src, DstData* p_dst) const { constexpr auto generic_address_space = - integral_constant{}; + integral_constant{}; Run(p_src, p_dst, generic_address_space, generic_address_space); } diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index 9d32cc81af..7ff99e0af6 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -54,21 +54,23 @@ #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0 +#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0 +#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0 namespace ck { enum AddressSpace { - generic, - global, - lds, - vgpr + Generic, + Global, + Lds, + Vgpr }; enum InMemoryDataOperation { - none, - atomic_add + Set, + AtomicAdd }; #if CK_UNSIGNED_INDEX_TYPE diff --git a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in index 294be0536a..2ba30a183b 100644 --- a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in @@ -10,13 +10,14 @@ template -__device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) +__device__ void set_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) { using vector_t = typename vector_type::MemoryType; #if CK_USE_AMD_BUFFER_ADDRESSING // TODO: use static_if::ElseIf, instead of nested static_if - static_if{}([&](auto) { + static_if{}([&](auto) { // buffer_load requires: // 1) p_src must be in global memory space, p_dst must be vgpr // 2) p_src to be a block-invariant pointer. @@ -24,7 +25,8 @@ __device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t *reinterpret_cast(&p_dst[dst_offset]) = amd_intrinsic_buffer_load(p_src, src_offset, 0); }).Else([&](auto) { - static_if{}([&](auto) { + static_if{}([&](auto) { // buffer_store requires: // 1) p_src must be in vgpr space, p_dst must be global memory // 2) p_dst to be a block-invariant pointer.
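
Context for the renames above: transfer_data is the single entry point the copy classes call, and it dispatches at compile time on the destination's InMemoryDataOperation. Set routes to set_data (the renamed copy_data, a plain overwrite implemented via buffer_load/buffer_store on the AMD path), while AtomicAdd routes to atomic_add_data. Below is a minimal host-side sketch of that dispatch, for illustration only: it uses plain if constexpr where CK uses static_if, and the scalar copy/add bodies are stand-ins for the real vectorized device code.

#include <cstddef>

enum class InMemoryDataOperation { Set, AtomicAdd };

// Stand-in for set_data(): overwrite the destination element.
template <typename T>
void set_data(const T* p_src, std::size_t src_offset, T* p_dst, std::size_t dst_offset)
{
    p_dst[dst_offset] = p_src[src_offset];
}

// Stand-in for atomic_add_data(): accumulate into the destination
// (the device version uses atomicAdd / buffer atomics instead).
template <typename T>
void atomic_add_data(const T* p_src, std::size_t src_offset, T* p_dst, std::size_t dst_offset)
{
    p_dst[dst_offset] += p_src[src_offset];
}

template <InMemoryDataOperation DstInMemOp, typename T>
void transfer_data(const T* p_src, std::size_t src_offset, T* p_dst, std::size_t dst_offset)
{
    static_assert(DstInMemOp == InMemoryDataOperation::Set ||
                      DstInMemOp == InMemoryDataOperation::AtomicAdd,
                  "wrong! InMemoryDataOperation not supported!");

    if constexpr(DstInMemOp == InMemoryDataOperation::Set)
        set_data(p_src, src_offset, p_dst, dst_offset);
    else
        atomic_add_data(p_src, src_offset, p_dst, dst_offset);
}

Usage would look like transfer_data<InMemoryDataOperation::AtomicAdd>(src, 0, dst, 0). In gridwise_gemm.hpp above, the same choice surfaces as the CGlobalMemoryDataOperation template parameter on the final register-to-global copy of the C matrix.
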
@@ -50,19 +52,18 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in { using vector_t = typename vector_type::MemoryType; - static_if{}( - [&](auto) { + static_if{}([&](auto) { #if CK_USE_AMD_BUFFER_ATOMIC_ADD - amd_intrinsic_buffer_atomic_add( - *reinterpret_cast(&p_src[src_offset]), p_dst, dst_offset, 0); + amd_intrinsic_buffer_atomic_add( + *reinterpret_cast(&p_src[src_offset]), p_dst, dst_offset, 0); #else - atomicAdd(reinterpret_cast(&p_dst[dst_offset]), - *reinterpret_cast(&p_src[src_offset])); + atomicAdd(reinterpret_cast(&p_dst[dst_offset]), + *reinterpret_cast(&p_src[src_offset])); #endif - }) - .Else([&](auto fwd) { - static_assert(fwd(false), "atomic_add doesn't support this memory space"); - }); + }).Else([&](auto fwd) { + static_assert(fwd(false), "atomic_add doesn't support this memory space"); + }); } template __device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) { - static_assert(DstInMemOp == InMemoryDataOperation::none || - DstInMemOp == InMemoryDataOperation::atomic_add, + static_assert(DstInMemOp == InMemoryDataOperation::Set || + DstInMemOp == InMemoryDataOperation::AtomicAdd, "wrong! InMemoryDataOperation not supported!"); // TODO: use static_if::ElseIf - static_if{}([&](auto) { - copy_data( + static_if{}([&](auto) { + set_data( p_src, src_offset, p_dst, dst_offset); }); - static_if{}([&](auto) { + static_if{}([&](auto) { atomic_add_data( p_src, src_offset, p_dst, dst_offset); }); diff --git a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in index 3a984a0e07..4061aff125 100644 --- a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in @@ -23,14 +23,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in { using vector_t = typename vector_type::MemoryType; - static_if{}( - [&](auto) { - atomicAdd(reinterpret_cast(&p_dst[dst_offset]), - *reinterpret_cast(&p_src[src_offset])); - }) - .Else([&](auto fwd) { - static_assert(fwd(false), "atomic_add doesn't support this memory space"); - }); + static_if{}([&](auto) { + atomicAdd(reinterpret_cast(&p_dst[dst_offset]), + *reinterpret_cast(&p_src[src_offset])); + }).Else([&](auto fwd) { + static_assert(fwd(false), "atomic_add doesn't support this memory space"); + }); } template __device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) { - static_assert(DstInMemOp == InMemoryDataOperation::none || - DstInMemOp == InMemoryDataOperation::atomic_add, + static_assert(DstInMemOp == InMemoryDataOperation::Set || + DstInMemOp == InMemoryDataOperation::AtomicAdd, "wrong! InMemoryDataOperation not supported!"); // TODO: use static_if::ElseIf - static_if{}([&](auto) { + static_if{}([&](auto) { copy_data( p_src, src_offset, p_dst, dst_offset); }); - static_if{}([&](auto) { + static_if{}([&](auto) { atomic_add_data( p_src, src_offset, p_dst, dst_offset); }); diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 20f51552f6..4c9cd85d5e 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -107,27 +107,22 @@ __host__ __device__ constexpr T min(T x, Ts... 
xs) template __host__ __device__ constexpr T gcd(T x, T y) { - if(x == 0) + if(x == y || x == 0) { return y; } - - if(y == 0) + else if(y == 0) { return x; } - - if(x == y) - { - return x; - } - - if(x > y) + else if(x > y) { return gcd(x - y, y); } - - return gcd(x, y - x); + else + { + return gcd(x, y - x); + } } template @@ -150,10 +145,10 @@ __host__ __device__ constexpr T lcm(T x, T y) return (x * y) / gcd(x, y); } -template -__host__ __device__ constexpr auto lcm(X x, Y y, Zs... zs) +template +__host__ __device__ constexpr auto lcm(X x, Ys... ys) { - return lcm(x, lcm(y, zs...)); + return lcm(x, lcm(ys...)); } template diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 4545488aa2..00dcbdc832 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -49,20 +49,20 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); -#if 1 +#if 0 // BlockSize = 256, each thread holds 64 values constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; @@ -79,6 +79,36 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // BlockSize = 256, each thread holds 64 values + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; #endif constexpr index_t GemmM = C * Y * X; @@ -104,13 +134,13 @@ void
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN, GemmABlockCopyThreadSliceLengths_GemmK_GemmM, diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index e82e72c179..622062018d 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; @@ -96,13 +96,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; @@ -127,13 +127,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; @@ -152,33 +152,33 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; #endif - constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - constexpr 
index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); - constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - constexpr index_t HtildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); - constexpr index_t WtildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + constexpr index_t HTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t WTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - constexpr index_t HtildaRight = math::min( - Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WtildaRight = math::min( - Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + constexpr index_t HTildaRight = math::min( + HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WTildaRight = math::min( + WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; - constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; + constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; + constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; - constexpr index_t GemmM = C * Ytilda * Xtilda; - constexpr index_t GemmN = N * HtildaTrim * WtildaTrim; + constexpr index_t GemmM = C * YTilda * XTilda; + constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * math::integer_divide_ceil(GemmN, GemmNPerBlock); @@ -200,13 +200,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN, GemmABlockCopyThreadSliceLengths_GemmK_GemmM, diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp index 3b84c0ba9b..2fec94b08b 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp @@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i constexpr 
index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; @@ -91,33 +91,33 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #endif - constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); - constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - constexpr index_t HtildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); - constexpr index_t WtildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + constexpr index_t HTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t WTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - constexpr index_t HtildaRight = math::min( - Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WtildaRight = math::min( - Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + constexpr index_t HTildaRight = math::min( + HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WTildaRight = math::min( + WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; - constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; + constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; + constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; constexpr index_t GemmM = C; - constexpr index_t GemmN = N * HtildaTrim * WtildaTrim; + constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; constexpr index_t GridSize = 
math::integer_divide_ceil(GemmM, GemmMPerBlock) * math::integer_divide_ceil(GemmN, GemmNPerBlock); @@ -139,13 +139,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN, GemmABlockCopyThreadSliceLengths_GemmK_GemmM, diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index c34953d919..8ae1c72527 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -69,13 +69,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; @@ -99,13 +99,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 16; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; @@ -124,33 +124,33 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #endif - constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW); + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h; - constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w; + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); - constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr 
index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - constexpr index_t HtildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); - constexpr index_t WtildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]); + constexpr index_t HTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t WTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - constexpr index_t HtildaRight = math::min( - Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WtildaRight = math::min( - Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + constexpr index_t HTildaRight = math::min( + HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t WTildaRight = math::min( + WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; - constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; + constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; + constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; constexpr index_t GemmM = C; - constexpr index_t GemmN = N * HtildaTrim * WtildaTrim; + constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * math::integer_divide_ceil(GemmN, GemmNPerBlock); @@ -159,7 +159,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i for(index_t i = 0; i < nrepeat; ++i) { - using GridwiseConv = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw< + using GridwiseConvBwdData = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw< GridSize, BlockSize, T, @@ -174,13 +174,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN, GemmABlockCopyThreadSliceLengths_GemmK_GemmM, @@ -196,21 +196,29 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i KernelTimer timer; timer.Start(); - static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) { + static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) { constexpr index_t gemm_id = decltype(gemm_id_){}; - launch_kernel(run_gridwise_convolution_backward_data_v4r1, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id); + constexpr index_t gemm_k = gemm_sizes.At(2); + constexpr bool is_gemm_not_empty = gemm_k > 0; + + // only compile and run if GEMM is not empty + static_if{}([&](auto fwd) {
+ launch_kernel( + run_gridwise_convolution_backward_data_v4r1, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + }); }); timer.End(); diff --git a/driver/src/conv_bwd_data_driver.cpp b/driver/src/conv_bwd_data_driver.cpp index 17a0cd7e98..a94dcb55bf 100644 --- a/driver/src/conv_bwd_data_driver.cpp +++ b/driver/src/conv_bwd_data_driver.cpp @@ -23,17 +23,16 @@ int main(int argc, char* argv[]) { using namespace launcher; -#if 0 - // 3x3 filter, 2x2 stride, 35x35 input - constexpr index_t N = 128; - constexpr index_t C = 1024; - constexpr index_t HI = 35; - constexpr index_t WI = 35; - constexpr index_t K = 1024; - constexpr index_t Y = 3; - constexpr index_t X = 3; +#if 1 + constexpr index_t N = 64; + constexpr index_t C = 256; + constexpr index_t HI = 56; + constexpr index_t WI = 56; + constexpr index_t K = 256; + constexpr index_t Y = 1; + constexpr index_t X = 1; - using ConvStrides = Sequence<2, 2>; + using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; using LeftPads = Sequence<0, 0>; @@ -158,7 +157,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>; -#elif 1 +#elif 0 // 1x7 filter, 0x3 pad, 17x17 input constexpr index_t N = 128; constexpr index_t C = 128; @@ -188,7 +187,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>; -#elif 0 +#elif 1 // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output constexpr index_t N = 128; constexpr index_t C = 1024;
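
To make the tilde arithmetic above concrete, take the 3x3-filter, 2x2-stride, 35x35-input case that this driver change re-enables (17x17 output, unit dilation). The sketch below replays the driver's compile-time math on the host; gcd mirrors the refactored subtraction-form math::gcd, and integer_divide_ceil is assumed to be the usual round-up division.

// Mirrors the subtraction-form gcd from math.hpp after this patch.
constexpr int gcd(int x, int y)
{
    if(x == y || x == 0)
    {
        return y;
    }
    else if(y == 0)
    {
        return x;
    }
    else if(x > y)
    {
        return gcd(x - y, y);
    }
    else
    {
        return gcd(x, y - x);
    }
}

// Assumed behavior of math::integer_divide_ceil for positive operands.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

// 3x3 filter, 2x2 stride, 1x1 dilation, 35x35 input -> 17x17 output.
constexpr int Y = 3, ConvStrideH = 2, ConvDilationH = 1, Ho = 17;

constexpr int GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);  // gcd(2, 1) = 1
constexpr int YTilda = ConvStrideH / GcdStrideDilationH;             // 2 / 1 = 2
constexpr int YDot   = integer_divide_ceil(Y, YTilda);               // ceil(3 / 2) = 2
constexpr int HTilda =
    Ho + integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);  // 17 + 1 = 18

static_assert(YTilda == 2 && YDot == 2 && HTilda == 18, "tilde arithmetic check");

The W-direction numbers are identical by symmetry. With YTilda = XTilda = 2, the v4r1 path splits the backward-data problem by filter-tap residue, one GEMM per (ytilda, xtilda) pair, so four here: that is what the gemm_id loop above iterates, skipping any sub-GEMM whose GemmK comes out zero instead of falling back to atomic adds.
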