diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index 6ac6f4158b..fb673bd6a6 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -189,7 +189,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t BlockSize = 256; -#elif 1 +#elif 0 // 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer constexpr index_t BPerBlock = 64; constexpr index_t KPerBlock = 128; @@ -219,7 +219,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t BlockSize = 128; -#elif 0 +#elif 1 // 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer constexpr index_t BPerBlock = 128; constexpr index_t KPerBlock = 128; diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index cecc5c7e1b..f9c8a3ee21 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -409,7 +409,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 1 +#elif 0 // 3x3, 34x34 constexpr index_t N = 64; constexpr index_t C = 256; @@ -583,7 +583,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 1x1 filter, 14x14 image, C = 2048 constexpr index_t N = 128; constexpr index_t C = 2048; @@ -667,9 +667,9 @@ int main(int argc, char* argv[]) device_direct_convolution_2_nchw_kcyx_nkhw #elif 0 device_direct_convolution_2_vectorized_nchw_kcyx_nkhw -#elif 1 - device_implicit_gemm_convolution_1_chwn_cyxk_khwn #elif 0 + device_implicit_gemm_convolution_1_chwn_cyxk_khwn +#elif 1 device_implicit_gemm_convolution_2_chwn_cyxk_khwn #endif (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); diff --git a/script/compile-hip.sh b/script/compile-hip.sh index 4c91e8a293..2ebf032d73 100755 --- a/script/compile-hip.sh +++ b/script/compile-hip.sh @@ -1,6 +1,7 @@ #!/bin/bash export KMDUMPISA=1 export KMDUMPLLVM=1 +export KMOPTLLC=-mattr=+enable-ds128 make -j driver -/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.isa +/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.asm diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 3e8d10e193..3a4d34faf0 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -132,10 +132,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 const FloatB* __restrict__ p_b_block, FloatC* __restrict__ p_c_thread) const { - static_assert(is_same::value && is_same::value && - is_same::value, - "Run_asm only deal with float\n"); - constexpr auto True = integral_constant{}; constexpr auto False = integral_constant{}; @@ -164,56 +160,48 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( Number{}, Number{}, Number{}); + FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + + // assertion for inline asm + static_assert(is_same::value && is_same::value && + is_same::value, + "Run_asm only deal with float\n"); + static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 && MPerThread == 8 && NPerThread == 8, "Run_asm cannot deal with this GEMM shape yet\n"); using Float4 = vector_type::MemoryType; - float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()]; - - FloatA* p_a_thread = p_thread; - FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace(); - - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - Float4* reg_a = (Float4*)(p_a_thread); Float4* reg_b = (Float4*)(p_b_thread); Float4* reg_c = (Float4*)(p_c_thread); - void* a_loc = (void*)(p_a_block + mMyThreadOffsetA); - void* b_loc = (void*)(p_b_block + mMyThreadOffsetB); - int lds_a_block_off = sizeof(Float) * M; - int lds_b_block_off = sizeof(Float) * N; - int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float); - int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float); - ds_read_b128(reg_a[0], a_loc, 0); - ds_read_b128(reg_b[0], b_loc, 0); - ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1); - ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1); - lgkmcnt(2); + reg_a[0] = *reinterpret_cast(&p_a_block[mMyThreadOffsetA]); + reg_b[0] = *reinterpret_cast(&p_b_block[mMyThreadOffsetB]); + reg_b[1] = + *reinterpret_cast(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]); + reg_a[1] = + *reinterpret_cast(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); - lgkmcnt(1); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); - lgkmcnt(0); #pragma unroll - for(int k_i = 1; k_i < K; k_i++) + for(index_t k = 1; k < K; ++k) { - ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off); + reg_a[0] = *reinterpret_cast(&p_a_block[mMyThreadOffsetA + k * M]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); - ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off); + reg_b[0] = *reinterpret_cast(&p_b_block[mMyThreadOffsetB + k * N]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); - ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off); - ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off); - lgkmcnt(2); + reg_b[1] = *reinterpret_cast( + &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]); + reg_a[1] = *reinterpret_cast( + &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); - lgkmcnt(1); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); - lgkmcnt(0); } outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); diff --git a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index 33edd968b3..82c3b11a4f 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -213,17 +213,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_clipboard); -#if 1 blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double); blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_double); -#else - vmcnt(0); - blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard, - p_in_block_double); - blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard, - p_wei_block_double); -#endif } // register @@ -261,7 +253,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, p_in_register_clipboard); - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_clipboard); @@ -271,31 +262,23 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer { for(index_t x = 0; x < X; ++x) { -#if 1 +#if 0 blockwise_gemm.Run #elif 0 blockwise_gemm.Run_RegisterDoubleBuffer -#elif 0 +#elif 1 blockwise_gemm.Run_asm #endif - (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread); + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread); } } -#if 1 blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); -#else - vmcnt(0); - blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard, - p_in_block_next); - blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard, - p_wei_block_next); -#endif } } @@ -320,32 +303,23 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer { for(index_t x = 0; x < X; ++x) { -#if 1 +#if 0 blockwise_gemm.Run #elif 0 blockwise_gemm.Run_RegisterDoubleBuffer -#elif 0 +#elif 1 blockwise_gemm.Run_asm #endif - (p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_double + y * Wi + x, - p_out_thread); + (p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_double + y * Wi + x, + p_out_thread); } } -#if 1 blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_double + wei_block_space); -#else - vmcnt(0); - blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard, - p_wei_block_double + wei_block_space); -#endif // odd __syncthreads(); @@ -354,17 +328,17 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer { for(index_t x = 0; x < X; ++x) { -#if 1 +#if 0 blockwise_gemm.Run #elif 0 blockwise_gemm.Run_RegisterDoubleBuffer -#elif 0 +#elif 1 blockwise_gemm.Run_asm #endif - (p_wei_block_double + wei_block_space + - wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_double + in_block_space + y * Wi + x, - p_out_thread); + (p_wei_block_double + wei_block_space + + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_double + in_block_space + y * Wi + x, + p_out_thread); } } }