diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index f49fcbdbd6..920b65e1b6 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -190,8 +190,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t BlockSize = 256; -#elif 1 - // 1x1, 14x14, Vega 10 +#elif 0 + // 1x1, 14x14, Vega 20 constexpr index_t BPerBlock = 64; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; @@ -219,6 +219,36 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t BlockSize = 128; +#elif 1 + // 1x1, 14x14, Vega 20, hack CPerBlock = 1 + constexpr index_t BPerBlock = 64; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 1; + + constexpr index_t BPerThread = 8; + constexpr index_t KPerThread = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t GemmThreadPerColumnPerCluster = 8; + constexpr index_t GemmThreadPerRowPerCluster = 8; + + constexpr index_t InBlockCopyThreadPerDim0 = 4; + constexpr index_t InBlockCopyThreadPerDim1 = 16; + + constexpr index_t WeiBlockCopyThreadPerDim0 = 4; + constexpr index_t WeiBlockCopyThreadPerDim1 = 16; + + constexpr index_t InBlockCopyDataPerRead = 4; + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t BlockSize = 128; #endif diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index f7cb637d4e..366616d60c 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -420,9 +420,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 } template - __device__ void Run_asm(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, + __device__ void Run_asm(const FloatA* const __restrict__ p_a_block, + const FloatB* const __restrict__ p_b_block, + FloatC* const __restrict__ p_c_thread, Accumulator f_accum) const { constexpr auto True = integral_constant{}; @@ -462,11 +462,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC; + static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 && + KPerThreadLoop == 1 && K == 1, + "asm is not for this mtx shape"); + + const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA; + #pragma unroll // loop over k for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { - //#pragma unroll +#if 0 +#pragma unroll // copy A-sub to form A for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { @@ -475,9 +482,65 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + mMyThreadOffsetA, a_thread_mtx, - p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), + a_thread_sub_mtx.NCol(p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), a_thread_sub_mtx.GetLengths()); } +#elif 1 + // this produce right result + using vectorA_t = typename vector_type::MemoryType; // this is float4* + + asm volatile( + "\n \ + ds_read_b128 %0, %1 \n \ + s_waitcnt lgkmcnt(0)" + : "=v"(*(reinterpret_cast(p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))) + : "v"(__to_local( + (void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA)))); + + asm volatile("\n \ + ds_read_b128 %0, %1 \n \ + s_waitcnt lgkmcnt(0)" + : "=v"(*(reinterpret_cast( + p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC)))) + : "v"(__to_local(( + void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) + + mMyThreadOffsetA)))); +#elif 0 + // this produce wrong result + using vectorA_t = typename vector_type::MemoryType; // this is float4* + + asm volatile( + "\n \ + ds_read_b128 %0, %2 \n \ + ds_read_b128 %1, %3 \n \ + s_waitcnt lgkmcnt(0)" + : "=v"(*(reinterpret_cast(p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))), + "=v"(*(reinterpret_cast(p_a_thread + + a_thread_mtx.Get1dIndex(0, MPerThreadSubC)))) + : "v"(__to_local( + (void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA))), + "v"(__to_local((void*)(p_a_block + + a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) + + mMyThreadOffsetA)))); +#elif 1 + // this produce wrong result + using vectorA_t = typename vector_type::MemoryType; // this is float4* + + asm volatile( + "\n \ + ds_read_b128 %0, %1 \n \ + s_waitcnt lgkmcnt(0)" + : "=v"(*(reinterpret_cast(p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))) + : "v"(__to_local((void*)(p_a_block_thread_offset)))); + + asm volatile("\n \ + ds_read_b128 %0, %1 offset:16 \n \ + s_waitcnt lgkmcnt(0)" + : "=v"(*(reinterpret_cast( + p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC)))) + : "v"(__to_local((void*)(p_a_block_thread_offset)))); + +#endif //#pragma unroll // copy B-sub to form B diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp index 6770b590a9..5e3b88f670 100644 --- a/src/include/common.hip.hpp +++ b/src/include/common.hip.hpp @@ -5,6 +5,8 @@ #include "Array.hip.hpp" #include "functional.hip.hpp" +extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; + __device__ index_t get_thread_local_1d_id() { return threadIdx.x; } __device__ index_t get_block_1d_id() { return blockIdx.x; } diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index 08aa8f90f5..2964fcedde 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -238,7 +238,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; #if 0 blockwise_gemm.Run -#elif 0 +#elif 1 blockwise_gemm.Run_asm #elif 1 blockwise_gemm.Run_RegisterDoubleBuffer diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index f15bc1807b..74efbca112 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -289,10 +289,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( #else blockwise_gemm.Run_RegisterDoubleBuffer #endif - (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread, - f_accum); + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread, + f_accum); } } @@ -319,10 +319,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( #else blockwise_gemm.Run_RegisterDoubleBuffer #endif - (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread, - f_accum); + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread, + f_accum); } } } diff --git a/src/include/threadwise_gemm.hip.hpp b/src/include/threadwise_gemm.hip.hpp index 8cf2404c63..ece0c54d10 100644 --- a/src/include/threadwise_gemm.hip.hpp +++ b/src/include/threadwise_gemm.hip.hpp @@ -10,7 +10,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix, constexpr auto src_mtx = SrcMatrix{}; constexpr auto dst_mtx = DstMatrix{}; -#if 0 +#if 1 for(index_t i = 0; i < NRow; ++i) { for(index_t j = 0; j < NCol; ++j) @@ -21,7 +21,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix, p_dst[dst_index] = p_src[src_index]; } } -#elif 1 +#elif 0 static_assert(NCol == 4, "only for NCol == 4"); using vector_t = typename vector_type::MemoryType; @@ -31,15 +31,21 @@ __device__ void threadwise_matrix_copy(SrcMatrix, const index_t src_index = src_mtx.Get1dIndex(i, 0); const index_t dst_index = dst_mtx.Get1dIndex(i, 0); -#if 1 - *(reinterpret_cast(p_dst + dst_index)) = - *(reinterpret_cast(p_src + src_index)); +#if 0 + *(reinterpret_cast(&p_dst[dst_index]) = + *(reinterpret_cast(&p_src[src_index])); +#elif 0 + asm volatile("\n \ + ds_read2_b64 %0, %1 offset1:1 \n \ + s_waitcnt lgkmcnt(0)" + : "=v"(*(reinterpret_cast(&p_dst[dst_index]))) + : "v"(__to_local((void*)(&p_src[src_index])))); #elif 1 asm volatile("\n \ - ds_read_b128 %0, %1, offset:0 \n \ - " - : "=v"(*(reinterpret_cast(p_dst+dst_index))) - : "v"((uint32_t)(p_src + src_index))); + ds_read_b128 %0, %1 \n \ + s_waitcnt lgkmcnt(0)" + : "=v"(*(reinterpret_cast(&p_dst[dst_index]))) + : "v"(__to_local((void*)(&p_src[src_index])))); #endif } #endif