From bdbc0eaad175cd4054105cfb3fc812a8526e3b49 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Tue, 2 Apr 2019 17:58:44 -0500
Subject: [PATCH] cleaning up dead code

Remove the unused Blockwise1dStridedBatchedGemmBlockABlockBThreadC struct
and the unused Run_v3() method from blockwise_batched_gemm.hip.hpp, and the
unused BlockwiseGemmBlockABlockBThreadC struct from blockwise_gemm.hip.hpp.
Rename Run_v2() to the more descriptive Run_PipelineReadAndCompute(). Move
the cmake helper scripts from build/ to script/, deactivate the 1x1-filter,
14x14-image, C = 2048 problem config in the driver, and clarify a comment in
the device convolution driver.
---
 ...icit_gemm_convolution_2_chwn_cyxk_khwn.hpp |   2 +-
 driver/driver.hip.cpp                         |   2 +-
 {build => script}/cmake-cuda.sh               |   0
 {build => script}/cmake-hip.sh                |   0
 src/include/blockwise_batched_gemm.hip.hpp    | 455 ------
 src/include/blockwise_gemm.hip.hpp            | 217 +--------
 6 files changed, 6 insertions(+), 670 deletions(-)
 rename {build => script}/cmake-cuda.sh (100%)
 rename {build => script}/cmake-hip.sh (100%)

diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
index 497aa3e9c1..a3489bc8cc 100644
--- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
@@ -221,7 +221,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t BlockSize = 128;
 
 #elif 0
-    // 1x1, 14x14, Vega 20, hack CPerBlock = 1
+    // 1x1, 14x14, Vega 20, hack CPerBlock = 1 for debugging
     constexpr index_t BPerBlock = 64;
     constexpr index_t KPerBlock = 128;
     constexpr index_t CPerBlock = 1;
diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp
index 0ea091e607..a83e4082c7 100644
--- a/driver/driver.hip.cpp
+++ b/driver/driver.hip.cpp
@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
 
-#elif 1
+#elif 0
     // 1x1 filter, 14x14 image, C = 2048
     constexpr index_t N = 128;
    constexpr index_t C = 2048;
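Note: the driver selects exactly one problem size at compile time through a
chain of mutually exclusive #if / #elif blocks, so flipping this block's
guard from 1 to 0 deactivates the 1x1, 14x14, C = 2048 configuration without
deleting it. A minimal sketch of the pattern (constant names illustrative,
not copied from driver.hip.cpp):

    // Exactly one branch is enabled by setting its guard to 1; the
    // constants are constexpr, so the kernel is specialized at compile time.
    #include <cstdint>
    using index_t = std::uint32_t;

    #if 0
    // 3x3 filter, 28x28 image (hypothetical alternative config)
    constexpr index_t N = 128, C = 256, HI = 28, WI = 28, K = 512;
    #elif 1
    // 1x1 filter, 14x14 image
    constexpr index_t N = 128, C = 2048, HI = 14, WI = 14, K = 512;
    #endif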
diff --git a/build/cmake-cuda.sh b/script/cmake-cuda.sh
similarity index 100%
rename from build/cmake-cuda.sh
rename to script/cmake-cuda.sh
diff --git a/build/cmake-hip.sh b/script/cmake-hip.sh
similarity index 100%
rename from build/cmake-hip.sh
rename to script/cmake-hip.sh
diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp
index bf2777f140..deba68a261 100644
--- a/src/include/blockwise_batched_gemm.hip.hpp
+++ b/src/include/blockwise_batched_gemm.hip.hpp
@@ -1,231 +1,6 @@
 #pragma once
 #include "threadwise_gemm.hip.hpp"
 
-template <index_t BlockSize,
-          class BlockMatrixA,
-          class BlockMatrixB,
-          class ThreadMatrixC,
-          bool TransA,
-          bool TransB,
-          bool TransC,
-          index_t BlockMatrixStrideA,
-          index_t BlockMatrixStrideB,
-          index_t ThreadMatrixStrideC,
-          index_t BatchSize,
-          index_t BatchPerThread,
-          index_t KPerThreadLoop,
-          bool DistributeThreadAlongColumnFirst>
-struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
-{
-    index_t mMyThreadOffsetA = 0;
-    index_t mMyThreadOffsetB = 0;
-
-    struct MatrixIndex
-    {
-        index_t batch;
-        index_t row;
-        index_t col;
-    };
-
-    __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
-    {
-        constexpr auto a_block_mtx = BlockMatrixA{};
-        constexpr auto b_block_mtx = BlockMatrixB{};
-
-        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
-                           ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
-                                      : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row));
-
-        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
-                           ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
-                                      : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0));
-
-#if 0
-        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-        {
-            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
-            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
-            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
-
-            printf("%u %u, %u %u %u, %u %u\n",
-                   get_block_1d_id(),
-                   get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch,
-                   c_thread_mtx_index.row,
-                   c_thread_mtx_index.col,
-                   mMyThreadOffsetA,
-                   mMyThreadOffsetB);
-        }
-#endif
-    }
-
-    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
-    {
-
-        if(TransA && (!TransB) && (!TransC))
-        {
-            constexpr auto a_block_mtx = BlockMatrixA{};
-            constexpr auto b_block_mtx = BlockMatrixB{};
-
-            static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
-                          "wrong! k dimension not consistent!");
-
-            constexpr index_t MPerBlock = a_block_mtx.NCol();
-            constexpr index_t NPerBlock = b_block_mtx.NCol();
-
-            constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-            // divide thread work
-            constexpr index_t MPerThread = c_thread_mtx.NRow();
-            constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-            static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
-            static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
-            static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");
-
-            constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
-            constexpr index_t MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
-            constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
-
-            static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
-                          "wrong! wrong BlockSize");
-
-            if(DistributeThreadAlongColumnFirst)
-            {
-                // num of operations can be reduced
-                const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
-                index_t itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);
-                const index_t m_work_id = itmp / NThreadWork;
-                const index_t n_work_id = itmp - m_work_id * NThreadWork;
-
-                return MatrixIndex{
-                    b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
-            }
-            else
-            {
-                // not implemented
-                assert(false);
-            }
-        }
-        else
-        {
-            // not implemented
-            assert(false);
-        }
-    }
-
-    // this should be optimized away if input is known
-    __device__ static MatrixIndex
-    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
-    {
-        return MatrixIndex{batch_in_c, m_in_c, n_in_c};
-    }
-
-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run(const FloatA* __restrict__ p_a_block,
-                        const FloatB* __restrict__ p_b_block,
-                        FloatC* __restrict__ p_c_thread,
-                        Accumulator f_accum) const
-    {
-        if(TransA && (!TransB) && (!TransC))
-        {
-            constexpr auto True  = integral_constant<bool, true>{};
-            constexpr auto False = integral_constant<bool, false>{};
-
-            constexpr auto a_block_mtx  = BlockMatrixA{};
-            constexpr auto b_block_mtx  = BlockMatrixB{};
-            constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-            constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
-
-            constexpr index_t MPerThread = c_thread_mtx.NRow();
-            constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-            // a is transposed, b is not
-            constexpr auto a_thread_mtx =
-                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-
-            constexpr auto b_thread_mtx =
-                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-            // loop over k
-            for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
-            {
-                // read first batch of a, b
-                threadwise_matrix_copy(a_block_mtx,
-                                       p_a_block + mMyThreadOffsetA +
-                                           k_begin * a_block_mtx.RowStride(),
-                                       a_thread_mtx,
-                                       p_a_thread,
-                                       a_thread_mtx.GetLengths());
-
-                threadwise_matrix_copy(b_block_mtx,
-                                       p_b_block + mMyThreadOffsetB +
-                                           k_begin * b_block_mtx.RowStride(),
-                                       b_thread_mtx,
-                                       p_b_thread,
-                                       b_thread_mtx.GetLengths());
-
-                // loop over batch
-                for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
-                {
-                    // do current batch of gemm
-                    threadwise_gemm(a_thread_mtx,
-                                    True,
-                                    p_a_thread,
-                                    b_thread_mtx,
-                                    False,
-                                    p_b_thread,
-                                    c_thread_mtx,
-                                    False,
-                                    p_c_thread + ib * ThreadMatrixStrideC,
-                                    f_accum);
-
-                    // read next batch of a, b
-                    if(BlockMatrixStrideA != 0)
-                    {
-                        threadwise_matrix_copy(a_block_mtx,
-                                               p_a_block + mMyThreadOffsetA +
-                                                   (ib + 1) * BlockMatrixStrideA +
-                                                   +k_begin * a_block_mtx.RowStride(),
-                                               a_thread_mtx,
-                                               p_a_thread,
-                                               a_thread_mtx.GetLengths());
-                    }
-
-                    if(BlockMatrixStrideB != 0)
-                    {
-                        threadwise_matrix_copy(b_block_mtx,
-                                               p_b_block + mMyThreadOffsetB +
-                                                   (ib + 1) * BlockMatrixStrideB +
-                                                   k_begin * b_block_mtx.RowStride(),
-                                               b_thread_mtx,
-                                               p_b_thread,
-                                               b_thread_mtx.GetLengths());
-                    }
-                }
-
-                // do last batch of gemm
-                threadwise_gemm(a_thread_mtx,
-                                True,
-                                p_a_thread,
-                                b_thread_mtx,
-                                False,
-                                p_b_thread,
-                                c_thread_mtx,
-                                False,
-                                p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
-                                f_accum);
-            }
-        }
-    }
-};
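Note: the removed Run() pipelined its batch loop by hand: it computed the
current batch's GEMM while issuing the register reads for batch ib + 1, then
drained the last batch after the loop. A minimal sketch of that read-ahead
structure, with hypothetical load()/compute() callables standing in for
threadwise_matrix_copy and threadwise_gemm:

    // Read-ahead pipelining over BatchPerThread batches (sketch only).
    // The real code additionally guards the loads with
    // BlockMatrixStride{A,B} != 0 so a broadcast operand is read once.
    template <int BatchPerThread, class Load, class Compute>
    void run_batches(Load load, Compute compute)
    {
        load(0); // prime the pipeline with batch 0
        for(int ib = 0; ib + 1 < BatchPerThread; ++ib)
        {
            compute(ib);  // consume the registers holding batch ib
            load(ib + 1); // overlap: fetch batch ib + 1 into registers
        }
        compute(BatchPerThread - 1); // drain the last batch
    }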
wrong BlockSize"); - - if(DistributeThreadAlongColumnFirst) - { - // num of operations can be reduced - const index_t b_work_id = thread_id / (MThreadWork * NThreadWork); - index_t itmp = thread_id - b_work_id * (MThreadWork * NThreadWork); - const index_t m_work_id = itmp / NThreadWork; - const index_t n_work_id = itmp - m_work_id * NThreadWork; - - return MatrixIndex{ - b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread}; - } - else - { - // not implemented - assert(false); - } - } - else - { - // not implemented - assert(false); - } - } - - // this should be optimized away if input is known - __device__ static MatrixIndex - GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c) - { - return MatrixIndex{batch_in_c, m_in_c, n_in_c}; - } - - template - __device__ void Run(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, - Accumulator f_accum) const - { - if(TransA && (!TransB) && (!TransC)) - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - // a is transposed, b is not - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - // loop over k - for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) - { - // read first batch of a, b - threadwise_matrix_copy(a_block_mtx, - p_a_block + mMyThreadOffsetA + - k_begin * a_block_mtx.RowStride(), - a_thread_mtx, - p_a_thread, - a_thread_mtx.GetLengths()); - - threadwise_matrix_copy(b_block_mtx, - p_b_block + mMyThreadOffsetB + - k_begin * b_block_mtx.RowStride(), - b_thread_mtx, - p_b_thread, - b_thread_mtx.GetLengths()); - - // loop over batch - for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib) - { - // do current batch of gemm - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread + ib * ThreadMatrixStrideC, - f_accum); - - // read next batch of a, b - if(BlockMatrixStrideA != 0) - { - threadwise_matrix_copy(a_block_mtx, - p_a_block + mMyThreadOffsetA + - (ib + 1) * BlockMatrixStrideA + - +k_begin * a_block_mtx.RowStride(), - a_thread_mtx, - p_a_thread, - a_thread_mtx.GetLengths()); - } - - if(BlockMatrixStrideB != 0) - { - threadwise_matrix_copy(b_block_mtx, - p_b_block + mMyThreadOffsetB + - (ib + 1) * BlockMatrixStrideB + - k_begin * b_block_mtx.RowStride(), - b_thread_mtx, - p_b_thread, - b_thread_mtx.GetLengths()); - } - } - - // do last batch of gemm - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC, - f_accum); - } - } - } -}; - template - __device__ void Run_v3(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, - Accumulator f_accum) const - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto 
-
-                // read next batch of a, b
-                if(BlockMatrixStrideA != 0)
-                {
-                    //#pragma unroll
-                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-                    {
-                        for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
-                        {
-                            for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
-                            {
-                                p_a_thread[a_thread_mtx.Get1dIndex(i,
-                                                                   m_repeat * MPerThreadSubC + j)] =
-                                    p_a_block[a_block_mtx.Get1dIndex(
-                                                  k_begin + i, m_repeat * MPerLevel1Cluster + j) +
-                                              (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA];
-                            }
-                        }
-                    }
-                }
-
-                if(BlockMatrixStrideB != 0)
-                {
-                    //#pragma unroll
-                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-                    {
-                        for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
-                        {
-                            for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
-                            {
-                                p_b_thread[b_thread_mtx.Get1dIndex(i,
-                                                                   n_repeat * NPerThreadSubC + j)] =
-                                    p_b_block[b_block_mtx.Get1dIndex(
-                                                  k_begin + i, n_repeat * MPerLevel1Cluster + j) +
-                                              (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
-                            }
-                        }
-                    }
-                }
-            }
-
-            // do last batch of gemm
-            for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
-            {
-#if 0
-                for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
-                {
-                    for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
-                    {
-                        const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                        const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
-                        const index_t cindex = c_thread_mtx.Get1dIndex(i, j) +
-                                               (BatchPerThread - 1) * ThreadMatrixStrideC;
-
-                        f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
-                    }
-                }
-#elif 1
-                static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
-                              "asm is only for 16x4");
-
-                const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
-                for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
-                {
-                    const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                    const index_t cindex =
-                        c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
-
-                    asm volatile("\n \
-                    v_mac_f32 %0, %4, %5 \n \
-                    v_mac_f32 %1, %4, %6 \n \
-                    v_mac_f32 %2, %4, %7 \n \
-                    v_mac_f32 %3, %4, %8 \n \
-                    "
-                                 : "=v"(p_c_thread[cindex + 0]),
-                                   "=v"(p_c_thread[cindex + 1]),
-                                   "=v"(p_c_thread[cindex + 2]),
-                                   "=v"(p_c_thread[cindex + 3])
-                                 : "v"(p_a_thread[aindex]),
-                                   "v"(p_b_thread[bindex + 0]),
-                                   "v"(p_b_thread[bindex + 1]),
-                                   "v"(p_b_thread[bindex + 2]),
-                                   "v"(p_b_thread[bindex + 3]),
-                                   "0"(p_c_thread[cindex + 0]),
-                                   "1"(p_c_thread[cindex + 1]),
-                                   "2"(p_c_thread[cindex + 2]),
-                                   "3"(p_c_thread[cindex + 3]));
-                }
-#endif
-            }
-        }
-    }
-
     template <class FloatC>
     __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                     FloatC* __restrict__ p_c_block) const
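Note: both removed compute branches kept a scalar reference path (#if 0)
next to a hand-written GCN path. v_mac_f32 d, a, b computes d += a * b on
32-bit floats, and the "0".."3" input constraints tie each accumulator input
to the matching "=v" output so the compiler keeps them in the same VGPRs.
The scalar equivalent of one 16x4 micro-tile step, following the removed
reference branch (array sizes from the static_assert; names illustrative):

    // One k-slice of the removed 16x4 micro-kernel: the rank-1 update
    // c[16][4] += a[16] * b[4]. f_accum in the real code generalizes `+=`.
    void micro_tile_mac(float c[16][4], const float a[16], const float b[4])
    {
        for(int i = 0; i < 16; ++i)     // rows of the C thread-tile
            for(int j = 0; j < 4; ++j)  // 4 columns, read as consecutive VGPRs
                c[i][j] += a[i] * b[j]; // what each v_mac_f32 computes
    }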
"v"(p_b_thread[bindex + 1]), - "v"(p_b_thread[bindex + 2]), - "v"(p_b_thread[bindex + 3]), - "0"(p_c_thread[cindex + 0]), - "1"(p_c_thread[cindex + 1]), - "2"(p_c_thread[cindex + 2]), - "3"(p_c_thread[cindex + 3])); - } -#endif - } - - // read next batch of a, b - if(BlockMatrixStrideA != 0) - { - //#pragma unroll - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i) - { - for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j) - { - p_a_thread[a_thread_mtx.Get1dIndex(i, - m_repeat * MPerThreadSubC + j)] = - p_a_block[a_block_mtx.Get1dIndex( - k_begin + i, m_repeat * MPerLevel1Cluster + j) + - (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA]; - } - } - } - } - - if(BlockMatrixStrideB != 0) - { - //#pragma unroll - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i) - { - for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j) - { - p_b_thread[b_thread_mtx.Get1dIndex(i, - n_repeat * NPerThreadSubC + j)] = - p_b_block[b_block_mtx.Get1dIndex( - k_begin + i, n_repeat * MPerLevel1Cluster + j) + - (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB]; - } - } - } - } - } - - // do last batch of gemm - for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) - { -#if 0 - for(index_t i = 0; i < c_thread_mtx.NRow(); ++i) - { - for(index_t j = 0; j < c_thread_mtx.NCol(); ++j) - { - const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const index_t bindex = b_thread_mtx.Get1dIndex(k, j); - const index_t cindex = c_thread_mtx.Get1dIndex(i, j) + - (BatchPerThread - 1) * ThreadMatrixStrideC; - - f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); - } - } -#elif 1 - static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4, - "asm is only for 16x4"); - - const index_t bindex = b_thread_mtx.Get1dIndex(k, 0); - for(index_t i = 0; i < c_thread_mtx.NRow(); ++i) - { - const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const index_t cindex = - c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC; - - asm volatile("\n \ - v_mac_f32 %0, %4, %5 \n \ - v_mac_f32 %1, %4, %6 \n \ - v_mac_f32 %2, %4, %7 \n \ - v_mac_f32 %3, %4, %8 \n \ - " - : "=v"(p_c_thread[cindex + 0]), - "=v"(p_c_thread[cindex + 1]), - "=v"(p_c_thread[cindex + 2]), - "=v"(p_c_thread[cindex + 3]) - : "v"(p_a_thread[aindex]), - "v"(p_b_thread[bindex + 0]), - "v"(p_b_thread[bindex + 1]), - "v"(p_b_thread[bindex + 2]), - "v"(p_b_thread[bindex + 3]), - "0"(p_c_thread[cindex + 0]), - "1"(p_c_thread[cindex + 1]), - "2"(p_c_thread[cindex + 2]), - "3"(p_c_thread[cindex + 3])); - } -#endif - } - } - } - template __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread, FloatC* __restrict__ p_c_block) const diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 3e9c57d15f..fee5b704f3 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -3,215 +3,6 @@ extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; -template -struct BlockwiseGemmBlockABlockBThreadC -{ - index_t mMyThreadOffsetA = 0; - index_t mMyThreadOffsetB = 0; - - struct MatrixIndex - { - index_t row; - index_t col; - }; - - __device__ BlockwiseGemmBlockABlockBThreadC() - { - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - - const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - 
-        mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
-                                     : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
-
-        mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
-                                     : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0);
-
-#if 0
-        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-        {
-            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
-            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
-            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
-
-            printf("%u %u, %u %u %u, %u %u\n",
-                   get_block_1d_id(),
-                   get_thread_local_1d_id(),
-                   c_thread_mtx_index.batch,
-                   c_thread_mtx_index.row,
-                   c_thread_mtx_index.col,
-                   mMyThreadOffsetA,
-                   mMyThreadOffsetB);
-        }
-#endif
-    }
-
-    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
-    {
-
-        if(TransA && (!TransB) && (!TransC))
-        {
-            constexpr auto a_block_mtx = BlockMatrixA{};
-            constexpr auto b_block_mtx = BlockMatrixB{};
-
-            static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
-                          "wrong! k dimension not consistent!");
-
-            constexpr index_t MPerBlock = a_block_mtx.NCol();
-            constexpr index_t NPerBlock = b_block_mtx.NCol();
-
-            constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-            // divide thread work
-            constexpr index_t MPerThread = c_thread_mtx.NRow();
-            constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-            static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
-                          "MPerBlock % (MPerThread * MThreadPerCluster) != 0");
-
-            static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
-                          "NPerBlock % (NPerThread * NThreadPerCluster) != 0");
-
-            constexpr index_t MClusterWork =
-                (MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);
-
-            constexpr index_t NClusterWork =
-                (NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);
-
-            static_assert(BlockSize ==
-                              (MClusterWork * MThreadPerCluster) *
-                                  (NClusterWork * NThreadPerCluster),
-                          "wrong! wrong BlockSize");
-
-            if(DistributeThreadAlongColumnFirst)
-            {
-                const index_t cluster_work_block_id =
-                    thread_id / (MThreadPerCluster * NThreadPerCluster);
-
-                const index_t thread_work_cluster_id =
-                    thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);
-
-                const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
-                const index_t n_cluster_work_block_id =
-                    cluster_work_block_id - m_cluster_work_block_id * NClusterWork;
-
-                const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
-                const index_t n_thread_work_cluster_id =
-                    thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;
-
-#if 0
-                if(get_block_1d_id() == 0)
-                {
-                    printf("%u %u, \t"
-                           "MClusterWork %u MThreadPerCluster %u NClusterWork %u NThreadPerCluster %u \t"
-                           "m_cluster_work_block_id %u n_cluster_work_block_id %u \t"
-                           "m_thread_work_cluster_id %u n_thread_work_cluster_id %u \t"
-                           "\n",
-                           get_block_1d_id(), get_thread_local_1d_id(),
-                           MClusterWork, MThreadPerCluster, NClusterWork, NThreadPerCluster,
-                           m_cluster_work_block_id, n_cluster_work_block_id,
-                           m_thread_work_cluster_id, n_thread_work_cluster_id);
-                }
-#endif
-
-                return MatrixIndex{m_cluster_work_block_id * (MThreadPerCluster * MPerThread) +
-                                       m_thread_work_cluster_id * MPerThread,
-                                   n_cluster_work_block_id * (NThreadPerCluster * NPerThread) +
-                                       n_thread_work_cluster_id * NPerThread};
-            }
-            else
-            {
-                // not implemented
-                assert(false);
-            }
-        }
-        else
-        {
-            // not implemented
-            assert(false);
-        }
-    }
-
-    // this should be optimized away if input is known
-    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
-                                                                      index_t n_in_c)
-    {
-        return MatrixIndex{m_in_c, n_in_c};
-    }
-
-    template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run(const FloatA* __restrict__ p_a_block,
-                        const FloatB* __restrict__ p_b_block,
-                        FloatC* __restrict__ p_c_thread,
-                        Accumulator f_accum) const
-    {
-        if(TransA && (!TransB) && (!TransC))
-        {
-            constexpr auto True  = integral_constant<bool, true>{};
-            constexpr auto False = integral_constant<bool, false>{};
-
-            constexpr auto a_block_mtx  = BlockMatrixA{};
-            constexpr auto b_block_mtx  = BlockMatrixB{};
-            constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-            constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
-
-            constexpr index_t MPerThread = c_thread_mtx.NRow();
-            constexpr index_t NPerThread = c_thread_mtx.NCol();
-
-            // a is transposed, b is not
-            constexpr auto a_thread_mtx =
-                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
-
-            constexpr auto b_thread_mtx =
-                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-            // loop over k
-            for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
-            {
-                threadwise_matrix_copy(a_block_mtx,
-                                       p_a_block + mMyThreadOffsetA +
-                                           k_begin * a_block_mtx.RowStride(),
-                                       a_thread_mtx,
-                                       p_a_thread,
-                                       a_thread_mtx.GetLengths());
-
-                threadwise_matrix_copy(b_block_mtx,
-                                       p_b_block + mMyThreadOffsetB +
-                                           k_begin * b_block_mtx.RowStride(),
-                                       b_thread_mtx,
-                                       p_b_thread,
-                                       b_thread_mtx.GetLengths());
-
-                threadwise_gemm(a_thread_mtx,
-                                True,
-                                p_a_thread,
-                                b_thread_mtx,
-                                False,
-                                p_b_thread,
-                                c_thread_mtx,
-                                False,
-                                p_c_thread,
-                                f_accum);
-            }
-        }
-    }
-};
-
 // if following number are power of 2, index calculation shall be greatly reduced:
 // MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
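Note: the comment kept above is about strength reduction. When the cluster
and sub-tile sizes are compile-time powers of two, the / and % used to
decompose a thread id into cluster and in-cluster coordinates lower to
shifts and masks instead of integer-division expansions, which are costly
on GCN. A small illustration of the identity involved (not repo code):

    // With a power-of-two divisor N, division and modulo reduce to bit
    // operations; compilers do this automatically for constexpr N.
    constexpr unsigned N = 8; // e.g. NThreadPerCluster
    static_assert((N & (N - 1)) == 0, "N must be a power of two");

    unsigned row_of(unsigned tid) { return tid / N; } // compiles to tid >> 3
    unsigned col_of(unsigned tid) { return tid % N; } // compiles to tid & 7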
wrong BlockSize"); - - if(DistributeThreadAlongColumnFirst) - { - const index_t cluster_work_block_id = - thread_id / (MThreadPerCluster * NThreadPerCluster); - - const index_t thread_work_cluster_id = - thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster); - - const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork; - const index_t n_cluster_work_block_id = - cluster_work_block_id - m_cluster_work_block_id * NClusterWork; - - const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster; - const index_t n_thread_work_cluster_id = - thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster; - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, \t" - "MClusterWork %u MThreadPerCluster %u NClusterWork %u NThreadPerCluster %u \t" - "m_cluster_work_block_id %u n_cluster_work_block_id %u \t" - "m_thread_work_cluster_id %u n_thread_work_cluster_id %u \t" - "\n", - get_block_1d_id(), get_thread_local_1d_id(), - MClusterWork, MThreadPerCluster, NClusterWork, NThreadPerCluster, - m_cluster_work_block_id, n_cluster_work_block_id, - m_thread_work_cluster_id, n_thread_work_cluster_id); - } -#endif - - return MatrixIndex{m_cluster_work_block_id * (MThreadPerCluster * MPerThread) + - m_thread_work_cluster_id * MPerThread, - n_cluster_work_block_id * (NThreadPerCluster * NPerThread) + - n_thread_work_cluster_id * NPerThread}; - } - else - { - // not implemented - assert(false); - } - } - else - { - // not implemented - assert(false); - } - } - - // this should be optimized away if input is known - __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c, - index_t n_in_c) - { - return MatrixIndex{m_in_c, n_in_c}; - } - - template - __device__ void Run(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, - Accumulator f_accum) const - { - if(TransA && (!TransB) && (!TransC)) - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - // a is transposed, b is not - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - // loop over k - for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) - { - threadwise_matrix_copy(a_block_mtx, - p_a_block + mMyThreadOffsetA + - k_begin * a_block_mtx.RowStride(), - a_thread_mtx, - p_a_thread, - a_thread_mtx.GetLengths()); - - threadwise_matrix_copy(b_block_mtx, - p_b_block + mMyThreadOffsetB + - k_begin * b_block_mtx.RowStride(), - b_thread_mtx, - p_b_thread, - b_thread_mtx.GetLengths()); - - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread, - f_accum); - } - } - } -}; - // if following number are power of 2, index calculation shall be greatly reduced: // MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster template - __device__ void Run_v2(const FloatA* __restrict__ p_a_block, - const 
+    __device__ void Run_PipelineReadAndCompute(const FloatA* __restrict__ p_a_block,
+                                               const FloatB* __restrict__ p_b_block,
+                                               FloatC* __restrict__ p_c_thread,
+                                               Accumulator f_accum) const
     {
         constexpr auto True  = integral_constant<bool, true>{};
         constexpr auto False = integral_constant<bool, false>{};
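Note: the new name says what the method does, judging by the rename: it
pipelines the register reads of the A/B sub-tiles with the compute. The
rename is source-level only, so any remaining caller of Run_v2 must be
updated mechanically. A hypothetical call site (object and argument names
illustrative, not from this patch):

    // Before: blockwise_gemm.Run_v2(p_a_block, p_b_block, p_c_thread, f_accum);
    blockwise_gemm.Run_PipelineReadAndCompute(p_a_block, p_b_block, p_c_thread, f_accum);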