From e43d7bc63c2df138c376412fa5b4aaebc26ca131 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 1 Apr 2019 15:17:22 -0500 Subject: [PATCH] refactor --- ...icit_gemm_convolution_2_chwn_cyxk_khwn.hpp | 16 +- driver/driver.hip.cpp | 2 +- src/include/ConstantTensorDescriptor.hip.hpp | 5 +- src/include/blockwise_gemm.hip.hpp | 848 ++++++++---------- src/include/common.hip.hpp | 54 +- .../gridwise_direct_convolution_1.hip.hpp | 10 +- ...irect_convolution_2_nchw_kcyx_nkhw.hip.hpp | 9 +- ...lution_2_vectorized_nchw_kcyx_nkhw.hip.hpp | 8 +- ..._gemm_convolution_1_chwn_cyxk_khwn.hip.hpp | 9 +- ...onvolution_1_chwn_cyxk_khwn_padded.hip.hpp | 8 +- ..._gemm_convolution_2_chwn_cyxk_khwn.hip.hpp | 385 ++++---- ...2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp | 422 ++++----- src/include/threadwise_gemm.hip.hpp | 25 +- 13 files changed, 873 insertions(+), 928 deletions(-) diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index 920b65e1b6..a773a078e4 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -270,7 +270,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel( + constexpr auto gridwise_conv = #if 1 gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn #else @@ -301,12 +301,14 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, WeiBlockCopyThreadPerDim0, WeiBlockCopyThreadPerDim1, InBlockCopyDataPerRead, - WeiBlockCopyDataPerRead>, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); + WeiBlockCopyDataPerRead>(); + + float time = launch_kernel(gridwise_conv.Run, + dim3(GridSize), + dim3(BlockSize), + static_cast(in_chwn_device_buf.GetDeviceBuffer()), + static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), + static_cast(out_khwn_device_buf.GetDeviceBuffer())); printf("Elapsed time : %f ms\n", time); usleep(std::min(time * 1000, float(10000))); diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index a83e4082c7..0ea091e607 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -580,7 +580,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 1x1 filter, 14x14 image, C = 2048 constexpr index_t N = 128; constexpr index_t C = 2048; diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index 4e883f12e7..d36a752eb4 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -137,7 +137,10 @@ struct ConstantTensorDescriptor } }; - return static_const_reduce_n{}(GetElementSpace_f{}, add{}) + align.Get(); + index_t element_space_unaligned = + static_const_reduce_n{}(GetElementSpace_f{}, add{}) + 1; + + return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get()); } template diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index dfaef91a83..3e9c57d15f 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -1,7 +1,7 @@ #pragma once #include "threadwise_gemm.hip.hpp" -extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]]; +extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; template - __device__ void Run(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, - Accumulator f_accum) const + __device__ void Run_asm(const FloatA* __restrict__ p_a_block, + const FloatB* __restrict__ p_b_block, + FloatC* __restrict__ p_c_thread, + Accumulator f_accum) const { constexpr auto True = integral_constant{}; constexpr auto False = integral_constant{}; @@ -368,10 +368,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( Number{}, Number{}, Number{}); - float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()]; + float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()]; - FloatA *p_a_thread = p_thread; - FloatB *p_b_thread = p_thread + a_thread_mtx.GetElementSpace(); + FloatA* p_a_thread = p_thread; + FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace(); constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; @@ -387,9 +387,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; - const float4* a_loc = (const float4 *)(p_a_block + a_src_index); - const float4* b_loc = (const float4 *)(p_b_block + b_src_index); - float4* reg = (float4 *)(p_thread); + const float4* a_loc = (const float4*)(p_a_block + a_src_index); + const float4* b_loc = (const float4*)(p_b_block + b_src_index); + float4* reg = (float4*)(p_thread); reg[0] = a_loc[0]; reg[1] = a_loc[16]; @@ -398,41 +398,41 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 //asm volatile("\n \ //ds_read2_b64 %0, %1 offset1:1 \n \ - //s_waitcnt lgkmcnt(0)" - //: "=v"(reg[0]) - //: "v"(__to_local((void *)(a_loc))) - //); + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[0]) + //: "v"(__to_local((void *)(a_loc))) + //); //asm volatile("\n \ //ds_read2_b64 %0, %1 offset1:1 \n \ - //s_waitcnt lgkmcnt(0)" - //: "=v"(reg[1]) - //: "v"(__to_local((void *)(a_loc + 16))) - //); + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[1]) + //: "v"(__to_local((void *)(a_loc + 16))) + //); //asm volatile("\n \ //ds_read2_b64 %0, %1 offset1:1 \n \ - //s_waitcnt lgkmcnt(0)" - //: "=v"(reg[2]) - //: "v"(__to_local((void *)(b_loc))) - //); + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[2]) + //: "v"(__to_local((void *)(b_loc))) + //); //asm volatile("\n \ //ds_read2_b64 %0, %1 offset1:1 \n \ - //s_waitcnt lgkmcnt(0)" - //: "=v"(reg[3]) - //: "v"(__to_local((void *)(b_loc + 8))) - //); - + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[3]) + //: "v"(__to_local((void *)(b_loc + 8))) + //); + //asm volatile("\n \ //ds_read2_b64 %0, %4 offset1:1 \n \ //ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \ //ds_read2_b64 %2, %5 offset1:1 \n \ //ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \ - //s_waitcnt lgkmcnt(0)" - //: "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3]) - //: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc))) - //); + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3]) + //: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc))) + //); //asm volatile("\n \ //ds_read_b32 %0, %16 \n \ @@ -451,32 +451,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 //ds_read_b32 %13, %19 offset:1\n \ //ds_read_b32 %14, %19 offset:2\n \ //ds_read_b32 %15, %19 offset:3\n \ - //s_waitcnt lgkmcnt(0)" - //: - //"=v"(p_a_thread[0]), - //"=v"(p_a_thread[1]), - //"=v"(p_a_thread[2]), - //"=v"(p_a_thread[3]), - //"=v"(p_a_thread[4]), - //"=v"(p_a_thread[5]), - //"=v"(p_a_thread[6]), - //"=v"(p_a_thread[7]), - //"=v"(p_b_thread[0]), - //"=v"(p_b_thread[1]), - //"=v"(p_b_thread[2]), - //"=v"(p_b_thread[3]), - //"=v"(p_b_thread[4]), - //"=v"(p_b_thread[5]), - //"=v"(p_b_thread[6]), - //"=v"(p_b_thread[7]) - //: - //"v"(__to_local((void *)(&p_a_block[0]))), - //"v"(__to_local((void *)(&p_a_block[64]))), - //"v"(__to_local((void *)(&p_b_block[0]))), - //"v"(__to_local((void *)(&p_b_block[32]))) + //s_waitcnt lgkmcnt(0)" + //: + //"=v"(p_a_thread[0]), + //"=v"(p_a_thread[1]), + //"=v"(p_a_thread[2]), + //"=v"(p_a_thread[3]), + //"=v"(p_a_thread[4]), + //"=v"(p_a_thread[5]), + //"=v"(p_a_thread[6]), + //"=v"(p_a_thread[7]), + //"=v"(p_b_thread[0]), + //"=v"(p_b_thread[1]), + //"=v"(p_b_thread[2]), + //"=v"(p_b_thread[3]), + //"=v"(p_b_thread[4]), + //"=v"(p_b_thread[5]), + //"=v"(p_b_thread[6]), + //"=v"(p_b_thread[7]) + //: + //"v"(__to_local((void *)(&p_a_block[0]))), + //"v"(__to_local((void *)(&p_a_block[64]))), + //"v"(__to_local((void *)(&p_b_block[0]))), + //"v"(__to_local((void *)(&p_b_block[32]))) //); - // C = A * B asm volatile("\n \ v_mac_f32 %0, %64, %72 \n \ @@ -544,165 +543,161 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 v_mac_f32 %62, %71, %78 \n \ v_mac_f32 %63, %71, %79 \n \ " - : - "=v"(p_c_thread[0]), - "=v"(p_c_thread[1]), - "=v"(p_c_thread[2]), - "=v"(p_c_thread[3]), - "=v"(p_c_thread[4]), - "=v"(p_c_thread[5]), - "=v"(p_c_thread[6]), - "=v"(p_c_thread[7]), - "=v"(p_c_thread[8]), - "=v"(p_c_thread[9]), - "=v"(p_c_thread[10]), - "=v"(p_c_thread[11]), - "=v"(p_c_thread[12]), - "=v"(p_c_thread[13]), - "=v"(p_c_thread[14]), - "=v"(p_c_thread[15]), - "=v"(p_c_thread[16]), - "=v"(p_c_thread[17]), - "=v"(p_c_thread[18]), - "=v"(p_c_thread[19]), - "=v"(p_c_thread[20]), - "=v"(p_c_thread[21]), - "=v"(p_c_thread[22]), - "=v"(p_c_thread[23]), - "=v"(p_c_thread[24]), - "=v"(p_c_thread[25]), - "=v"(p_c_thread[26]), - "=v"(p_c_thread[27]), - "=v"(p_c_thread[28]), - "=v"(p_c_thread[29]), - "=v"(p_c_thread[30]), - "=v"(p_c_thread[31]), - "=v"(p_c_thread[32]), - "=v"(p_c_thread[33]), - "=v"(p_c_thread[34]), - "=v"(p_c_thread[35]), - "=v"(p_c_thread[36]), - "=v"(p_c_thread[37]), - "=v"(p_c_thread[38]), - "=v"(p_c_thread[39]), - "=v"(p_c_thread[40]), - "=v"(p_c_thread[41]), - "=v"(p_c_thread[42]), - "=v"(p_c_thread[43]), - "=v"(p_c_thread[44]), - "=v"(p_c_thread[45]), - "=v"(p_c_thread[46]), - "=v"(p_c_thread[47]), - "=v"(p_c_thread[48]), - "=v"(p_c_thread[49]), - "=v"(p_c_thread[50]), - "=v"(p_c_thread[51]), - "=v"(p_c_thread[52]), - "=v"(p_c_thread[53]), - "=v"(p_c_thread[54]), - "=v"(p_c_thread[55]), - "=v"(p_c_thread[56]), - "=v"(p_c_thread[57]), - "=v"(p_c_thread[58]), - "=v"(p_c_thread[59]), - "=v"(p_c_thread[60]), - "=v"(p_c_thread[61]), - "=v"(p_c_thread[62]), - "=v"(p_c_thread[63]) - : - "v"(p_a_thread[0]), - "v"(p_a_thread[1]), - "v"(p_a_thread[2]), - "v"(p_a_thread[3]), - "v"(p_a_thread[4]), - "v"(p_a_thread[5]), - "v"(p_a_thread[6]), - "v"(p_a_thread[7]), - "v"(p_b_thread[0]), - "v"(p_b_thread[1]), - "v"(p_b_thread[2]), - "v"(p_b_thread[3]), - "v"(p_b_thread[4]), - "v"(p_b_thread[5]), - "v"(p_b_thread[6]), - "v"(p_b_thread[7]), - "0"(p_c_thread[0]), - "1"(p_c_thread[1]), - "2"(p_c_thread[2]), - "3"(p_c_thread[3]), - "4"(p_c_thread[4]), - "5"(p_c_thread[5]), - "6"(p_c_thread[6]), - "7"(p_c_thread[7]), - "8"(p_c_thread[8]), - "9"(p_c_thread[9]), - "10"(p_c_thread[10]), - "11"(p_c_thread[11]), - "12"(p_c_thread[12]), - "13"(p_c_thread[13]), - "14"(p_c_thread[14]), - "15"(p_c_thread[15]), - "16"(p_c_thread[16]), - "17"(p_c_thread[17]), - "18"(p_c_thread[18]), - "19"(p_c_thread[19]), - "20"(p_c_thread[20]), - "21"(p_c_thread[21]), - "22"(p_c_thread[22]), - "23"(p_c_thread[23]), - "24"(p_c_thread[24]), - "25"(p_c_thread[25]), - "26"(p_c_thread[26]), - "27"(p_c_thread[27]), - "28"(p_c_thread[28]), - "29"(p_c_thread[29]), - "30"(p_c_thread[30]), - "31"(p_c_thread[31]), - "32"(p_c_thread[32]), - "33"(p_c_thread[33]), - "34"(p_c_thread[34]), - "35"(p_c_thread[35]), - "36"(p_c_thread[36]), - "37"(p_c_thread[37]), - "38"(p_c_thread[38]), - "39"(p_c_thread[39]), - "40"(p_c_thread[40]), - "41"(p_c_thread[41]), - "42"(p_c_thread[42]), - "43"(p_c_thread[43]), - "44"(p_c_thread[44]), - "45"(p_c_thread[45]), - "46"(p_c_thread[46]), - "47"(p_c_thread[47]), - "48"(p_c_thread[48]), - "49"(p_c_thread[49]), - "50"(p_c_thread[50]), - "51"(p_c_thread[51]), - "52"(p_c_thread[52]), - "53"(p_c_thread[53]), - "54"(p_c_thread[54]), - "55"(p_c_thread[55]), - "56"(p_c_thread[56]), - "57"(p_c_thread[57]), - "58"(p_c_thread[58]), - "59"(p_c_thread[59]), - "60"(p_c_thread[60]), - "61"(p_c_thread[61]), - "62"(p_c_thread[62]), - "63"(p_c_thread[63]) - ); + : "=v"(p_c_thread[0]), + "=v"(p_c_thread[1]), + "=v"(p_c_thread[2]), + "=v"(p_c_thread[3]), + "=v"(p_c_thread[4]), + "=v"(p_c_thread[5]), + "=v"(p_c_thread[6]), + "=v"(p_c_thread[7]), + "=v"(p_c_thread[8]), + "=v"(p_c_thread[9]), + "=v"(p_c_thread[10]), + "=v"(p_c_thread[11]), + "=v"(p_c_thread[12]), + "=v"(p_c_thread[13]), + "=v"(p_c_thread[14]), + "=v"(p_c_thread[15]), + "=v"(p_c_thread[16]), + "=v"(p_c_thread[17]), + "=v"(p_c_thread[18]), + "=v"(p_c_thread[19]), + "=v"(p_c_thread[20]), + "=v"(p_c_thread[21]), + "=v"(p_c_thread[22]), + "=v"(p_c_thread[23]), + "=v"(p_c_thread[24]), + "=v"(p_c_thread[25]), + "=v"(p_c_thread[26]), + "=v"(p_c_thread[27]), + "=v"(p_c_thread[28]), + "=v"(p_c_thread[29]), + "=v"(p_c_thread[30]), + "=v"(p_c_thread[31]), + "=v"(p_c_thread[32]), + "=v"(p_c_thread[33]), + "=v"(p_c_thread[34]), + "=v"(p_c_thread[35]), + "=v"(p_c_thread[36]), + "=v"(p_c_thread[37]), + "=v"(p_c_thread[38]), + "=v"(p_c_thread[39]), + "=v"(p_c_thread[40]), + "=v"(p_c_thread[41]), + "=v"(p_c_thread[42]), + "=v"(p_c_thread[43]), + "=v"(p_c_thread[44]), + "=v"(p_c_thread[45]), + "=v"(p_c_thread[46]), + "=v"(p_c_thread[47]), + "=v"(p_c_thread[48]), + "=v"(p_c_thread[49]), + "=v"(p_c_thread[50]), + "=v"(p_c_thread[51]), + "=v"(p_c_thread[52]), + "=v"(p_c_thread[53]), + "=v"(p_c_thread[54]), + "=v"(p_c_thread[55]), + "=v"(p_c_thread[56]), + "=v"(p_c_thread[57]), + "=v"(p_c_thread[58]), + "=v"(p_c_thread[59]), + "=v"(p_c_thread[60]), + "=v"(p_c_thread[61]), + "=v"(p_c_thread[62]), + "=v"(p_c_thread[63]) + : "v"(p_a_thread[0]), + "v"(p_a_thread[1]), + "v"(p_a_thread[2]), + "v"(p_a_thread[3]), + "v"(p_a_thread[4]), + "v"(p_a_thread[5]), + "v"(p_a_thread[6]), + "v"(p_a_thread[7]), + "v"(p_b_thread[0]), + "v"(p_b_thread[1]), + "v"(p_b_thread[2]), + "v"(p_b_thread[3]), + "v"(p_b_thread[4]), + "v"(p_b_thread[5]), + "v"(p_b_thread[6]), + "v"(p_b_thread[7]), + "0"(p_c_thread[0]), + "1"(p_c_thread[1]), + "2"(p_c_thread[2]), + "3"(p_c_thread[3]), + "4"(p_c_thread[4]), + "5"(p_c_thread[5]), + "6"(p_c_thread[6]), + "7"(p_c_thread[7]), + "8"(p_c_thread[8]), + "9"(p_c_thread[9]), + "10"(p_c_thread[10]), + "11"(p_c_thread[11]), + "12"(p_c_thread[12]), + "13"(p_c_thread[13]), + "14"(p_c_thread[14]), + "15"(p_c_thread[15]), + "16"(p_c_thread[16]), + "17"(p_c_thread[17]), + "18"(p_c_thread[18]), + "19"(p_c_thread[19]), + "20"(p_c_thread[20]), + "21"(p_c_thread[21]), + "22"(p_c_thread[22]), + "23"(p_c_thread[23]), + "24"(p_c_thread[24]), + "25"(p_c_thread[25]), + "26"(p_c_thread[26]), + "27"(p_c_thread[27]), + "28"(p_c_thread[28]), + "29"(p_c_thread[29]), + "30"(p_c_thread[30]), + "31"(p_c_thread[31]), + "32"(p_c_thread[32]), + "33"(p_c_thread[33]), + "34"(p_c_thread[34]), + "35"(p_c_thread[35]), + "36"(p_c_thread[36]), + "37"(p_c_thread[37]), + "38"(p_c_thread[38]), + "39"(p_c_thread[39]), + "40"(p_c_thread[40]), + "41"(p_c_thread[41]), + "42"(p_c_thread[42]), + "43"(p_c_thread[43]), + "44"(p_c_thread[44]), + "45"(p_c_thread[45]), + "46"(p_c_thread[46]), + "47"(p_c_thread[47]), + "48"(p_c_thread[48]), + "49"(p_c_thread[49]), + "50"(p_c_thread[50]), + "51"(p_c_thread[51]), + "52"(p_c_thread[52]), + "53"(p_c_thread[53]), + "54"(p_c_thread[54]), + "55"(p_c_thread[55]), + "56"(p_c_thread[56]), + "57"(p_c_thread[57]), + "58"(p_c_thread[58]), + "59"(p_c_thread[59]), + "60"(p_c_thread[60]), + "61"(p_c_thread[61]), + "62"(p_c_thread[62]), + "63"(p_c_thread[63])); #else - auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; - auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; - auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0); + auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; + auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; + auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0); - const float4* a_loc = (const float4 *)(p_a_block + a_src_index); - const float4* b_loc = (const float4 *)(p_b_block + b_src_index); - float4* reg = (float4 *)(p_a_thread + dst_index); + const float4* a_loc = (const float4*)(p_a_block + a_src_index); + const float4* b_loc = (const float4*)(p_b_block + b_src_index); + float4* reg = (float4*)(p_a_thread + dst_index); - - asm volatile("\n \ + asm volatile("\n \ ds_read2_b64 %0, %84 offset1:1 \n \ ds_read2_b64 %1, %84 offset0:32 offset1:33 \n \ ds_read2_b64 %2, %85 offset1:1 \n \ @@ -773,168 +768,165 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 v_mac_f32 %66, %75, %82 \n \ v_mac_f32 %67, %75, %83 \n \ " - : - "=v"(reg[0]), - "=v"(reg[1]), - "=v"(reg[2]), - "=v"(reg[3]), - "=v"(p_c_thread[0]), - "=v"(p_c_thread[1]), - "=v"(p_c_thread[2]), - "=v"(p_c_thread[3]), - "=v"(p_c_thread[4]), - "=v"(p_c_thread[5]), - "=v"(p_c_thread[6]), - "=v"(p_c_thread[7]), - "=v"(p_c_thread[8]), - "=v"(p_c_thread[9]), - "=v"(p_c_thread[10]), - "=v"(p_c_thread[11]), - "=v"(p_c_thread[12]), - "=v"(p_c_thread[13]), - "=v"(p_c_thread[14]), - "=v"(p_c_thread[15]), - "=v"(p_c_thread[16]), - "=v"(p_c_thread[17]), - "=v"(p_c_thread[18]), - "=v"(p_c_thread[19]), - "=v"(p_c_thread[20]), - "=v"(p_c_thread[21]), - "=v"(p_c_thread[22]), - "=v"(p_c_thread[23]), - "=v"(p_c_thread[24]), - "=v"(p_c_thread[25]), - "=v"(p_c_thread[26]), - "=v"(p_c_thread[27]), - "=v"(p_c_thread[28]), - "=v"(p_c_thread[29]), - "=v"(p_c_thread[30]), - "=v"(p_c_thread[31]), - "=v"(p_c_thread[32]), - "=v"(p_c_thread[33]), - "=v"(p_c_thread[34]), - "=v"(p_c_thread[35]), - "=v"(p_c_thread[36]), - "=v"(p_c_thread[37]), - "=v"(p_c_thread[38]), - "=v"(p_c_thread[39]), - "=v"(p_c_thread[40]), - "=v"(p_c_thread[41]), - "=v"(p_c_thread[42]), - "=v"(p_c_thread[43]), - "=v"(p_c_thread[44]), - "=v"(p_c_thread[45]), - "=v"(p_c_thread[46]), - "=v"(p_c_thread[47]), - "=v"(p_c_thread[48]), - "=v"(p_c_thread[49]), - "=v"(p_c_thread[50]), - "=v"(p_c_thread[51]), - "=v"(p_c_thread[52]), - "=v"(p_c_thread[53]), - "=v"(p_c_thread[54]), - "=v"(p_c_thread[55]), - "=v"(p_c_thread[56]), - "=v"(p_c_thread[57]), - "=v"(p_c_thread[58]), - "=v"(p_c_thread[59]), - "=v"(p_c_thread[60]), - "=v"(p_c_thread[61]), - "=v"(p_c_thread[62]), - "=v"(p_c_thread[63]) - : - "v"(p_a_thread[0]), - "v"(p_a_thread[1]), - "v"(p_a_thread[2]), - "v"(p_a_thread[3]), - "v"(p_a_thread[4]), - "v"(p_a_thread[5]), - "v"(p_a_thread[6]), - "v"(p_a_thread[7]), - "v"(p_b_thread[0]), - "v"(p_b_thread[1]), - "v"(p_b_thread[2]), - "v"(p_b_thread[3]), - "v"(p_b_thread[4]), - "v"(p_b_thread[5]), - "v"(p_b_thread[6]), - "v"(p_b_thread[7]), - "v"(__to_local((void *)(a_loc))), - "v"(__to_local((void *)(b_loc))), - "4"(p_c_thread[0]), - "5"(p_c_thread[1]), - "6"(p_c_thread[2]), - "7"(p_c_thread[3]), - "8"(p_c_thread[4]), - "9"(p_c_thread[5]), - "10"(p_c_thread[6]), - "11"(p_c_thread[7]), - "12"(p_c_thread[8]), - "13"(p_c_thread[9]), - "14"(p_c_thread[10]), - "15"(p_c_thread[11]), - "16"(p_c_thread[12]), - "17"(p_c_thread[13]), - "18"(p_c_thread[14]), - "19"(p_c_thread[15]), - "20"(p_c_thread[16]), - "21"(p_c_thread[17]), - "22"(p_c_thread[18]), - "23"(p_c_thread[19]), - "24"(p_c_thread[20]), - "25"(p_c_thread[21]), - "26"(p_c_thread[22]), - "27"(p_c_thread[23]), - "28"(p_c_thread[24]), - "29"(p_c_thread[25]), - "30"(p_c_thread[26]), - "31"(p_c_thread[27]), - "32"(p_c_thread[28]), - "33"(p_c_thread[29]), - "34"(p_c_thread[30]), - "35"(p_c_thread[31]), - "36"(p_c_thread[32]), - "37"(p_c_thread[33]), - "38"(p_c_thread[34]), - "39"(p_c_thread[35]), - "40"(p_c_thread[36]), - "41"(p_c_thread[37]), - "42"(p_c_thread[38]), - "43"(p_c_thread[39]), - "44"(p_c_thread[40]), - "45"(p_c_thread[41]), - "46"(p_c_thread[42]), - "47"(p_c_thread[43]), - "48"(p_c_thread[44]), - "49"(p_c_thread[45]), - "50"(p_c_thread[46]), - "51"(p_c_thread[47]), - "52"(p_c_thread[48]), - "53"(p_c_thread[49]), - "54"(p_c_thread[50]), - "55"(p_c_thread[51]), - "56"(p_c_thread[52]), - "57"(p_c_thread[53]), - "58"(p_c_thread[54]), - "59"(p_c_thread[55]), - "60"(p_c_thread[56]), - "61"(p_c_thread[57]), - "62"(p_c_thread[58]), - "63"(p_c_thread[59]), - "64"(p_c_thread[60]), - "65"(p_c_thread[61]), - "66"(p_c_thread[62]), - "67"(p_c_thread[63]) - ); + : "=v"(reg[0]), + "=v"(reg[1]), + "=v"(reg[2]), + "=v"(reg[3]), + "=v"(p_c_thread[0]), + "=v"(p_c_thread[1]), + "=v"(p_c_thread[2]), + "=v"(p_c_thread[3]), + "=v"(p_c_thread[4]), + "=v"(p_c_thread[5]), + "=v"(p_c_thread[6]), + "=v"(p_c_thread[7]), + "=v"(p_c_thread[8]), + "=v"(p_c_thread[9]), + "=v"(p_c_thread[10]), + "=v"(p_c_thread[11]), + "=v"(p_c_thread[12]), + "=v"(p_c_thread[13]), + "=v"(p_c_thread[14]), + "=v"(p_c_thread[15]), + "=v"(p_c_thread[16]), + "=v"(p_c_thread[17]), + "=v"(p_c_thread[18]), + "=v"(p_c_thread[19]), + "=v"(p_c_thread[20]), + "=v"(p_c_thread[21]), + "=v"(p_c_thread[22]), + "=v"(p_c_thread[23]), + "=v"(p_c_thread[24]), + "=v"(p_c_thread[25]), + "=v"(p_c_thread[26]), + "=v"(p_c_thread[27]), + "=v"(p_c_thread[28]), + "=v"(p_c_thread[29]), + "=v"(p_c_thread[30]), + "=v"(p_c_thread[31]), + "=v"(p_c_thread[32]), + "=v"(p_c_thread[33]), + "=v"(p_c_thread[34]), + "=v"(p_c_thread[35]), + "=v"(p_c_thread[36]), + "=v"(p_c_thread[37]), + "=v"(p_c_thread[38]), + "=v"(p_c_thread[39]), + "=v"(p_c_thread[40]), + "=v"(p_c_thread[41]), + "=v"(p_c_thread[42]), + "=v"(p_c_thread[43]), + "=v"(p_c_thread[44]), + "=v"(p_c_thread[45]), + "=v"(p_c_thread[46]), + "=v"(p_c_thread[47]), + "=v"(p_c_thread[48]), + "=v"(p_c_thread[49]), + "=v"(p_c_thread[50]), + "=v"(p_c_thread[51]), + "=v"(p_c_thread[52]), + "=v"(p_c_thread[53]), + "=v"(p_c_thread[54]), + "=v"(p_c_thread[55]), + "=v"(p_c_thread[56]), + "=v"(p_c_thread[57]), + "=v"(p_c_thread[58]), + "=v"(p_c_thread[59]), + "=v"(p_c_thread[60]), + "=v"(p_c_thread[61]), + "=v"(p_c_thread[62]), + "=v"(p_c_thread[63]) + : "v"(p_a_thread[0]), + "v"(p_a_thread[1]), + "v"(p_a_thread[2]), + "v"(p_a_thread[3]), + "v"(p_a_thread[4]), + "v"(p_a_thread[5]), + "v"(p_a_thread[6]), + "v"(p_a_thread[7]), + "v"(p_b_thread[0]), + "v"(p_b_thread[1]), + "v"(p_b_thread[2]), + "v"(p_b_thread[3]), + "v"(p_b_thread[4]), + "v"(p_b_thread[5]), + "v"(p_b_thread[6]), + "v"(p_b_thread[7]), + "v"(__to_local((void*)(a_loc))), + "v"(__to_local((void*)(b_loc))), + "4"(p_c_thread[0]), + "5"(p_c_thread[1]), + "6"(p_c_thread[2]), + "7"(p_c_thread[3]), + "8"(p_c_thread[4]), + "9"(p_c_thread[5]), + "10"(p_c_thread[6]), + "11"(p_c_thread[7]), + "12"(p_c_thread[8]), + "13"(p_c_thread[9]), + "14"(p_c_thread[10]), + "15"(p_c_thread[11]), + "16"(p_c_thread[12]), + "17"(p_c_thread[13]), + "18"(p_c_thread[14]), + "19"(p_c_thread[15]), + "20"(p_c_thread[16]), + "21"(p_c_thread[17]), + "22"(p_c_thread[18]), + "23"(p_c_thread[19]), + "24"(p_c_thread[20]), + "25"(p_c_thread[21]), + "26"(p_c_thread[22]), + "27"(p_c_thread[23]), + "28"(p_c_thread[24]), + "29"(p_c_thread[25]), + "30"(p_c_thread[26]), + "31"(p_c_thread[27]), + "32"(p_c_thread[28]), + "33"(p_c_thread[29]), + "34"(p_c_thread[30]), + "35"(p_c_thread[31]), + "36"(p_c_thread[32]), + "37"(p_c_thread[33]), + "38"(p_c_thread[34]), + "39"(p_c_thread[35]), + "40"(p_c_thread[36]), + "41"(p_c_thread[37]), + "42"(p_c_thread[38]), + "43"(p_c_thread[39]), + "44"(p_c_thread[40]), + "45"(p_c_thread[41]), + "46"(p_c_thread[42]), + "47"(p_c_thread[43]), + "48"(p_c_thread[44]), + "49"(p_c_thread[45]), + "50"(p_c_thread[46]), + "51"(p_c_thread[47]), + "52"(p_c_thread[48]), + "53"(p_c_thread[49]), + "54"(p_c_thread[50]), + "55"(p_c_thread[51]), + "56"(p_c_thread[52]), + "57"(p_c_thread[53]), + "58"(p_c_thread[54]), + "59"(p_c_thread[55]), + "60"(p_c_thread[56]), + "61"(p_c_thread[57]), + "62"(p_c_thread[58]), + "63"(p_c_thread[59]), + "64"(p_c_thread[60]), + "65"(p_c_thread[61]), + "66"(p_c_thread[62]), + "67"(p_c_thread[63])); #endif } } template - __device__ void Run_asm(const FloatA* const __restrict__ p_a_block, - const FloatB* const __restrict__ p_b_block, - FloatC* const __restrict__ p_c_thread, - Accumulator f_accum) const + __device__ void Run(const FloatA* const __restrict__ p_a_block, + const FloatB* const __restrict__ p_b_block, + FloatC* const __restrict__ p_c_thread, + Accumulator f_accum) const { constexpr auto True = integral_constant{}; constexpr auto False = integral_constant{}; @@ -973,17 +965,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 && - KPerThreadLoop == 1 && K == 1, - "asm is not for this mtx shape"); - const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA; #pragma unroll // loop over k for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { -#if 0 #pragma unroll // copy A-sub to form A for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) @@ -993,67 +980,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + mMyThreadOffsetA, a_thread_mtx, - a_thread_sub_mtx.NCol(p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), + p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), a_thread_sub_mtx.GetLengths()); } -#elif 1 - // this produce right result - using vectorA_t = typename vector_type::MemoryType; // this is float4* - asm volatile( - "\n \ - ds_read_b128 %0, %1 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(*(reinterpret_cast(p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))) - : "v"(__to_local( - (void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA)))); - - asm volatile("\n \ - ds_read_b128 %0, %1 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(*(reinterpret_cast( - p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC)))) - : "v"(__to_local(( - void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) + - mMyThreadOffsetA)))); -#elif 0 - // this produce wrong result - using vectorA_t = typename vector_type::MemoryType; // this is float4* - - asm volatile( - "\n \ - ds_read_b128 %0, %2 \n \ - ds_read_b128 %1, %3 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(*(reinterpret_cast(p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))), - "=v"(*(reinterpret_cast(p_a_thread + - a_thread_mtx.Get1dIndex(0, MPerThreadSubC)))) - : "v"(__to_local( - (void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA))), - "v"(__to_local((void*)(p_a_block + - a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) + - mMyThreadOffsetA)))); -#elif 1 - // this produce wrong result - using vectorA_t = typename vector_type::MemoryType; // this is float4* - - asm volatile( - "\n \ - ds_read_b128 %0, %1 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(*(reinterpret_cast(p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))) - : "v"(__to_local((void*)(p_a_block_thread_offset)))); - - asm volatile("\n \ - ds_read_b128 %0, %1 offset:16 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(*(reinterpret_cast( - p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC)))) - : "v"(__to_local((void*)(p_a_block_thread_offset)))); - -#endif - - //#pragma unroll +#pragma unroll // copy B-sub to form B for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { @@ -1066,8 +997,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 b_thread_sub_mtx.GetLengths()); } -// C = A * B -#if 1 + // C = A * B threadwise_gemm(a_thread_mtx, True, p_a_thread, @@ -1078,58 +1008,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 False, p_c_thread, f_accum); -#elif 0 - // inline asm - static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8, - "asm is only for 8x8"); - - for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed - { - const index_t bindex = b_thread_mtx.Get1dIndex(k, 0); - - for(index_t i = 0; i < c_thread_mtx.NRow(); ++i) - { - const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const index_t cindex = c_thread_mtx.Get1dIndex(i, 0); - - asm volatile("\n \ - v_mac_f32 %0, %8, %9 \n \ - v_mac_f32 %1, %8, %10 \n \ - v_mac_f32 %2, %8, %11 \n \ - v_mac_f32 %3, %8, %12 \n \ - v_mac_f32 %4, %8, %13 \n \ - v_mac_f32 %5, %8, %14 \n \ - v_mac_f32 %6, %8, %15 \n \ - v_mac_f32 %7, %8, %16 \n \ - " - : "=v"(p_c_thread[cindex + 0]), - "=v"(p_c_thread[cindex + 1]), - "=v"(p_c_thread[cindex + 2]), - "=v"(p_c_thread[cindex + 3]), - "=v"(p_c_thread[cindex + 4]), - "=v"(p_c_thread[cindex + 5]), - "=v"(p_c_thread[cindex + 6]), - "=v"(p_c_thread[cindex + 7]) - : "v"(p_a_thread[aindex]), - "v"(p_b_thread[bindex + 0]), - "v"(p_b_thread[bindex + 1]), - "v"(p_b_thread[bindex + 2]), - "v"(p_b_thread[bindex + 3]), - "v"(p_b_thread[bindex + 4]), - "v"(p_b_thread[bindex + 5]), - "v"(p_b_thread[bindex + 6]), - "v"(p_b_thread[bindex + 7]), - "0"(p_c_thread[cindex + 0]), - "1"(p_c_thread[cindex + 1]), - "2"(p_c_thread[cindex + 2]), - "3"(p_c_thread[cindex + 3]), - "4"(p_c_thread[cindex + 4]), - "5"(p_c_thread[cindex + 5]), - "6"(p_c_thread[cindex + 6]), - "7"(p_c_thread[cindex + 7])); - } - } -#endif } } diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp index 5e3b88f670..ca083ee640 100644 --- a/src/include/common.hip.hpp +++ b/src/include/common.hip.hpp @@ -5,8 +5,6 @@ #include "Array.hip.hpp" #include "functional.hip.hpp" -extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]]; - __device__ index_t get_thread_local_1d_id() { return threadIdx.x; } __device__ index_t get_block_1d_id() { return blockIdx.x; } @@ -23,21 +21,45 @@ struct is_same static const bool value = true; }; -#if DEVICE_BACKEND_CUDA -template -__host__ __device__ constexpr T max(T a, T b) -{ - return a > b ? a : b; -} - -template -__host__ __device__ constexpr T min(T a, T b) -{ - return a < b ? a : b; -} -#endif - __host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; } + +namespace mod_conv { +template +__host__ __device__ constexpr T max(T x, T y) +{ + return x > y ? x : y; +} + +template +__host__ __device__ constexpr T max(T x, Ts... xs) +{ + static_assert(sizeof...(xs) > 0, "not enough argument"); + + auto y = max(xs...); + + static_assert(is_same::value, "not the same type"); + + return x > y ? x : y; +} + +template +__host__ __device__ constexpr T min(T x, T y) +{ + return x < y ? x : y; +} + +template +__host__ __device__ constexpr T min(T x, Ts... xs) +{ + static_assert(sizeof...(xs) > 0, "not enough argument"); + + auto y = min(xs...); + + static_assert(is_same::value, "not the same type"); + + return x < y ? x : y; +} +} diff --git a/src/include/gridwise_direct_convolution_1.hip.hpp b/src/include/gridwise_direct_convolution_1.hip.hpp index 1fb76988a7..7723fb78b4 100644 --- a/src/include/gridwise_direct_convolution_1.hip.hpp +++ b/src/include/gridwise_direct_convolution_1.hip.hpp @@ -59,12 +59,12 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_ constexpr auto out_block_desc = make_ConstantTensorDescriptor(out_block_global_desc.GetLengths()); - constexpr index_t in_block_size = in_block_desc.GetElementSpace(); - constexpr index_t wei_block_size = wei_block_desc.GetElementSpace(); - constexpr index_t out_block_size = out_block_desc.GetElementSpace(); + constexpr index_t in_block_element_size = in_block_desc.GetElementSpace(); + constexpr index_t wei_block_element_size = wei_block_desc.GetElementSpace(); + constexpr index_t out_block_size = out_block_desc.GetElementSpace(); - __shared__ Float p_in_block[in_block_size]; - __shared__ Float p_wei_block[wei_block_size]; + __shared__ Float p_in_block[in_block_element_size]; + __shared__ Float p_wei_block[wei_block_element_size]; __shared__ Float p_out_block[out_block_size]; const index_t block_id = blockIdx.x; diff --git a/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp index 944a1624ee..b301fc1e52 100644 --- a/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp +++ b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp @@ -63,17 +63,18 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i Sequence{}); // shared mem - constexpr index_t in_block_size = + constexpr index_t in_block_element_size = in_nchw_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_size = + constexpr index_t wei_block_element_size = wei_kcyx_block_desc.GetElementSpace(Number{}); constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead ? InBlockCopyDataPerRead : WeiBlockCopyDataPerRead; - __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; + __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)]; + __shared__ Float + p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)]; // threadwise tensors constexpr index_t HiPerThread = HoPerThread + Y - 1; diff --git a/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp index 71b8828de7..250253f2ff 100644 --- a/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp +++ b/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp @@ -73,10 +73,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( Sequence{}); // shared mem - constexpr index_t in_block_size = + constexpr index_t in_block_element_size = in_nchw_vec_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_size = + constexpr index_t wei_block_element_size = wei_kcyx_vec_block_desc.GetElementSpace(Number{}); constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead @@ -84,9 +84,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( : WeiBlockCopyDataPerRead; __shared__ in_vector_mem_t - p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)]; + p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)]; __shared__ in_vector_mem_t - p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; + p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)]; // threadwise tensors constexpr index_t HiPerThread = HoPerThread + Y - 1; diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp index 38fd90ca37..a214393379 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp @@ -164,18 +164,19 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric HoPerThread>{}; // LDS: be careful of alignment - constexpr index_t in_block_size = + constexpr index_t in_block_element_size = in_chwn_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_size = + constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace(Number{}); constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead ? InBlockCopyDataPerRead : WeiBlockCopyDataPerRead; - __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; + __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)]; + __shared__ Float + p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)]; // register Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp index fb0c781bfd..f04a283fcf 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp @@ -204,11 +204,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( true>{}; // LDS - constexpr index_t in_block_size = in_chwn_block_desc.GetElementSpace(); - constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace(); + constexpr index_t in_block_element_size = in_chwn_block_desc.GetElementSpace(); + constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace(); - __shared__ Float p_in_block[in_block_size]; - __shared__ Float p_wei_block[wei_block_size]; + __shared__ Float p_in_block[in_block_element_size]; + __shared__ Float p_wei_block[wei_block_element_size]; // register Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index 74ddeaf2cc..da689bc6b9 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -34,63 +34,109 @@ template -__global__ void -gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) +class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; + public: + __host__ __device__ static index_t GetSharedMemorySize() + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - constexpr index_t C = in_chwn_global_desc.GetLength(I0); - constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); - constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); - constexpr index_t N = in_chwn_global_desc.GetLength(I3); + constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); + constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); - constexpr index_t K = out_khwn_global_desc.GetLength(I0); - constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); - constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); + constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); + constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); - constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); + constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); - constexpr index_t B = N * Hi * Wi; - constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); + // tensor view of blockwise input and weight + // be careful of alignment + constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); - // divide block work by 2d: [K, B] - constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock; + constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); - const index_t k_block_work_id = get_block_1d_id() / BBlockWork; - const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + // tensor view of threadwise output in register + constexpr auto out_kb_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t b_block_data_begin = b_block_work_id * BPerBlock; + constexpr index_t max_align = + mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + // LDS: be careful of alignment + constexpr index_t in_block_element_space = + in_cb_block_desc.GetElementSpace(Number{}); - // tensor view of blockwise input and weight - // be careful of alignment - constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + constexpr index_t wei_block_element_space = + wei_cyxk_block_desc.GetElementSpace(Number{}); - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + return (in_block_element_space + wei_block_element_space) * sizeof(Float); + } - constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + __global__ static void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + + constexpr index_t C = in_chwn_global_desc.GetLength(I0); + constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); + constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); + constexpr index_t N = in_chwn_global_desc.GetLength(I3); + + constexpr index_t K = out_khwn_global_desc.GetLength(I0); + constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); + constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); + + constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); + constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); + + constexpr index_t B = N * Hi * Wi; + constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); + + // divide block work by 2d: [K, B] + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock; + + const index_t k_block_work_id = get_block_1d_id() / BBlockWork; + const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t b_block_data_begin = b_block_work_id * BPerBlock; + + // flattend (2d) tensor view of gridwise input + constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of blockwise input and weight + // be careful of alignment + constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); + + // tensor view of threadwise output in register + constexpr auto out_kb_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) @@ -121,20 +167,22 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric decltype(in_cb_block_desc), decltype(in_cb_block_desc.GetLengths())>{}; #elif 0 - const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; + const auto blockwise_in_copy = + Blockwise2dTensorCopy2{}; #elif 1 - const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; + const auto blockwise_in_copy = + Blockwise2dTensorCopy3{}; #endif // blockwise wei copy @@ -147,137 +195,138 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric decltype(wei_ek_block_desc), decltype(wei_ek_block_desc.GetLengths())>{}; #elif 0 - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; + const auto blockwise_wei_copy = + Blockwise2dTensorCopy2{}; #elif 1 - const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; + const auto blockwise_wei_copy = + Blockwise2dTensorCopy3{}; #endif - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); + // a series of blockwise GEMM + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx and b_mtx saved in LDS, c_mtx saved in register + // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] + // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] + // c_mtx[K,B] is out_block[K,B] + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); - const auto blockwise_gemm = - BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; + const auto blockwise_gemm = + BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; - // LDS: be careful of alignment - constexpr index_t in_block_size = - in_cb_block_desc.GetElementSpace(Number{}); + // LDS: be careful of alignment + constexpr index_t max_align = + mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); - constexpr index_t wei_block_size = - wei_cyxk_block_desc.GetElementSpace(Number{}); + constexpr index_t in_block_element_space = + in_cb_block_desc.GetElementSpace(Number{}); - constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; + constexpr index_t wei_block_element_space = + wei_cyxk_block_desc.GetElementSpace(Number{}); - // LDS - __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; + __shared__ Float p_in_block[in_block_element_space]; + __shared__ Float p_wei_block[wei_block_element_space]; - const Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); + const Float* p_in_global_block_offset = + p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); - const Float* p_wei_global_block_offset = - p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + const Float* p_wei_global_block_offset = + p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; + // register + Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - // set threadwise output tensor to 0 - threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); + // set threadwise output tensor to 0 + threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), - __syncthreads()) - { - // load data - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); - - __syncthreads(); - - // compute on current data - // a series of GEMM - for(index_t y = 0; y < Y; ++y) + for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + __syncthreads()) { - for(index_t x = 0; x < X; ++x) + // load data + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); + + __syncthreads(); + + // compute on current data + // a series of GEMM + for(index_t y = 0; y < Y; ++y) { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 1 - blockwise_gemm.Run + for(index_t x = 0; x < X; ++x) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; +#if 0 + blockwise_gemm.Run #elif 1 - blockwise_gemm.Run_asm -#elif 1 - blockwise_gemm.Run_RegisterDoubleBuffer + blockwise_gemm.Run_RegisterDoubleBuffer +#elif 0 + blockwise_gemm.Run_asm #endif - (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block + y * Wi + x, - p_out_thread, - f_accum); + (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + y * Wi + x, + p_out_thread, + f_accum); + } + } + } + + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; + const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; + + for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) + { + for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + { + const auto c_thread_mtx_distance = + blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); + + index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; + index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; + + index_t h_data = b_data / (Wi * N); + index_t itmp = b_data - h_data * (Wi * N); + index_t w_data = itmp / N; + index_t n_data = itmp - w_data * N; + + if(n_data < N && h_data < Ho && w_data < Wo) + { + p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + } } } } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; - - for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - const auto c_thread_mtx_distance = - blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - - index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; - index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; - - index_t h_data = b_data / (Wi * N); - index_t itmp = b_data - h_data * (Wi * N); - index_t w_data = itmp / N; - index_t n_data = itmp - w_data * N; - - if(n_data < N && h_data < Ho && w_data < Wo) - { - p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; - } - } - } -} +}; diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index 74efbca112..488b0a0da7 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -34,67 +34,65 @@ template -__global__ void -#if 0 -__launch_bounds__(256,2) -#endif -gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( - const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) +class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; + public: + __global__ static void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; + constexpr auto in_chwn_global_desc = InGlobalDesc{}; + constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; + constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - constexpr index_t C = in_chwn_global_desc.GetLength(I0); - constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); - constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); - constexpr index_t N = in_chwn_global_desc.GetLength(I3); + constexpr index_t C = in_chwn_global_desc.GetLength(I0); + constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); + constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); + constexpr index_t N = in_chwn_global_desc.GetLength(I3); - constexpr index_t K = out_khwn_global_desc.GetLength(I0); - constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); - constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); + constexpr index_t K = out_khwn_global_desc.GetLength(I0); + constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); + constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); - constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); - constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); + constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); + constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - constexpr index_t B = N * Hi * Wi; - constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); + constexpr index_t B = N * Hi * Wi; + constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); - // divide block work by 2d: [K, B] - constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock; + // divide block work by 2d: [K, B] + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock; - const index_t k_block_work_id = get_block_1d_id() / BBlockWork; - const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + const index_t k_block_work_id = get_block_1d_id() / BBlockWork; + const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t b_block_data_begin = b_block_work_id * BPerBlock; + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t b_block_data_begin = b_block_work_id * BPerBlock; - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); + // flattend (2d) tensor view of gridwise input + constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - // tensor view of blockwise input and weight - // be careful of alignment - constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + // tensor view of blockwise input and weight + // be careful of alignment + constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); - constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); + constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, Number{}); - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); + // tensor view of threadwise output in register + constexpr auto out_kb_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); #if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) @@ -125,20 +123,22 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( decltype(in_cb_block_desc), decltype(in_cb_block_desc.GetLengths())>{}; #elif 0 - const auto blockwise_in_copy = Blockwise2dTensorCopy2{}; + const auto blockwise_in_copy = + Blockwise2dTensorCopy2{}; #elif 1 - const auto blockwise_in_copy = Blockwise2dTensorCopy3{}; + const auto blockwise_in_copy = + Blockwise2dTensorCopy3{}; #endif // blockwise wei copy @@ -151,36 +151,38 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( decltype(wei_ek_block_desc), decltype(wei_ek_block_desc.GetLengths())>{}; #elif 0 - const auto blockwise_wei_copy = Blockwise2dTensorCopy2{}; + const auto blockwise_wei_copy = + Blockwise2dTensorCopy2{}; #elif 1 - const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; + const auto blockwise_wei_copy = + Blockwise2dTensorCopy3{}; #endif - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); + // a series of blockwise GEMM + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx and b_mtx saved in LDS, c_mtx saved in register + // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] + // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] + // c_mtx[K,B] is out_block[K,B] + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); #if 0 const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}; #else - const auto blockwise_gemm = - BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; + const auto blockwise_gemm = + BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; #endif - // LDS: be careful of alignment - constexpr index_t in_block_size = - in_cb_block_desc.GetElementSpace(Number{}); + // LDS: be careful of alignment + constexpr index_t in_block_element_size = + in_cb_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_size = - wei_cyxk_block_desc.GetElementSpace(Number{}); + constexpr index_t wei_block_element_size = + wei_cyxk_block_desc.GetElementSpace(Number{}); - constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; + constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; - // LDS double buffer - __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block_0[max_align * ((wei_block_size + max_align - 1) / max_align)]; + // LDS double buffer + __shared__ Float + p_in_block_0[max_align * ((in_block_element_size + max_align - 1) / max_align)]; + __shared__ Float + p_wei_block_0[max_align * ((wei_block_element_size + max_align - 1) / max_align)]; - __shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)]; + __shared__ Float + p_in_block_1[max_align * ((in_block_element_size + max_align - 1) / max_align)]; + __shared__ Float + p_wei_block_1[max_align * ((wei_block_element_size + max_align - 1) / max_align)]; - const Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); + const Float* p_in_global_block_offset = + p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); - const Float* p_wei_global_block_offset = - p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); + const Float* p_wei_global_block_offset = + p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - // preload data into LDS - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); + // preload data into LDS + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0); + blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0); - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; + // register + Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - // set threadwise output tensor to 0 - threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); + // set threadwise output tensor to 0 + threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - bool even_loop = true; + bool even_loop = true; - for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; - c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), - even_loop = !even_loop) - { - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; + for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; + c_block_data_begin += CPerBlock, + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + even_loop = !even_loop) + { + Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; + Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0; - Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0; + Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0; + Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0; - __syncthreads(); + __syncthreads(); // load next data #if 0 blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); #elif 1 - Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; - Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; - blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, - p_in_register_clipboard); + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); - blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, - p_wei_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); #endif - // compute on current data - // a series of GEMM - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) + // compute on current data + // a series of GEMM + for(index_t y = 0; y < Y; ++y) { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; + for(index_t x = 0; x < X; ++x) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; #if 1 - blockwise_gemm.Run + blockwise_gemm.Run #else - blockwise_gemm.Run_RegisterDoubleBuffer + blockwise_gemm.Run_RegisterDoubleBuffer #endif - (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread, - f_accum); + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread, + f_accum); + } + } + +#if 1 + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, + p_wei_block_next); +#endif + } + + // last computation + { + Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; + Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; + + __syncthreads(); + + for(index_t y = 0; y < Y; ++y) + { + for(index_t x = 0; x < X; ++x) + { + auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; +#if 1 + blockwise_gemm.Run +#else + blockwise_gemm.Run_RegisterDoubleBuffer +#endif + (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block_now + y * Wi + x, + p_out_thread, + f_accum); + } } } -#if 1 - blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); - blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); -#endif - } + // output: register to global mem, + const auto c_thread_mtx_begin = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - // last computation - { - Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; - Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; - - __syncthreads(); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 1 - blockwise_gemm.Run -#else - blockwise_gemm.Run_RegisterDoubleBuffer -#endif - (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread, - f_accum); - } - } - } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; + const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; + const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; #if 0 if(get_block_1d_id() == 0) @@ -348,26 +355,27 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( } #endif - for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) { - const auto c_thread_mtx_distance = - blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - - index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; - index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; - - index_t h_data = b_data / (Wi * N); - index_t itmp = b_data - h_data * (Wi * N); - index_t w_data = itmp / N; - index_t n_data = itmp - w_data * N; - - if(n_data < N && h_data < Ho && w_data < Wo) + for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) { - p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + const auto c_thread_mtx_distance = + blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); + + index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; + index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; + + index_t h_data = b_data / (Wi * N); + index_t itmp = b_data - h_data * (Wi * N); + index_t w_data = itmp / N; + index_t n_data = itmp - w_data * N; + + if(n_data < N && h_data < Ho && w_data < Wo) + { + p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = + p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; + } } } } -} +}; diff --git a/src/include/threadwise_gemm.hip.hpp b/src/include/threadwise_gemm.hip.hpp index e1c22ce39a..d1c7e830d0 100644 --- a/src/include/threadwise_gemm.hip.hpp +++ b/src/include/threadwise_gemm.hip.hpp @@ -10,11 +10,9 @@ __device__ void threadwise_matrix_copy(SrcMatrix, constexpr auto src_mtx = SrcMatrix{}; constexpr auto dst_mtx = DstMatrix{}; -#if 1 - //NRow = 1 +#if 0 for(index_t i = 0; i < NRow; ++i) { - //NCol = 4 for(index_t j = 0; j < NCol; ++j) { const index_t src_index = src_mtx.Get1dIndex(i, j); @@ -23,7 +21,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix, p_dst[dst_index] = p_src[src_index]; } } -#elif 0 +#elif 1 static_assert(NCol == 4, "only for NCol == 4"); using vector_t = typename vector_type::MemoryType; @@ -33,22 +31,8 @@ __device__ void threadwise_matrix_copy(SrcMatrix, const index_t src_index = src_mtx.Get1dIndex(i, 0); const index_t dst_index = dst_mtx.Get1dIndex(i, 0); -#if 0 - *(reinterpret_cast(&p_dst[dst_index]) = + *(reinterpret_cast(&p_dst[dst_index])) = *(reinterpret_cast(&p_src[src_index])); -#elif 0 - asm volatile("\n \ - ds_read2_b64 %0, %1 offset1:1 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(*(reinterpret_cast(&p_dst[dst_index]))) - : "v"(__to_local((void*)(&p_src[src_index])))); -#elif 1 - asm volatile("\n \ - ds_read_b128 %0, %1 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(*(reinterpret_cast(&p_dst[dst_index]))) - : "v"(__to_local((void*)(&p_src[src_index])))); -#endif } #endif } @@ -84,13 +68,10 @@ __device__ void threadwise_gemm(MatrixA, constexpr index_t N = c_mtx.NCol(); constexpr index_t K = a_mtx.NRow(); // A is transposed - // K = 1 for(index_t k = 0; k < K; ++k) { - // M = 8 for(index_t i = 0; i < M; ++i) { - // N = 8 for(index_t j = 0; j < N; ++j) { const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed