diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index 198bae87e4..ddfafdd020 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -271,7 +271,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, for(index_t i = 0; i < nrepeat; ++i) { constexpr auto gridwise_conv = -#if 1 +#if 0 gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn #else gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer @@ -306,7 +306,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, float time = launch_kernel(gridwise_conv.Run, dim3(GridSize), dim3(BlockSize), - gridwise_conv.GetSharedMemoryUsage(), + gridwise_conv.GetDynamicSharedMemoryUsage(), static_cast(in_chwn_device_buf.GetDeviceBuffer()), static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), static_cast(out_khwn_device_buf.GetDeviceBuffer())); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index 0c4777b465..35696912cf 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -34,9 +34,8 @@ template -class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn +struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn { - public: __host__ __device__ constexpr index_t GetInputBlockElementSpace() const { constexpr auto I0 = Number<0>{}; @@ -97,7 +96,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn return wei_cyxk_block_desc.GetElementSpace(Number{}); } - __host__ __device__ constexpr index_t GetSharedMemoryUsage() const + __host__ __device__ constexpr index_t GetDynamicSharedMemoryUsage() const { return (GetInputBlockElementSpace() + GetWeightBlockElementSpace()) * sizeof(Float); @@ -300,22 +299,38 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), __syncthreads()) { - // load data - //blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - //blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); +// load data +#if 0 + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); +#elif 0 + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block); +#elif 1 Float4 tmp_in, tmp_wei; - Float4* glb_in_p = (Float4 *)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset); - Float4* loc_in_p = (Float4 *)(p_in_block + blockwise_in_copy.mDstMyThreadOffset); + Float4* glb_in_p = + (Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset); + Float4* loc_in_p = (Float4*)(p_in_block + blockwise_in_copy.mDstMyThreadOffset); - Float4* glb_wei_p = (Float4 *)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset); - Float4* loc_wei_p = (Float4 *)(p_wei_block + blockwise_wei_copy.mDstMyThreadOffset); + Float4* glb_wei_p = + (Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset); + Float4* loc_wei_p = (Float4*)(p_wei_block + blockwise_wei_copy.mDstMyThreadOffset); global_load(tmp_in, glb_in_p); global_load(tmp_wei, glb_wei_p); vmcnt(0); ds_write_b128(tmp_in, loc_in_p); ds_write_b128(tmp_wei, loc_wei_p); +#endif __syncthreads(); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index 488b0a0da7..57572a7c0b 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -34,9 +34,10 @@ template -class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer +struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer { - public: + __host__ __device__ constexpr index_t GetDynamicSharedMemoryUsage() const { return 0; } + __global__ static void Run(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, Float* const __restrict__ p_out_global) @@ -239,9 +240,27 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer const Float* p_wei_global_block_offset = p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - // preload data into LDS +// preload data into LDS +#if 0 blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); +#else + Float4 tmp_in, tmp_wei; + Float4* glb_in_p = + (Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset); + Float4* glb_wei_p = + (Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset); + + global_load(tmp_in, glb_in_p); + global_load(tmp_wei, glb_wei_p); + + Float4* loc_in_p = (Float4*)(p_in_block_0 + blockwise_in_copy.mDstMyThreadOffset); + Float4* loc_wei_p = (Float4*)(p_wei_block_0 + blockwise_wei_copy.mDstMyThreadOffset); + + vmcnt(0); + ds_write_b128(tmp_in, loc_in_p); + ds_write_b128(tmp_wei, loc_wei_p); +#endif p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); @@ -270,9 +289,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer // load next data #if 0 - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); -#elif 1 Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; @@ -281,6 +297,15 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_clipboard); +#elif 1 + Float4 tmp_in, tmp_wei; + Float4* glb_in_p = + (Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset); + Float4* glb_wei_p = + (Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset); + + global_load(tmp_in, glb_in_p); + global_load(tmp_wei, glb_wei_p); #endif // compute on current data @@ -290,22 +315,31 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer for(index_t x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 1 +#if 0 blockwise_gemm.Run -#else +#elif 0 blockwise_gemm.Run_RegisterDoubleBuffer +#elif 1 + blockwise_gemm.Run_asm #endif - (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread, - f_accum); + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread, + f_accum); } } -#if 1 +#if 0 blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); +#elif 1 + Float4* loc_in_p = (Float4*)(p_in_block_next + blockwise_in_copy.mDstMyThreadOffset); + Float4* loc_wei_p = (Float4*)(p_wei_block_next + blockwise_wei_copy.mDstMyThreadOffset); + + vmcnt(0); + ds_write_b128(tmp_in, loc_in_p); + ds_write_b128(tmp_wei, loc_wei_p); #endif } @@ -321,15 +355,17 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer for(index_t x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 1 +#if 0 blockwise_gemm.Run -#else +#elif 1 + blockwise_gemm.Run_asm +#elif 0 blockwise_gemm.Run_RegisterDoubleBuffer #endif - (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread, - f_accum); + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread, + f_accum); } } } diff --git a/src/include/inline_asm.hpp b/src/include/inline_asm.hpp index bfe6d4e70f..2dfee50688 100644 --- a/src/include/inline_asm.hpp +++ b/src/include/inline_asm.hpp @@ -4,28 +4,34 @@ typedef float Float4 __attribute__((ext_vector_type(4))); extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; -inline __device__ void vmcnt(int cnt) { - if(cnt == 0) { - asm volatile ("\n \ +inline __device__ void vmcnt(int cnt) +{ + if(cnt == 0) + { + asm volatile("\n \ s_waitcnt vmcnt(0) \n \ - "::); + " ::); } - else if(cnt == 1) { - asm volatile ("\n \ + else if(cnt == 1) + { + asm volatile("\n \ s_waitcnt vmcnt(1) \n \ - "::); + " ::); } - else if(cnt == 2) { - asm volatile ("\n \ + else if(cnt == 2) + { + asm volatile("\n \ s_waitcnt vmcnt(2) \n \ - "::); + " ::); } - else if(cnt == 4) { - asm volatile ("\n \ + else if(cnt == 4) + { + asm volatile("\n \ s_waitcnt vmcnt(2) \n \ - "::); + " ::); } - else { + else + { assert(0); } } @@ -397,13 +403,13 @@ inline __device__ void ds_read_b128(Float4& r, void* lds, int offset = 0) } } -inline __device__ void global_load(Float4 &r, Float4* ptr) { - asm volatile("\n \ +inline __device__ void global_load(Float4& r, Float4* ptr) +{ + asm volatile("\n \ global_load_dwordx4 %0, %1, off \n \ " - :"=v"(r) - :"v"(ptr) - ); + : "=v"(r) + : "v"(ptr)); } inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0) @@ -411,8 +417,6 @@ inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0) asm volatile("\n \ ds_write_b128 %0, %1 \n \ " - : - : "v"(__to_local(lds)), "v"(r) - ); + : + : "v"(__to_local(lds)), "v"(r)); } -