diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 4ab5cf8efb..fecd7c5ca1 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -18,6 +18,8 @@ template {}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; constexpr auto True = integral_constant{}; @@ -75,10 +74,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); - constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); - constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2); - constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3); + constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); + constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); @@ -87,6 +84,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3); + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread"); constexpr index_t N0 = N / (N1 * N2); @@ -95,6 +98,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t E = C * Y * X; + // sanity-check for vectorized memory load + static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1, + "wrong! global vector load of input tensor is wrong"); + + static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0), + "wrong! aligment requirement for vectorized global load of input tensor will " + "be violated"); + // divide block work by [K, B] static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0, "wrong! cannot divide work evenly among block"); @@ -113,15 +124,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw // input tensor // tensor descriptor in device memory [N0, N1, N2, Ho, Wo] - constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number{}) - .Slice(I3, Number{}) - .Fold(I0, Number{}, Number{}) - .Extract(Sequence<0, 1, 2, 4, 5>{}); + constexpr auto in_n0_n1_n2_h_w_global_desc = + in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) + .StridedSlice(I3, Number{}, Number{}) + .Fold(I0, Number{}, Number{}) + .Extract(Sequence<0, 1, 2, 4, 5>{}); // batch descritpor for device memory - constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number{}) - .Slice(I3, Number{}) - .Extract(Sequence<1, 2, 3>{}); + constexpr auto in_c_y_x_global_desc = + in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) + .StridedSlice(I3, Number{}, Number{}) + .Extract(Sequence<1, 2, 3>{}); // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( @@ -131,17 +144,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw Sequence<3, 6, 7>{}, Sequence<5>{}); -#if 0 - if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_n0_n1_n2_h_w_global_desc, - "in_n0_n1_n2_h_w_global_desc: "); - print_ConstantTensorDescriptor(in_c_y_x_global_desc, "in_c_y_x_global_desc: "); - print_ConstantMergedTensorDescriptor(in_e_n1_b_n2_global_merged_desc, - "in_e_n1_b_n2_global_merged_desc: "); - } -#endif - // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy // be careful of LDS alignment constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned( @@ -206,13 +208,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in // register - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); + constexpr auto a_e_k_block_mtx_desc = + make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor(wei_e_k_block_desc); constexpr auto b_e_n1bn2_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); + make_ConstantMatrixDescriptor_from_ConstantTensorDescriptor( + in_e_n1_b_n2_block_desc.Unfold(I1, I3)); // sanity check static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == @@ -242,15 +243,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw GemmDataPerReadA, GemmDataPerReadB>{}; - // choose GEMM implementation here - const auto run_blockwise_gemm = [&](auto... Xs) { -#if 1 - return blockwise_gemm.Run(Xs...); -#else - return blockwise_gemm.Run_amd_asm(Xs...); -#endif - }; - // LDS allocation for input and weight: be careful of alignment constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2, WeiBlockCopyDstDataPerWrite_K, @@ -281,7 +273,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw __syncthreads(); - run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread); + blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread); __syncthreads(); @@ -293,7 +285,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw { constexpr index_t K2 = GemmMPerThreadSubC; constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t K0 = K / (K1 * K2); // define tensor descriptor for threadwise copy // output memory layout descriptor in register diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index eb132cd331..67395b978d 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -59,7 +59,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t B = (N * Ho * Wo) / (N1 * N2); -#if 0 +#if 1 // each thread hold 64 data constexpr index_t BlockSize = 256; @@ -94,7 +94,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 +#elif 0 // each thread hold 32 data constexpr index_t BlockSize = 256;