From 05d7a0875c8e4cd12aed8e63100591ed07328d6a Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 3 Apr 2019 19:04:17 -0500 Subject: [PATCH] enable 128x128 block gemm --- ...icit_gemm_convolution_2_chwn_cyxk_khwn.hpp | 6 +- src/include/blockwise_gemm.hip.hpp | 14 +- ..._gemm_convolution_2_chwn_cyxk_khwn.hip.hpp | 5 +- src/include/inline_asm.hpp | 200 ++++++++---------- src/include/threadwise_gemm.hip.hpp | 12 +- 5 files changed, 112 insertions(+), 125 deletions(-) diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index 3690622c86..198bae87e4 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -190,7 +190,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t BlockSize = 256; -#elif 1 +#elif 0 // 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer constexpr index_t BPerBlock = 64; constexpr index_t KPerBlock = 128; @@ -221,8 +221,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t BlockSize = 128; #elif 1 - // 1x1, 14x14, Vega 20, hack CPerBlock = 1 - constexpr index_t BPerBlock = 64; + // 1x1, 14x14, Vega 20, try + constexpr index_t BPerBlock = 128; constexpr index_t KPerBlock = 128; constexpr index_t CPerBlock = 8; diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index dcbb4e7e92..bdf79540c3 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -377,13 +377,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - //auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; - //auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; + // auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; + // auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; Float4* reg_a = (Float4*)(p_a_thread); Float4* reg_b = (Float4*)(p_b_thread); Float4* reg_c = (Float4*)(p_c_thread); - void* a_loc = (void *)(p_a_block + mMyThreadOffsetA); - void* b_loc = (void *)(p_b_block + mMyThreadOffsetB); + void* a_loc = (void*)(p_a_block + mMyThreadOffsetA); + void* b_loc = (void*)(p_b_block + mMyThreadOffsetB); // loop over k int k_chunk = 2; #pragma unroll @@ -403,9 +403,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); #else - int k = k_begin; - int lds_a_block_off = sizeof(Float) * M; - int lds_b_block_off = sizeof(Float) * N; + int k = k_begin; + int lds_a_block_off = sizeof(Float) * M; + int lds_b_block_off = sizeof(Float) * N; int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float); int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float); ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index 573b5923b7..a2194c4266 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -272,7 +272,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn // LDS: be careful of alignment constexpr index_t max_align = - mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); + mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); constexpr index_t in_block_element_space = in_cb_block_desc.GetElementSpace(Number{}); @@ -297,7 +297,8 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), __syncthreads()) + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + __syncthreads()) { // load data blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); diff --git a/src/include/inline_asm.hpp b/src/include/inline_asm.hpp index 7cec3dc5ec..09e2284ec2 100644 --- a/src/include/inline_asm.hpp +++ b/src/include/inline_asm.hpp @@ -4,56 +4,68 @@ typedef float Float4 __attribute__((ext_vector_type(4))); extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; -inline __device__ void lgkmcnt(int cnt){ +inline __device__ void lgkmcnt(int cnt) +{ #if 1 - if(cnt == 0) { + if(cnt == 0) + { asm volatile("\n \ s_waitcnt lgkmcnt(0) \n \ - "::); + " ::); } - else if(cnt == 1) { + else if(cnt == 1) + { asm volatile("\n \ s_waitcnt lgkmcnt(1) \n \ - "::); + " ::); } - else if(cnt == 2) { + else if(cnt == 2) + { asm volatile("\n \ s_waitcnt lgkmcnt(2) \n \ - "::); + " ::); } - else if(cnt == 3) { + else if(cnt == 3) + { asm volatile("\n \ s_waitcnt lgkmcnt(3) \n \ - "::); + " ::); } - else if(cnt == 4) { + else if(cnt == 4) + { asm volatile("\n \ s_waitcnt lgkmcnt(4) \n \ - "::); + " ::); } - else { + else + { assert(0); } #endif } -inline __device__ void outerProduct1x4(const float *a, const float *b, float *c) { +inline __device__ void outerProduct1x4(const float* a, const float* b, float* c) +{ asm volatile("\n \ v_mac_f32 %0, %4, %5 \n \ v_mac_f32 %1, %4, %6 \n \ v_mac_f32 %2, %4, %7 \n \ v_mac_f32 %3, %4, %8 \n \ " - : - "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) - : - "v"(a[0]), - "v"(b[0]), "v"(b[1]), "v"(b[2]), "v"(b[3]), - "0"(c[0]), "1"(c[1]), "2"(c[2]), "3"(c[3]) - ); + : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) + : "v"(a[0]), + "v"(b[0]), + "v"(b[1]), + "v"(b[2]), + "v"(b[3]), + "0"(c[0]), + "1"(c[1]), + "2"(c[2]), + "3"(c[3])); } -inline __device__ void outerProduct1x4(const float &a, const Float4 &b, Float4 &c) { +inline __device__ void outerProduct1x4(const float& a, const Float4& b, Float4& c) +{ #if 0 asm volatile( "\n \ @@ -67,12 +79,13 @@ inline __device__ void outerProduct1x4(const float &a, const Float4 &b, Float4 & "v"(a.x),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w) ); #else - outerProduct1x4(&a, (float *)&b, (float *)&c); + outerProduct1x4(&a, (float*)&b, (float*)&c); #endif } - -inline __device__ void outerProduct4x4(const Float4 &a, const Float4 &b, Float4 &c0, Float4 &c1, Float4 &c2, Float4 &c3) { +inline __device__ void +outerProduct4x4(const Float4& a, const Float4& b, Float4& c0, Float4& c1, Float4& c2, Float4& c3) +{ #if 0 asm volatile( "\n \ @@ -126,7 +139,7 @@ inline __device__ void outerProduct4x4(const Float4 &a, const Float4 &b, Float4 #endif } -inline __device__ void outerProduct8x8(const Float4 *a, const Float4 *b, Float4 *c) +inline __device__ void outerProduct8x8(const Float4* a, const Float4* b, Float4* c) { outerProduct4x4(a[0], b[0], c[0], c[2], c[4], c[6]); outerProduct4x4(a[0], b[1], c[1], c[3], c[5], c[7]); @@ -134,250 +147,223 @@ inline __device__ void outerProduct8x8(const Float4 *a, const Float4 *b, Float4 outerProduct4x4(a[1], b[1], c[9], c[11], c[13], c[15]); } -inline __device__ void ds_read_b128(Float4 &r, void *lds, int offset = 0) +inline __device__ void ds_read_b128(Float4& r, void* lds, int offset = 0) { if(offset == 0) { asm volatile("\n \ ds_read_b128 %0, %1 offset:0 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 128) { asm volatile("\n \ ds_read_b128 %0, %1 offset:128 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 256) { asm volatile("\n \ ds_read_b128 %0, %1 offset:256 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 384) { asm volatile("\n \ ds_read_b128 %0, %1 offset:384 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 512) { asm volatile("\n \ ds_read_b128 %0, %1 offset:512 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 640) { asm volatile("\n \ ds_read_b128 %0, %1 offset:640 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 768) { asm volatile("\n \ ds_read_b128 %0, %1 offset:768 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 896) { asm volatile("\n \ ds_read_b128 %0, %1 offset:896 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 1024) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1024 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 1152) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1152 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 1280) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1280 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 1408) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1408 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 1536) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1536 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 1664) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1664 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 1792) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1792 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 1920) { asm volatile("\n \ ds_read_b128 %0, %1 offset:1920 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 2048) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2048 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 2176) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2176 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 2304) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2304 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 2560) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2560 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 2816) { asm volatile("\n \ ds_read_b128 %0, %1 offset:2816 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 3072) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3072 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 3328) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3328 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 3584) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3584 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 3840) { asm volatile("\n \ ds_read_b128 %0, %1 offset:3840 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 4096) { asm volatile("\n \ ds_read_b128 %0, %1 offset:4096 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else if(offset == 4352) { asm volatile("\n \ ds_read_b128 %0, %1 offset:4352 \n \ " - : "=v"(r) - : "v"(__to_local(lds)) - ); + : "=v"(r) + : "v"(__to_local(lds))); } else { diff --git a/src/include/threadwise_gemm.hip.hpp b/src/include/threadwise_gemm.hip.hpp index c5a7de8049..410358c349 100644 --- a/src/include/threadwise_gemm.hip.hpp +++ b/src/include/threadwise_gemm.hip.hpp @@ -31,10 +31,10 @@ __device__ void threadwise_matrix_copy(SrcMatrix, const index_t src_index = src_mtx.Get1dIndex(i, 0); const index_t dst_index = dst_mtx.Get1dIndex(i, 0); - Float4 *reg_p = (Float4 *)&p_dst[dst_index]; - Float4 *loc_p = (Float4 *)&p_src[src_index]; + Float4* reg_p = (Float4*)&p_dst[dst_index]; + Float4* loc_p = (Float4*)&p_src[src_index]; - ds_read_b128(reg_p[0], (void *)&loc_p[0]); + ds_read_b128(reg_p[0], (void*)&loc_p[0]); } #endif } @@ -86,9 +86,9 @@ __device__ void threadwise_gemm(MatrixA, } } #else - const Float4 *a_vec = (const Float4 *)p_a_thread; - const Float4 *b_vec = (const Float4 *)p_b_thread; - Float4 *c_vec = (Float4 *)p_c_thread; + const Float4* a_vec = (const Float4*)p_a_thread; + const Float4* b_vec = (const Float4*)p_b_thread; + Float4* c_vec = (Float4*)p_c_thread; outerProduct8x8(a_vec, b_vec, c_vec); #endif