diff --git a/driver/conv.cu b/driver/conv.cu index feb665d96c..7852e9bed7 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -357,7 +357,7 @@ int main() constexpr unsigned C = 1; constexpr unsigned HI = 34; constexpr unsigned WI = 34; - constexpr unsigned K = 4; + constexpr unsigned K = 1; constexpr unsigned S = 3; constexpr unsigned R = 3; #elif 1 diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 9585113fb9..6ee62008cc 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -67,29 +67,29 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, Tensor out_knhw(make_TensorDescriptor(out_knhw_desc)); #if 0 - constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 4; + constexpr unsigned BPerBlock = 256; + constexpr unsigned KPerBlock = 1; constexpr unsigned CPerBlock = 1; - constexpr unsigned BPerThread = 4; + constexpr unsigned BPerThread = 8; constexpr unsigned KPerThread = 1; constexpr unsigned CPerThread = 1; - constexpr unsigned ThreadPerClusterRow = 4; - constexpr unsigned ThreadPerClusterColumn = 16; + constexpr unsigned ThreadPerClusterRow = 1; + constexpr unsigned ThreadPerClusterColumn = 4; - constexpr unsigned BlockSize = 128; + constexpr unsigned BlockSize = 32; #elif 1 constexpr unsigned BPerBlock = 128; constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 2; - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 16; + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; constexpr unsigned CPerThread = 1; constexpr unsigned ThreadPerClusterRow = 4; - constexpr unsigned ThreadPerClusterColumn = 16; + constexpr unsigned ThreadPerClusterColumn = 4; constexpr unsigned BlockSize = 128; #endif diff --git a/src/include/gemm.cuh b/src/include/gemm.cuh index 760cc1ad4d..99ecc4b962 100644 --- a/src/include/gemm.cuh +++ b/src/include/gemm.cuh @@ -388,9 +388,9 @@ struct blockwise_gemm_block_a_block_b_thread_c const unsigned thread_work_cluster_id = thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster); - const unsigned m_cluster_work_block_id = cluster_work_block_id / NThreadPerCluster; + const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork; const unsigned n_cluster_work_block_id = - cluster_work_block_id - m_cluster_work_block_id * NThreadPerCluster; + cluster_work_block_id - m_cluster_work_block_id * NClusterWork; const unsigned m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster; @@ -401,12 +401,12 @@ struct blockwise_gemm_block_a_block_b_thread_c if(get_block_1d_id() == 0) { printf("%u %u, \t" - //"MClusterWork %u MThreadPerCluster %u NClusterWork %u NThreadPerCluster %u \t" + "MClusterWork %u MThreadPerCluster %u NClusterWork %u NThreadPerCluster %u \t" "m_cluster_work_block_id %u n_cluster_work_block_id %u \t" "m_thread_work_cluster_id %u n_thread_work_cluster_id %u \t" "\n", get_block_1d_id(), get_thread_local_1d_id(), - //MClusterWork, MThreadPerCluster, NClusterWork, NThreadPerCluster, + MClusterWork, MThreadPerCluster, NClusterWork, NThreadPerCluster, m_cluster_work_block_id, n_cluster_work_block_id, m_thread_work_cluster_id, n_thread_work_cluster_id); } diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 5ab64ca96d..2a609f046a 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -239,10 +239,13 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]); } #endif - if(k_data < K && n_data < N && h_data < Ho && w_data < Wo) + if(n_data < N && h_data < Ho && w_data < Wo) { +#if 1 p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; +#endif + #if 0 if(get_block_1d_id() == 0) {