diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp index 335655d5a3..6f6df1b0eb 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp @@ -77,7 +77,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); out_khwn_device_buf.ToDevice(out_khwn.mData.data()); -#if 0 +#if 1 // for 3x3, 34x34 constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 64; @@ -105,6 +105,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmMLevel1Cluster = 2; constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; constexpr index_t OutThreadCopyDataPerWrite = 2; @@ -145,7 +147,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t BlockSize = 128; #elif 0 - // 3x3 58x58, NKC = 64, 64, 256 + // 3x3 58x58 constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 64; constexpr index_t CPerBlock = 4; @@ -166,21 +168,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t BlockSize = 128; #elif 0 - // 3x3 58x58, NKC = 16,256,128 - constexpr index_t NPerBlock = 8; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 2; - constexpr index_t HoPerBlock = 4; - constexpr index_t WoPerBlock = 4; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 16; - constexpr index_t CPerThread = 1; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 1; - - constexpr index_t BlockSize = 128; -#elif 1 // for 7x7, 38x38 constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 128; @@ -210,9 +197,42 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 
4; constexpr index_t OutThreadCopyDataPerWrite = 4; + constexpr index_t BlockSize = 128; +#elif 0 + // for 3x3, 56x56, v1, Pascal + constexpr index_t NPerBlock = 32; + constexpr index_t KPerBlock = 64; + constexpr index_t CPerBlock = 4; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimC = 1; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 8; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t BlockSize = 128; #elif 1 - // for 3x3, 56x56 + // for 3x3, 56x56, v1r2, Pascal + // for 3x3, 34x34, v1r2, Pascal constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; @@ -231,6 +251,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 2; constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 1; + constexpr index_t GemmDataPerReadB = 1; constexpr index_t InBlockCopy_ThreadPerDimC = 2; constexpr index_t InBlockCopy_ThreadPerDimH = 4; @@ -317,7 +339,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, for(index_t i = 0; i < nrepeat; ++i) { constexpr auto gridwise_conv = -#if 0 +#if 1 GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn #elif 1
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn @@ -346,6 +368,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, GemmMLevel1Cluster, GemmNLevel1Cluster, GemmKPerThreadLoop, + GemmDataPerReadA, + GemmDataPerReadB, Sequence + index_t BatchPerThread, + index_t DataPerReadA, + index_t DataPerReadB> struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 { index_t mMyThreadOffsetA = 0; @@ -220,7 +222,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 mMyThreadOffsetA, a_thread_mtx, p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths()); + a_thread_sub_mtx.GetLengths(), + Number{}); } // copy B-sub to form B @@ -233,7 +236,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 mMyThreadOffsetB, b_thread_mtx, p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths()); + b_thread_sub_mtx.GetLengths(), + Number{}); } // loop over batch @@ -264,7 +268,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA, a_thread_mtx, p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths()); + a_thread_sub_mtx.GetLengths(), + Number{}); } } @@ -280,7 +285,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB, b_thread_mtx, p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths()); + b_thread_sub_mtx.GetLengths(), + Number{}); } } } diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 1ebc780bf8..3e8d10e193 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -14,7 +14,9 @@ template + index_t KPerThreadLoop, + index_t DataPerReadA, + index_t DataPerReadB> struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 { struct MatrixIndex 
@@ -276,7 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 mMyThreadOffsetA, a_thread_mtx, p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths()); + a_thread_sub_mtx.GetLengths(), + Number{}); } #pragma unroll @@ -289,7 +292,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 mMyThreadOffsetB, b_thread_mtx, p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths()); + b_thread_sub_mtx.GetLengths(), + Number{}); } // C = A * B @@ -359,7 +363,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster, a_thread_sub_mtx, p_a_thread_0 + m_repeat * MPerThreadSubC, - a_thread_sub_mtx.GetLengths()); + a_thread_sub_mtx.GetLengths(), + Number{}); } #pragma unroll @@ -369,7 +374,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster, b_thread_sub_mtx, p_b_thread_0 + n_repeat * NPerThreadSubC, - b_thread_sub_mtx.GetLengths()); + b_thread_sub_mtx.GetLengths(), + Number{}); } bool even_loop = true; @@ -394,7 +400,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 m_repeat * MPerLevel1Cluster, a_thread_sub_mtx, p_a_thread_next + m_repeat * MPerThreadSubC, - a_thread_sub_mtx.GetLengths()); + a_thread_sub_mtx.GetLengths(), + Number{}); } #pragma unroll @@ -406,7 +413,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 n_repeat * NPerLevel1Cluster, b_thread_sub_mtx, p_b_thread_next + n_repeat * NPerThreadSubC, - b_thread_sub_mtx.GetLengths()); + b_thread_sub_mtx.GetLengths(), + Number{}); } // C = A * B diff --git a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp index 7d1a383c51..7f2bd71a49 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp +++ 
b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp @@ -30,6 +30,8 @@ template {}; + HoPerThread, + GemmDataPerReadA, + GemmDataPerReadB>{}; // LDS: be careful of alignment constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number{}); diff --git a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index 322a7007c6..a13671fa08 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -30,6 +30,8 @@ template {}; + HoPerThread, + GemmDataPerReadA, + GemmDataPerReadB>{}; // LDS: be careful of alignment constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number{}); diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp index a7d01ec4ca..b32a19de87 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp @@ -30,6 +30,8 @@ template {}; + HoPerThread, + GemmDataPerReadA, + GemmDataPerReadB>{}; // LDS: be careful of alignment constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number{}); @@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn // register Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; -#if 1 +#if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc"); diff --git a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp index 32e0175b9e..cbe913b65e 100644 --- 
a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp @@ -26,6 +26,8 @@ template {}; + GemmKPerThreadLoop, + GemmDataPerReadA, + GemmDataPerReadB>{}; // LDS: be careful of alignment constexpr index_t max_align = diff --git a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index 250010c195..33edd968b3 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -27,6 +27,8 @@ template {}; + GemmKPerThreadLoop, + GemmDataPerReadA, + GemmDataPerReadB>{}; // LDS: be careful of alignment constexpr index_t max_align = diff --git a/src/include/threadwise_gemm.hip.hpp b/src/include/threadwise_gemm.hip.hpp index 590b4ba1cb..fea45f30a9 100644 --- a/src/include/threadwise_gemm.hip.hpp +++ b/src/include/threadwise_gemm.hip.hpp @@ -1,23 +1,29 @@ #pragma once -template +template __device__ void threadwise_matrix_copy(SrcMatrix, const Float* __restrict__ p_src, DstMatrix, Float* __restrict__ p_dst, - Sequence) + Sequence, + Number) { + static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0"); + + using vector_t = typename vector_type::MemoryType; + constexpr auto src_mtx = SrcMatrix{}; constexpr auto dst_mtx = DstMatrix{}; for(index_t i = 0; i < NRow; ++i) { - for(index_t j = 0; j < NCol; ++j) + for(index_t j = 0; j < NCol; j += DataPerRead) { const index_t src_index = src_mtx.Get1dIndex(i, j); const index_t dst_index = dst_mtx.Get1dIndex(i, j); - p_dst[dst_index] = p_src[src_index]; + *reinterpret_cast(&p_dst[dst_index]) = + *reinterpret_cast(&p_src[src_index]); } } }