diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh index 0d066dbdb7..61d1d9cd6c 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh @@ -77,8 +77,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, constexpr unsigned KPerThread = 16; constexpr unsigned CPerThread = 1; - constexpr unsigned GemmRowThreadPerCluster = 4; - constexpr unsigned GemmColumnThreadPerCluster = 8; + constexpr unsigned GemmThreadPerColumnPerCluster = 4; + constexpr unsigned GemmThreadPerRowPerCluster = 8; constexpr unsigned InBlockCopyThreadPerDim0 = 4; constexpr unsigned InBlockCopyThreadPerDim1 = 16; @@ -120,7 +120,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, #if 1 gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw -#else +#elif 0 gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline #endif <<>>(in_cnhw_desc, diff --git a/src/include/blockwise_2d_tensor_op.cuh b/src/include/blockwise_2d_tensor_op.cuh index dd440fe1a2..71b916f27f 100644 --- a/src/include/blockwise_2d_tensor_op.cuh +++ b/src/include/blockwise_2d_tensor_op.cuh @@ -449,5 +449,37 @@ struct Blockwise2dTensorCopy3 assert(false); } } + + if(has_tail_d0) + { + constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0; + + if(get_thread_local_1d_id() < tail_d0 * thread_per_d1) + { + if(DataPerRead == 1) + { + p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] = + p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride]; + } + else if(DataPerRead == 2) + { + *(reinterpret_cast(p_dst + mDstMyThreadOffset + + nloop_d0 * dst_loop_stride)) = + *(reinterpret_cast(p_src + mSrcMyThreadOffset + + nloop_d0 * src_loop_stride)); + } + else if(DataPerRead == 4) + { + *(reinterpret_cast(p_dst + mDstMyThreadOffset + + nloop_d0 * dst_loop_stride)) = + *(reinterpret_cast(p_src + mSrcMyThreadOffset + + nloop_d0 * src_loop_stride)); + } + else + { + assert(false); + } + } + } } }; diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh index 59c0aa419a..65d508e9ea 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh @@ -20,8 +20,8 @@ template {}; // LDS diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh index 0222752f17..465971a7f4 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline.cuh @@ -20,8 +20,8 @@ template {}; // LDS diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 5acb9d13be..20a6dff81f 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -20,8 +20,8 @@ template __global__ void @@ -159,8 +159,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, false, false, CPerThread, - GemmThreadPerClusterRow, - GemmThreadPerClusterColumn, + GemmThreadPerColumnPerCluster, + GemmThreadPerRowPerCluster, true>{}; // LDS diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh index 9d9ff8f3fb..c54919cdf4 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline.cuh @@ -20,8 +20,8 @@ template __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline( @@ -175,8 +175,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline false, false, CPerThread, - GemmRowThreadPerCluster, - GemmColumnThreadPerCluster, + GemmThreadPerColumnPerCluster, + GemmThreadPerRowPerCluster, true>{}; // LDS