diff --git a/driver/device_direct_convolution_1.hpp b/driver/device_direct_convolution_1.hpp
index 57abc62432..93b57c7511 100644
--- a/driver/device_direct_convolution_1.hpp
+++ b/driver/device_direct_convolution_1.hpp
@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
                                  const Tensor<T>& wei,
                                  OutDesc,
                                  Tensor<T>& out,
-                                 unsigned nrepeat)
+                                 index_t nrepeat)
 {
     std::size_t data_sz = sizeof(T);
     DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
 #if 1
     // 3x3, 34x34
-    constexpr unsigned NPerBlock = 2;
-    constexpr unsigned KPerBlock = 16;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 32;
+    constexpr index_t NPerBlock = 2;
+    constexpr index_t KPerBlock = 16;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
 
-    constexpr unsigned NPerThread = 2;
-    constexpr unsigned KPerThread = 4;
-    constexpr unsigned CPerThread = 2;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
+    constexpr index_t NPerThread = 2;
+    constexpr index_t KPerThread = 4;
+    constexpr index_t CPerThread = 2;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #endif
 
-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
         (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
 
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(gridwise_direct_convolution_1
diff --git a/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp b/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
--- a/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
+++ b/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
                                                 const Tensor<T>& wei,
                                                 OutDesc,
                                                 Tensor<T>& out,
-                                                unsigned nrepeat)
+                                                index_t nrepeat)
 {
     std::size_t data_sz = sizeof(T);
     DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
 #if 1
     // 3x3, 34x34, 128 thread
-    constexpr unsigned NPerBlock = 2;
-    constexpr unsigned KPerBlock = 32;
-    constexpr unsigned CPerBlock = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 32;
+    constexpr index_t NPerBlock = 2;
+    constexpr index_t KPerBlock = 32;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 32;
 
-    constexpr unsigned NPerThread = 2;
-    constexpr unsigned KPerThread = 4;
-    constexpr unsigned CPerThread = 2;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
+    constexpr index_t NPerThread = 2;
+    constexpr index_t KPerThread = 4;
+    constexpr index_t CPerThread = 2;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
 
-    constexpr unsigned InBlockCopyDataPerRead = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 1
     // 3x3, 34x34, 128 thread, fp16
-    constexpr unsigned NPerBlock = 2;
-    constexpr unsigned KPerBlock = 32;
-    constexpr unsigned CPerBlock = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 32;
+    constexpr index_t NPerBlock = 2;
+    constexpr index_t KPerBlock = 32;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 32;
 
-    constexpr unsigned NPerThread = 2;
-    constexpr unsigned KPerThread = 4;
-    constexpr unsigned CPerThread = 2;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
+    constexpr index_t NPerThread = 2;
+    constexpr index_t KPerThread = 4;
+    constexpr index_t CPerThread = 2;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
 
-    constexpr unsigned InBlockCopyDataPerRead = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #endif
 
-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
         (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
 
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw
diff --git a/driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp b/driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
--- a/driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
+++ b/driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
                                                            const Tensor<T>& wei_kcyx,
                                                            OutDesc,
                                                            Tensor<T>& out_nkhw,
-                                                           unsigned nrepeat)
+                                                           index_t nrepeat)
 {
     // this suppose in / wei data type is int8x4
-    constexpr unsigned NVector = 4;
-    using accum_t      = int32_t;
-    using vector_t     = vector_type;
-    using vector_mem_t = typename vector_t::MemoryType;
+    constexpr index_t NVector = 4;
+    using accum_t      = int32_t;
+    using vector_t     = vector_type;
+    using vector_mem_t = typename vector_t::MemoryType;
 
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     constexpr auto wei_kcyx_desc = WeiDesc{};
     constexpr auto out_nkhw_desc = OutDesc{};
 
-    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);
 
-    constexpr unsigned N  = out_nkhw_desc.GetLength(I0);
-    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
 
-    constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 
     // vectorized input
     auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence{});
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 #if 0
     // 3x3, 34x34, 128 thread, fp32, vector = 1
-    constexpr unsigned NPerBlock = 2;
-    constexpr unsigned KPerBlock = 32;
-    constexpr unsigned CPerBlock = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 32;
+    constexpr index_t NPerBlock = 2;
+    constexpr index_t KPerBlock = 32;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 32;
 
-    constexpr unsigned NPerThread = 2;
-    constexpr unsigned KPerThread = 4;
-    constexpr unsigned CPerThread = 2;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
+    constexpr index_t NPerThread = 2;
+    constexpr index_t KPerThread = 4;
+    constexpr index_t CPerThread = 2;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
 
-    constexpr unsigned InBlockCopyDataPerRead = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
+    constexpr index_t InBlockCopyDataPerRead = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3, 34x34, 128 thread, fp32, vector = 2
-    constexpr unsigned NPerBlock = 2;
-    constexpr unsigned KPerBlock = 32;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 32;
+    constexpr index_t NPerBlock = 2;
+    constexpr index_t KPerBlock = 32;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 32;
 
-    constexpr unsigned NPerThread = 2;
-    constexpr unsigned KPerThread = 4;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
+    constexpr index_t NPerThread = 2;
+    constexpr index_t KPerThread = 4;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
 
-    constexpr unsigned InBlockCopyDataPerRead = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
+    constexpr index_t InBlockCopyDataPerRead = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3, 34x34, 128 thread, int8, vector = 4
-    constexpr unsigned NPerBlock = 2;
-    constexpr unsigned KPerBlock = 32;
-    constexpr unsigned CPerBlock = 8;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 32;
+    constexpr index_t NPerBlock = 2;
+    constexpr index_t KPerBlock = 32;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
 
-    constexpr unsigned NPerThread = 1;
-    constexpr unsigned KPerThread = 8;
-    constexpr unsigned CPerThread = 2;
-    constexpr unsigned HoPerThread = 4;
-    constexpr unsigned WoPerThread = 2;
+    constexpr index_t NPerThread = 1;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t CPerThread = 2;
+    constexpr index_t HoPerThread = 4;
+    constexpr index_t WoPerThread = 2;
 
-    constexpr unsigned InBlockCopyDataPerRead = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
+    constexpr index_t InBlockCopyDataPerRead = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 1
     // 1x1, 32x32, 128 thread, int8, vector = 4
-    constexpr unsigned NPerBlock = 1;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 16;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 32;
+    constexpr index_t NPerBlock = 1;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 16;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
 
-    constexpr unsigned NPerThread = 1;
-    constexpr unsigned KPerThread = 8;
-    constexpr unsigned CPerThread = 2;
-    constexpr unsigned HoPerThread = 4;
-    constexpr unsigned WoPerThread = 2;
+    constexpr index_t NPerThread = 1;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t CPerThread = 2;
+    constexpr index_t HoPerThread = 4;
+    constexpr index_t WoPerThread = 2;
 
-    constexpr unsigned InBlockCopyDataPerRead = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
+    constexpr index_t InBlockCopyDataPerRead = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #endif
 
-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
 
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(
             gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
--- a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
+++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
                                                const Tensor<T>& wei_kcyx,
                                                OutDesc,
                                                Tensor<T>& out_nkhw,
-                                               unsigned nrepeat)
+                                               index_t nrepeat)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr auto wei_kcyx_desc = WeiDesc{};
     constexpr auto out_nkhw_desc = OutDesc{};
 
-    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);
 
-    constexpr unsigned N  = out_nkhw_desc.GetLength(I0);
-    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
 
-    constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 
     // reorder weight
     auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence{});
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 #if 0
     // for 3x3, 34x34
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 8;
-    constexpr unsigned KPerThread = 8;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 8;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
-    constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
-    constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-    constexpr unsigned InBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
 
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
 
-    constexpr unsigned OutThreadCopyDataPerWrite = 2;
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 5x5, 36x36
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 8;
-    constexpr unsigned KPerThread = 8;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 8;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
 
-    constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
-    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-    constexpr unsigned InBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
 
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
 
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
 
-    constexpr unsigned OutThreadCopyDataPerWrite = 2;
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3 58x58, NKC = 64, 64, 256
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
 
-    constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3 58x58, NKC = 16,256,128
-    constexpr unsigned NPerBlock = 8;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 8;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 7x7, 38x38
-    constexpr unsigned NPerBlock = 8;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 1;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 8;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 1;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
 
-    constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 3x3, 56x56
-    constexpr unsigned NPerBlock = 32;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
+    constexpr index_t NPerBlock = 32;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 1x1, 28x28
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 128;
-    constexpr unsigned CPerBlock = 8;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-    constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-    constexpr unsigned InBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
 
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
 
-    constexpr unsigned OutThreadCopyDataPerWrite = 2;
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 1
     // for 1x1, 14x14
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 128;
-    constexpr unsigned CPerBlock = 8;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-    constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-    constexpr unsigned InBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
 
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
 
-    constexpr unsigned OutThreadCopyDataPerWrite = 2;
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #endif
 
-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
         ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
    {
         float time = launch_kernel(
             gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn
diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
--- a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
+++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
                                                       Tensor<T>& out_nkhw,
                                                       LowerPads,
                                                       UpperPads,
-                                                      unsigned nrepeat)
+                                                      index_t nrepeat)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
     constexpr auto wei_kcyx_desc = WeiDesc{};
     constexpr auto out_nkhw_desc = OutDesc{};
 
-    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);
 
-    constexpr unsigned N  = out_nkhw_desc.GetLength(I0);
-    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
 
-    constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 
     // reorder weight
     auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence{});
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
     out_khwn_device_buf.ToDevice(out_khwn.mData.data());
 
 #if 0
-    constexpr unsigned NPerBlock = 1;
-    constexpr unsigned KPerBlock = 1;
-    constexpr unsigned CPerBlock = 1;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 1;
+    constexpr index_t KPerBlock = 1;
+    constexpr index_t CPerBlock = 1;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 1;
-    constexpr unsigned KPerThread = 1;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 1;
+    constexpr index_t KPerThread = 1;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
 
-    constexpr unsigned BlockSize = 8;
+    constexpr index_t BlockSize = 8;
 #elif 1
     // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3 58x58, NKC = 16,256,128
-    constexpr unsigned NPerBlock = 8;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 8;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 5x5, 36x36
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 7x7, 38x38
-    constexpr unsigned NPerBlock = 8;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 8;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 3x3, 56x56
-    constexpr unsigned NPerBlock = 32;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
+    constexpr index_t NPerBlock = 32;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 1
     // 3x3 56x56, NKC = 16,256,128, with padding
     // 3x3 28x28, NKC = 16,512,256, with padding
     // 3x3 20x84, NKC = 16,256,256, with padding
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 5x5 filter, 20x84 image, 1x1 padding
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 1;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 1;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 5x5 filter, 28x28 image, 2x2 padding
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 32;
-    constexpr unsigned CPerBlock = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 32;
+    constexpr index_t CPerBlock = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 1x1, 28x28
-    constexpr unsigned NPerBlock = 16;
-    constexpr unsigned KPerBlock = 128;
-    constexpr unsigned CPerBlock = 8;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
+    constexpr index_t NPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
 
-    constexpr unsigned NPerThread = 4;
-    constexpr unsigned KPerThread = 16;
-    constexpr unsigned CPerThread = 2;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread = 4;
+    constexpr index_t KPerThread = 16;
+    constexpr index_t CPerThread = 2;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #endif
 
-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
         ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
 
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(
             gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
--- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
                                                        const Tensor<T>& wei_kcyx,
                                                        OutDesc,
                                                        Tensor<T>& out_nkhw,
-                                                       unsigned nrepeat)
+                                                       index_t nrepeat)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr auto wei_kcyx_desc = WeiDesc{};
     constexpr auto out_nkhw_desc = OutDesc{};
 
-    constexpr unsigned N  = in_nchw_desc.GetLength(I0);
-    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+    constexpr index_t N  = in_nchw_desc.GetLength(I0);
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);
 
-    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
 
-    constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);
 
-    constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
+    constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
 
     // convert in_nchw to in_cnhw
     auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{});
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 #if 0
     // 3x3, 34x34
     // need to use register double buffer for GEMM
-    constexpr unsigned BPerBlock = 128;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 4;
+    constexpr index_t BPerBlock = 128;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;
 
-    constexpr unsigned BPerThread = 8;
-    constexpr unsigned KPerThread = 8;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;
 
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 8;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmKPerThreadLoop = 1;
 
-    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;
 
-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
 
-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 1x1, 28x28, 64 threads
-    constexpr unsigned BPerBlock = 64;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 8;
+    constexpr index_t BPerBlock = 64;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 8;
 
-    constexpr unsigned BPerThread = 8;
-    constexpr unsigned KPerThread = 8;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;
 
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
 
-    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;
 
-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
 
-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned BlockSize = 64;
-#elif 1
+    constexpr index_t BlockSize = 64;
+#elif 0
     // 1x1, 28x28, 128 threads, no lds-double-buffer
     // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
-    constexpr unsigned BPerBlock = 64;
-    constexpr unsigned KPerBlock = 128;
-    constexpr unsigned CPerBlock = 8;
+    constexpr index_t BPerBlock = 64;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
 
-    constexpr unsigned BPerThread = 8;
-    constexpr unsigned KPerThread = 8;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;
 
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 4;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
 
-    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;
 
-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
 
-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 1x1, 28x28, 256 thread
-    constexpr unsigned BPerBlock = 128;
-    constexpr unsigned KPerBlock = 128;
-    constexpr unsigned CPerBlock = 8;
+    constexpr index_t BPerBlock = 128;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
 
-    constexpr unsigned BPerThread = 8;
-    constexpr unsigned KPerThread = 8;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;
 
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 4;
-    constexpr unsigned GemmMLevel1Cluster = 4;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
 
-    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;
 
-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;
 
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
 
-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
 
-    constexpr unsigned BlockSize = 256;
+    constexpr index_t BlockSize = 256;
+#elif 1
+    // 1x1, 14x14, Vega 10
+    constexpr index_t BPerBlock = 64;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;
+
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;
+
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;
+
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+
+    constexpr index_t BlockSize = 128;
 #endif
 
-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
 
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
     out_khwn_device_buf.ToDevice(out_khwn.mData.data());
 
-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(
 #if 1
diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp
index 1ae7ecb78b..a83e4082c7 100644
--- a/driver/driver.hip.cpp
+++ b/driver/driver.hip.cpp
@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
     template
     double operator()(Ts... Xs) const
     {
-        std::array dims = {{Xs...}};
+        std::array dims = {{Xs...}};
         return std::accumulate(dims.begin(),
                                dims.end(),
                                true,
-                               [](bool init, unsigned long x) -> int { return init != (x % 2); })
+                               [](bool init, index_t x) -> int { return init != (x % 2); })
1 : -1; } @@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc) constexpr auto I3 = Number<3>{}; constexpr auto desc = TConstTensorDesc{}; - std::initializer_list lengths = { + std::initializer_list lengths = { desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)}; - std::initializer_list strides = { + std::initializer_list strides = { desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)}; return TensorDescriptor(lengths, strides); @@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor& in_nchw, LowerPads, UpperPads) { - unsigned h_pad_low = LowerPads{}.Get(Number<0>{}); - unsigned w_pad_low = LowerPads{}.Get(Number<1>{}); + index_t h_pad_low = LowerPads{}.Get(Number<0>{}); + index_t w_pad_low = LowerPads{}.Get(Number<1>{}); - unsigned h_pad_up = UpperPads{}.Get(Number<0>{}); - unsigned w_pad_up = UpperPads{}.Get(Number<1>{}); + index_t h_pad_up = UpperPads{}.Get(Number<0>{}); + index_t w_pad_up = UpperPads{}.Get(Number<1>{}); auto f = [&](auto n, auto k, auto ho, auto wo) { double v = 0; @@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw, std::size_t HO = out_nkhw.mDesc.GetLengths()[2]; std::size_t WO = out_nkhw.mDesc.GetLengths()[3]; - unsigned h_pad_low = LowerPads{}.Get(Number<0>{}); - unsigned w_pad_low = LowerPads{}.Get(Number<1>{}); + index_t h_pad_low = LowerPads{}.Get(Number<0>{}); + index_t w_pad_low = LowerPads{}.Get(Number<1>{}); - unsigned h_pad_up = UpperPads{}.Get(Number<0>{}); - unsigned w_pad_up = UpperPads{}.Get(Number<1>{}); + index_t h_pad_up = UpperPads{}.Get(Number<0>{}); + index_t w_pad_up = UpperPads{}.Get(Number<1>{}); std::size_t HiPerTile = HoPerTile + Y - 1; std::size_t WiPerTile = WoPerTile + X - 1; @@ -399,211 +399,211 @@ void check_error(const Tensor& ref, const Tensor& result) int main(int argc, char* argv[]) { #if 0 - constexpr unsigned N = 1; - constexpr unsigned C = 1; - constexpr unsigned HI = 28; - constexpr unsigned WI = 28; - constexpr unsigned K = 1; - constexpr unsigned Y = 3; - constexpr unsigned X = 3; + constexpr index_t N = 1; + constexpr index_t C = 1; + constexpr index_t HI = 28; + constexpr index_t WI = 28; + constexpr index_t K = 1; + constexpr index_t Y = 3; + constexpr index_t X = 3; - constexpr unsigned HPad = 0; - constexpr unsigned WPad = 0; + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; #elif 0 // 3x3, 34x34 - constexpr unsigned N = 64; - constexpr unsigned C = 256; - constexpr unsigned HI = 34; - constexpr unsigned WI = 34; - constexpr unsigned K = 64; - constexpr unsigned Y = 3; - constexpr unsigned X = 3; + constexpr index_t N = 64; + constexpr index_t C = 256; + constexpr index_t HI = 34; + constexpr index_t WI = 34; + constexpr index_t K = 64; + constexpr index_t Y = 3; + constexpr index_t X = 3; - constexpr unsigned HPad = 0; - constexpr unsigned WPad = 0; + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; #elif 0 // 3x3, 56x56 - constexpr unsigned N = 64; - constexpr unsigned C = 64; - constexpr unsigned HI = 56; - constexpr unsigned WI = 56; - constexpr unsigned K = 64; - constexpr unsigned Y = 3; - constexpr unsigned X = 3; + constexpr index_t N = 64; + constexpr index_t C = 64; + constexpr index_t HI = 56; + constexpr index_t WI = 56; + constexpr index_t K = 64; + constexpr index_t Y = 3; + constexpr index_t X = 3; #elif 0 // 3x3, 58x58 - constexpr unsigned N = 64; - constexpr unsigned C = 64; - constexpr unsigned HI = 58; - constexpr unsigned WI = 58; - constexpr unsigned K = 64; - constexpr unsigned Y = 3; - 
constexpr unsigned X = 3; + constexpr index_t N = 64; + constexpr index_t C = 64; + constexpr index_t HI = 58; + constexpr index_t WI = 58; + constexpr index_t K = 64; + constexpr index_t Y = 3; + constexpr index_t X = 3; #elif 0 // 5x5, 36x36 - constexpr unsigned N = 64; - constexpr unsigned C = 256; - constexpr unsigned HI = 36; - constexpr unsigned WI = 36; - constexpr unsigned K = 64; - constexpr unsigned Y = 5; - constexpr unsigned X = 5; + constexpr index_t N = 64; + constexpr index_t C = 256; + constexpr index_t HI = 36; + constexpr index_t WI = 36; + constexpr index_t K = 64; + constexpr index_t Y = 5; + constexpr index_t X = 5; - constexpr unsigned HPad = 0; - constexpr unsigned WPad = 0; + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; #elif 0 // 7x7, 38x38 - constexpr unsigned N = 64; - constexpr unsigned C = 256; - constexpr unsigned HI = 38; - constexpr unsigned WI = 38; - constexpr unsigned K = 64; - constexpr unsigned Y = 7; - constexpr unsigned X = 7; + constexpr index_t N = 64; + constexpr index_t C = 256; + constexpr index_t HI = 38; + constexpr index_t WI = 38; + constexpr index_t K = 64; + constexpr index_t Y = 7; + constexpr index_t X = 7; - constexpr unsigned HPad = 0; - constexpr unsigned WPad = 0; + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; #elif 0 // 3x3, 58x58 - constexpr unsigned N = 16; - constexpr unsigned C = 128; - constexpr unsigned HI = 58; - constexpr unsigned WI = 58; - constexpr unsigned K = 256; - constexpr unsigned Y = 3; - constexpr unsigned X = 3; + constexpr index_t N = 16; + constexpr index_t C = 128; + constexpr index_t HI = 58; + constexpr index_t WI = 58; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; #elif 0 // 3x3 filter, 58x58 image, 0x0 padding - constexpr unsigned N = 16; - constexpr unsigned C = 128; - constexpr unsigned HI = 58; - constexpr unsigned WI = 58; - constexpr unsigned K = 256; - constexpr unsigned Y = 3; - constexpr unsigned X = 3; + constexpr index_t N = 16; + constexpr index_t C = 128; + constexpr index_t HI = 58; + constexpr index_t WI = 58; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; - constexpr unsigned HPad = 0; - constexpr unsigned WPad = 0; + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; #elif 0 // 3x3 filter, 56x56 image, 1x1 padding - constexpr unsigned N = 16; - constexpr unsigned C = 128; - constexpr unsigned HI = 56; - constexpr unsigned WI = 56; - constexpr unsigned K = 256; - constexpr unsigned Y = 3; - constexpr unsigned X = 3; + constexpr index_t N = 16; + constexpr index_t C = 128; + constexpr index_t HI = 56; + constexpr index_t WI = 56; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; - constexpr unsigned HPad = 1; - constexpr unsigned WPad = 1; + constexpr index_t HPad = 1; + constexpr index_t WPad = 1; #elif 0 // 3x3 filter, 28x28 image, 1x1 padding - constexpr unsigned N = 16; - constexpr unsigned C = 256; - constexpr unsigned HI = 28; - constexpr unsigned WI = 28; - constexpr unsigned K = 512; - constexpr unsigned Y = 3; - constexpr unsigned X = 3; + constexpr index_t N = 16; + constexpr index_t C = 256; + constexpr index_t HI = 28; + constexpr index_t WI = 28; + constexpr index_t K = 512; + constexpr index_t Y = 3; + constexpr index_t X = 3; - constexpr unsigned HPad = 1; - constexpr unsigned WPad = 1; + constexpr index_t HPad = 1; + constexpr index_t WPad = 1; #elif 0 // 1x1 filter, 28x28 image - constexpr unsigned N = 16; - constexpr unsigned C = 256; - 
-    constexpr unsigned HI = 28;
-    constexpr unsigned WI = 28;
-    constexpr unsigned K = 512;
-    constexpr unsigned Y = 1;
-    constexpr unsigned X = 1;
+    constexpr index_t N = 16;
+    constexpr index_t C = 256;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;

-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 3x3 filter, 20x84 image, 1x1 padding
-    constexpr unsigned N = 16;
-    constexpr unsigned C = 256;
-    constexpr unsigned HI = 20;
-    constexpr unsigned WI = 84;
-    constexpr unsigned K = 256;
-    constexpr unsigned Y = 3;
-    constexpr unsigned X = 3;
+    constexpr index_t N = 16;
+    constexpr index_t C = 256;
+    constexpr index_t HI = 20;
+    constexpr index_t WI = 84;
+    constexpr index_t K = 256;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;

-    constexpr unsigned HPad = 1;
-    constexpr unsigned WPad = 1;
+    constexpr index_t HPad = 1;
+    constexpr index_t WPad = 1;
 #elif 0
     // 3x3 filter, 112x112 image, 1x1 padding
-    constexpr unsigned N = 16;
-    constexpr unsigned C = 64;
-    constexpr unsigned HI = 112;
-    constexpr unsigned WI = 112;
-    constexpr unsigned K = 128;
-    constexpr unsigned Y = 3;
-    constexpr unsigned X = 3;
+    constexpr index_t N = 16;
+    constexpr index_t C = 64;
+    constexpr index_t HI = 112;
+    constexpr index_t WI = 112;
+    constexpr index_t K = 128;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;

-    constexpr unsigned HPad = 1;
-    constexpr unsigned WPad = 1;
+    constexpr index_t HPad = 1;
+    constexpr index_t WPad = 1;
 #elif 0
     // 5x5 filter, 20x86 image, 1x1 padding
-    constexpr unsigned N = 16;
-    constexpr unsigned C = 256;
-    constexpr unsigned HI = 20;
-    constexpr unsigned WI = 86;
-    constexpr unsigned K = 512;
-    constexpr unsigned Y = 5;
-    constexpr unsigned X = 5;
+    constexpr index_t N = 16;
+    constexpr index_t C = 256;
+    constexpr index_t HI = 20;
+    constexpr index_t WI = 86;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 5;
+    constexpr index_t X = 5;

-    constexpr unsigned HPad = 1;
-    constexpr unsigned WPad = 1;
+    constexpr index_t HPad = 1;
+    constexpr index_t WPad = 1;
 #elif 0
     // 5x5 filter, 28x28 image, 2x2 padding
-    constexpr unsigned N = 16;
-    constexpr unsigned C = 192;
-    constexpr unsigned HI = 28;
-    constexpr unsigned WI = 28;
-    constexpr unsigned K = 32;
-    constexpr unsigned Y = 5;
-    constexpr unsigned X = 5;
+    constexpr index_t N = 16;
+    constexpr index_t C = 192;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K = 32;
+    constexpr index_t Y = 5;
+    constexpr index_t X = 5;

-    constexpr unsigned HPad = 2;
-    constexpr unsigned WPad = 2;
+    constexpr index_t HPad = 2;
+    constexpr index_t WPad = 2;
 #elif 0
     // 1x1 filter, 32x32 image
-    constexpr unsigned N = 64;
-    constexpr unsigned C = 256;
-    constexpr unsigned HI = 32;
-    constexpr unsigned WI = 32;
-    constexpr unsigned K = 512;
-    constexpr unsigned Y = 1;
-    constexpr unsigned X = 1;
+    constexpr index_t N = 64;
+    constexpr index_t C = 256;
+    constexpr index_t HI = 32;
+    constexpr index_t WI = 32;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;

-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
-    // 1x1 filter, 14x14 image
-    constexpr unsigned N = 128;
-    constexpr unsigned C = 2048;
-    constexpr unsigned HI = 14;
-    constexpr unsigned WI = 14;
-    constexpr unsigned K = 512;
-    constexpr unsigned Y = 1;
-    constexpr unsigned X = 1;
+    // 1x1 filter, 14x14 image, C = 2048
+    constexpr index_t N = 128;
+    constexpr index_t C = 2048;
+    constexpr index_t HI = 14;
+    constexpr index_t WI = 14;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;

-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 1
     // 1x1 filter, 14x14 image, C = 512
-    constexpr unsigned N = 128;
-    constexpr unsigned C = 512;
-    constexpr unsigned HI = 14;
-    constexpr unsigned WI = 14;
-    constexpr unsigned K = 512;
-    constexpr unsigned Y = 1;
-    constexpr unsigned X = 1;
+    constexpr index_t N = 128;
+    constexpr index_t C = 512;
+    constexpr index_t HI = 14;
+    constexpr index_t WI = 14;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;

-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #endif

     auto lower_pads = Sequence<HPad, WPad>{};
@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
     }

     bool do_verification = atoi(argv[1]);
-    unsigned nrepeat = atoi(argv[2]);
+    index_t nrepeat = atoi(argv[2]);

     if(do_verification)
     {
diff --git a/src/include/Array.hip.hpp b/src/include/Array.hip.hpp
index 89654cbc2b..f3a3d13681 100644
--- a/src/include/Array.hip.hpp
+++ b/src/include/Array.hip.hpp
@@ -1,18 +1,18 @@
 #pragma once

-template <class TData, unsigned NSize>
+template <class TData, index_t NSize>
 struct Array
 {
     using Type = Array;

-    static constexpr unsigned nSize = NSize;
+    static constexpr index_t nSize = NSize;

-    unsigned mData[nSize];
+    index_t mData[nSize];

     template <class... Xs>
     __host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
     {
     }

-    __host__ __device__ TData operator[](unsigned i) const { return mData[i]; }
+    __host__ __device__ TData operator[](index_t i) const { return mData[i]; }
 };
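Note: every hunk in this patch rides on one project-wide alias whose definition is not itself part of the diff (it presumably lives in common.hip.hpp, which these headers include). A minimal sketch of what the rest of the patch assumes:

    // sketch only -- the real definition sits in common.hip.hpp, outside this diff;
    // one 32-bit alias lets the whole code base swap its index arithmetic type
    // (signed vs. unsigned, 32 vs. 64 bit) in a single place
    using index_t = int32_t; // assumption: 32-bit indices

With loop counters, descriptor lengths and kernel parameters all spelled index_t, experiments such as the __mul24 fast paths below become one-line changes instead of a repo-wide hunt for unsigned.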
diff --git a/src/include/ConstantMatrixDescriptor.hip.hpp b/src/include/ConstantMatrixDescriptor.hip.hpp
index d014e93574..9cacf27553 100644
--- a/src/include/ConstantMatrixDescriptor.hip.hpp
+++ b/src/include/ConstantMatrixDescriptor.hip.hpp
@@ -1,7 +1,7 @@
 #pragma once
 #include "common.hip.hpp"

-template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
+template <index_t NRow_, index_t NCol_, index_t RowStride_>
 struct ConstantMatrixDescriptor
 {
     __host__ __device__ constexpr ConstantMatrixDescriptor()
@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
         static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
     }

-    __host__ __device__ constexpr unsigned NRow() const { return NRow_; }
+    __host__ __device__ constexpr index_t NRow() const { return NRow_; }

-    __host__ __device__ constexpr unsigned NCol() const { return NCol_; }
+    __host__ __device__ constexpr index_t NCol() const { return NCol_; }

-    __host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }
+    __host__ __device__ constexpr index_t RowStride() const { return RowStride_; }

     __host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }

-    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }
+    __host__ __device__ constexpr index_t GetElementSize() const { return NRow_ * NCol_; }

-    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }
+    __host__ __device__ constexpr index_t GetElementSpace() const { return NRow_ * RowStride_; }

-    __host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
+    __host__ __device__ index_t Get1dIndex(index_t irow, index_t icol) const
     {
+#if DEVICE_BACKEND_HIP
+        return __mul24(irow, RowStride_) + icol;
+#else
         return irow * RowStride_ + icol;
+#endif
     }

-    template
+    template
     __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number, Number) const
     {
@@ -34,13 +38,13 @@
     }
 };

-template <unsigned NRow, unsigned NCol>
+template <index_t NRow, index_t NCol>
 __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
 {
     return ConstantMatrixDescriptor{};
 }

-template <unsigned NRow, unsigned NCol, unsigned RowStride>
+template <index_t NRow, index_t NCol, index_t RowStride>
 __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
 {
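The only behavioural change in this header (everything else is the type rename) is the DEVICE_BACKEND_HIP branch of Get1dIndex. __mul24 multiplies the low 24 bits of its operands, which some GPU ISAs issue more cheaply than a full 32-bit multiply, so the fast path only matches irow * RowStride_ + icol while both operands fit in 24 bits. A standalone restatement of that contract (hypothetical helper, not from the diff):

    __device__ index_t row_major_offset(index_t irow, index_t icol, index_t stride)
    {
    #if DEVICE_BACKEND_HIP
        // correct only while irow and stride stay below 2^24 -- true for the
        // block and thread tile sizes used in this code base
        return __mul24(irow, stride) + icol;
    #else
        return irow * stride + icol;
    #endif
    }

The same guard applies to the Get1dIndex change in ConstantTensorDescriptor.hip.hpp below.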
diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp
index 2e5d237e81..4e883f12e7 100644
--- a/src/include/ConstantTensorDescriptor.hip.hpp
+++ b/src/include/ConstantTensorDescriptor.hip.hpp
@@ -2,35 +2,35 @@
 #include "common.hip.hpp"

 // this is ugly, only for 2d
-template
+template
 __host__ __device__ constexpr auto calculate_default_strides(Sequence)
 {
     return Sequence{};
 }

 // this is ugly, only for 4d
-template
+template
 __host__ __device__ constexpr auto calculate_default_strides(Sequence)
 {
     return Sequence{};
 }

 // this is ugly, only for 6d
-template
+template
 __host__ __device__ constexpr auto calculate_default_strides(Sequence)
 {
     return Sequence{};
 }

 // this is ugly, only for 8d
-template
+template
 __host__ __device__ constexpr auto calculate_default_strides(Sequence)
 {
@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto
 }

 // this is ugly, only for 2d
-template <unsigned L0, unsigned L1, unsigned Align>
+template <index_t L0, index_t L1, index_t Align>
 __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>, Number<Align>)
 {
-    constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
+    constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
     return Sequence<L1_align, 1>{};
 }

 // this is ugly, only for 4d
-template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
+template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
 __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>, Number<Align>)
 {
-    constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
+    constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
     return Sequence{};
 }

 template <class Lengths, class Strides>
 struct ConstantTensorDescriptor
 {
-    using Type = ConstantTensorDescriptor;
-    static constexpr unsigned nDim = Lengths::nDim;
+    using Type = ConstantTensorDescriptor;
+    static constexpr index_t nDim = Lengths::nDim;

     __host__ __device__ constexpr ConstantTensorDescriptor()
     {
         static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
     }

-    __host__ __device__ constexpr unsigned GetDimension() const { return nDim; }
+    __host__ __device__ constexpr index_t GetDimension() const { return nDim; }

     __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }

     __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }

-    template <unsigned I>
-    __host__ __device__ constexpr unsigned GetLength(Number<I>) const
+    template <index_t I>
+    __host__ __device__ constexpr index_t GetLength(Number<I>) const
     {
         return Lengths{}.Get(Number<I>{});
     }

-    template <unsigned I>
-    __host__ __device__ constexpr unsigned GetStride(Number<I>) const
+    template <index_t I>
+    __host__ __device__ constexpr index_t GetStride(Number<I>) const
     {
         return Strides{}.Get(Number<I>{});
     }
@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
     struct GetElementSize_f
     {
         template <class IDim>
-        __host__ __device__ constexpr unsigned operator()(IDim idim) const
+        __host__ __device__ constexpr index_t operator()(IDim idim) const
         {
             return Type{}.GetLength(idim);
         }
     };

-    __host__ __device__ constexpr unsigned GetElementSize() const
+    __host__ __device__ constexpr index_t GetElementSize() const
     {
         // c++14 doesn't support constexpr lambdas, has to use this trick instead
         struct multiply
         {
-            __host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
+            __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
             {
                 return a * b;
             }
@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
     struct GetElementSpace_f
     {
         template <class IDim>
-        __host__ __device__ constexpr unsigned operator()(IDim idim) const
+        __host__ __device__ constexpr index_t operator()(IDim idim) const
         {
             return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
         }
     };

     template <class Align = Number<1>>
-    __host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
+    __host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
     {
         // c++14 doesn't support constexpr lambdas, has to use this trick instead
         struct add
         {
-            __host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
+            __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
             {
                 return a + b;
             }
@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
     }

     template <class... Is>
-    __host__ __device__ unsigned Get1dIndex(Is... is) const
+    __host__ __device__ index_t Get1dIndex(Is... is) const
     {
         static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");

-        const auto multi_id = Array<unsigned, nDim>(is...);
+        const auto multi_id = Array<index_t, nDim>(is...);

-        unsigned id = 0;
+        index_t id = 0;

         static_loop_n{}([&](auto IDim) {
-            constexpr unsigned idim = IDim.Get();
+            constexpr index_t idim = IDim.Get();
+#if DEVICE_BACKEND_HIP
+            id += __mul24(multi_id[idim], GetStride(IDim));
+#else
             id += multi_id[idim] * GetStride(IDim);
+#endif
         });

         return id;
@@ -163,7 +167,7 @@
         return ConstantTensorDescriptor{};
     }

-    template
+    template
     __host__ __device__ constexpr auto Vectorize(Number, Number) const
     {
         assert(false); // not implemented
@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
     return ConstantTensorDescriptor{};
 }

-template <class Lengths, unsigned Align>
+template <class Lengths, index_t Align>
 __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
 {
     using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
@@ -193,8 +197,8 @@
 template <class TDesc>
 __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
 {
-    constexpr auto desc = TDesc{};
-    constexpr unsigned ndim = desc.GetDimension();
+    constexpr auto desc = TDesc{};
+    constexpr index_t ndim = desc.GetDimension();

     static_assert(ndim >= 2 && ndim <= 8, "wrong!");
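Get1dIndex above is a row-major dot product of the multi-index with the stride vector, unrolled at compile time. A worked sketch (shapes hypothetical, assuming the two-argument make_ConstantTensorDescriptor visible in the hunk context):

    // lengths {2, 3, 4, 5} with packed strides {60, 20, 5, 1}:
    // Get1dIndex(1, 2, 3, 4) = 1*60 + 2*20 + 3*5 + 4*1 = 119
    constexpr auto desc =
        make_ConstantTensorDescriptor(Sequence<2, 3, 4, 5>{}, Sequence<60, 20, 5, 1>{});
    static_assert(desc.GetElementSize() == 120, "2*3*4*5 elements");

GetElementSpace deliberately differs from GetElementSize: it is summed from (length - 1) * stride terms, so the padded strides produced by calculate_default_strides_aligned can over-allocate a row without changing the element count.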
diff --git a/src/include/Sequence.hip.hpp b/src/include/Sequence.hip.hpp
index c8ca7a0f24..55caf14591 100644
--- a/src/include/Sequence.hip.hpp
+++ b/src/include/Sequence.hip.hpp
@@ -2,38 +2,38 @@
 #include "constant_integral.hip.hpp"
 #include "functional.hip.hpp"

-template <unsigned... Is>
+template <index_t... Is>
 struct Sequence
 {
     using Type = Sequence;

-    static constexpr unsigned nDim = sizeof...(Is);
+    static constexpr index_t nDim = sizeof...(Is);

-    const unsigned mData[nDim] = {Is...};
+    const index_t mData[nDim] = {Is...};

-    template <unsigned I>
-    __host__ __device__ constexpr unsigned Get(Number<I>) const
+    template <index_t I>
+    __host__ __device__ constexpr index_t Get(Number<I>) const
    {
         return mData[I];
     }

     // this is ugly, only for nDim = 4
-    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
+    template <index_t I0, index_t I1, index_t I2, index_t I3>
     __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
     {
         static_assert(nDim == 4, "nDim != 4");

         constexpr auto old_sequence = Type{};

-        constexpr unsigned NR0 = old_sequence.mData[I0];
-        constexpr unsigned NR1 = old_sequence.mData[I1];
-        constexpr unsigned NR2 = old_sequence.mData[I2];
-        constexpr unsigned NR3 = old_sequence.mData[I3];
+        constexpr index_t NR0 = old_sequence.mData[I0];
+        constexpr index_t NR1 = old_sequence.mData[I1];
+        constexpr index_t NR2 = old_sequence.mData[I2];
+        constexpr index_t NR3 = old_sequence.mData[I3];

         return Sequence<NR0, NR1, NR2, NR3>{};
     }

-    template
+    template
     __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence) const
     {
         // don't know how to implement this
@@ -41,7 +41,7 @@ struct Sequence
         assert(false);
     }

-    template <unsigned I>
+    template <index_t I>
     __host__ __device__ constexpr auto PushBack(Number<I>) const
     {
         return Sequence{};
@@ -56,14 +56,14 @@ struct Sequence
     }
 };

-template <unsigned... Is>
+template <index_t... Is>
 __host__ __device__ constexpr auto sequence_pop_back(Sequence<Is...>)
 {
     static_assert(sizeof...(Is) >= 1, "empty Sequence!");
     return Sequence{};
 }

-template
+template
 __host__ __device__ constexpr auto sequence_sequence_op(Sequence, Sequence, F f)
 {
     static_assert(Sequence::nDim == Sequence::nDim, "Dim not the same");
@@ -71,12 +71,12 @@
     return Sequence{};
 }

-template
+template
 __host__ __device__ constexpr auto sequence_sequence_add(Sequence, Sequence)
 {
     struct add
     {
-        __host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const
+        __host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
         {
             return x + y;
         }
@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence, Sequen
     return sequence_sequence_op(Sequence{}, Sequence{}, add{});
 }

-template <unsigned... Is>
+template <index_t... Is>
 __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
 {
     return sequence_pop_back(Type{});
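Everything above stores integers purely in the type, so shape arithmetic costs nothing at run time. A hypothetical usage sketch, assuming sequence_sequence_op maps the functor element-wise as its name suggests:

    constexpr auto a = Sequence<1, 2, 3, 4>{};
    constexpr auto b = Sequence<4, 3, 2, 1>{};
    constexpr auto c = sequence_sequence_add(a, b); // element-wise: Sequence<5, 5, 5, 5>
    static_assert(c.Get(Number<0>{}) == 5, "computed entirely at compile time");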
diff --git a/src/include/blockwise_2d_tensor_op.hip.hpp b/src/include/blockwise_2d_tensor_op.hip.hpp
index ce3a7a37b9..5a29f94712 100644
--- a/src/include/blockwise_2d_tensor_op.hip.hpp
+++ b/src/include/blockwise_2d_tensor_op.hip.hpp
@@ -1,7 +1,7 @@
 #pragma once
 #include "ConstantTensorDescriptor.hip.hpp"

-template
+template
 __device__ void
 blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
 {
@@ -20,19 +20,19 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
     }
 #endif

-    constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
+    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;

-    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
+    for(index_t iloop = 0; iloop < NLoop; ++iloop)
     {
-        unsigned is = threadIdx.x + iloop * BlockSize;
+        index_t is = threadIdx.x + iloop * BlockSize;

-        const unsigned did0 = is / desc.GetStride(I0);
+        const index_t did0 = is / desc.GetStride(I0);

         is -= did0 * desc.GetStride(I0);

-        const unsigned did1 = is / desc.GetStride(I1);
+        const index_t did1 = is / desc.GetStride(I1);

-        const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+        const index_t dindex = dst_desc.Get1dIndex(did0, did1);

         f(p_dst[dindex]);
     }
@@ -41,17 +41,17 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst

     if(has_tail)
     {
-        unsigned is = threadIdx.x + NLoop * BlockSize;
+        index_t is = threadIdx.x + NLoop * BlockSize;

         if(is < desc.GetElementSize())
         {
-            const unsigned did0 = is / desc.GetStride(I0);
+            const index_t did0 = is / desc.GetStride(I0);

             is -= did0 * desc.GetStride(I0);

-            const unsigned did1 = is / desc.GetStride(I1);
+            const index_t did1 = is / desc.GetStride(I1);

-            const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+            const index_t dindex = dst_desc.Get1dIndex(did0, did1);

             f(p_dst[dindex]);
         }
@@ -61,7 +61,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst

 // Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
 // TODO: in order to optimize mem access for different mem type,
 // need to write specialized version
-template
+template
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};

-    constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
-    constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
+    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
+    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);

     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
     constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

-    constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
+    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

-    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
+    for(index_t iloop = 0; iloop < NLoop; ++iloop)
     {
-        unsigned is = threadIdx.x + iloop * BlockSize;
+        index_t is = threadIdx.x + iloop * BlockSize;

-        unsigned did[2];
+        index_t did[2];

         did[0] = is / ref_desc.GetStride(I0);
@@ -101,9 +101,9 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds

         did[1] = is / ref_desc.GetStride(I1);

-        const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
+        const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);

-        const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
+        const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);

         f(p_src[aindex], p_dst[bindex]);
     }
@@ -112,11 +112,11 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds

     if(has_tail)
     {
-        unsigned is = threadIdx.x + NLoop * BlockSize;
+        index_t is = threadIdx.x + NLoop * BlockSize;

         if(is < ref_desc.GetElementSize())
         {
-            unsigned did[2];
+            index_t did[2];

             did[0] = is / ref_desc.GetStride(I0);
@@ -124,16 +124,16 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds

             did[1] = is / ref_desc.GetStride(I1);

-            const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
+            const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);

-            const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
+            const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);

             f(p_src[aindex], p_dst[bindex]);
         }
     }
 }

-template
+template
 __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
 {
     auto f_set_zero = [](Float& v) { v = Float(0); };
@@ -141,7 +141,7 @@
     blockwise_2d_tensor_pointwise_operation_unary(DstDesc{}, p_dst, f_set_zero);
 }

-template
+template
 struct Blockwise2dTensorCopy1
 {
     __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
@@ -175,17 +175,17 @@ struct Blockwise2dTensorCopy1

 // need to be aligned to float4 and float2
 // stride1 need to be 1 for both source and destination
-template
+          index_t ThreadPerDim0,
+          index_t ThreadPerDim1>
 struct Blockwise2dTensorCopy2
 {
-    unsigned mThreadId0;
-    unsigned mThreadId1;
+    index_t mThreadId0;
+    index_t mThreadId1;

     __device__ Blockwise2dTensorCopy2()
     {
@@ -222,61 +222,61 @@ struct Blockwise2dTensorCopy2
         constexpr bool align_v2 =
             src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

-        constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
-        constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
+        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
+        constexpr index_t L1 = SrcOpLengths{}.Get(I1);

-        constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
-        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
+        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
+        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);

-        constexpr unsigned Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
+        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;

-        constexpr unsigned Dim1V2Loop =
+        constexpr index_t Dim1V2Loop =
             align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

-        constexpr unsigned Dim1V1Loop =
+        constexpr index_t Dim1V1Loop =
             (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
             ThreadPerDim1;

         constexpr bool d1_has_tail =
             (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

-        for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop)
+        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
         {
-            unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0;
+            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;

             // v4
-            for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
+            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
             {
-                unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
+                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

-                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                const index_t dindex = dst_desc.Get1dIndex(did0, did1);

                 *(reinterpret_cast<float4*>(p_dst + dindex)) =
                     *(reinterpret_cast<const float4*>(p_src + sindex));
             }

             // v2
-            for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
+            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
             {
-                unsigned did1 =
+                index_t did1 =
                     Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

-                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                const index_t dindex = dst_desc.Get1dIndex(did0, did1);

                 *(reinterpret_cast<float2*>(p_dst + dindex)) =
                     *(reinterpret_cast<const float2*>(p_src + sindex));
             }

             // v1
-            for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
+            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
             {
-                unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
-                                d1v1loop * ThreadPerDim1 + mThreadId1;
+                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
+                               d1v1loop * ThreadPerDim1 + mThreadId1;

-                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                const index_t dindex = dst_desc.Get1dIndex(did0, did1);

                 p_dst[dindex] = p_src[sindex];
             }
@@ -284,13 +284,13 @@ struct Blockwise2dTensorCopy2
             // dim-1 tail
             if(d1_has_tail)
             {
-                unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
-                                Dim1V1Loop * ThreadPerDim1 + mThreadId1;
+                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
+                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;

                 if(did1 < L1)
                 {
-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);

                     p_dst[dindex] = p_src[sindex];
                 }
@@ -300,45 +300,44 @@ struct Blockwise2dTensorCopy2
         // dim-0 tail
        if(d0_has_tail)
        {
-            unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
+            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;

             if(did0 < L0)
             {
                 // v4
-                for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
+                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
                 {
-                    unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
+                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);

                     *(reinterpret_cast<float4*>(p_dst + dindex)) =
                         *(reinterpret_cast<const float4*>(p_src + sindex));
                 }

                 // v2
-                for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
+                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
                 {
-                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
-                                    2 * mThreadId1;
+                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
+                                   2 * mThreadId1;

-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);

                     *(reinterpret_cast<float2*>(p_dst + dindex)) =
                         *(reinterpret_cast<const float2*>(p_src + sindex));
                 }

                 // v1
-                for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
+                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
                 {
-                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
-                                    Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
-                                    mThreadId1;
+                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
+                                   d1v1loop * ThreadPerDim1 + mThreadId1;

-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);

                     p_dst[dindex] = p_src[sindex];
                 }
@@ -346,14 +345,13 @@ struct Blockwise2dTensorCopy2
                 // tail
                 if(d1_has_tail)
                 {
-                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
-                                    Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
-                                    mThreadId1;
+                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
+                                   Dim1V1Loop * ThreadPerDim1 + mThreadId1;

                     if(did1 < L1)
                     {
-                        const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                        const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                        const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                        const index_t dindex = dst_desc.Get1dIndex(did0, did1);

                         p_dst[dindex] = p_src[sindex];
                     }
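Blockwise2dTensorCopy2 peels each row into a float4 stretch, a float2 stretch, a scalar stretch and a guarded tail, and the d0 tail repeats the same three-way split. Plugging hypothetical numbers into the loop-bound formulas above (with align_v4 and align_v2 both true):

    constexpr index_t L1 = 75, T1 = 8;                   // row length, threads on dim-1
    constexpr index_t Dim1V4Loop = L1 / (T1 * 4);        // 2 -> 64 floats move as float4
    constexpr index_t Dim1V2Loop = (L1 - 64) / (T1 * 2); // 0 -> nothing left for float2
    constexpr index_t Dim1V1Loop = (L1 - 64) / T1;       // 1 -> 8 floats move as float
    static_assert(L1 > T1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop),
                  "3 floats remain for the d1 tail, written only by threads with did1 < L1");

Blockwise2dTensorCopy3 below trades this generality for a single compile-time vector width (DataPerRead), which removes the per-row case analysis entirely.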
@@ -365,18 +363,18 @@ struct Blockwise2dTensorCopy2

 // starting point need to be aligned to float4 or float2 or float
 // stride1 need to be 1 for both source and destination
-template
+          index_t DataPerRead>
 struct Blockwise2dTensorCopy3
 {
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

-    unsigned mSrcMyThreadOffset;
-    unsigned mDstMyThreadOffset;
+    index_t mSrcMyThreadOffset;
+    index_t mDstMyThreadOffset;

     __device__ Blockwise2dTensorCopy3()
     {
@@ -394,11 +392,11 @@ struct Blockwise2dTensorCopy3
             DstDesc{}.GetStride(I0) % DataPerRead == 0,
             "src and dst stride should be multiple of DataPerRead to keep alignment");

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);

-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

         // we allow out-of-bound read from src in D1 dimension,
         // but we need to make sure dst stride is big enough,
@@ -408,7 +406,7 @@ struct Blockwise2dTensorCopy3
         static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

-        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

         if(BlockSize > num_active_thread)
         {
@@ -418,8 +416,8 @@ struct Blockwise2dTensorCopy3
             }
         }

-        const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
-        const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
+        const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
+        const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;

         mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
         mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
@@ -430,13 +428,13 @@ struct Blockwise2dTensorCopy3
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);

-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

-        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

         if(BlockSize > num_active_thread)
         {
@@ -446,18 +444,18 @@ struct Blockwise2dTensorCopy3
             }
         }

-        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;

-        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
-        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

-        auto f_copy = [&](unsigned iloop) {
+        auto f_copy = [&](index_t iloop) {
             *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                 *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
                                                     iloop * src_loop_stride));
         };

-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
         {
             f_copy(iloop);
         }
@@ -466,7 +464,7 @@ struct Blockwise2dTensorCopy3

         if(has_tail_d0)
         {
-            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
+            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

             if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
             {
@@ -475,18 +473,18 @@ struct Blockwise2dTensorCopy3
         }
     }

-    __device__ constexpr unsigned GetRegisterClipboardSize() const
+    __device__ constexpr index_t GetRegisterClipboardSize() const
     {
         static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);

-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

         return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
     }
@@ -497,13 +495,13 @@ struct Blockwise2dTensorCopy3
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);

-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

-        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

         if(BlockSize > num_active_thread)
         {
@@ -513,18 +511,18 @@ struct Blockwise2dTensorCopy3
             }
         }

-        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;

-        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
-        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

-        auto f_copy = [&](unsigned iloop) {
+        auto f_copy = [&](index_t iloop) {
             *(reinterpret_cast<vector_t*>(p_clipboard + iloop * 4)) =
                 *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
                                                     iloop * src_loop_stride));
         };

-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
         {
             f_copy(iloop);
         }
@@ -533,7 +531,7 @@ struct Blockwise2dTensorCopy3

         if(has_tail_d0)
         {
-            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
+            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

             if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
             {
@@ -548,13 +546,13 @@ struct Blockwise2dTensorCopy3
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);

-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

-        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

         if(BlockSize > num_active_thread)
         {
@@ -564,17 +562,17 @@ struct Blockwise2dTensorCopy3
             }
         }

-        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;

-        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
-        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

-        auto f_copy = [&](unsigned iloop) {
+        auto f_copy = [&](index_t iloop) {
             *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                 *(reinterpret_cast<const vector_t*>(p_clipboard + iloop * 4));
         };

-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
         {
             f_copy(iloop);
         }
@@ -583,7 +581,7 @@ struct Blockwise2dTensorCopy3

         if(has_tail_d0)
         {
-            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
+            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

             if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
             {
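Splitting the copy into a load half and a store half lets a kernel stage global-memory reads in registers (the "clipboard"), overlap them with other work, and only then commit to LDS; GetRegisterClipboardSize is just "vector width times row-loops per thread". The exact names of the two Run-style halves sit outside the visible hunks, so this usage sketch is an assumption about the API shape, not a quote of it:

    // hypothetical call sequence for a copy staged through registers
    float p_clipboard[blockwise_copy.GetRegisterClipboardSize()];
    blockwise_copy.RunLoadRegisterClipboard(p_src_global, p_clipboard); // global -> regs
    // ...independent work can be scheduled here while the loads are in flight...
    blockwise_copy.RunStoreRegisterClipboard(p_clipboard, p_dst_lds);   // regs -> LDS

One sharp edge visible in the hunks: both clipboard f_copy lambdas index the buffer as iloop * 4, which quietly assumes DataPerRead == 4; any other vector width would need that constant changed.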
diff --git a/src/include/blockwise_4d_tensor_op.hip.hpp b/src/include/blockwise_4d_tensor_op.hip.hpp
index 0660c34ebb..685bc67eea 100644
--- a/src/include/blockwise_4d_tensor_op.hip.hpp
+++ b/src/include/blockwise_4d_tensor_op.hip.hpp
@@ -1,7 +1,7 @@
 #pragma once
 #include "ConstantTensorDescriptor.hip.hpp"

-template
+template
 __device__ void
 blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
 {
@@ -22,27 +22,27 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
     }
 #endif

-    constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
+    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;

-    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
+    for(index_t iloop = 0; iloop < NLoop; ++iloop)
     {
-        unsigned is = threadIdx.x + iloop * BlockSize;
+        index_t is = threadIdx.x + iloop * BlockSize;

-        const unsigned did0 = is / desc.GetStride(I0);
+        const index_t did0 = is / desc.GetStride(I0);

         is -= did0 * desc.GetStride(I0);

-        const unsigned did1 = is / desc.GetStride(I1);
+        const index_t did1 = is / desc.GetStride(I1);

         is -= did1 * desc.GetStride(I1);

-        const unsigned did2 = is / desc.GetStride(I2);
+        const index_t did2 = is / desc.GetStride(I2);

         is -= did2 * desc.GetStride(I2);

-        const unsigned did3 = is / desc.GetStride(I3);
+        const index_t did3 = is / desc.GetStride(I3);

-        const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
+        const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);

         f(p_dst[dindex]);
     }
@@ -51,25 +51,25 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst

     if(has_tail)
     {
-        unsigned is = threadIdx.x + NLoop * BlockSize;
+        index_t is = threadIdx.x + NLoop * BlockSize;

         if(is < desc.GetElementSize())
         {
-            const unsigned did0 = is / desc.GetStride(I0);
+            const index_t did0 = is / desc.GetStride(I0);

             is -= did0 * desc.GetStride(I0);

-            const unsigned did1 = is / desc.GetStride(I1);
+            const index_t did1 = is / desc.GetStride(I1);

             is -= did1 * desc.GetStride(I1);

-            const unsigned did2 = is / desc.GetStride(I2);
+            const index_t did2 = is / desc.GetStride(I2);

             is -= did2 * desc.GetStride(I2);

-            const unsigned did3 = is / desc.GetStride(I3);
+            const index_t did3 = is / desc.GetStride(I3);

-            const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
+            const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);

             f(p_dst[dindex]);
         }
@@ -79,7 +79,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst

 // Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
 // TODO: in order to optimize mem access for different mem type,
 // need to write specialized version
-template
+template
     constexpr auto I3 = Number<3>{};

-    constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
-    constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
-    constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
-    constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);
+    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
+    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
+    constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
+    constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);

     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
     constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

-    constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
+    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

-    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
+    for(index_t iloop = 0; iloop < NLoop; ++iloop)
     {
-        unsigned is = threadIdx.x + iloop * BlockSize;
+        index_t is = threadIdx.x + iloop * BlockSize;

-        unsigned did[4];
+        index_t did[4];

         did[0] = is / ref_desc.GetStride(I0);
@@ -131,9 +131,9 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds

         did[3] = is / ref_desc.GetStride(I3);

-        const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
+        const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

-        const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
+        const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

         f(p_src[src_index], p_dst[dst_index]);
     }
@@ -142,11 +142,11 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds

     if(has_tail)
     {
-        unsigned is = threadIdx.x + NLoop * BlockSize;
+        index_t is = threadIdx.x + NLoop * BlockSize;

         if(is < ref_desc.GetElementSize())
         {
-            unsigned did[4];
+            index_t did[4];

             did[0] = is / ref_desc.GetStride(I0);
@@ -162,16 +162,16 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds

             did[3] = is / ref_desc.GetStride(I3);

-            const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
+            const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

-            const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
+            const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

             f(p_src[src_index], p_dst[dst_index]);
         }
     }
 }

-template
+template
 __device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
 {
     auto f_set_zero = [](Float& v) { v = Float(0); };
@@ -179,7 +179,7 @@
     blockwise_4d_tensor_pointwise_operation_unary(DstDesc{}, p_dst, f_set_zero);
 }

-template
+          index_t DataPerRead>
 struct Blockwise4dTensorCopy1
 {
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
@@ -230,8 +230,8 @@ struct Blockwise4dTensorCopy1
         // we allow out-of-bound read from src in D3 dimension,
         // but we need to make sure dst stride2 is big enough,
         // so that the out-of-bound write won't contaminate next line in dst
-        constexpr unsigned L3 = CopyLengths{}.Get(I3);
-        constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);
+        constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);

         static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
                       "wrong! out-of-bound write will contaminate next line!\n");
@@ -247,20 +247,20 @@ struct Blockwise4dTensorCopy1
         constexpr auto src_desc = SrcDesc{};
         constexpr auto dst_desc = DstDesc{};

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned L2 = CopyLengths{}.Get(I2);
-        constexpr unsigned L3 = CopyLengths{}.Get(I3);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);

-        constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead);
+        constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);

         constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence{});

-        constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
+        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

-        auto f_copy = [&](unsigned is) {
-            unsigned did[4];
+        auto f_copy = [&](index_t is) {
+            index_t did[4];

             did[0] = is / ref_desc.GetStride(I0);
@@ -276,18 +276,18 @@ struct Blockwise4dTensorCopy1

             did[3] = is / ref_desc.GetStride(I3);

-            const unsigned src_index =
+            const index_t src_index =
                 src_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);
-            const unsigned dst_index =
+            const index_t dst_index =
                 dst_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);

             *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                 *(reinterpret_cast<const vector_t*>(p_src + src_index));
         };

-        for(unsigned iloop = 0; iloop < NLoop; ++iloop)
+        for(index_t iloop = 0; iloop < NLoop; ++iloop)
         {
-            unsigned is = threadIdx.x + iloop * BlockSize;
+            index_t is = threadIdx.x + iloop * BlockSize;

             f_copy(is);
         }
@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1

         if(has_tail)
         {
-            unsigned is = threadIdx.x + NLoop * BlockSize;
+            index_t is = threadIdx.x + NLoop * BlockSize;

             if(is < ref_desc.GetElementSize())
             {
@@ -306,7 +306,7 @@ struct Blockwise4dTensorCopy1
         }
     }
 };

-template
+template
 struct BlockwiseChwnTensorCopyPadded
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -337,7 +337,7 @@ struct BlockwiseChwnTensorCopyPadded
         constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
         constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);

-        constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
+        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

         const Float* p_src_tmp =
             p_src +
@@ -368,11 +368,11 @@ struct BlockwiseChwnTensorCopyPadded
         }
 #endif

-        for(unsigned iloop = 0; iloop < NLoop; ++iloop)
+        for(index_t iloop = 0; iloop < NLoop; ++iloop)
         {
-            unsigned is = threadIdx.x + iloop * BlockSize;
+            index_t is = threadIdx.x + iloop * BlockSize;

-            unsigned did[4];
+            index_t did[4];

             did[0] = is / ref_desc.GetStride(I0);
@@ -388,7 +388,7 @@ struct BlockwiseChwnTensorCopyPadded

             did[3] = is / ref_desc.GetStride(I3);

-            const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
+            const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

             p_dst[bindex] =
                 (did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
@@ -401,11 +401,11 @@ struct BlockwiseChwnTensorCopyPadded

         if(has_tail)
         {
-            unsigned is = threadIdx.x + NLoop * BlockSize;
+            index_t is = threadIdx.x + NLoop * BlockSize;

             if(is < ref_desc.GetElementSize())
             {
-                unsigned did[4];
+                index_t did[4];

                 did[0] = is / ref_desc.GetStride(I0);
@@ -421,7 +421,7 @@ struct BlockwiseChwnTensorCopyPadded

                 did[3] = is / ref_desc.GetStride(I3);

-                const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
+                const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

                 p_dst[bindex] =
                     (did[1] < h_block_pad_low ||
@@ -436,19 +436,19 @@ struct BlockwiseChwnTensorCopyPadded

 // starting point need to be aligned to float4 or float2 or float
 // stride3 need to be 1 for both source and destination
-template
+          index_t DataPerRead>
 struct Blockwise4dTensorCopy3
 {
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

-    unsigned mSrcMyThreadOffset;
-    unsigned mDstMyThreadOffset;
+    index_t mSrcMyThreadOffset;
+    index_t mDstMyThreadOffset;

     __device__ Blockwise4dTensorCopy3()
     {
@@ -469,20 +469,20 @@ struct Blockwise4dTensorCopy3
             DstDesc{}.GetStride(I2) % DataPerRead == 0,
             "wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned L2 = CopyLengths{}.Get(I2);
-        constexpr unsigned L3 = CopyLengths{}.Get(I3);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);

-        constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
-        constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
-        constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
-        constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

         // we allow out-of-bound read from src in D3 dimension,
         // but we need to make sure dst stride is big enough,
         // so that the out-of-bound write won't contaminate next line in dst
-        constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

         static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
                       "wrong! out-of-bound write will contaminate next line!\n");
@@ -493,7 +493,7 @@ struct Blockwise4dTensorCopy3
         static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
                       "wrong! BlockSize is not big enough for ThreadPerDims!");

-        constexpr unsigned num_active_thread =
+        constexpr index_t num_active_thread =
             thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

         if(BlockSize > num_active_thread)
@@ -504,14 +504,14 @@ struct Blockwise4dTensorCopy3
             }
         }

-        const unsigned thread_id_d0 =
+        const index_t thread_id_d0 =
             get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
-        unsigned itmp = get_thread_local_1d_id() -
-                        thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
-        const unsigned thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
+        index_t itmp = get_thread_local_1d_id() -
+                       thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
+        const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
         itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
-        const unsigned thread_id_d2 = itmp / thread_per_d3;
-        const unsigned thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
+        const index_t thread_id_d2 = itmp / thread_per_d3;
+        const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;

         mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
             thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
@@ -526,17 +526,17 @@ struct Blockwise4dTensorCopy3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned L2 = CopyLengths{}.Get(I2);
-        constexpr unsigned L3 = CopyLengths{}.Get(I3);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);

-        constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
-        constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
-        constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
-        constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

-        constexpr unsigned num_active_thread =
+        constexpr index_t num_active_thread =
             thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

         if(BlockSize > num_active_thread)
@@ -547,30 +547,30 @@ struct Blockwise4dTensorCopy3
             }
         }

-        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
-        constexpr unsigned nloop_d1 = L1 / thread_per_d1;
-        constexpr unsigned nloop_d2 = L2 / thread_per_d2;
-        constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = L2 / thread_per_d2;
+        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

 #pragma unroll
-        for(unsigned iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
+        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
         {
 #pragma unroll
-            for(unsigned iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
+            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
             {
 #pragma unroll
-                for(unsigned iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
+                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                 {
 #pragma unroll
-                    for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
+                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                     {
-                        const unsigned src_offset =
+                        const index_t src_offset =
                             SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
                                                  iloop_d1 * thread_per_d1,
                                                  iloop_d2 * thread_per_d2,
                                                  iloop_d3 * thread_per_d3 * DataPerRead);

-                        const unsigned dst_offset =
+                        const index_t dst_offset =
                             DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
                                                  iloop_d1 * thread_per_d1,
                                                  iloop_d2 * thread_per_d2,
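Both 4d copy classes reuse the overshoot trick from the 2d file: the last vector read of a row may run past the row's end, so the destination stride must absorb the excess -- exactly what the two "contaminate next line" static_asserts check. In numbers (hypothetical row length):

    // a row of L3 = 30 floats copied as float4 (DataPerRead = 4)
    constexpr index_t read_per_d3 = integer_divide_ceil(30, 4); // 8 reads cover 32 floats
    // the destination descriptor must therefore give this dimension a stride >= 32,
    // e.g. via make_ConstantTensorDescriptor_aligned(..., Number<4>{}); otherwise the
    // final float4 write would land in the next row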
diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp
index 1218f173b3..bf2777f140 100644
--- a/src/include/blockwise_batched_gemm.hip.hpp
+++ b/src/include/blockwise_batched_gemm.hip.hpp
@@ -1,30 +1,30 @@
 #pragma once
 #include "threadwise_gemm.hip.hpp"

-template
+template
 struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 {
-    unsigned mMyThreadOffsetA = 0;
-    unsigned mMyThreadOffsetB = 0;
+    index_t mMyThreadOffsetA = 0;
+    index_t mMyThreadOffsetB = 0;

     struct MatrixIndex
     {
-        unsigned batch;
-        unsigned row;
-        unsigned col;
+        index_t batch;
+        index_t row;
+        index_t col;
     };

     __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
@@ -61,7 +61,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 #endif
     }

-    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
     {
         if(TransA && (!TransB) && (!TransC))
@@ -72,22 +72,22 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
             static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                           "wrong! k dimension not consistent!");

-            constexpr unsigned MPerBlock = a_block_mtx.NCol();
-            constexpr unsigned NPerBlock = b_block_mtx.NCol();
+            constexpr index_t MPerBlock = a_block_mtx.NCol();
+            constexpr index_t NPerBlock = b_block_mtx.NCol();

             constexpr auto c_thread_mtx = ThreadMatrixC{};

             // divide thread work
-            constexpr unsigned MPerThread = c_thread_mtx.NRow();
-            constexpr unsigned NPerThread = c_thread_mtx.NCol();
+            constexpr index_t MPerThread = c_thread_mtx.NRow();
+            constexpr index_t NPerThread = c_thread_mtx.NCol();

             static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
             static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
             static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");

-            constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
-            constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
-            constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
+            constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
+            constexpr index_t MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
+            constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;

             static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
                           "wrong! wrong BlockSize");
@@ -95,10 +95,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
             if(DistributeThreadAlongColumnFirst)
             {
                 // num of operations can be reduced
-                const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
-                unsigned itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
-                const unsigned m_work_id = itmp / NThreadWork;
-                const unsigned n_work_id = itmp - m_work_id * NThreadWork;
+                const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
+                index_t itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
+                const index_t m_work_id = itmp / NThreadWork;
+                const index_t n_work_id = itmp - m_work_id * NThreadWork;

                 return MatrixIndex{
                     b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
@@ -118,7 +118,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC

     // this should be optimized away if input is known
     __device__ static MatrixIndex
-    GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
+    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
     {
         return MatrixIndex{batch_in_c, m_in_c, n_in_c};
     }
@@ -138,10 +138,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
         constexpr auto b_block_mtx = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // a is transposed, b is not
         constexpr auto a_thread_mtx =
@@ -154,7 +154,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

         // loop over k
-        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
         {
             // read first batch of a, b
             threadwise_matrix_copy(a_block_mtx,
@@ -172,7 +172,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
                                    b_thread_mtx.GetLengths());

             // loop over batch
-            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
+            for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
             {
                 // do current batch of gemm
                 threadwise_gemm(a_thread_mtx,
@@ -226,32 +226,32 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
     }
 };

-template
+          index_t BlockMatrixStrideA,
+          index_t BlockMatrixStrideB,
+          index_t ThreadMatrixStrideC,
+          index_t BatchSize,
+          index_t MPerThreadSubC,
+          index_t NPerThreadSubC,
+          index_t MLevel0Cluster,
+          index_t NLevel0Cluster,
+          index_t MLevel1Cluster,
+          index_t NLevel1Cluster,
+          index_t KPerThreadLoop,
+          index_t BatchPerThread>
 struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 {
-    unsigned mMyThreadOffsetA = 0;
-    unsigned mMyThreadOffsetB = 0;
+    index_t mMyThreadOffsetA = 0;
+    index_t mMyThreadOffsetB = 0;

     struct MatrixIndex
     {
-        unsigned batch;
-        unsigned row;
-        unsigned col;
+        index_t batch;
+        index_t row;
+        index_t col;
     };

     __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
@@ -259,9 +259,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         static_assert(BatchSize % BatchPerThread == 0,
                       "wrong! BatchSize is not divisible by BatchPerThread");

-        constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
+        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

-        constexpr unsigned ThreadPerLevel1Cluster =
+        constexpr index_t ThreadPerLevel1Cluster =
             MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

         static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
@@ -274,31 +274,31 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                       "wrong! K dimension not consistent\n");

-        constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
-        constexpr unsigned N = b_block_mtx.NCol();
-        constexpr unsigned K = a_block_mtx.NRow();
+        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
+        constexpr index_t N = b_block_mtx.NCol();
+        constexpr index_t K = a_block_mtx.NRow();

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                       "wrong! Cannot evenly divide thread work among repeat \n");

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

         static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                       "wrong! Cannot evenly divide work among repeat\n");

-        constexpr unsigned MPerLevel1Cluster = M / MRepeat;
-        constexpr unsigned NPerLevel1Cluster = N / NRepeat;
+        constexpr index_t MPerLevel1Cluster = M / MRepeat;
+        constexpr index_t NPerLevel1Cluster = N / NRepeat;

         static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                           (NPerLevel1Cluster % NLevel1Cluster == 0),
                       "wrong! Cannot evenly divide work among Level1Cluster\n");

-        constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
-        constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
+        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
+        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

         static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                           (NPerLevel0Cluster % NLevel0Cluster == 0),
@@ -335,28 +335,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 #endif
     }

-    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
     {
-        constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
+        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

-        constexpr unsigned ThreadPerLevel1Cluster =
+        constexpr index_t ThreadPerLevel1Cluster =
             MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

-        constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
+        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

-        unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
-        unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
+        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
+        index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;

-        unsigned level1_id = cluster_id / ThreadPerLevel0Cluster;
-        unsigned level1_m_id = level1_id / NLevel1Cluster;
-        unsigned level1_n_id = level1_id % NLevel1Cluster;
+        index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
+        index_t level1_m_id = level1_id / NLevel1Cluster;
+        index_t level1_n_id = level1_id % NLevel1Cluster;

-        unsigned level0_id = cluster_id % ThreadPerLevel0Cluster;
-        unsigned level0_m_id = level0_id / NLevel0Cluster;
-        unsigned level0_n_id = level0_id % NLevel0Cluster;
+        index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
+        index_t level0_m_id = level0_id / NLevel0Cluster;
+        index_t level0_n_id = level0_id % NLevel0Cluster;

-        constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
-        constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
+        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
+        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

         return MatrixIndex{batch_work_id * BatchPerThread,
                            level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
@@ -365,24 +365,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
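GetBeginOfThreadMatrixC is pure integer decomposition: the flat thread id is peeled into batch, level-1 and level-0 cluster coordinates, mirroring the two-level tiling of the C block. Tracing the code above with hypothetical cluster sizes (BatchPerThread = 1, MLevel0Cluster = NLevel0Cluster = 4, MLevel1Cluster = NLevel1Cluster = 2, MPerThreadSubC = NPerThreadSubC = 4, so 16 threads per level-0 cluster and 64 per level-1 cluster):

    // thread_id = 45:
    // level1_id = 45 / 16 = 2  -> level1_m_id = 2 / 2 = 1, level1_n_id = 2 % 2 = 0
    // level0_id = 45 % 16 = 13 -> level0_m_id = 13 / 4 = 3, level0_n_id = 13 % 4 = 1
    // row = 1 * 16 + 3 * 4 = 28, col = 0 * 16 + 1 * 4 = 4
    // i.e. thread 45 owns the 4x4 C sub-tile whose top-left corner is (28, 4)

Because every divisor is a template constant, all of this folds away at compile time; thread_id is the only run-time input.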
index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - unsigned m_repeat = m_in_c / MPerThreadSubC; - unsigned n_repeat = n_in_c / NPerThreadSubC; + index_t m_repeat = m_in_c / MPerThreadSubC; + index_t n_repeat = n_in_c / NPerThreadSubC; - unsigned m_in_sub_c = m_in_c % MPerThreadSubC; - unsigned n_in_sub_c = n_in_c % NPerThreadSubC; + index_t m_in_sub_c = m_in_c % MPerThreadSubC; + index_t n_in_sub_c = n_in_c % NPerThreadSubC; return MatrixIndex{batch_in_c, m_repeat * MPerLevel1Cluster + m_in_sub_c, @@ -402,10 +402,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed + constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); // thread A, B for GEMM // A is transposed, b is not @@ -425,20 +425,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; // loop over k #pragma unroll - for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) + for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) { // read first batch of A, B // copy A-sub to form A #pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { threadwise_matrix_copy( a_block_mtx, @@ -451,7 +451,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 // copy B-sub to form B #pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { threadwise_matrix_copy( b_block_mtx, @@ -464,7 +464,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 // loop over batch #pragma unroll - for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) + for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib) { // do current batch of gemm threadwise_gemm(a_thread_mtx, @@ -482,7 +482,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 if(BlockMatrixStrideA != 0) { #pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { threadwise_matrix_copy( a_block_mtx, @@ -498,7 +498,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 if(BlockMatrixStrideB != 0) { #pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { 
threadwise_matrix_copy( b_block_mtx, @@ -539,10 +539,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed + constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); // thread A, B for GEMM // A is transposed, b is not @@ -562,25 +562,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; // loop over k //#pragma unroll - for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) + for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) { // read first batch of A, B // copy A-sub to form A //#pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { - for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i) + for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i) { #if 1 - for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j) + for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j) { p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] = p_a_block[a_block_mtx.Get1dIndex(k_begin + i, @@ -596,11 +596,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 // copy B-sub to form B //#pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { - for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i) + for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i) { - for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j) + for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j) { p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] = p_b_block[b_block_mtx.Get1dIndex(k_begin + i, @@ -612,20 +612,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 // loop over batch //#pragma unroll - for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) + for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib) { // do current batch of gemm - for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) + for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) { #if 0 - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) + for(index_t i = 0; i < c_thread_mtx.NRow(); ++i) { - for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j) + for(index_t j = 0; j < c_thread_mtx.NCol(); ++j) { - const unsigned aindex = + const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned bindex = b_thread_mtx.Get1dIndex(k, j); - const unsigned cindex = + const index_t bindex = 
b_thread_mtx.Get1dIndex(k, j); + const index_t cindex = c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC; f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); @@ -635,11 +635,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4, "asm is only for 16x4"); - const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0); - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) + const index_t bindex = b_thread_mtx.Get1dIndex(k, 0); + for(index_t i = 0; i < c_thread_mtx.NRow(); ++i) { - const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0); + const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed + const index_t cindex = c_thread_mtx.Get1dIndex(i, 0); asm volatile("\n \ v_mac_f32 %0, %4, %5 \n \ @@ -668,11 +668,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 if(BlockMatrixStrideA != 0) { //#pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { - for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i) + for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i) { - for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j) + for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j) { p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] = @@ -687,11 +687,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 if(BlockMatrixStrideB != 0) { //#pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { - for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i) + for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i) { - for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j) + for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j) { p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] = @@ -705,16 +705,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 } // do last batch of gemm - for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) + for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) { #if 0 - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) + for(index_t i = 0; i < c_thread_mtx.NRow(); ++i) { - for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j) + for(index_t j = 0; j < c_thread_mtx.NCol(); ++j) { - const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned bindex = b_thread_mtx.Get1dIndex(k, j); - const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) + + const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed + const index_t bindex = b_thread_mtx.Get1dIndex(k, j); + const index_t cindex = c_thread_mtx.Get1dIndex(i, j) + (BatchPerThread - 1) * ThreadMatrixStrideC; f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); @@ -724,11 +724,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4, "asm is only for 16x4"); - const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0); - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) + const index_t bindex = b_thread_mtx.Get1dIndex(k, 0); + for(index_t i = 0; i < c_thread_mtx.NRow(); ++i) { - const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned cindex = + const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed + const index_t cindex = 
c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC; asm volatile("\n \ @@ -756,34 +756,34 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 } } - template + template __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread, FloatC* __restrict__ p_c_block) const { constexpr auto c_block_mtx = BlockMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor( Number{}, Number{}, Number{}); - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned c_thread_offset = + const index_t c_thread_offset = c_thread_mtx_begin.batch * BlockMatrixStrideC + c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col); - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { threadwise_matrix_copy( c_thread_sub_mtx, diff --git a/src/include/blockwise_direct_convolution.hip.hpp b/src/include/blockwise_direct_convolution.hip.hpp index 7666607c9c..3aff3b7936 100644 --- a/src/include/blockwise_direct_convolution.hip.hpp +++ b/src/include/blockwise_direct_convolution.hip.hpp @@ -3,16 +3,16 @@ #include "threadwise_4d_tensor_op.hip.hpp" #include "threadwise_direct_convolution.hip.hpp" -template + index_t NPerThread, + index_t KPerThread, + index_t CPerThread, + index_t HoPerThread, + index_t WoPerThread> __device__ void blockwise_direct_convolution(InBlockDesc, Float* const __restrict__ p_in_block, WeiBlockDesc, @@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc, constexpr auto wei_block_desc = WeiBlockDesc{}; constexpr auto out_block_desc = OutBlockDesc{}; - constexpr unsigned Y = wei_block_desc.GetLength(I2); - constexpr unsigned X = wei_block_desc.GetLength(I3); + constexpr index_t Y = wei_block_desc.GetLength(I2); + constexpr index_t X = wei_block_desc.GetLength(I3); - constexpr unsigned InTileSizeH = HoPerThread + Y - 1; - constexpr unsigned InTileSizeW = WoPerThread + X - 1; + constexpr index_t InTileSizeH = HoPerThread + Y - 1; + constexpr index_t InTileSizeW = WoPerThread + X - 1; // divide thread work - constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread; - constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread; - constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread; - constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread; + constexpr index_t 
NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread; + constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread; + constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread; + constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread; #if 0 if(threadIdx.x == 0) @@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc, constexpr auto out_thread_block_desc = make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides()); - const unsigned thread_id = threadIdx.x; + const index_t thread_id = threadIdx.x; - for(unsigned thread_work_id = thread_id; + for(index_t thread_work_id = thread_id; thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork; thread_work_id += BlockSize) { - unsigned itmp = thread_work_id; - unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork); + index_t itmp = thread_work_id; + index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork); itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork); - unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork); + index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork); itmp -= k_thread_work_id * (YThreadWork * XThreadWork); - unsigned y_thread_work_id = itmp / XThreadWork; - unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork; + index_t y_thread_work_id = itmp / XThreadWork; + index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork; - unsigned n_thread_data_begin = n_thread_work_id * NPerThread; - unsigned k_thread_data_begin = k_thread_work_id * KPerThread; - unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread; - unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread; + index_t n_thread_data_begin = n_thread_work_id * NPerThread; + index_t k_thread_data_begin = k_thread_work_id * KPerThread; + index_t ho_thread_data_begin = y_thread_work_id * HoPerThread; + index_t wo_thread_data_begin = x_thread_work_id * WoPerThread; - unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding - unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding + index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding + index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding Float p_out_thread[out_thread_desc.GetElementSpace()]; @@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc, p_out_thread, out_thread_desc.GetLengths()); - for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1); + for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1); c_thread_data_begin += CPerThread) { // threadwise convolution diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 9471776a74..f7cb637d4e 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -1,26 +1,26 @@ #pragma once #include "threadwise_gemm.hip.hpp" -template struct BlockwiseGemmBlockABlockBThreadC { - unsigned mMyThreadOffsetA = 0; - unsigned mMyThreadOffsetB = 0; + index_t mMyThreadOffsetA = 0; + index_t mMyThreadOffsetB = 0; struct MatrixIndex { - unsigned row; - unsigned col; + index_t row; + index_t col; }; __device__ BlockwiseGemmBlockABlockBThreadC() @@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC #endif } - __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const + 
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const { if(TransA && (!TransB) && (!TransC)) @@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), "wrong! k dimension not consistent!"); - constexpr unsigned MPerBlock = a_block_mtx.NCol(); - constexpr unsigned NPerBlock = b_block_mtx.NCol(); + constexpr index_t MPerBlock = a_block_mtx.NCol(); + constexpr index_t NPerBlock = b_block_mtx.NCol(); constexpr auto c_thread_mtx = ThreadMatrixC{}; // divide thread work - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0, "MPerBlock % (MPerThread * MThreadPerCluster) != 0"); @@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0, "NPerBlock % (NPerThread * NThreadPerCluster) != 0"); - constexpr unsigned MClusterWork = + constexpr index_t MClusterWork = (MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster); - constexpr unsigned NClusterWork = + constexpr index_t NClusterWork = (NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster); static_assert(BlockSize == @@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC if(DistributeThreadAlongColumnFirst) { - const unsigned cluster_work_block_id = + const index_t cluster_work_block_id = thread_id / (MThreadPerCluster * NThreadPerCluster); - const unsigned thread_work_cluster_id = + const index_t thread_work_cluster_id = thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster); - const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork; - const unsigned n_cluster_work_block_id = + const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork; + const index_t n_cluster_work_block_id = cluster_work_block_id - m_cluster_work_block_id * NClusterWork; - const unsigned m_thread_work_cluster_id = - thread_work_cluster_id / NThreadPerCluster; - const unsigned n_thread_work_cluster_id = + const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster; + const index_t n_thread_work_cluster_id = thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster; #if 0 @@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC } // this should be optimized away if input is known - __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c, - unsigned n_in_c) + __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c, + index_t n_in_c) { return MatrixIndex{m_in_c, n_in_c}; } @@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed + constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); // a is transposed, b is not constexpr auto a_thread_mtx = @@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; // loop over k - for(unsigned k_begin = 0; k_begin < KPerBlock; 
k_begin += KPerThreadLoop) + for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) { threadwise_matrix_copy(a_block_mtx, p_a_block + mMyThreadOffsetA + @@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC // if following number are power of 2, index calculation shall be greatly reduced: // MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster -template + index_t MPerThreadSubC, + index_t NPerThreadSubC, + index_t MLevel0Cluster, + index_t NLevel0Cluster, + index_t MLevel1Cluster, + index_t NLevel1Cluster, + index_t KPerThreadLoop> struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 { struct MatrixIndex { - unsigned row; - unsigned col; + index_t row; + index_t col; }; - unsigned mMyThreadOffsetA; - unsigned mMyThreadOffsetB; + index_t mMyThreadOffsetA; + index_t mMyThreadOffsetB; __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2() { - constexpr unsigned ThreadPerLevel1Cluster = + constexpr index_t ThreadPerLevel1Cluster = MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n"); @@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), "wrong! K dimension not consistent\n"); - constexpr unsigned M = a_block_mtx.NCol(); // A is transposed - constexpr unsigned N = b_block_mtx.NCol(); - constexpr unsigned K = a_block_mtx.NRow(); + constexpr index_t M = a_block_mtx.NCol(); // A is transposed + constexpr index_t N = b_block_mtx.NCol(); + constexpr index_t K = a_block_mtx.NRow(); - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0), "wrong! Cannot evenly divide thread work among repeat \n"); - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; static_assert((M % MRepeat == 0) && (N % NRepeat == 0), "wrong! Cannot evenly divide work among repeat\n"); - constexpr unsigned MPerLevel1Cluster = M / MRepeat; - constexpr unsigned NPerLevel1Cluster = N / NRepeat; + constexpr index_t MPerLevel1Cluster = M / MRepeat; + constexpr index_t NPerLevel1Cluster = N / NRepeat; static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) && (NPerLevel1Cluster % NLevel1Cluster == 0), "wrong! 
Cannot evenly divide work among Level1Cluster\n"); - constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster; - constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster; + constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster; + constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster; static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) && (NPerLevel0Cluster % NLevel0Cluster == 0), @@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col); } - __device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) + __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) { - constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; + constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; - unsigned level1_id = thread_id / ThreadPerLevel0Cluster; - unsigned level1_m_id = level1_id / NLevel1Cluster; - unsigned level1_n_id = level1_id % NLevel1Cluster; + index_t level1_id = thread_id / ThreadPerLevel0Cluster; + index_t level1_m_id = level1_id / NLevel1Cluster; + index_t level1_n_id = level1_id % NLevel1Cluster; - unsigned level0_id = thread_id % ThreadPerLevel0Cluster; - unsigned level0_m_id = level0_id / NLevel0Cluster; - unsigned level0_n_id = level0_id % NLevel0Cluster; + index_t level0_id = thread_id % ThreadPerLevel0Cluster; + index_t level0_m_id = level0_id / NLevel0Cluster; + index_t level0_n_id = level0_id % NLevel0Cluster; - constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; - constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; + constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; + constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; } // this should be optimized away if input is known - __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c, - unsigned n_in_c) + __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c, + index_t n_in_c) { constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - unsigned m_repeat = m_in_c / MPerThreadSubC; - unsigned n_repeat = n_in_c / NPerThreadSubC; + index_t m_repeat = m_in_c / MPerThreadSubC; + index_t n_repeat = n_in_c / NPerThreadSubC; - unsigned m_in_sub_c = m_in_c % MPerThreadSubC; - unsigned n_in_sub_c = n_in_c % NPerThreadSubC; + index_t m_in_sub_c = m_in_c % MPerThreadSubC; + index_t n_in_sub_c = n_in_c % NPerThreadSubC; return 
MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c, n_repeat * NPerLevel1Cluster + n_in_sub_c}; @@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned M = a_block_mtx.NCol(); - constexpr unsigned N = b_block_mtx.NCol(); - constexpr unsigned K = a_block_mtx.NRow(); + constexpr index_t M = a_block_mtx.NCol(); + constexpr index_t N = b_block_mtx.NCol(); + constexpr index_t K = a_block_mtx.NRow(); - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); // thread A, B for GEMM constexpr auto a_thread_mtx = @@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; #pragma unroll // loop over k - for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) + for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { #pragma unroll // copy A-sub to form A - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { threadwise_matrix_copy( a_block_mtx, @@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 #pragma unroll // copy B-sub to form B - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { threadwise_matrix_copy( b_block_mtx, @@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned M = a_block_mtx.NCol(); - constexpr unsigned N = b_block_mtx.NCol(); - constexpr unsigned K = a_block_mtx.NRow(); + constexpr index_t M = a_block_mtx.NCol(); + constexpr index_t N = b_block_mtx.NCol(); + constexpr index_t K = a_block_mtx.NRow(); - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); // thread A, B for GEMM constexpr auto a_thread_mtx = @@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - constexpr unsigned MRepeat = 
MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; #pragma unroll // loop over k - for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) + for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { -#pragma unroll + //#pragma unroll // copy A-sub to form A - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { threadwise_matrix_copy( a_block_mtx, @@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 a_thread_sub_mtx.GetLengths()); } -#pragma unroll + //#pragma unroll // copy B-sub to form B - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { threadwise_matrix_copy( b_block_mtx, @@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 False, p_c_thread, f_accum); -#else +#elif 0 // inline asm static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8, "asm is only for 8x8"); - for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed + for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed { - const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0); + const index_t bindex = b_thread_mtx.Get1dIndex(k, 0); - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) + for(index_t i = 0; i < c_thread_mtx.NRow(); ++i) { - const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0); + const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed + const index_t cindex = c_thread_mtx.Get1dIndex(i, 0); asm volatile("\n \ v_mac_f32 %0, %8, %9 \n \ @@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned M = a_block_mtx.NCol(); - constexpr unsigned N = b_block_mtx.NCol(); - constexpr unsigned K = a_block_mtx.NRow(); + constexpr index_t M = a_block_mtx.NCol(); + constexpr index_t N = b_block_mtx.NCol(); + constexpr index_t K = a_block_mtx.NRow(); - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); // thread A, B for GEMM constexpr auto a_thread_mtx = @@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()]; - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; // preload A, B #pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { // copy A-sub to form A 
threadwise_matrix_copy(a_block_mtx, p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster, @@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 } #pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { // copy B-sub to form B threadwise_matrix_copy(b_block_mtx, p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster, @@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 bool even_loop = true; #pragma unroll - for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K; + for(index_t k_begin = 0; k_begin + KPerThreadLoop < K; k_begin += KPerThreadLoop, even_loop = !even_loop) { // loop over k FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1; @@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 // preload next A, B #pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { // copy A-sub to form A threadwise_matrix_copy(a_block_mtx, p_a_block + mMyThreadOffsetA + @@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 } #pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { // copy B-sub to form B threadwise_matrix_copy(b_block_mtx, p_b_block + mMyThreadOffsetB + @@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr unsigned M = a_block_mtx.NCol(); - constexpr unsigned N = b_block_mtx.NCol(); - constexpr unsigned K = a_block_mtx.NRow(); + constexpr index_t M = a_block_mtx.NCol(); + constexpr index_t N = b_block_mtx.NCol(); + constexpr index_t K = a_block_mtx.NRow(); - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); // thread A-sub, B-sub, C-sub constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( @@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; #pragma unroll // loop over k - for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) + for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { // C-sub(s) in first row-wise subblock of C { @@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 #pragma unroll // copy next B-sub, and do GEMM - for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat) { threadwise_matrix_copy( b_block_mtx, @@ -805,7 +804,7 @@ struct 
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 #pragma unroll // loop over rest of row-wise subblock // all B-sub(s) has been copied, so only A-sub(s) need to be copied - for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat) + for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat) { // copy a A-sub threadwise_matrix_copy( @@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 a_thread_sub_mtx.GetLengths()); // do some GEMMs - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { threadwise_gemm( a_thread_sub_mtx, diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp index a6c9d128e8..6770b590a9 100644 --- a/src/include/common.hip.hpp +++ b/src/include/common.hip.hpp @@ -5,9 +5,9 @@ #include "Array.hip.hpp" #include "functional.hip.hpp" -__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; } +__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } -__device__ unsigned get_block_1d_id() { return blockIdx.x; } +__device__ index_t get_block_1d_id() { return blockIdx.x; } template struct is_same @@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b) } #endif -__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b) +__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; } diff --git a/src/include/config.h.in b/src/include/config.h.in index bb4f6cb51d..ce3232489d 100644 --- a/src/include/config.h.in +++ b/src/include/config.h.in @@ -11,3 +11,5 @@ #include "nvToolsExt.h" #include "helper_cuda.h" #endif + +using index_t = uint32_t; diff --git a/src/include/constant_integral.hip.hpp b/src/include/constant_integral.hip.hpp index 70dc69d181..cdba3290a0 100644 --- a/src/include/constant_integral.hip.hpp +++ b/src/include/constant_integral.hip.hpp @@ -8,5 +8,5 @@ struct integral_constant __host__ __device__ constexpr T Get() const { return value; } }; -template -using Number = integral_constant; +template +using Number = integral_constant; diff --git a/src/include/data_type.hip.hpp b/src/include/data_type.hip.hpp index 95d5b0b33f..1261e19989 100644 --- a/src/include/data_type.hip.hpp +++ b/src/include/data_type.hip.hpp @@ -1,7 +1,7 @@ #pragma once #include "config.h" -template +template struct vector_type { }; diff --git a/src/include/functional.hip.hpp b/src/include/functional.hip.hpp index d3f645eaae..c5403f0452 100644 --- a/src/include/functional.hip.hpp +++ b/src/include/functional.hip.hpp @@ -1,7 +1,7 @@ #pragma once #include "constant_integral.hip.hpp" -template +template struct static_loop_n { template @@ -24,7 +24,7 @@ struct static_loop_n<1> } }; -template +template struct static_const_reduce_n { template diff --git a/src/include/gridwise_direct_convolution_1.hip.hpp b/src/include/gridwise_direct_convolution_1.hip.hpp index edcfd6d38e..1fb76988a7 100644 --- a/src/include/gridwise_direct_convolution_1.hip.hpp +++ b/src/include/gridwise_direct_convolution_1.hip.hpp @@ -8,18 +8,18 @@ template + index_t NPerBlock, + index_t KPerBlock, + index_t CPerBlock, + index_t HoPerBlock, + index_t WoPerBlock, + index_t NPerThread, + index_t KPerThread, + index_t CPerThread, + index_t HoPerThread, + index_t WoPerThread, + index_t BlockSize, + index_t GridSize> __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, Float* const __restrict__ p_out_global) @@ -33,16 +33,16 @@ __global__ void 
gridwise_direct_convolution_1(const Float* const __restrict__ p_ constexpr auto wei_global_desc = WeiGlobalDesc{}; constexpr auto out_global_desc = OutGlobalDesc{}; - constexpr unsigned Y = wei_global_desc.GetLength(I2); - constexpr unsigned X = wei_global_desc.GetLength(I3); + constexpr index_t Y = wei_global_desc.GetLength(I2); + constexpr index_t X = wei_global_desc.GetLength(I3); - constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; - constexpr unsigned WiPerBlock = WoPerBlock + X - 1; + constexpr index_t HiPerBlock = HoPerBlock + Y - 1; + constexpr index_t WiPerBlock = WoPerBlock + X - 1; - constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; + constexpr index_t NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; + constexpr index_t KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; + constexpr index_t WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; constexpr auto in_block_global_desc = make_ConstantTensorDescriptor( Sequence{}, in_global_desc.GetStrides()); @@ -59,31 +59,31 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_ constexpr auto out_block_desc = make_ConstantTensorDescriptor(out_block_global_desc.GetLengths()); - constexpr unsigned in_block_size = in_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace(); - constexpr unsigned out_block_size = out_block_desc.GetElementSpace(); + constexpr index_t in_block_size = in_block_desc.GetElementSpace(); + constexpr index_t wei_block_size = wei_block_desc.GetElementSpace(); + constexpr index_t out_block_size = out_block_desc.GetElementSpace(); __shared__ Float p_in_block[in_block_size]; __shared__ Float p_wei_block[wei_block_size]; __shared__ Float p_out_block[out_block_size]; - const unsigned block_id = blockIdx.x; + const index_t block_id = blockIdx.x; - unsigned itmp = block_id; - unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + index_t itmp = block_id; + index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); - unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); + index_t k_block_work_id = itmp / (HBlockWork * WBlockWork); itmp -= k_block_work_id * (HBlockWork * WBlockWork); - unsigned h_block_work_id = itmp / WBlockWork; - unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; + index_t h_block_work_id = itmp / WBlockWork; + index_t w_block_work_id = itmp - h_block_work_id * WBlockWork; - unsigned n_block_work_begin = n_block_work_id * NPerBlock; - unsigned k_block_work_begin = k_block_work_id * KPerBlock; - unsigned ho_block_work_begin = h_block_work_id * HoPerBlock; - unsigned wo_block_work_begin = w_block_work_id * WoPerBlock; + index_t n_block_work_begin = n_block_work_id * NPerBlock; + index_t k_block_work_begin = k_block_work_id * KPerBlock; + index_t ho_block_work_begin = h_block_work_id * HoPerBlock; + index_t wo_block_work_begin = w_block_work_id * WoPerBlock; - unsigned hi_block_work_begin = ho_block_work_begin; // minus padding 
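The `n_block_work_id` / `k_block_work_id` / `h_block_work_id` / `w_block_work_id` split above is a mixed-radix unflattening of the linear block id, with the W dimension varying fastest and N slowest; the per-thread work split later in the kernel follows the same pattern. A minimal host-side sketch of that arithmetic, separate from the patch (the `BlockWork` struct and `decompose` helper are hypothetical names introduced here for illustration):

```cpp
#include <cstdint>

using index_t = uint32_t; // same typedef the patch adds in config.h.in

// Unflatten a linear block id into (n, k, h, w) work ids, mirroring the
// repeated div/mod scheme in the kernels: W varies fastest, N slowest.
struct BlockWork
{
    index_t n, k, h, w;
};

constexpr BlockWork decompose(index_t block_id,
                              index_t KBlockWork,
                              index_t HBlockWork,
                              index_t WBlockWork)
{
    const index_t n = block_id / (KBlockWork * HBlockWork * WBlockWork);
    index_t itmp    = block_id - n * (KBlockWork * HBlockWork * WBlockWork);
    const index_t k = itmp / (HBlockWork * WBlockWork);
    itmp            = itmp - k * (HBlockWork * WBlockWork);
    const index_t h = itmp / WBlockWork;
    const index_t w = itmp - h * WBlockWork;
    return {n, k, h, w};
}

// With KBlockWork = 4, HBlockWork = 8, WBlockWork = 1, block id 37 maps to
// (n, k, h, w) = (1, 0, 5, 0).
constexpr BlockWork bw = decompose(37, 4, 8, 1);
static_assert(bw.n == 1 && bw.k == 0 && bw.h == 5 && bw.w == 0, "unflatten check");
```

The same decomposition, with different radices, is then applied to `thread_id` inside each block to divide the block's tile among threads.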
- unsigned wi_block_work_begin = wo_block_work_begin; // minus padding + index_t hi_block_work_begin = ho_block_work_begin; // minus padding + index_t wi_block_work_begin = wo_block_work_begin; // minus padding constexpr auto blockwise_in_copy = Blockwise4dTensorCopy1(out_block_desc, p_out_block); - for(unsigned c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1); + for(index_t c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1); c_block_work_begin += CPerBlock) { // copy input tensor to LDS diff --git a/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp index 1e6d3d24bd..944a1624ee 100644 --- a/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp +++ b/src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp @@ -11,20 +11,20 @@ template + index_t NPerBlock, + index_t KPerBlock, + index_t CPerBlock, + index_t HoPerBlock, + index_t WoPerBlock, + index_t NPerThread, + index_t KPerThread, + index_t CPerThread, + index_t HoPerThread, + index_t WoPerThread, + index_t InBlockCopyDataPerRead, + index_t WeiBlockCopyDataPerRead, + index_t BlockSize, + index_t GridSize> __global__ void gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, @@ -39,17 +39,17 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{}; constexpr auto out_nkhw_global_desc = OutGlobalDesc{}; - constexpr unsigned N = in_nchw_global_desc.GetLength(I0); - constexpr unsigned K = wei_kcyx_global_desc.GetLength(I0); - constexpr unsigned C = wei_kcyx_global_desc.GetLength(I1); - constexpr unsigned Y = wei_kcyx_global_desc.GetLength(I2); - constexpr unsigned X = wei_kcyx_global_desc.GetLength(I3); + constexpr index_t N = in_nchw_global_desc.GetLength(I0); + constexpr index_t K = wei_kcyx_global_desc.GetLength(I0); + constexpr index_t C = wei_kcyx_global_desc.GetLength(I1); + constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2); + constexpr index_t X = wei_kcyx_global_desc.GetLength(I3); constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor( Sequence{}); // 2d view of wei for blockwise copy - constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; - constexpr unsigned WiPerBlock = WoPerBlock + X - 1; + constexpr index_t HiPerBlock = HoPerBlock + Y - 1; + constexpr index_t WiPerBlock = WoPerBlock + X - 1; constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned( Sequence{}, Number{}); @@ -63,21 +63,21 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i Sequence{}); // shared mem - constexpr unsigned in_block_size = + constexpr index_t in_block_size = in_nchw_block_desc.GetElementSpace(Number{}); - constexpr unsigned wei_block_size = + constexpr index_t wei_block_size = wei_kcyx_block_desc.GetElementSpace(Number{}); - constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; + constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? 
InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; // threadwise tensors - constexpr unsigned HiPerThread = HoPerThread + Y - 1; - constexpr unsigned WiPerThread = WoPerThread + X - 1; + constexpr index_t HiPerThread = HoPerThread + Y - 1; + constexpr index_t WiPerThread = WoPerThread + X - 1; constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor(Sequence{}, @@ -93,56 +93,54 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()]; // divide block work - constexpr unsigned NBlockWork = - (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr unsigned KBlockWork = - (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = + constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; + constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = + constexpr index_t WBlockWork = (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; - const unsigned block_id = blockIdx.x; + const index_t block_id = blockIdx.x; - unsigned itmp = block_id; - const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + index_t itmp = block_id; + const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); - const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); + const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork); itmp -= k_block_work_id * (HBlockWork * WBlockWork); - const unsigned h_block_work_id = itmp / WBlockWork; - const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; + const index_t h_block_work_id = itmp / WBlockWork; + const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork; - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; + const index_t n_block_data_begin = n_block_work_id * NPerBlock; + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding - const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding + const index_t hi_block_data_begin = ho_block_data_begin; // minus padding + const index_t wi_block_data_begin = wo_block_data_begin; // minus padding // divide thread work - constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; - constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; - constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; - constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; + constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; + constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / 
KPerThread; + constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; + constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; - const unsigned thread_id = threadIdx.x; + const index_t thread_id = threadIdx.x; - itmp = thread_id; - const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); + itmp = thread_id; + const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork); - const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork); + const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork); itmp -= k_thread_work_id * (HThreadWork * WThreadWork); - const unsigned h_thread_work_id = itmp / WThreadWork; - const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork; + const index_t h_thread_work_id = itmp / WThreadWork; + const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork; - const unsigned n_thread_data_begin = n_thread_work_id * NPerThread; - const unsigned k_thread_data_begin = k_thread_work_id * KPerThread; - const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread; - const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread; + const index_t n_thread_data_begin = n_thread_work_id * NPerThread; + const index_t k_thread_data_begin = k_thread_work_id * KPerThread; + const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread; + const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread; - const unsigned hi_thread_data_begin = ho_thread_data_begin; - const unsigned wi_thread_data_begin = wo_thread_data_begin; + const index_t hi_thread_data_begin = ho_thread_data_begin; + const index_t wi_thread_data_begin = wo_thread_data_begin; constexpr auto blockwise_in_copy = Blockwise4dTensorCopy1 + index_t ScalarPerVector, + index_t NPerBlock, + index_t KPerBlock, + index_t CPerBlock, + index_t HoPerBlock, + index_t WoPerBlock, + index_t NPerThread, + index_t KPerThread, + index_t CPerThread, + index_t HoPerThread, + index_t WoPerThread, + index_t InBlockCopyDataPerRead, + index_t WeiBlockCopyDataPerRead, + index_t BlockSize, + index_t GridSize> __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( const typename vector_type::MemoryType* const __restrict__ p_in_vec_global, @@ -49,17 +49,17 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{}; constexpr auto out_nkhw_global_desc = OutGlobalDesc{}; - constexpr unsigned N = in_nchw_vec_global_desc.GetLength(I0); - constexpr unsigned K = wei_kcyx_vec_global_desc.GetLength(I0); - constexpr unsigned C = wei_kcyx_vec_global_desc.GetLength(I1); - constexpr unsigned Y = wei_kcyx_vec_global_desc.GetLength(I2); - constexpr unsigned X = wei_kcyx_vec_global_desc.GetLength(I3); + constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0); + constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0); + constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1); + constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2); + constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3); constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor( Sequence{}); // 2d view of wei for blockwise copy - constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; - constexpr unsigned WiPerBlock = WoPerBlock + X - 1; + constexpr index_t HiPerBlock = HoPerBlock + Y - 1; + constexpr index_t WiPerBlock = WoPerBlock + X - 1; 
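The `HiPerBlock = HoPerBlock + Y - 1` and `WiPerBlock = WoPerBlock + X - 1` relations (and the per-thread `HiPerThread` / `WiPerThread` analogues below) encode the input halo of a stride-1, unpadded convolution: producing an Ho x Wo output tile with a Y x X filter requires reading Y - 1 extra input rows and X - 1 extra input columns. A compile-time sketch of that arithmetic, separate from the patch (`in_tile` is a hypothetical helper):

```cpp
#include <cstdint>

using index_t = uint32_t;

// Stride-1, no-padding convolution: each output pixel reads a Y x X input
// window, so a tile of out_per_block output rows needs out_per_block + Y - 1
// input rows; the Y - 1 overlap between adjacent windows is the halo.
constexpr index_t in_tile(index_t out_per_block, index_t filter_len)
{
    return out_per_block + filter_len - 1;
}

// e.g. a 3x3 filter with a 2x32 output tile reads a 4x34 input tile
static_assert(in_tile(2, 3) == 4, "halo rows");
static_assert(in_tile(32, 3) == 34, "halo columns");
```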
constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned( Sequence{}, Number{}); @@ -73,15 +73,15 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( Sequence{}); // shared mem - constexpr unsigned in_block_size = + constexpr index_t in_block_size = in_nchw_vec_block_desc.GetElementSpace(Number{}); - constexpr unsigned wei_block_size = + constexpr index_t wei_block_size = wei_kcyx_vec_block_desc.GetElementSpace(Number{}); - constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; + constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; __shared__ in_vector_mem_t p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)]; @@ -89,8 +89,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; // threadwise tensors - constexpr unsigned HiPerThread = HoPerThread + Y - 1; - constexpr unsigned WiPerThread = WoPerThread + X - 1; + constexpr index_t HiPerThread = HoPerThread + Y - 1; + constexpr index_t WiPerThread = WoPerThread + X - 1; constexpr auto in_nchw_vec_thread_block_desc = make_ConstantTensorDescriptor(Sequence{}, @@ -106,56 +106,54 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()]; // divide block work - constexpr unsigned NBlockWork = - (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr unsigned KBlockWork = - (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = + constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; + constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = + constexpr index_t WBlockWork = (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; - const unsigned block_id = blockIdx.x; + const index_t block_id = blockIdx.x; - unsigned itmp = block_id; - const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + index_t itmp = block_id; + const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); - const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); + const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork); itmp -= k_block_work_id * (HBlockWork * WBlockWork); - const unsigned h_block_work_id = itmp / WBlockWork; - const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; + const index_t h_block_work_id = itmp / WBlockWork; + const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork; - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; + const index_t n_block_data_begin = n_block_work_id * NPerBlock; + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - const 
unsigned hi_block_data_begin = ho_block_data_begin; // minus padding - const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding + const index_t hi_block_data_begin = ho_block_data_begin; // minus padding + const index_t wi_block_data_begin = wo_block_data_begin; // minus padding // divide thread work - constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; - constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; - constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; - constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; + constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; + constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; + constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; + constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; - const unsigned thread_id = threadIdx.x; + const index_t thread_id = threadIdx.x; - itmp = thread_id; - const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); + itmp = thread_id; + const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork); - const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork); + const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork); itmp -= k_thread_work_id * (HThreadWork * WThreadWork); - const unsigned h_thread_work_id = itmp / WThreadWork; - const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork; + const index_t h_thread_work_id = itmp / WThreadWork; + const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork; - const unsigned n_thread_data_begin = n_thread_work_id * NPerThread; - const unsigned k_thread_data_begin = k_thread_work_id * KPerThread; - const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread; - const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread; + const index_t n_thread_data_begin = n_thread_work_id * NPerThread; + const index_t k_thread_data_begin = k_thread_work_id * KPerThread; + const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread; + const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread; - const unsigned hi_thread_data_begin = ho_thread_data_begin; - const unsigned wi_thread_data_begin = wo_thread_data_begin; + const index_t hi_thread_data_begin = ho_thread_data_begin; + const index_t wi_thread_data_begin = wo_thread_data_begin; constexpr auto blockwise_in_copy = Blockwise4dTensorCopy1 + index_t InBlockCopyDataPerRead, + index_t WeiBlockCopyDataPerRead, + index_t GemmMPerThreadSubC, + index_t GemmNPerThreadSubC, + index_t GemmMLevel0Cluster, + index_t GemmNLevel0Cluster, + index_t GemmMLevel1Cluster, + index_t GemmNLevel1Cluster, + index_t GemmKPerThreadLoop, + index_t OutThreadCopyDataPerWrite> __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, @@ -55,39 +55,39 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - constexpr unsigned C = in_chwn_global_desc.GetLength(I0); + constexpr index_t C = in_chwn_global_desc.GetLength(I0); - constexpr unsigned K = out_khwn_global_desc.GetLength(I0); - constexpr unsigned Ho = 
out_khwn_global_desc.GetLength(I1); - constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); - constexpr unsigned N = out_khwn_global_desc.GetLength(I3); + constexpr index_t K = out_khwn_global_desc.GetLength(I0); + constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); + constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); + constexpr index_t N = out_khwn_global_desc.GetLength(I3); - constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); - constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); + constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); + constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; - constexpr unsigned WiPerBlock = WoPerBlock + X - 1; + constexpr index_t HiPerBlock = HoPerBlock + Y - 1; + constexpr index_t WiPerBlock = WoPerBlock + X - 1; // divide block work: [K, Ho, Wo, N] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock; + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; - const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork); + const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const unsigned w_block_work_id = itmp / NBlockWork; - const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork; + const index_t w_block_work_id = itmp / NBlockWork; + const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; + const index_t n_block_data_begin = n_block_work_id * NPerBlock; - const unsigned hi_block_data_begin = ho_block_data_begin; - const unsigned wi_block_data_begin = wo_block_data_begin; + const index_t hi_block_data_begin = ho_block_data_begin; + const index_t wi_block_data_begin = wo_block_data_begin; // flattend (2d) tensor view of gridwise weight constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); @@ -164,15 +164,15 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric HoPerThread>{}; // LDS: be careful of alignment - constexpr unsigned in_block_size = + constexpr index_t in_block_size = in_chwn_block_desc.GetElementSpace(Number{}); - constexpr unsigned wei_block_size = + constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace(Number{}); 
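// [editor's note, not part of the patch] Two idioms recur throughout these kernels and are
// worth stating once. First, work is divided with ceiling division,
// (Total + PerBlock - 1) / PerBlock, and the flat block id is decomposed back into
// per-dimension ids by successive div/mod in the same [K, Ho, Wo, N] order used above.
// Second, each __shared__ buffer is padded up to a multiple of the larger of the two vector
// read widths, so reinterpret_cast'ed vector loads from either buffer never straddle an
// alignment boundary. A minimal sketch of both rules, assuming only the codebase's index_t
// (helper names are illustrative, not part of this patch):

// e.g. round_up_sketch(10, 4) == 12; this is the
// max_align * ((size + max_align - 1) / max_align) pattern used for the LDS sizes above
__host__ __device__ constexpr index_t round_up_sketch(index_t size, index_t align)
{
    return align * ((size + align - 1) / align);
}

__device__ void decompose_block_id_sketch(index_t block_id,
                                          index_t HBlockWork,
                                          index_t WBlockWork,
                                          index_t NBlockWork,
                                          index_t& k_id,
                                          index_t& h_id,
                                          index_t& w_id,
                                          index_t& n_id)
{
    // flat id -> (k, h, w, n): peel off the slowest-varying dimension first
    k_id         = block_id / (HBlockWork * WBlockWork * NBlockWork);
    index_t itmp = block_id - k_id * (HBlockWork * WBlockWork * NBlockWork);
    h_id         = itmp / (WBlockWork * NBlockWork);
    itmp -= h_id * (WBlockWork * NBlockWork);
    w_id = itmp / NBlockWork;
    n_id = itmp - w_id * NBlockWork;
}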
- constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; + constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; @@ -191,10 +191,10 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric const Float* p_wei_global_block_begin = p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); - for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0), - p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0), - __syncthreads()) + for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0), + p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + __syncthreads()) { // input: global mem to LDS blockwise_in_copy.Run(p_in_global_block_begin, p_in_block); @@ -205,9 +205,9 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric __syncthreads(); // a series of batched GEMM - for(unsigned y = 0; y < Y; ++y) + for(index_t y = 0; y < Y; ++y) { - for(unsigned x = 0; x < X; ++x) + for(index_t x = 0; x < X; ++x) { #if 0 blockwise_batch_gemm.Run @@ -227,26 +227,26 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric const auto c_thread_mtx_begin = blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) + for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k) { - for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) + for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho) { - for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) + for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo) { - for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) + for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n) { - const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); + const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n); const auto c_thread_mtx_distance = blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b); - const unsigned ho_thread = + const index_t ho_thread = c_thread_mtx_begin.batch + c_thread_mtx_distance.batch; - const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; - const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; + const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row; + const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col; - const unsigned wo_thread = b_thread / NPerBlock; - const unsigned n_thread = b_thread % NPerBlock; + const index_t wo_thread = b_thread / NPerBlock; + const index_t n_thread = b_thread % NPerBlock; p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread, ho_block_data_begin + ho_thread, @@ -261,19 +261,19 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric const auto c_thread_mtx_begin = 
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned k_thread_data_begin = c_thread_mtx_begin.row; - const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch; - const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const unsigned n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; + const index_t k_thread_data_begin = c_thread_mtx_begin.row; + const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; + const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; + const index_t n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin; // this is for v2 GEMM // output is a 8d tensor if(NPerThread < NPerBlock && WoPerThread == 1) { - constexpr unsigned N1_ = GemmNPerThreadSubC; - constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC); - constexpr unsigned K2_ = GemmMPerThreadSubC; - constexpr unsigned K1_ = KPerBlock / KPerThread; + constexpr index_t N1_ = GemmNPerThreadSubC; + constexpr index_t W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC); + constexpr index_t K2_ = GemmMPerThreadSubC; + constexpr index_t K1_ = KPerBlock / KPerThread; constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor( Sequence{}); diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp index 790a006023..fb0c781bfd 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hip.hpp @@ -7,26 +7,26 @@ #include "threadwise_4d_tensor_op.hip.hpp" #include "blockwise_gemm.hip.hpp" -template + index_t NPerBlock, + index_t KPerBlock, + index_t CPerBlock, + index_t HoPerBlock, + index_t WoPerBlock, + index_t NPerThread, + index_t KPerThread, + index_t CPerThread, + index_t HoPerThread, + index_t WoPerThread, + index_t WeiBlockCopyThreadPerDim0, + index_t WeiBlockCopyThreadPerDim1> __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, @@ -48,42 +48,42 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - constexpr unsigned C = in_chwn_global_desc.GetLength(I0); + constexpr index_t C = in_chwn_global_desc.GetLength(I0); - constexpr unsigned K = out_khwn_global_desc.GetLength(I0); - constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); - constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); - constexpr unsigned N = out_khwn_global_desc.GetLength(I3); + constexpr index_t K = out_khwn_global_desc.GetLength(I0); + constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); + constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); + constexpr index_t N = out_khwn_global_desc.GetLength(I3); - constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); - constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); + constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); + constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - constexpr unsigned HPadLow = LowerPads{}.Get(I0); - constexpr unsigned WPadLow = LowerPads{}.Get(I1); + constexpr index_t HPadLow = LowerPads{}.Get(I0); + constexpr index_t WPadLow = LowerPads{}.Get(I1); - constexpr unsigned HPadUp = 
UpperPads{}.Get(I0); - constexpr unsigned WPadUp = UpperPads{}.Get(I1); + constexpr index_t HPadUp = UpperPads{}.Get(I0); + constexpr index_t WPadUp = UpperPads{}.Get(I1); - constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; - constexpr unsigned WiPerBlock = WoPerBlock + X - 1; + constexpr index_t HiPerBlock = HoPerBlock + Y - 1; + constexpr index_t WiPerBlock = WoPerBlock + X - 1; // divide block work: [K, Ho, Wo, N] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock; + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; + constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; + constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; - const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork); + const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); + index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); + const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const unsigned w_block_work_id = itmp / NBlockWork; - const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork; + const index_t w_block_work_id = itmp / NBlockWork; + const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; - const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; - const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; + const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; + const index_t n_block_data_begin = n_block_work_id * NPerBlock; // flattened (2d) tensor view of wei in global mem constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); @@ -114,11 +114,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( // blockwise copy // input: format is [C, Hi, Wi, N] - const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0; - const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0; + const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0; + const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0; - const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0; - const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0; + const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0; + const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? 
WPadUp : 0; #if 0 if(get_thread_local_1d_id() == 0) @@ -204,8 +204,8 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( true>{}; // LDS - constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_cyxk_block_desc.GetElementSpace(); + constexpr index_t in_block_size = in_chwn_block_desc.GetElementSpace(); + constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace(); __shared__ Float p_in_block[in_block_size]; __shared__ Float p_wei_block[wei_block_size]; @@ -219,9 +219,9 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( const Float* p_wei_global_block_begin = p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin); - for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0), - __syncthreads()) + for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0), + __syncthreads()) { #if 1 // input: global mem to LDS, @@ -245,9 +245,9 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( __syncthreads(); // a series of batched GEMM - for(unsigned y = 0; y < Y; ++y) + for(index_t y = 0; y < Y; ++y) { - for(unsigned x = 0; x < X; ++x) + for(index_t x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; @@ -262,10 +262,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( const auto matrix_c_index = blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned ho_thread_data_begin = matrix_c_index.batch; - const unsigned k_thread_data_begin = matrix_c_index.row; - const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock; - const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; + const index_t ho_thread_data_begin = matrix_c_index.batch; + const index_t k_thread_data_begin = matrix_c_index.row; + const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock; + const index_t n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock; #if 0 printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n", diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index c359001b85..08aa8f90f5 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -8,32 +8,32 @@ #include "blockwise_gemm.hip.hpp" // define B = flatten(N, Hi, Wi) -template + index_t BPerBlock, + index_t KPerBlock, + index_t CPerBlock, + index_t BPerThread, + index_t KPerThread, + index_t GemmThreadPerColumnPerCluster, + index_t GemmThreadPerRowPerCluster, + index_t GemmMPerThreadSubC, + index_t GemmNPerThreadSubC, + index_t GemmMLevel0Cluster, + index_t GemmNLevel0Cluster, + index_t GemmMLevel1Cluster, + index_t GemmNLevel1Cluster, + index_t GemmKPerThreadLoop, + index_t InBlockCopyThreadPerDim0, + index_t InBlockCopyThreadPerDim1, + index_t WeiBlockCopyThreadPerDim0, + index_t WeiBlockCopyThreadPerDim1, + index_t InBlockCopyDataPerRead, + index_t WeiBlockCopyDataPerRead> __global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global, const Float* const __restrict__ 
p_wei_global, @@ -48,30 +48,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - constexpr unsigned C = in_chwn_global_desc.GetLength(I0); - constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1); - constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2); - constexpr unsigned N = in_chwn_global_desc.GetLength(I3); + constexpr index_t C = in_chwn_global_desc.GetLength(I0); + constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); + constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); + constexpr index_t N = in_chwn_global_desc.GetLength(I3); - constexpr unsigned K = out_khwn_global_desc.GetLength(I0); - constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); - constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); + constexpr index_t K = out_khwn_global_desc.GetLength(I0); + constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); + constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); - constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); - constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); + constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); + constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); + constexpr index_t B = N * Hi * Wi; + constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); // divide block work by 2d: [K, B] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock; - const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; - const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + const index_t k_block_work_id = get_block_1d_id() / BBlockWork; + const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned b_block_data_begin = b_block_work_id * BPerBlock; + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t b_block_data_begin = b_block_work_id * BPerBlock; // flattend (2d) tensor view of gridwise input constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); @@ -192,15 +192,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric GemmKPerThreadLoop>{}; // LDS: be careful of alignment - constexpr unsigned in_block_size = + constexpr index_t in_block_size = in_cb_block_desc.GetElementSpace(Number{}); - constexpr unsigned wei_block_size = + constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace(Number{}); - constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; + constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? 
InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; // LDS __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; @@ -218,10 +218,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric // set threadwise output tensor to 0 threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); - for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), - __syncthreads()) + for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + __syncthreads()) { // load data blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); @@ -231,18 +231,16 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric // compute on current data // a series of GEMM - for(unsigned y = 0; y < Y; ++y) + for(index_t y = 0; y < Y; ++y) { - for(unsigned x = 0; x < X; ++x) + for(index_t x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; #if 0 blockwise_gemm.Run -#elif 1 +#elif 0 blockwise_gemm.Run_asm -#elif 0 - blockwise_gemm.Run_v2 -#elif 0 +#elif 1 blockwise_gemm.Run_RegisterDoubleBuffer #endif (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), @@ -257,23 +255,23 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric const auto c_thread_mtx_begin = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; + const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; + const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; - for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) + for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) { - for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) { const auto c_thread_mtx_distance = blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; - unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; + index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; + index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; - unsigned h_data = b_data / (Wi * N); - unsigned itmp = b_data - h_data * (Wi * N); - unsigned w_data = itmp / N; - unsigned n_data = itmp - w_data * N; + index_t h_data = b_data / (Wi * N); + index_t itmp = b_data - h_data * (Wi * N); + index_t w_data = itmp / N; + index_t n_data = itmp - w_data * N; if(n_data < N && h_data < Ho && w_data < Wo) { diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp index b38348cfac..f15bc1807b 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp @@ -8,32 +8,32 @@ #include "blockwise_gemm.hip.hpp" // define B = flatten(N, Hi, 
Wi) -template + index_t BPerBlock, + index_t KPerBlock, + index_t CPerBlock, + index_t BPerThread, + index_t KPerThread, + index_t GemmThreadPerColumnPerCluster, + index_t GemmThreadPerRowPerCluster, + index_t GemmMPerThreadSubC, + index_t GemmNPerThreadSubC, + index_t GemmMLevel0Cluster, + index_t GemmNLevel0Cluster, + index_t GemmMLevel1Cluster, + index_t GemmNLevel1Cluster, + index_t GemmKPerThreadLoop, + index_t InBlockCopyThreadPerDim0, + index_t InBlockCopyThreadPerDim1, + index_t WeiBlockCopyThreadPerDim0, + index_t WeiBlockCopyThreadPerDim1, + index_t InBlockCopyDataPerRead, + index_t WeiBlockCopyDataPerRead> __global__ void #if 0 __launch_bounds__(256,2) @@ -52,30 +52,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - constexpr unsigned C = in_chwn_global_desc.GetLength(I0); - constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1); - constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2); - constexpr unsigned N = in_chwn_global_desc.GetLength(I3); + constexpr index_t C = in_chwn_global_desc.GetLength(I0); + constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); + constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); + constexpr index_t N = in_chwn_global_desc.GetLength(I3); - constexpr unsigned K = out_khwn_global_desc.GetLength(I0); - constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); - constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); + constexpr index_t K = out_khwn_global_desc.GetLength(I0); + constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); + constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); - constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); - constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); + constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); + constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - constexpr unsigned B = N * Hi * Wi; - constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); + constexpr index_t B = N * Hi * Wi; + constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); // divide block work by 2d: [K, B] - constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; + constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; + constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock; - const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; - const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; + const index_t k_block_work_id = get_block_1d_id() / BBlockWork; + const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - const unsigned k_block_data_begin = k_block_work_id * KPerBlock; - const unsigned b_block_data_begin = b_block_work_id * BPerBlock; + const index_t k_block_data_begin = k_block_work_id * KPerBlock; + const index_t b_block_data_begin = b_block_work_id * BPerBlock; // flattend (2d) tensor view of gridwise input constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); @@ -210,15 +210,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( #endif // LDS: be careful of alignment - constexpr unsigned in_block_size = + constexpr index_t in_block_size = in_cb_block_desc.GetElementSpace(Number{}); - constexpr unsigned wei_block_size = + constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace(Number{}); - constexpr unsigned max_align = 
InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; + constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead + ? InBlockCopyDataPerRead + : WeiBlockCopyDataPerRead; // LDS double buffer __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)]; @@ -248,11 +248,11 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( bool even_loop = true; - for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; + for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), - even_loop = !even_loop) + p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), + even_loop = !even_loop) { Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; @@ -279,12 +279,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( // compute on current data // a series of GEMM - for(unsigned y = 0; y < Y; ++y) + for(index_t y = 0; y < Y; ++y) { - for(unsigned x = 0; x < X; ++x) + for(index_t x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 0 +#if 1 blockwise_gemm.Run #else blockwise_gemm.Run_RegisterDoubleBuffer @@ -309,12 +309,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( __syncthreads(); - for(unsigned y = 0; y < Y; ++y) + for(index_t y = 0; y < Y; ++y) { - for(unsigned x = 0; x < X; ++x) + for(index_t x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 0 +#if 1 blockwise_gemm.Run #else blockwise_gemm.Run_RegisterDoubleBuffer @@ -331,8 +331,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( const auto c_thread_mtx_begin = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; + const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; + const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; #if 0 if(get_block_1d_id() == 0) @@ -348,20 +348,20 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( } #endif - for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) + for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) { - for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) + for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) { const auto c_thread_mtx_distance = blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; - unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; + index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; + index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; - unsigned h_data = b_data / (Wi * N); - unsigned itmp = b_data - h_data * (Wi * N); - unsigned w_data = itmp / N; - unsigned n_data = itmp - w_data * N; + index_t h_data = b_data / (Wi * N); + index_t itmp = b_data - h_data * (Wi * N); + index_t w_data = itmp / N; + index_t n_data = itmp - w_data * N; if(n_data < N 
&& h_data < Ho && w_data < Wo) { diff --git a/src/include/threadwise_2d_tensor_op.hip.hpp b/src/include/threadwise_2d_tensor_op.hip.hpp index cc48e88317..6e25b61b73 100644 --- a/src/include/threadwise_2d_tensor_op.hip.hpp +++ b/src/include/threadwise_2d_tensor_op.hip.hpp @@ -16,11 +16,11 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re } #endif - for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0) + for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0) { - for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1) + for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1) { - const unsigned dindex = desc.Get1dIndex(did0, did1); + const index_t dindex = desc.Get1dIndex(did0, did1); f(p[dindex]); } @@ -47,22 +47,22 @@ __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_d constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; - constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0); - constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1); + constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0); + constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1); constexpr auto src_desc = SrcDesc{}; constexpr auto dst_desc = DstDesc{}; constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); - for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) + for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) { - for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) + for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) { - const unsigned aindex = src_desc.Get1dIndex(did0, did1); + const index_t aindex = src_desc.Get1dIndex(did0, did1); - const unsigned did[2] = {did0, did1}; + const index_t did[2] = {did0, did1}; - const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]); + const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]); f(p_src[aindex], p_dst[bindex]); } @@ -118,21 +118,21 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi } #endif - constexpr unsigned nshift = NShift::mValue; + constexpr index_t nshift = NShift::mValue; - constexpr unsigned did0_end = + constexpr index_t did0_end = is_same::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0); - constexpr unsigned did1_end = + constexpr index_t did1_end = is_same::value ? 
desc.GetLength(I1) - nshift : desc.GetLength(I1); - for(unsigned did0 = 0; did0 < did0_end; ++did0) + for(index_t did0 = 0; did0 < did0_end; ++did0) { - for(unsigned did1 = 0; did1 < did1_end; ++did1) + for(index_t did1 = 0; did1 < did1_end; ++did1) { - const unsigned dindex = desc.Get1dIndex(did0, did1); + const index_t dindex = desc.Get1dIndex(did0, did1); - const unsigned sindex = dindex + nshift * desc.GetStride(IDim{}); + const index_t sindex = dindex + nshift * desc.GetStride(IDim{}); p[dindex] = p[sindex]; } diff --git a/src/include/threadwise_4d_tensor_op.hip.hpp b/src/include/threadwise_4d_tensor_op.hip.hpp index 5b908d3ac6..19ab68d013 100644 --- a/src/include/threadwise_4d_tensor_op.hip.hpp +++ b/src/include/threadwise_4d_tensor_op.hip.hpp @@ -18,15 +18,15 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re } #endif - for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0) + for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0) { - for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1) + for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1) { - for(unsigned did2 = 0; did2 < desc.GetLength(I2); ++did2) + for(index_t did2 = 0; did2 < desc.GetLength(I2); ++did2) { - for(unsigned did3 = 0; did3 < desc.GetLength(I3); ++did3) + for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3) { - const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3); + const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3); f(p[dindex]); } @@ -58,28 +58,28 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0); - constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1); - constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2); - constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3); + constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0); + constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1); + constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2); + constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3); constexpr auto src_desc = SrcDesc{}; constexpr auto dst_desc = DstDesc{}; constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); - for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) + for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) { - for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) + for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) { - for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) + for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) { - for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) + for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) { - const unsigned aindex = src_desc.Get1dIndex(did0, did1, did2, did3); + const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3); - const unsigned did[4] = {did0, did1, did2, did3}; + const index_t did[4] = {did0, did1, did2, did3}; - const unsigned bindex = + const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); f(p_src[aindex], p_dst[bindex]); @@ -129,7 +129,7 @@ __device__ void threadwise_4d_tensor_copy( } // need to assume src and dst is aligned -template +template __device__ void threadwise_4d_tensor_copy_v2(SrcDesc, const Float* __restrict__ p_src, DstDesc, @@ -163,24 +163,24 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc, DstDesc{}.GetStride(I2) % DataPerRead == 0, "wrong! 
src and dst stride should be multiple of DataPerRead to keep alignment"); - constexpr unsigned L3 = SrcOpLengths{}.Get(I3); + constexpr index_t L3 = SrcOpLengths{}.Get(I3); static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead"); - constexpr unsigned nloop_d3 = L3 / DataPerRead; + constexpr index_t nloop_d3 = L3 / DataPerRead; - for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) + for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) { - for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) + for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) { - for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) + for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) { - for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3) + for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3) { - const unsigned src_index = + const index_t src_index = src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead); - const unsigned dst_index = + const index_t dst_index = dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead); if(DataPerRead == 1) @@ -224,31 +224,31 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi } #endif - constexpr unsigned nshift = NShift::mValue; + constexpr index_t nshift = NShift::mValue; - constexpr unsigned did0_end = + constexpr index_t did0_end = is_same::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0); - constexpr unsigned did1_end = + constexpr index_t did1_end = is_same::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1); - constexpr unsigned did2_end = + constexpr index_t did2_end = is_same::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2); - constexpr unsigned did3_end = + constexpr index_t did3_end = is_same::value ? 
desc.GetLength(I3) - nshift : desc.GetLength(I3); - for(unsigned did0 = 0; did0 < did0_end; ++did0) + for(index_t did0 = 0; did0 < did0_end; ++did0) { - for(unsigned did1 = 0; did1 < did1_end; ++did1) + for(index_t did1 = 0; did1 < did1_end; ++did1) { - for(unsigned did2 = 0; did2 < did2_end; ++did2) + for(index_t did2 = 0; did2 < did2_end; ++did2) { - for(unsigned did3 = 0; did3 < did3_end; ++did3) + for(index_t did3 = 0; did3 < did3_end; ++did3) { - const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3); + const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3); - const unsigned sindex = dindex + nshift * desc.GetStride(IDim{}); + const index_t sindex = dindex + nshift * desc.GetStride(IDim{}); p[dindex] = p[sindex]; } diff --git a/src/include/threadwise_direct_convolution.hip.hpp b/src/include/threadwise_direct_convolution.hip.hpp index b9a509d6a0..9b90c402a1 100644 --- a/src/include/threadwise_direct_convolution.hip.hpp +++ b/src/include/threadwise_direct_convolution.hip.hpp @@ -28,28 +28,28 @@ __device__ void threadwise_direct_convolution_1(InDesc, } #endif - for(unsigned n = 0; n < out_desc.GetLength(I0); ++n) + for(index_t n = 0; n < out_desc.GetLength(I0); ++n) { - for(unsigned k = 0; k < out_desc.GetLength(I1); ++k) + for(index_t k = 0; k < out_desc.GetLength(I1); ++k) { - for(unsigned ho = 0; ho < out_desc.GetLength(I2); ++ho) + for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho) { - for(unsigned wo = 0; wo < out_desc.GetLength(I3); ++wo) + for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo) { - for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c) + for(index_t c = 0; c < wei_desc.GetLength(I1); ++c) { - for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y) + for(index_t y = 0; y < wei_desc.GetLength(I2); ++y) { - for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x) + for(index_t x = 0; x < wei_desc.GetLength(I3); ++x) { - const unsigned hi = ho + y; - const unsigned wi = wo + x; + const index_t hi = ho + y; + const index_t wi = wo + x; - const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi); + const index_t in_index = in_desc.Get1dIndex(n, c, hi, wi); - const unsigned wei_index = wei_desc.Get1dIndex(k, c, y, x); + const index_t wei_index = wei_desc.Get1dIndex(k, c, y, x); - const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo); + const index_t out_index = out_desc.Get1dIndex(n, k, ho, wo); fused_multiply_accumulate( p_out[out_index], p_wei[wei_index], p_in[in_index]); @@ -125,7 +125,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, Data p_in_reg[in_reg_desc.GetElementSpace()]; Data p_wei_reg[wei_reg_desc.GetElementSpace()]; - constexpr unsigned in_w_new_read = 1; + constexpr index_t in_w_new_read = 1; constexpr auto in_desc_reg_new_read = make_ConstantTensorDescriptor(Sequence +template __device__ void threadwise_matrix_copy(SrcMatrix, const Float* __restrict__ p_src, DstMatrix, @@ -10,16 +10,39 @@ __device__ void threadwise_matrix_copy(SrcMatrix, constexpr auto src_mtx = SrcMatrix{}; constexpr auto dst_mtx = DstMatrix{}; - for(unsigned i = 0; i < NRow; ++i) +#if 0 + for(index_t i = 0; i < NRow; ++i) { - for(unsigned j = 0; j < NCol; ++j) + for(index_t j = 0; j < NCol; ++j) { - const unsigned src_index = src_mtx.Get1dIndex(i, j); - const unsigned dst_index = dst_mtx.Get1dIndex(i, j); + const index_t src_index = src_mtx.Get1dIndex(i, j); + const index_t dst_index = dst_mtx.Get1dIndex(i, j); p_dst[dst_index] = p_src[src_index]; } } +#elif 1 + static_assert(NCol == 4, "only for NCol == 4"); + + using vector_t = typename 
vector_type::MemoryType; + + for(index_t i = 0; i < NRow; ++i) + { + const index_t src_index = src_mtx.Get1dIndex(i, 0); + const index_t dst_index = dst_mtx.Get1dIndex(i, 0); + +#if 1 + *(reinterpret_cast(p_dst + dst_index)) = + *(reinterpret_cast(p_src + src_index)); +#elif 1 + asm volatile("\n \ + ds_read_b128 %0, %1, offset:0 \n \ + " + : "=v"(*(reinterpret_cast(p_dst+dst_index))) + : "v"((uint32_t)(p_src + src_index))); +#endif + } +#endif } template +template __device__ void threadwise_6d_tensor_copy(SrcDesc, const Float* __restrict__ p_src, DstDesc, @@ -37,28 +37,28 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, DstDesc{}.GetStride(I4) % DataPerRead == 0, "wrong! src and dst stride should be multiple of DataPerRead to keep alignment"); - constexpr unsigned L5 = SrcOpLengths{}.Get(I5); + constexpr index_t L5 = SrcOpLengths{}.Get(I5); static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead"); - constexpr unsigned nloop_d5 = L5 / DataPerRead; + constexpr index_t nloop_d5 = L5 / DataPerRead; - for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) + for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) { - for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) + for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) { - for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) + for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) { - for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) + for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) { - for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4) + for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4) { - for(unsigned iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5) + for(index_t iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5) { - const unsigned src_index = src_desc.Get1dIndex( + const index_t src_index = src_desc.Get1dIndex( did0, did1, did2, did3, did4, iloop_d5 * DataPerRead); - const unsigned dst_index = dst_desc.Get1dIndex( + const index_t dst_index = dst_desc.Get1dIndex( did0, did1, did2, did3, did4, iloop_d5 * DataPerRead); *(reinterpret_cast(p_dst + dst_index)) = @@ -72,7 +72,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, } // need to assume src and dst is aligned -template +template __device__ void threadwise_8d_tensor_copy(SrcDesc, const Float* __restrict__ p_src, DstDesc, @@ -109,29 +109,29 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, DstDesc{}.GetStride(I6) % DataPerRead == 0, "wrong! src and dst stride should be multiple of DataPerRead to keep alignment"); - constexpr unsigned L7 = SrcOpLengths{}.Get(I7); + constexpr index_t L7 = SrcOpLengths{}.Get(I7); static_assert(L7 % DataPerRead == 0, "wrong! 
L7 should be evenly divided by DataPerRead"); - constexpr unsigned nloop_d7 = L7 / DataPerRead; + constexpr index_t nloop_d7 = L7 / DataPerRead; - for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) + for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) { - for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) + for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) { - for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) + for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) { - for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) + for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) { - for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4) + for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4) { - for(unsigned did5 = 0; did5 < ref_desc.GetLength(I5); ++did5) + for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5) { - for(unsigned did6 = 0; did6 < ref_desc.GetLength(I6); ++did6) + for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6) { - for(unsigned iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7) + for(index_t iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7) { - const unsigned src_index = + const index_t src_index = src_desc.Get1dIndex(did0, did1, did2, @@ -141,7 +141,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, did6, iloop_d7 * DataPerRead); - const unsigned dst_index = + const index_t dst_index = dst_desc.Get1dIndex(did0, did1, did2,
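// [editor's note, not part of the patch] The threadwise_4d/6d/8d_tensor_copy variants above
// all share one pattern: the innermost dimension is walked in steps of DataPerRead, and each
// step is moved through a single vector-typed load/store via reinterpret_cast. That is why
// the static_asserts require the innermost length (L3/L5/L7) and the trailing strides to be
// multiples of DataPerRead. A minimal 1-d sketch, assuming the codebase's index_t and
// vector_type<Float, N>::MemoryType helpers (the function name is illustrative):
template <class Float, index_t L, index_t DataPerRead>
__device__ void threadwise_1d_tensor_copy_sketch(const Float* __restrict__ p_src,
                                                 Float* __restrict__ p_dst)
{
    static_assert(L % DataPerRead == 0, "wrong! L should be evenly divided by DataPerRead");

    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    for(index_t i = 0; i < L / DataPerRead; ++i)
    {
        // one vector load/store moves DataPerRead scalars; p_src / p_dst assumed aligned
        *(reinterpret_cast<vector_t*>(p_dst + i * DataPerRead)) =
            *(reinterpret_cast<const vector_t*>(p_src + i * DataPerRead));
    }
}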