From c138e2126d0fe5b6b87cdb660dbd44f0192bf0f8 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 26 Apr 2019 18:03:55 -0500 Subject: [PATCH] nchw*cyxk*nkhw on AMD --- ...lution_implicit_gemm_v1_nchw_cyxk_khwn.hpp | 4 +- ...lution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp | 175 +++++++----- driver/driver.hip.cpp | 2 +- src/include/amd_inline_asm.hip.hpp | 260 +++++++++--------- ...3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp | 17 +- ..._implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp | 43 +-- 6 files changed, 278 insertions(+), 223 deletions(-) diff --git a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp index 22658d35ef..26fa9c8ca8 100644 --- a/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp +++ b/driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp @@ -217,9 +217,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc, constexpr auto gridwise_conv = #if 0 GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn -#elif 1 - GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn #elif 0 + GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn +#elif 1 GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn #endif ; + using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>; + using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; + constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW + constexpr index_t InBlockReorderDataPerWrite_N = 1; + + using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_W = 2; +#elif 0 + // for 3x3, 34x34, v1r3, Vega 20 + constexpr index_t BlockSize = 256; + + constexpr index_t NPerBlock = 2; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 4; + constexpr index_t WoPerBlock = 16; + + constexpr index_t NPerThread = 2; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>; + using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 2, 16>; + using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; + constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW + constexpr index_t InBlockReorderDataPerWrite_N = 2; + + using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_W = 4; +#elif 1 + // for 3x3, 34x34, v1r3, Vega 20, try + constexpr index_t BlockSize = 256; + + constexpr index_t NPerBlock = 4; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 8; + constexpr index_t HoPerBlock = 4; + constexpr index_t WoPerBlock = 8; + + constexpr index_t NPerThread = 2; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>; + using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 4, 8>; + using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; + constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW + constexpr index_t InBlockReorderDataPerWrite_N = 1; + + using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used + constexpr index_t WeiBlockCopyDataPerRead_K = 4; + + constexpr index_t OutThreadCopyDataPerWrite_W = 1; +#elif 0 // for 3x3, 28x28, v1r2, Pascal constexpr index_t BlockSize = 128; @@ -90,76 +195,6 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>; constexpr index_t WeiBlockCopyDataPerRead_K = 4; - constexpr index_t OutThreadCopyDataPerWrite_W = 2; -#elif 0 - // for 3x3, 28x28, v1r3, Pascal, bad - constexpr index_t BlockSize = 128; - - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet - - using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - - constexpr index_t OutThreadCopyDataPerWrite_W = 2; -#elif 1 - // for 3x3, 34x34, v1r3, Pascal - constexpr index_t BlockSize = 128; - - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 16; - - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet - - using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - constexpr index_t OutThreadCopyDataPerWrite_W = 2; #endif diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 09d0e61433..fd6a3bbf8d 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -608,7 +608,7 @@ int main(int argc, char* argv[]) device_direct_convolution_2_vectorized_nchw_kcyx_nkhw #elif 0 device_convolution_implicit_gemm_v1_chwn_cyxk_khwn -#elif 1 +#elif 0 device_convolution_implicit_gemm_v1_nchw_cyxk_khwn #elif 1 device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw diff --git a/src/include/amd_inline_asm.hip.hpp b/src/include/amd_inline_asm.hip.hpp index 44b480f542..1e453d3cf5 100644 --- a/src/include/amd_inline_asm.hip.hpp +++ b/src/include/amd_inline_asm.hip.hpp @@ -203,520 +203,520 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in asm volatile("\n \ ds_read_b128 %0, %1 offset:0\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 64) { asm volatile("\n \ ds_read_b128 %0, %1 offset:64\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 128) { asm volatile("\n \ ds_read_b128 %0, %1 offset:128\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 192) { asm volatile("\n \ ds_read_b128 %0, %1 offset:192\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 256) { asm volatile("\n \ ds_read_b128 %0, %1 offset:256\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 320) { asm volatile("\n \ ds_read_b128 %0, %1 offset:320\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 384) { asm volatile("\n \ ds_read_b128 %0, %1 offset:384\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 448) { asm volatile("\n \ ds_read_b128 %0, %1 offset:448\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 512) { asm volatile("\n \ ds_read_b128 %0, %1 offset:512\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 576) { asm volatile("\n \ ds_read_b128 %0, %1 offset:576\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 640) { asm volatile("\n \ ds_read_b128 %0, %1 offset:640\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 704) { asm volatile("\n \ ds_read_b128 %0, %1 offset:704\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 768) { asm volatile("\n \ ds_read_b128 %0, %1 offset:768\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 832) { asm volatile("\n \ ds_read_b128 %0, %1 offset:832\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 896) { asm volatile("\n \ ds_read_b128 %0, %1 offset:896\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 960) { asm volatile("\n \ ds_read_b128 %0, %1 offset:960\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1024) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1024\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1088) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1088\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1152) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1152\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1216) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1216\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1280) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1280\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1344) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1344\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1408) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1408\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1472) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1472\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1536) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1536\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1600) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1600\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1664) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1664\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1728) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1728\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1792) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1792\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1856) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1856\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1920) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1920\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 1984) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1984\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2048) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2048\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2112) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2112\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2176) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2176\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2240) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2240\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2304) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2304\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2368) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2368\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2432) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2432\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2496) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2496\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2560) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2560\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2624) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2624\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2688) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2688\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2752) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2752\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2816) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2816\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2880) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2880\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 2944) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2944\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3008) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3008\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3072) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3072\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3136) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3136\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3200) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3200\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3264) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3264\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3328) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3328\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3392) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3392\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3456) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3456\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3520) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3520\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3584) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3584\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3648) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3648\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3712) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3712\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3776) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3776\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3840) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3840\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3904) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3904\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 3968) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3968\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 4032) { asm volatile("\n \ ds_read_b128 %0, %1 offset:4032\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } if(offset == 4096) { asm volatile("\n \ ds_read_b128 %0, %1 offset:4096\n \ " - : "=v"(r) - : "v"(__to_local(lds))); + : "=v"(r) + : "v"(__to_local(lds))); } #endif } diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp index c8bd2efc35..ac96fff9fc 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp @@ -196,6 +196,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn GemmDataPerReadA, GemmDataPerReadB>{}; + // choose GEMM implementation here + const auto run_blockwise_batch_gemm = [&](auto... Xs) { +#if 0 + return blockwise_batch_gemm.Run(Xs...); +#elif 0 + return blockwise_batch_gemm.Run_asm(Xs...); +#else + return blockwise_batch_gemm.Run_asm_v2(Xs...); +#endif + }; + // LDS: be careful of alignment constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace(Number{}); @@ -293,7 +304,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn p_wei_register_clipboard); // LDS double buffer: GEMM on current data - blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); + run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread); // LDS double buffer: store next data to LDS blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard, @@ -322,7 +333,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn p_wei_register_clipboard); // LDS double buffer: GEMM on current data - blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); + run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread); // LDS double buffer: store next data to LDS blockwise_in_copy_reorder.RunStoreRegisterClipboard( @@ -334,7 +345,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn __syncthreads(); // LDS double buffer: GEMM on current data - blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space, + run_blockwise_batch_gemm(p_wei_block_double + wei_block_space, p_in_block_double + in_block_space, p_out_thread); } diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp index 445512922a..dcafa0f4c8 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp @@ -78,22 +78,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, "wrong! cannot evenly divide work for workgroup "); - // constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; + constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock); + constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock); + constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock); + constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock); - const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); - itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const index_t w_block_work_id = itmp / NBlockWork; - const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; + constexpr auto block_work_desc = make_ConstantTensorDescriptor( + Sequence{}); - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; - const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - const index_t n_block_data_begin = n_block_work_id * NPerBlock; + const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id()); + + const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock; + const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock; + const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock; + const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock; const index_t hi_block_data_begin = ho_block_data_begin; const index_t wi_block_data_begin = wo_block_data_begin; @@ -193,6 +191,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw GemmDataPerReadA, GemmDataPerReadB>{}; + // choose GEMM implementation here + const auto run_blockwise_batch_gemm = [&](auto... Xs) { +#if 1 + return blockwise_batch_gemm.Run(Xs...); +#elif 0 + return blockwise_batch_gemm.Run_asm(Xs...); +#else + return blockwise_batch_gemm.Run_asm_v2(Xs...); +#endif + }; + // LDS: be careful of alignment constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace(Number{}); @@ -222,7 +231,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw // set threadwise output tensor to 0 threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); -#if 1 +#if 0 const Float* p_in_global_block_offset = p_in_global + in_n_c_h_w_global_desc.Get1dIndex( @@ -267,7 +276,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw __syncthreads(); - blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); + run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread); __syncthreads(); } @@ -314,7 +323,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw __syncthreads(); - blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); + run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread); __syncthreads(); }