From cd29b09a824311bb33fd3f66b4d97a291b5e90e0 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 19 May 2019 11:41:00 -0500 Subject: [PATCH] refactor --- driver/driver.hip.cpp | 4 +- ...3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp | 94 +++++++++---------- 2 files changed, 48 insertions(+), 50 deletions(-) diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 3e032a2333..ef0db9e5b7 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -608,11 +608,11 @@ int main(int argc, char* argv[]) device_convolution_direct_v2_nchw_kcyx_nkhw #elif 0 device_direct_convolution_2_vectorized_nchw_kcyx_nkhw -#elif 1 +#elif 0 device_convolution_implicit_gemm_v1_chwn_cyxk_khwn #elif 0 device_convolution_implicit_gemm_v1_nchw_cyxk_khwn -#elif 0 +#elif 1 device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw #elif 0 device_convolution_implicit_gemm_v2_chwn_cyxk_khwn diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp index 33673dbaa4..2bfe348d0d 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp @@ -100,8 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw const index_t wi_block_data_begin = wo_block_data_begin; // global tensor view - constexpr auto wei_c_k_global_desc = - make_ConstantTensorDescriptor(Sequence{}, Sequence{}); + constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); // LDS tensor view // be careful of alignment @@ -359,13 +358,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - static_if{}([&](auto f_dummy) { // f_dummy do 
nothing but - // perfect forwarding. - // Using this trick to - // make this lambda a generic lambda, so it won't be compiled until - // instantiated + static_if{}([&](auto fwd) { + // fwd does nothing but perfect forwarding. + // Using this trick to make this lambda a generic lambda, so it won't be compiled until + // being instantiated here static_assert( - (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), + (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), "wrong!"); // output is a 10d tensor @@ -373,38 +371,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw constexpr index_t N1 = NPerBlock / N2; constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC); + (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC); constexpr index_t W1 = WoPerBlock / W2; constexpr index_t K2 = GemmMPerThreadSubC; constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_global_desc = - make_ConstantTensorDescriptor(Sequence{}); + constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc) .Fold(I3, Number{}, Number{}) .Fold(I1, Number{}, Number{}) .Fold(I0, Number{}, Number{}); - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( Sequence{}); + constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc) .Fold(I3, Number<1>{}, Number{}) .Fold(I2, Number{}, Number<1>{}) .Fold(I0, Number<1>{}, Number{}); #if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "a: out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "a: 
out_10d_thread_desc"); - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } + print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, + "a: out_n_k_h_w_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc"); + } #endif constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{}; @@ -421,8 +414,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw out_10d_thread_desc.GetLengths(), map_out_global2thread); // Number{}); - }).else_([&](auto f_dummy) { - static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && + }).else_([&](auto fwd) { + static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && GemmNPerThreadSubC % NPerThread == 0, "wrong!"); @@ -431,29 +424,34 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); + constexpr index_t W1 = WoPerBlock / fwd(W2 * W3); constexpr index_t K2 = GemmMPerThreadSubC; constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_global_desc = + fwd(out_n_k_h_w_global_desc) + .Fold(I3, Number{}, Number{}, Number{}) + .Fold(I1, Number{}, Number{}) + .Fold(I0, Number{}); - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_thread_desc = + fwd(out_k_h_w_n_thread_desc) + .Fold(I3, Number{}) + .Fold(I2, Number{}, Number<1>{}, Number{}) + .Fold(I0, Number<1>{}, Number{}); #if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "out_k_h_w_n_thread_desc"); - 
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "b: out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc"); - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - - } + print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, + "b: out_n_k_h_w_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc"); + } #endif constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};