diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp index 2d1cd532de..22c5d11564 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp @@ -98,8 +98,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn const index_t wi_block_data_begin = wo_block_data_begin; // global tensor view - constexpr auto wei_c_k_global_desc = - make_ConstantTensorDescriptor(Sequence{}, Sequence{}); + constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); // LDS tensor view // be careful of alignment @@ -212,44 +211,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn // set threadwise output tensor to 0 threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread); -#if 1 - const Float* p_in_global_block_offset = - p_in_global + - in_c_h_w_n_global_desc.GetOffsetFromMultiIndex( - 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) - { - for(index_t y = 0; y < Y; ++y) - { -#pragma unroll - for(index_t x = 0; x < X; ++x) - { - blockwise_in_copy.Run( - p_in_global_block_offset + - in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, x, 0), - p_in_block); - - blockwise_wei_copy.Run( - p_wei_global_block_offset + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0), - p_wei_block); - - __syncthreads(); - - blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); - - __syncthreads(); - } - } - } -#else for(index_t y = 0; y < Y; ++y) { for(index_t x = 0; x < X; ++x) @@ -282,7 +243,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn } } } -#endif // output: register to global mem, const auto c_thread_mtx_begin = diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp index 0df27009ad..8fa2aeb89f 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp @@ -128,17 +128,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed( Sequence{}); -// blockwise copy -// input: format is [C, Hi, Wi, N] -#if 0 - const auto blockwise_in_copy = - Blockwise4dTensorCopy1{}; -#else + // blockwise copy + // input: format is [C, Hi, Wi, N] const auto blockwise_in_copy = Blockwise4dTensorCopy3{}; -#endif // blockwise wei copy // format is [CPerBlock, X * KPerBlock] diff --git a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp index 0f5305d196..713834400d 100644 --- a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp +++ b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp @@ -478,9 +478,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn #elif 0 GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn -#elif 1 - GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn #elif 0 + GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn +#elif 1 GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer #endif