diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index bab0b5d7fe..bf7cdc8c5a 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -190,7 +190,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t BlockSize = 256; -#elif 0 +#elif 1 // 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer constexpr index_t BPerBlock = 64; constexpr index_t KPerBlock = 128; diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 1dd52cbc8b..c83e44b31b 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -332,12 +332,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 n_repeat * NPerLevel1Cluster + n_in_sub_c}; } - template + template __device__ void Run_asm(const FloatA* __restrict__ p_a_block, const FloatB* __restrict__ p_b_block, FloatC* __restrict__ p_c_thread, - Accumulator f_accum, - Number) const + Accumulator f_accum) const { constexpr auto True = integral_constant{}; constexpr auto False = integral_constant{}; @@ -378,45 +377,43 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC; + //auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; + //auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; + Float4* reg_a = (Float4*)(p_a_thread); + Float4* reg_b = (Float4*)(p_b_thread); + Float4* reg_c = (Float4*)(p_c_thread); + void* a_loc = (void *)(p_a_block + mMyThreadOffsetA); + void* b_loc = (void *)(p_b_block + mMyThreadOffsetB); #pragma unroll // loop over k for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { - - - auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; - auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; - Float4* reg_a = (Float4*)(p_a_thread); - Float4* reg_b = (Float4*)(p_b_thread); - void* a_loc = (void *)(p_a_block + a_src_index); - void* b_loc = (void *)(p_b_block + b_src_index); - - //asm volatile("\n \ - //ds_read_b128 %0, %2 \n \ - //ds_read_b128 %1, %2 offset:256\n \ - //" - //: "=v"(reg_a[0]), "=v"(reg_a[1]) - //: "v"(__to_local(a_loc)) - //); +#if 0 ds_read_b128(reg_a[0], a_loc, 0); ds_read_b128(reg_a[1], a_loc, 256); - ds_read_b128(reg_b[0], b_loc, 0); ds_read_b128(reg_b[1], b_loc, 128); lgkmcnt(0); - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread, - f_accum); + outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); + outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); + outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); +#else + ds_read_b128(reg_a[0], a_loc, k_begin * 512); + ds_read_b128(reg_b[0], b_loc, k_begin * 256); + ds_read_b128(reg_b[1], b_loc, 128 + k_begin * 256); + ds_read_b128(reg_a[1], a_loc, 256 + k_begin * 512); + lgkmcnt(2); + outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); + lgkmcnt(1); + outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); + lgkmcnt(0); + outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); +#endif } } diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index e0c5f80b3a..657c233e5e 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -323,7 +323,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), p_in_block + y * Wi + x, p_out_thread, - f_accum, Number()); + f_accum); } } } diff --git a/src/include/threadwise_gemm.hip.hpp b/src/include/threadwise_gemm.hip.hpp index 9b42e017a6..f1af308440 100644 --- a/src/include/threadwise_gemm.hip.hpp +++ b/src/include/threadwise_gemm.hip.hpp @@ -12,7 +12,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix, constexpr auto src_mtx = SrcMatrix{}; constexpr auto dst_mtx = DstMatrix{}; -#if 0 +#if 1 for(index_t i = 0; i < NRow; ++i) { for(index_t j = 0; j < NCol; ++j) @@ -72,6 +72,7 @@ __device__ void threadwise_gemm(MatrixA, for(index_t k = 0; k < K; ++k) { +#if 1 for(index_t i = 0; i < M; i+=4) { const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed @@ -88,6 +89,13 @@ __device__ void threadwise_gemm(MatrixA, outerProduct4x4(a_vec[0], b_vec[0], c_vec[0], c_vec[2], c_vec[4], c_vec[6]); } } +#else + const Float4 *a_vec = (const Float4 *)p_a_thread; + const Float4 *b_vec = (const Float4 *)p_b_thread; + Float4 *c_vec = (Float4 *)p_c_thread; + + outerProduct8x8(a_vec, b_vec, c_vec); +#endif } } else