diff --git a/driver/device_direct_convolution_1.hpp b/driver/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp similarity index 66% rename from driver/device_direct_convolution_1.hpp rename to driver/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp index 93b57c7511..a18a7be5c5 100644 --- a/driver/device_direct_convolution_1.hpp +++ b/driver/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp @@ -1,16 +1,17 @@ #pragma once #include #include "device.hpp" -#include "gridwise_direct_convolution_1.hip.hpp" +#include "gridwise_convolution_wrapper.hip.hpp" +#include "gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hip.hpp" template -void device_direct_convolution_1(InDesc, - const Tensor& in, - WeiDesc, - const Tensor& wei, - OutDesc, - Tensor& out, - index_t nrepeat) +void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc, + const Tensor& in, + WeiDesc, + const Tensor& wei, + OutDesc, + Tensor& out, + index_t nrepeat) { std::size_t data_sz = sizeof(T); DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace()); @@ -33,11 +34,11 @@ void device_direct_convolution_1(InDesc, constexpr auto out_desc = OutDesc{}; #if 1 - // 3x3, 34x34 + // 3x3, 34x34, 128 thread constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 16; - constexpr index_t CPerBlock = 2; - constexpr index_t HoPerBlock = 4; + constexpr index_t KPerBlock = 32; + constexpr index_t CPerBlock = 4; + constexpr index_t HoPerBlock = 2; constexpr index_t WoPerBlock = 32; constexpr index_t NPerThread = 2; @@ -46,6 +47,9 @@ void device_direct_convolution_1(InDesc, constexpr index_t HoPerThread = 2; constexpr index_t WoPerThread = 2; + constexpr index_t InBlockCopyDataPerRead = 2; + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t BlockSize = 128; #endif @@ -57,24 +61,28 @@ void device_direct_convolution_1(InDesc, for(index_t i = 0; i < nrepeat; ++i) { - float time = launch_kernel(gridwise_direct_convolution_1, + using gridwise_conv = GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw; + float time = launch_kernel(run_gridwise_convolution, dim3(GridSize), dim3(BlockSize), + 0, static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer())); diff --git a/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp b/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 93c0916eae..0000000000 --- a/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,111 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp" - -template -void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc, - const Tensor& in, - WeiDesc, - const Tensor& wei, - OutDesc, - Tensor& out, - index_t nrepeat) -{ - std::size_t data_sz = sizeof(T); - DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace()); - DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace()); - - int num_thread = std::thread::hardware_concurrency(); - - in_device_buf.ToDevice(in.mData.data()); - wei_device_buf.ToDevice(wei.mData.data()); - out_device_buf.ToDevice(out.mData.data()); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_desc = InDesc{}; - constexpr auto wei_desc = WeiDesc{}; - constexpr auto out_desc = OutDesc{}; - -#if 1 - // 3x3, 34x34, 128 thread - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 32; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 32; - - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 4; - constexpr index_t CPerThread = 2; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopyDataPerRead = 2; - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t BlockSize = 128; -#elif 1 - // 3x3, 34x34, 128 thread, fp16 - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 32; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 32; - - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 4; - constexpr index_t CPerThread = 2; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopyDataPerRead = 2; - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t BlockSize = 128; -#endif - - constexpr index_t GridSize = - (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) * - (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < nrepeat; ++i) - { - float time = - launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_device_buf.FromDevice(out.mData.data()); -} diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index ed7fa09d1d..c72e1eab3b 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -7,8 +7,7 @@ #include "tensor.hpp" #include "ConstantTensorDescriptor.hip.hpp" #include "conv_common.hip.hpp" -//#include "device_direct_convolution_1.hpp" -#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp" +#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp" //#include "device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp" #include "device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp" @@ -604,15 +603,15 @@ int main(int argc, char* argv[]) #if 1 #if 0 device_direct_convolution_1 -#elif 0 - device_direct_convolution_2_nchw_kcyx_nkhw +#elif 1 + device_convolution_direct_v2_nchw_kcyx_nkhw #elif 0 device_direct_convolution_2_vectorized_nchw_kcyx_nkhw #elif 0 device_convolution_implicit_gemm_v1_chwn_cyxk_khwn #elif 0 device_convolution_implicit_gemm_v1_nchw_cyxk_khwn -#elif 1 +#elif 0 device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw #elif 0 device_convolution_implicit_gemm_v2_chwn_cyxk_khwn diff --git a/src/include/blockwise_direct_convolution.hip.hpp b/src/include/blockwise_direct_convolution.hip.hpp deleted file mode 100644 index c79833f17d..0000000000 --- a/src/include/blockwise_direct_convolution.hip.hpp +++ /dev/null @@ -1,134 +0,0 @@ -#pragma once -#include "ConstantTensorDescriptor.hip.hpp" -#include "threadwise_4d_tensor_op.hip.hpp" -#include "threadwise_direct_convolution.hip.hpp" - -template -__device__ void blockwise_direct_convolution(InBlockDesc, - Float* const __restrict__ p_in_block, - WeiBlockDesc, - Float* const __restrict__ p_wei_block, - OutBlockDesc, - Float* __restrict__ p_out_block) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_block_desc = InBlockDesc{}; - constexpr auto wei_block_desc = WeiBlockDesc{}; - constexpr auto out_block_desc = OutBlockDesc{}; - - constexpr index_t Y = wei_block_desc.GetLength(I2); - constexpr index_t X = wei_block_desc.GetLength(I3); - - constexpr index_t InTileSizeH = HoPerThread + Y - 1; - constexpr index_t InTileSizeW = WoPerThread + X - 1; - - // divide thread work - constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread; - constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread; - constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread; - constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread; - -#if 0 - if(get_thread_local_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_block_desc); - print_ConstantTensorDescriptor(wei_block_desc); - print_ConstantTensorDescriptor(out_block_desc); - } -#endif - - constexpr auto in_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto out_thread_desc = - get_convolution_output_default_4d_tensor_descriptor(in_thread_desc, wei_thread_desc); - - constexpr auto in_thread_block_desc = - make_ConstantTensorDescriptor(in_thread_desc.GetLengths(), in_block_desc.GetStrides()); - - constexpr auto wei_thread_block_desc = - make_ConstantTensorDescriptor(wei_thread_desc.GetLengths(), wei_block_desc.GetStrides()); - - constexpr auto out_thread_block_desc = - make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides()); - - const index_t thread_id = get_thread_local_1d_id(); - - for(index_t thread_work_id = thread_id; - thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork; - thread_work_id += BlockSize) - { - index_t itmp = thread_work_id; - index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork); - itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork); - index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork); - itmp -= k_thread_work_id * (YThreadWork * XThreadWork); - index_t y_thread_work_id = itmp / XThreadWork; - index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork; - - index_t n_thread_data_begin = n_thread_work_id * NPerThread; - index_t k_thread_data_begin = k_thread_work_id * KPerThread; - index_t ho_thread_data_begin = y_thread_work_id * HoPerThread; - index_t wo_thread_data_begin = x_thread_work_id * WoPerThread; - - index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding - index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding - - Float p_out_thread[out_thread_desc.GetElementSpace()]; - - threadwise_4d_tensor_copy(out_block_desc, - p_out_block + - out_block_desc.Get1dIndex(n_thread_data_begin, - k_thread_data_begin, - ho_thread_data_begin, - wo_thread_data_begin), - out_thread_desc, - p_out_thread, - out_thread_desc.GetLengths()); - - for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1); - c_thread_data_begin += CPerThread) - { - // threadwise convolution - threadwise_direct_convolution_2( - in_thread_block_desc, - p_in_block + - in_block_desc.Get1dIndex(n_thread_data_begin, - c_thread_data_begin, - hi_thread_data_begin, - wi_thread_data_begin), - wei_thread_block_desc, - p_wei_block + - wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data_begin, 0, 0), - out_thread_desc, - p_out_thread); - } - - // copy output into LDS - threadwise_4d_tensor_copy(out_thread_desc, - p_out_thread, - out_block_desc, - p_out_block + - out_block_desc.Get1dIndex(n_thread_data_begin, - k_thread_data_begin, - ho_thread_data_begin, - wo_thread_data_begin), - out_thread_desc.GetLengths()); - } -} diff --git a/src/include/gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hip.hpp new file mode 100644 index 0000000000..92bd32376f --- /dev/null +++ b/src/include/gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hip.hpp @@ -0,0 +1,244 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "blockwise_2d_tensor_op.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" +#include "threadwise_direct_convolution.hip.hpp" + +template +struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_global_desc = InGlobalDesc{}; + constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{}; + constexpr auto out_nkhw_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_nchw_global_desc.GetLength(I0); + constexpr index_t K = wei_kcyx_global_desc.GetLength(I0); + constexpr index_t C = wei_kcyx_global_desc.GetLength(I1); + constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2); + constexpr index_t X = wei_kcyx_global_desc.GetLength(I3); + + constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor( + Sequence{}); // 2d view of wei for blockwise copy + + constexpr index_t HiPerBlock = HoPerBlock + Y - 1; + constexpr index_t WiPerBlock = WoPerBlock + X - 1; + + constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); + + constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned( + Sequence{}, + Number{}); // 2d view of wei for blockwise copy + + constexpr auto wei_kcyx_block_desc = + make_ConstantTensorDescriptor(Sequence{}, + Sequence{}); + + // shared mem + constexpr index_t in_block_element_size = + in_nchw_block_desc.GetElementSpace(Number{}); + constexpr index_t wei_block_element_size = + wei_kcyx_block_desc.GetElementSpace(Number{}); + + constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; + + __shared__ Float + p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)]; + __shared__ Float + p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)]; + + // threadwise tensors + constexpr index_t HiPerThread = HoPerThread + Y - 1; + constexpr index_t WiPerThread = WoPerThread + X - 1; + + constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor( + Sequence{}, + in_nchw_block_desc.GetStrides()); + + constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor( + Sequence{}, wei_kcyx_block_desc.GetStrides()); + + constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor( + in_nchw_thread_block_desc, wei_kcyx_thread_block_desc); + + // register + Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()]; + + // divide block work + constexpr index_t NBlockWork = + (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; + constexpr index_t KBlockWork = + (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = + (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; + constexpr index_t WBlockWork = + (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; + + const index_t block_id = blockIdx.x; + + index_t itmp = block_id; + const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); + const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork); + itmp -= k_block_work_id * (HBlockWork * WBlockWork); + const index_t h_block_work_id = itmp / WBlockWork; + const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork; + + const index_t n_block_data_begin = n_block_work_id * NPerBlock; + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; + + const index_t hi_block_data_begin = ho_block_data_begin; // minus padding + const index_t wi_block_data_begin = wo_block_data_begin; // minus padding + + // divide thread work + constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; + constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; + constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; + constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; + + const index_t thread_id = get_thread_local_1d_id(); + + itmp = thread_id; + const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); + itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork); + const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork); + itmp -= k_thread_work_id * (HThreadWork * WThreadWork); + const index_t h_thread_work_id = itmp / WThreadWork; + const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork; + + const index_t n_thread_data_begin = n_thread_work_id * NPerThread; + const index_t k_thread_data_begin = k_thread_work_id * KPerThread; + const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread; + const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread; + + const index_t hi_thread_data_begin = ho_thread_data_begin; + const index_t wi_thread_data_begin = wo_thread_data_begin; + + constexpr auto blockwise_in_copy = + Blockwise4dTensorCopy1{}; + +#if 0 + constexpr auto blockwise_wei_copy = + Blockwise4dTensorCopy1{}; +#elif 1 + const auto blockwise_wei_copy = + Blockwise2dTensorCopy3{}; +#endif + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread); + + for(index_t c_block_data_begin = 0; c_block_data_begin < C; + c_block_data_begin += CPerBlock, __syncthreads()) + { + // copy input tensor to LDS + blockwise_in_copy.Run(p_in_global + + in_nchw_global_desc.Get1dIndex(n_block_data_begin, + c_block_data_begin, + hi_block_data_begin, + wi_block_data_begin), + p_in_block); + + // copy weight tensor to LDS + blockwise_wei_copy.Run( + p_wei_global + + wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), + p_wei_block); + + __syncthreads(); + + for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread) + { +// threadwise convolution +#if 1 + threadwise_direct_convolution_2( + in_nchw_thread_block_desc, + p_in_block + + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data, + hi_thread_data_begin, + wi_thread_data_begin), + wei_kcyx_thread_block_desc, + p_wei_block + + wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), + out_nkhw_thread_desc, + p_out_thread); +#elif 0 + threadwise_direct_convolution_3( + in_nchw_thread_block_desc, + p_in_block + + in_nchw_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data, + hi_thread_data_begin, + wi_thread_data_begin), + wei_kcyx_thread_block_desc, + p_wei_block + + wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), + out_nkhw_thread_desc, + p_out_thread); +#endif + } + } + + // copy output tensor from register to global mem + threadwise_nd_tensor_copy( + out_nkhw_thread_desc, + p_out_thread, + out_nkhw_global_desc, + p_out_global + + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin), + out_nkhw_thread_desc.GetLengths(), + Number<1>{}); + } +}; diff --git a/src/include/gridwise_direct_convolution_1.hip.hpp b/src/include/gridwise_direct_convolution_1.hip.hpp deleted file mode 100644 index 7723fb78b4..0000000000 --- a/src/include/gridwise_direct_convolution_1.hip.hpp +++ /dev/null @@ -1,152 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_direct_convolution.hip.hpp" - -template -__global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_global_desc = InGlobalDesc{}; - constexpr auto wei_global_desc = WeiGlobalDesc{}; - constexpr auto out_global_desc = OutGlobalDesc{}; - - constexpr index_t Y = wei_global_desc.GetLength(I2); - constexpr index_t X = wei_global_desc.GetLength(I3); - - constexpr index_t HiPerBlock = HoPerBlock + Y - 1; - constexpr index_t WiPerBlock = WoPerBlock + X - 1; - - constexpr index_t NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr index_t KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr index_t HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr index_t WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; - - constexpr auto in_block_global_desc = make_ConstantTensorDescriptor( - Sequence{}, in_global_desc.GetStrides()); - - constexpr auto wei_block_global_desc = make_ConstantTensorDescriptor( - Sequence{}, wei_global_desc.GetStrides()); - - constexpr auto out_block_global_desc = make_ConstantTensorDescriptor( - Sequence{}, out_global_desc.GetStrides()); - - constexpr auto in_block_desc = make_ConstantTensorDescriptor(in_block_global_desc.GetLengths()); - constexpr auto wei_block_desc = - make_ConstantTensorDescriptor(wei_block_global_desc.GetLengths()); - constexpr auto out_block_desc = - make_ConstantTensorDescriptor(out_block_global_desc.GetLengths()); - - constexpr index_t in_block_element_size = in_block_desc.GetElementSpace(); - constexpr index_t wei_block_element_size = wei_block_desc.GetElementSpace(); - constexpr index_t out_block_size = out_block_desc.GetElementSpace(); - - __shared__ Float p_in_block[in_block_element_size]; - __shared__ Float p_wei_block[wei_block_element_size]; - __shared__ Float p_out_block[out_block_size]; - - const index_t block_id = blockIdx.x; - - index_t itmp = block_id; - index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); - itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); - index_t k_block_work_id = itmp / (HBlockWork * WBlockWork); - itmp -= k_block_work_id * (HBlockWork * WBlockWork); - index_t h_block_work_id = itmp / WBlockWork; - index_t w_block_work_id = itmp - h_block_work_id * WBlockWork; - - index_t n_block_work_begin = n_block_work_id * NPerBlock; - index_t k_block_work_begin = k_block_work_id * KPerBlock; - index_t ho_block_work_begin = h_block_work_id * HoPerBlock; - index_t wo_block_work_begin = w_block_work_id * WoPerBlock; - - index_t hi_block_work_begin = ho_block_work_begin; // minus padding - index_t wi_block_work_begin = wo_block_work_begin; // minus padding - - constexpr auto blockwise_in_copy = - Blockwise4dTensorCopy1{}; - - constexpr auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; - - constexpr auto blockwise_out_copy = - Blockwise4dTensorCopy1{}; - - // set output tensor in LDS to 0 - blockwise_4d_tensor_set_zero(out_block_desc, p_out_block); - - for(index_t c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1); - c_block_work_begin += CPerBlock) - { - // copy input tensor to LDS - blockwise_in_copy.Run(p_in_global + - in_global_desc.Get1dIndex(n_block_work_begin, - c_block_work_begin, - hi_block_work_begin, - wi_block_work_begin), - p_in_block); - - // copy weight tensor to LDS - blockwise_wei_copy.Run( - p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0), - p_wei_block); - - __syncthreads(); - - // blockwise convolution - blockwise_direct_convolution( - in_block_desc, p_in_block, wei_block_desc, p_wei_block, out_block_desc, p_out_block); - - __syncthreads(); - } - - // copy output tensor from LDS to device mem - blockwise_out_copy.Run( - p_out_block, - p_out_global + - out_global_desc.Get1dIndex( - n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin)); -} diff --git a/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp deleted file mode 100644 index cbebe28f17..0000000000 --- a/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp +++ /dev/null @@ -1,237 +0,0 @@ -#pragma once -#include "common.hip.hpp" -#include "ConstantTensorDescriptor.hip.hpp" -#include "blockwise_2d_tensor_op.hip.hpp" -#include "blockwise_4d_tensor_op.hip.hpp" -#include "blockwise_direct_convolution.hip.hpp" -#include "threadwise_4d_tensor_op.hip.hpp" -#include "threadwise_direct_convolution.hip.hpp" - -template -__global__ void -gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_global_desc = InGlobalDesc{}; - constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{}; - constexpr auto out_nkhw_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_nchw_global_desc.GetLength(I0); - constexpr index_t K = wei_kcyx_global_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_global_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_global_desc.GetLength(I3); - - constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor( - Sequence{}); // 2d view of wei for blockwise copy - - constexpr index_t HiPerBlock = HoPerBlock + Y - 1; - constexpr index_t WiPerBlock = WoPerBlock + X - 1; - - constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); // 2d view of wei for blockwise copy - - constexpr auto wei_kcyx_block_desc = - make_ConstantTensorDescriptor(Sequence{}, - Sequence{}); - - // shared mem - constexpr index_t in_block_element_size = - in_nchw_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_element_size = - wei_kcyx_block_desc.GetElementSpace(Number{}); - - constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; - - __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)]; - __shared__ Float - p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)]; - - // threadwise tensors - constexpr index_t HiPerThread = HoPerThread + Y - 1; - constexpr index_t WiPerThread = WoPerThread + X - 1; - - constexpr auto in_nchw_thread_block_desc = - make_ConstantTensorDescriptor(Sequence{}, - in_nchw_block_desc.GetStrides()); - - constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor( - Sequence{}, wei_kcyx_block_desc.GetStrides()); - - constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor( - in_nchw_thread_block_desc, wei_kcyx_thread_block_desc); - - // register - Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()]; - - // divide block work - constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr index_t HBlockWork = - (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr index_t WBlockWork = - (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; - - const index_t block_id = blockIdx.x; - - index_t itmp = block_id; - const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); - itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); - const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork); - itmp -= k_block_work_id * (HBlockWork * WBlockWork); - const index_t h_block_work_id = itmp / WBlockWork; - const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork; - - const index_t n_block_data_begin = n_block_work_id * NPerBlock; - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; - const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; // minus padding - const index_t wi_block_data_begin = wo_block_data_begin; // minus padding - - // divide thread work - constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; - constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; - constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; - constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; - - const index_t thread_id = get_thread_local_1d_id(); - - itmp = thread_id; - const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); - itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork); - const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork); - itmp -= k_thread_work_id * (HThreadWork * WThreadWork); - const index_t h_thread_work_id = itmp / WThreadWork; - const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork; - - const index_t n_thread_data_begin = n_thread_work_id * NPerThread; - const index_t k_thread_data_begin = k_thread_work_id * KPerThread; - const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread; - const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread; - - const index_t hi_thread_data_begin = ho_thread_data_begin; - const index_t wi_thread_data_begin = wo_thread_data_begin; - - constexpr auto blockwise_in_copy = - Blockwise4dTensorCopy1{}; - -#if 0 - constexpr auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; -#elif 1 - const auto blockwise_wei_copy = Blockwise2dTensorCopy3{}; -#endif - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; - c_block_data_begin += CPerBlock, __syncthreads()) - { - // copy input tensor to LDS - blockwise_in_copy.Run(p_in_global + - in_nchw_global_desc.Get1dIndex(n_block_data_begin, - c_block_data_begin, - hi_block_data_begin, - wi_block_data_begin), - p_in_block); - - // copy weight tensor to LDS - blockwise_wei_copy.Run( - p_wei_global + - wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), - p_wei_block); - - __syncthreads(); - - for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread) - { -// threadwise convolution -#if 1 - threadwise_direct_convolution_2( - in_nchw_thread_block_desc, - p_in_block + - in_nchw_block_desc.Get1dIndex(n_thread_data_begin, - c_thread_data, - hi_thread_data_begin, - wi_thread_data_begin), - wei_kcyx_thread_block_desc, - p_wei_block + - wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), - out_nkhw_thread_desc, - p_out_thread); -#elif 0 - threadwise_direct_convolution_3( - in_nchw_thread_block_desc, - p_in_block + - in_nchw_block_desc.Get1dIndex(n_thread_data_begin, - c_thread_data, - hi_thread_data_begin, - wi_thread_data_begin), - wei_kcyx_thread_block_desc, - p_wei_block + - wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), - out_nkhw_thread_desc, - p_out_thread); -#endif - } - } - - // copy output tensor from register to global mem - threadwise_4d_tensor_copy( - out_nkhw_thread_desc, - p_out_thread, - out_nkhw_global_desc, - p_out_global + - out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_nkhw_thread_desc.GetLengths()); -} diff --git a/src/include/threadwise_direct_convolution.hip.hpp b/src/include/threadwise_direct_convolution.hip.hpp index 1c033573b9..70f60e67cf 100644 --- a/src/include/threadwise_direct_convolution.hip.hpp +++ b/src/include/threadwise_direct_convolution.hip.hpp @@ -1,5 +1,6 @@ #pragma once #include "ConstantTensorDescriptor.hip.hpp" +#include "threadwise_nd_tensor_op.hip.hpp" // optimized for scenario if p_in, p_wei, p_out are in register template @@ -84,10 +85,12 @@ __device__ void threadwise_direct_convolution_2(InDesc, TInWei p_wei_reg[wei_reg_desc.GetElementSpace()]; // copy input tensor into register - threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths()); + threadwise_nd_tensor_copy( + in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{}); // copy input tensor into register - threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths()); + threadwise_nd_tensor_copy( + wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{}); // do convolution threadwise_direct_convolution_1(