diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index a2194c4266..0c4777b465 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -301,8 +301,21 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn __syncthreads()) { // load data - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); + //blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); + //blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); + + Float4 tmp_in, tmp_wei; + Float4* glb_in_p = (Float4 *)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset); + Float4* loc_in_p = (Float4 *)(p_in_block + blockwise_in_copy.mDstMyThreadOffset); + + Float4* glb_wei_p = (Float4 *)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset); + Float4* loc_wei_p = (Float4 *)(p_wei_block + blockwise_wei_copy.mDstMyThreadOffset); + + global_load(tmp_in, glb_in_p); + global_load(tmp_wei, glb_wei_p); + vmcnt(0); + ds_write_b128(tmp_in, loc_in_p); + ds_write_b128(tmp_wei, loc_wei_p); __syncthreads(); diff --git a/src/include/inline_asm.hpp b/src/include/inline_asm.hpp index 09e2284ec2..bfe6d4e70f 100644 --- a/src/include/inline_asm.hpp +++ b/src/include/inline_asm.hpp @@ -4,6 +4,32 @@ typedef float Float4 __attribute__((ext_vector_type(4))); extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; +inline __device__ void vmcnt(int cnt) { + if(cnt == 0) { + asm volatile ("\n \ + s_waitcnt vmcnt(0) \n \ + "::); + } + else if(cnt == 1) { + asm volatile ("\n \ + s_waitcnt vmcnt(1) \n \ + "::); + } + else if(cnt == 2) { + asm volatile ("\n \ + s_waitcnt vmcnt(2) \n \ + "::); + } + else if(cnt == 4) { + asm volatile ("\n \ + s_waitcnt vmcnt(2) \n \ + "::); + } + else { + assert(0); + } +} + inline __device__ void lgkmcnt(int cnt) { #if 1 @@ -370,3 +396,23 @@ inline __device__ void ds_read_b128(Float4& r, void* lds, int offset = 0) assert(0); } } + +inline __device__ void global_load(Float4 &r, Float4* ptr) { + asm volatile("\n \ + global_load_dwordx4 %0, %1, off \n \ + " + :"=v"(r) + :"v"(ptr) + ); +} + +inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0) +{ + asm volatile("\n \ + ds_write_b128 %0, %1 \n \ + " + : + : "v"(__to_local(lds)), "v"(r) + ); +} +