experimenting

This commit is contained in:
Chao Liu
2019-03-24 12:09:57 -05:00
parent f35c64eb78
commit 766b0a9eaf
33 changed files with 1886 additions and 1822 deletions

View File

@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
unsigned nrepeat)
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
#if 1
// 3x3, 34x34
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 16;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 16;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(gridwise_direct_convolution_1<T,
InDesc,

View File

@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
unsigned nrepeat)
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
#if 1
// 3x3, 34x34, 128 thread
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 1
// 3x3, 34x34, 128 thread, fp16
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,

View File

@@ -10,13 +10,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
const Tensor<TInWei>& wei_kcyx,
OutDesc,
Tensor<TOut>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
// this suppose in / wei data type is int8x4
constexpr unsigned NVector = 4;
using accum_t = int32_t;
using vector_t = vector_type<TInWei, NVector>;
using vector_mem_t = typename vector_t::MemoryType;
constexpr index_t NVector = 4;
using accum_t = int32_t;
using vector_t = vector_type<TInWei, NVector>;
using vector_mem_t = typename vector_t::MemoryType;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// vectorized input
auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
#if 0
// 3x3, 34x34, 128 thread, fp32, vector = 1
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, fp32, vector = 2
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 1
// 1x1, 32x32, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 16;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 16;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,

View File

@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
#if 0
// for 3x3, 34x34
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 1
// for 1x1, 14x14
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<GridSize,

View File

@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
Tensor<T>& out_nkhw,
LowerPads,
UpperPads,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 1;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 1;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
constexpr unsigned BlockSize = 8;
constexpr index_t BlockSize = 8;
#elif 1
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 1
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5 filter, 20x84 image, 1x1 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,

View File

@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned N = in_nchw_desc.GetLength(I0);
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = in_nchw_desc.GetLength(I0);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// convert in_nchw to in_cnhw
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
#if 0
// 3x3, 34x34
// need to use register double buffer for GEMM
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 8;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 64 threads
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif 1
constexpr index_t BlockSize = 64;
#elif 0
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 256 thread
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 4;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 256;
constexpr index_t BlockSize = 256;
#elif 1
// 1x1, 14x14, Vega 10
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
#if 1

View File

@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
template <class... Ts>
double operator()(Ts... Xs) const
{
std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
return std::accumulate(dims.begin(),
dims.end(),
true,
[](bool init, unsigned long x) -> int { return init != (x % 2); })
[](bool init, index_t x) -> int { return init != (x % 2); })
? 1
: -1;
}
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
constexpr auto I3 = Number<3>{};
constexpr auto desc = TConstTensorDesc{};
std::initializer_list<unsigned> lengths = {
std::initializer_list<index_t> lengths = {
desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
std::initializer_list<unsigned> strides = {
std::initializer_list<index_t> strides = {
desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
return TensorDescriptor(lengths, strides);
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
LowerPads,
UpperPads)
{
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
std::size_t HiPerTile = HoPerTile + Y - 1;
std::size_t WiPerTile = WoPerTile + X - 1;
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
#if 0
constexpr unsigned N = 1;
constexpr unsigned C = 1;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 1;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 1;
constexpr index_t C = 1;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 34;
constexpr unsigned WI = 34;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 56x56
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 5x5, 36x36
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 36;
constexpr unsigned WI = 36;
constexpr unsigned K = 64;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 36;
constexpr index_t WI = 36;
constexpr index_t K = 64;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 7x7, 38x38
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 38;
constexpr unsigned WI = 38;
constexpr unsigned K = 64;
constexpr unsigned Y = 7;
constexpr unsigned X = 7;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 38;
constexpr index_t WI = 38;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 56x56 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 1x1 filter, 28x28 image
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 84;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 84;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 112x112 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 64;
constexpr unsigned HI = 112;
constexpr unsigned WI = 112;
constexpr unsigned K = 128;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 64;
constexpr index_t HI = 112;
constexpr index_t WI = 112;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 20x86 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 86;
constexpr unsigned K = 512;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 86;
constexpr index_t K = 512;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned N = 16;
constexpr unsigned C = 192;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 32;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 16;
constexpr index_t C = 192;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 32;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr unsigned HPad = 2;
constexpr unsigned WPad = 2;
constexpr index_t HPad = 2;
constexpr index_t WPad = 2;
#elif 0
// 1x1 filter, 32x32 image
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 32;
constexpr unsigned WI = 32;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 32;
constexpr index_t WI = 32;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 1x1 filter, 14x14 image
constexpr unsigned N = 128;
constexpr unsigned C = 2048;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
// 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 512
constexpr unsigned N = 128;
constexpr unsigned C = 512;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#endif
auto lower_pads = Sequence<HPad, WPad>{};
@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
}
bool do_verification = atoi(argv[1]);
unsigned nrepeat = atoi(argv[2]);
index_t nrepeat = atoi(argv[2]);
if(do_verification)
{

View File

@@ -1,18 +1,18 @@
#pragma once
template <class TData, unsigned NSize>
template <class TData, index_t NSize>
struct Array
{
using Type = Array<TData, NSize>;
static constexpr unsigned nSize = NSize;
static constexpr index_t nSize = NSize;
unsigned mData[nSize];
index_t mData[nSize];
template <class... Xs>
__host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
{
}
__host__ __device__ TData operator[](unsigned i) const { return mData[i]; }
__host__ __device__ TData operator[](index_t i) const { return mData[i]; }
};

View File

@@ -1,7 +1,7 @@
#pragma once
#include "common.hip.hpp"
template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor
{
__host__ __device__ constexpr ConstantMatrixDescriptor()
@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
}
__host__ __device__ constexpr unsigned NRow() const { return NRow_; }
__host__ __device__ constexpr index_t NRow() const { return NRow_; }
__host__ __device__ constexpr unsigned NCol() const { return NCol_; }
__host__ __device__ constexpr index_t NCol() const { return NCol_; }
__host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }
__host__ __device__ constexpr index_t RowStride() const { return RowStride_; }
__host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }
__host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr index_t GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ constexpr index_t GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
__host__ __device__ index_t Get1dIndex(index_t irow, index_t icol) const
{
#if DEVICE_BACKEND_HIP
return __mul24(irow, RowStride_) + icol;
#else
return irow * RowStride_ + icol;
#endif
}
template <unsigned SubNRow, unsigned SubNCol>
template <index_t SubNRow, index_t SubNCol>
__host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) const
{
@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
}
};
template <unsigned NRow, unsigned NCol>
template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}
template <unsigned NRow, unsigned NCol, unsigned RowStride>
template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{

View File

@@ -2,35 +2,35 @@
#include "common.hip.hpp"
// this is ugly, only for 2d
template <unsigned L0, unsigned L1>
template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{
return Sequence<L1, 1>{};
}
// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}
// this is ugly, only for 6d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
}
// this is ugly, only for 8d
template <unsigned L0,
unsigned L1,
unsigned L2,
unsigned L3,
unsigned L4,
unsigned L5,
unsigned L6,
unsigned L7>
template <index_t L0,
index_t L1,
index_t L2,
index_t L3,
index_t L4,
index_t L5,
index_t L6,
index_t L7>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{
@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto
}
// this is ugly, only for 2d
template <unsigned L0, unsigned L1, unsigned Align>
template <index_t L0, index_t L1, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
Number<Align>)
{
constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
return Sequence<L1_align, 1>{};
}
// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
Number<Align>)
{
constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
}
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr unsigned nDim = Lengths::nDim;
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr index_t nDim = Lengths::nDim;
__host__ __device__ constexpr ConstantTensorDescriptor()
{
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
}
__host__ __device__ constexpr unsigned GetDimension() const { return nDim; }
__host__ __device__ constexpr index_t GetDimension() const { return nDim; }
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
template <unsigned I>
__host__ __device__ constexpr unsigned GetLength(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t GetLength(Number<I>) const
{
return Lengths{}.Get(Number<I>{});
}
template <unsigned I>
__host__ __device__ constexpr unsigned GetStride(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t GetStride(Number<I>) const
{
return Strides{}.Get(Number<I>{});
}
@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
struct GetElementSize_f
{
template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return Type{}.GetLength(idim);
}
};
__host__ __device__ constexpr unsigned GetElementSize() const
__host__ __device__ constexpr index_t GetElementSize() const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct multiply
{
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{
return a * b;
}
@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
struct GetElementSpace_f
{
template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
}
};
template <class Align = Number<1>>
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
__host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{
return a + b;
}
@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
}
template <class... Is>
__host__ __device__ unsigned Get1dIndex(Is... is) const
__host__ __device__ index_t Get1dIndex(Is... is) const
{
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
const auto multi_id = Array<unsigned, nDim>(is...);
const auto multi_id = Array<index_t, nDim>(is...);
unsigned id = 0;
index_t id = 0;
static_loop_n<nDim>{}([&](auto IDim) {
constexpr unsigned idim = IDim.Get();
constexpr index_t idim = IDim.Get();
#if DEVICE_BACKEND_HIP
id += __mul24(multi_id[idim], GetStride(IDim));
#else
id += multi_id[idim] * GetStride(IDim);
#endif
});
return id;
@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
}
template <unsigned IDim, unsigned NVector>
template <index_t IDim, index_t NVector>
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
{
assert(false); // not implemented
@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
return ConstantTensorDescriptor<Lengths, Strides>{};
}
template <class Lengths, unsigned Align>
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
@@ -193,8 +197,8 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
constexpr auto desc = TDesc{};
constexpr unsigned ndim = desc.GetDimension();
constexpr auto desc = TDesc{};
constexpr index_t ndim = desc.GetDimension();
static_assert(ndim >= 2 && ndim <= 8, "wrong!");

View File

@@ -2,38 +2,38 @@
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"
template <unsigned... Is>
template <index_t... Is>
struct Sequence
{
using Type = Sequence<Is...>;
static constexpr unsigned nDim = sizeof...(Is);
static constexpr index_t nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...};
const index_t mData[nDim] = {Is...};
template <unsigned I>
__host__ __device__ constexpr unsigned Get(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t Get(Number<I>) const
{
return mData[I];
}
// this is ugly, only for nDIm = 4
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
{
static_assert(nDim == 4, "nDim != 4");
constexpr auto old_sequence = Type{};
constexpr unsigned NR0 = old_sequence.mData[I0];
constexpr unsigned NR1 = old_sequence.mData[I1];
constexpr unsigned NR2 = old_sequence.mData[I2];
constexpr unsigned NR3 = old_sequence.mData[I3];
constexpr index_t NR0 = old_sequence.mData[I0];
constexpr index_t NR1 = old_sequence.mData[I1];
constexpr index_t NR2 = old_sequence.mData[I2];
constexpr index_t NR3 = old_sequence.mData[I3];
return Sequence<NR0, NR1, NR2, NR3>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
{
// don't know how to implement this
@@ -41,7 +41,7 @@ struct Sequence
assert(false);
}
template <unsigned I>
template <index_t I>
__host__ __device__ constexpr auto PushBack(Number<I>) const
{
return Sequence<Is..., I>{};
@@ -56,14 +56,14 @@ struct Sequence
}
};
template <unsigned... Is, unsigned I>
template <index_t... Is, index_t I>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
static_assert(sizeof...(Is) >= 1, "empty Sequence!");
return Sequence<Is...>{};
}
template <class F, unsigned... Xs, unsigned... Ys>
template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
{
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
return Sequence<f(Xs, Ys)...>{};
}
template <unsigned... Xs, unsigned... Ys>
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
{
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const
__host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
{
return x + y;
}
@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
}
template <unsigned... Is>
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
return sequence_pop_back(Type{});

View File

@@ -1,7 +1,7 @@
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
template <unsigned BlockSize, class Float, class DstDesc, class F>
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
@@ -20,19 +20,19 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
}
#endif
constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
f(p_dst[dindex]);
}
@@ -41,17 +41,17 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < desc.GetElementSize())
{
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
f(p_dst[dindex]);
}
@@ -61,7 +61,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -80,20 +80,20 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[2];
index_t did[2];
did[0] = is / ref_desc.GetStride(I0);
@@ -101,9 +101,9 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[1] = is / ref_desc.GetStride(I1);
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]);
}
@@ -112,11 +112,11 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
unsigned did[2];
index_t did[2];
did[0] = is / ref_desc.GetStride(I0);
@@ -124,16 +124,16 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[1] = is / ref_desc.GetStride(I1);
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]);
}
}
}
template <unsigned BlockSize, class Float, class DstDesc>
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
auto f_set_zero = [](Float& v) { v = Float(0); };
@@ -141,7 +141,7 @@ __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -161,7 +161,7 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
template <index_t BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
struct Blockwise2dTensorCopy1
{
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
@@ -175,17 +175,17 @@ struct Blockwise2dTensorCopy1
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
unsigned ThreadPerDim0,
unsigned ThreadPerDim1>
index_t ThreadPerDim0,
index_t ThreadPerDim1>
struct Blockwise2dTensorCopy2
{
unsigned mThreadId0;
unsigned mThreadId1;
index_t mThreadId0;
index_t mThreadId1;
__device__ Blockwise2dTensorCopy2()
{
@@ -222,61 +222,61 @@ struct Blockwise2dTensorCopy2
constexpr bool align_v2 =
src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;
constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
constexpr index_t L0 = SrcOpLengths{}.Get(I0);
constexpr index_t L1 = SrcOpLengths{}.Get(I1);
constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
constexpr unsigned Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
constexpr unsigned Dim1V2Loop =
constexpr index_t Dim1V2Loop =
align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;
constexpr unsigned Dim1V1Loop =
constexpr index_t Dim1V1Loop =
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
ThreadPerDim1;
constexpr bool d1_has_tail =
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop)
for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
{
unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0;
index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
// v4
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex));
}
// v2
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{
unsigned did1 =
index_t did1 =
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex));
}
// v1
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
@@ -284,13 +284,13 @@ struct Blockwise2dTensorCopy2
// dim-1 tail
if(d1_has_tail)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
if(did1 < L1)
{
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
@@ -300,45 +300,44 @@ struct Blockwise2dTensorCopy2
// dim-0 tail
if(d0_has_tail)
{
unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
if(did0 < L0)
{
// v4
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex));
}
// v2
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
2 * mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
2 * mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex));
}
// v1
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
@@ -346,14 +345,13 @@ struct Blockwise2dTensorCopy2
// tail
if(d1_has_tail)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
if(did1 < L1)
{
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
@@ -365,18 +363,18 @@ struct Blockwise2dTensorCopy2
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise2dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
unsigned mSrcMyThreadOffset;
unsigned mDstMyThreadOffset;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise2dTensorCopy3()
{
@@ -394,11 +392,11 @@ struct Blockwise2dTensorCopy3
DstDesc{}.GetStride(I0) % DataPerRead == 0,
"src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
// we allow out-of-bound read from src in D1 dimension,
// but we need to make sure dst stride is big enough,
@@ -408,7 +406,7 @@ struct Blockwise2dTensorCopy3
static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
@@ -418,8 +416,8 @@ struct Blockwise2dTensorCopy3
}
}
const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
@@ -430,13 +428,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
@@ -446,18 +444,18 @@ struct Blockwise2dTensorCopy3
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
};
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
@@ -466,7 +464,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
@@ -475,18 +473,18 @@ struct Blockwise2dTensorCopy3
}
}
__device__ constexpr unsigned GetRegisterClipboardSize() const
__device__ constexpr index_t GetRegisterClipboardSize() const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
}
@@ -497,13 +495,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
@@ -513,18 +511,18 @@ struct Blockwise2dTensorCopy3
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_clipboard + iloop * 4)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
};
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
@@ -533,7 +531,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
@@ -548,13 +546,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
@@ -564,17 +562,17 @@ struct Blockwise2dTensorCopy3
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_clipboard + iloop * 4));
};
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
@@ -583,7 +581,7 @@ struct Blockwise2dTensorCopy3
if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{

View File

@@ -1,7 +1,7 @@
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
template <unsigned BlockSize, class Float, class DstDesc, class F>
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
@@ -22,27 +22,27 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
}
#endif
constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);
is -= did1 * desc.GetStride(I1);
const unsigned did2 = is / desc.GetStride(I2);
const index_t did2 = is / desc.GetStride(I2);
is -= did2 * desc.GetStride(I2);
const unsigned did3 = is / desc.GetStride(I3);
const index_t did3 = is / desc.GetStride(I3);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
f(p_dst[dindex]);
}
@@ -51,25 +51,25 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < desc.GetElementSize())
{
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);
is -= did1 * desc.GetStride(I1);
const unsigned did2 = is / desc.GetStride(I2);
const index_t did2 = is / desc.GetStride(I2);
is -= did2 * desc.GetStride(I2);
const unsigned did3 = is / desc.GetStride(I3);
const index_t did3 = is / desc.GetStride(I3);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
f(p_dst[dindex]);
}
@@ -79,7 +79,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -100,22 +100,22 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[4];
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
@@ -131,9 +131,9 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[src_index], p_dst[dst_index]);
}
@@ -142,11 +142,11 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
unsigned did[4];
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
@@ -162,16 +162,16 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[src_index], p_dst[dst_index]);
}
}
}
template <unsigned BlockSize, class Float, class DstDesc>
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
auto f_set_zero = [](Float& v) { v = Float(0); };
@@ -179,7 +179,7 @@ __device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -199,12 +199,12 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise4dTensorCopy1
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
@@ -230,8 +230,8 @@ struct Blockwise4dTensorCopy1
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);
static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
@@ -247,20 +247,20 @@ struct Blockwise4dTensorCopy1
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr auto ref_desc =
make_ConstantTensorDescriptor(Sequence<L0, L1, L2, read_per_d3>{});
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](unsigned is) {
unsigned did[4];
auto f_copy = [&](index_t is) {
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
@@ -276,18 +276,18 @@ struct Blockwise4dTensorCopy1
did[3] = is / ref_desc.GetStride(I3);
const unsigned src_index =
const index_t src_index =
src_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);
const unsigned dst_index =
const index_t dst_index =
dst_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
f_copy(is);
}
@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
@@ -306,7 +306,7 @@ struct Blockwise4dTensorCopy1
}
};
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -315,15 +315,15 @@ template <unsigned BlockSize,
struct BlockwiseChwnTensorCopyPadded
{
__device__ void Run(const Float* __restrict__ p_src,
unsigned c_block_data_begin,
unsigned ho_block_data_begin,
unsigned wo_block_data_begin,
unsigned n_block_data_begin,
index_t c_block_data_begin,
index_t ho_block_data_begin,
index_t wo_block_data_begin,
index_t n_block_data_begin,
Float* __restrict__ p_dst,
unsigned h_block_pad_low,
unsigned w_block_pad_low,
unsigned h_block_pad_up,
unsigned w_block_pad_up) const
index_t h_block_pad_low,
index_t w_block_pad_low,
index_t h_block_pad_up,
index_t w_block_pad_up) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -337,7 +337,7 @@ struct BlockwiseChwnTensorCopyPadded
constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
const Float* p_src_tmp =
p_src +
@@ -368,11 +368,11 @@ struct BlockwiseChwnTensorCopyPadded
}
#endif
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;
unsigned did[4];
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
@@ -388,7 +388,7 @@ struct BlockwiseChwnTensorCopyPadded
did[3] = is / ref_desc.GetStride(I3);
const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] =
(did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
@@ -401,11 +401,11 @@ struct BlockwiseChwnTensorCopyPadded
if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
unsigned did[4];
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
@@ -421,7 +421,7 @@ struct BlockwiseChwnTensorCopyPadded
did[3] = is / ref_desc.GetStride(I3);
const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] =
(did[1] < h_block_pad_low ||
@@ -436,19 +436,19 @@ struct BlockwiseChwnTensorCopyPadded
// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
class ThreadPerDims,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise4dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
unsigned mSrcMyThreadOffset;
unsigned mDstMyThreadOffset;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise4dTensorCopy3()
{
@@ -469,20 +469,20 @@ struct Blockwise4dTensorCopy3
DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
@@ -493,7 +493,7 @@ struct Blockwise4dTensorCopy3
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
"wrrong! BlockSize is not big enough for ThreadPerDims!");
constexpr unsigned num_active_thread =
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
@@ -504,14 +504,14 @@ struct Blockwise4dTensorCopy3
}
}
const unsigned thread_id_d0 =
const index_t thread_id_d0 =
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
unsigned itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const unsigned thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
index_t itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
const unsigned thread_id_d2 = itmp / thread_per_d3;
const unsigned thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
const index_t thread_id_d2 = itmp / thread_per_d3;
const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
@@ -526,17 +526,17 @@ struct Blockwise4dTensorCopy3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr unsigned num_active_thread =
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
@@ -547,30 +547,30 @@ struct Blockwise4dTensorCopy3
}
}
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr unsigned nloop_d1 = L1 / thread_per_d1;
constexpr unsigned nloop_d2 = L2 / thread_per_d2;
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
#pragma unroll
for(unsigned iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(unsigned iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(unsigned iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
#pragma unroll
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const unsigned src_offset =
const index_t src_offset =
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
const unsigned dst_offset =
const index_t dst_offset =
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,

View File

@@ -1,30 +1,30 @@
#pragma once
#include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
bool TransA,
bool TransB,
bool TransC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned BatchPerThread,
unsigned KPerThreadLoop,
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t BatchPerThread,
index_t KPerThreadLoop,
bool DistributeThreadAlongColumnFirst>
struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
index_t batch;
index_t row;
index_t col;
};
__device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
@@ -61,7 +61,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
if(TransA && (!TransB) && (!TransC))
@@ -72,22 +72,22 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!");
constexpr unsigned MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol();
constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr index_t NPerBlock = b_block_mtx.NCol();
constexpr auto c_thread_mtx = ThreadMatrixC{};
// divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");
constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr index_t MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
"wrong! wrong BlockSize");
@@ -95,10 +95,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst)
{
// num of operations can be reduced
const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
unsigned itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const unsigned m_work_id = itmp / NThreadWork;
const unsigned n_work_id = itmp - m_work_id * NThreadWork;
const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
index_t itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const index_t m_work_id = itmp / NThreadWork;
const index_t n_work_id = itmp - m_work_id * NThreadWork;
return MatrixIndex{
b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
@@ -118,7 +118,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
return MatrixIndex{batch_in_c, m_in_c, n_in_c};
}
@@ -138,10 +138,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// a is transposed, b is not
constexpr auto a_thread_mtx =
@@ -154,7 +154,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
// loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of a, b
threadwise_matrix_copy(a_block_mtx,
@@ -172,7 +172,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
b_thread_mtx.GetLengths());
// loop over batch
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
@@ -226,32 +226,32 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
}
};
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop,
unsigned BatchPerThread>
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t BatchPerThread>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
index_t batch;
index_t row;
index_t col;
};
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
@@ -259,9 +259,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(BatchSize % BatchPerThread == 0,
"wrong! BatchSize is not dividable by BatchPerThread");
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
@@ -274,31 +274,31 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
@@ -335,28 +335,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
unsigned level1_id = cluster_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = cluster_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{batch_work_id * BatchPerThread,
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
@@ -365,24 +365,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{batch_in_c,
m_repeat * MPerLevel1Cluster + m_in_sub_c,
@@ -402,10 +402,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
@@ -425,20 +425,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k
#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -451,7 +451,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// copy B-sub to form B
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -464,7 +464,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// loop over batch
#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
@@ -482,7 +482,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0)
{
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -498,7 +498,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0)
{
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -539,10 +539,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
@@ -562,25 +562,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k
//#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{
#if 1
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
@@ -596,11 +596,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// copy B-sub to form B
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
@@ -612,20 +612,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// loop over batch
//#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
const index_t aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const index_t cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
@@ -635,11 +635,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \
@@ -668,11 +668,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0)
{
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i,
m_repeat * MPerThreadSubC + j)] =
@@ -687,11 +687,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0)
{
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i,
n_repeat * NPerThreadSubC + j)] =
@@ -705,16 +705,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
// do last batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const index_t cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
@@ -724,11 +724,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex =
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex =
c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
asm volatile("\n \
@@ -756,34 +756,34 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const
{
constexpr auto c_block_mtx = BlockMatrixC{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned c_thread_offset =
const index_t c_thread_offset =
c_thread_mtx_begin.batch * BlockMatrixStrideC +
c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
c_thread_sub_mtx,

View File

@@ -3,16 +3,16 @@
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class InBlockDesc,
class WeiBlockDesc,
class OutBlockDesc,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread>
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread>
__device__ void blockwise_direct_convolution(InBlockDesc,
Float* const __restrict__ p_in_block,
WeiBlockDesc,
@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto wei_block_desc = WeiBlockDesc{};
constexpr auto out_block_desc = OutBlockDesc{};
constexpr unsigned Y = wei_block_desc.GetLength(I2);
constexpr unsigned X = wei_block_desc.GetLength(I3);
constexpr index_t Y = wei_block_desc.GetLength(I2);
constexpr index_t X = wei_block_desc.GetLength(I3);
constexpr unsigned InTileSizeH = HoPerThread + Y - 1;
constexpr unsigned InTileSizeW = WoPerThread + X - 1;
constexpr index_t InTileSizeH = HoPerThread + Y - 1;
constexpr index_t InTileSizeW = WoPerThread + X - 1;
// divide thread work
constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
#if 0
if(threadIdx.x == 0)
@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto out_thread_block_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());
const unsigned thread_id = threadIdx.x;
const index_t thread_id = threadIdx.x;
for(unsigned thread_work_id = thread_id;
for(index_t thread_work_id = thread_id;
thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
thread_work_id += BlockSize)
{
unsigned itmp = thread_work_id;
unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
index_t itmp = thread_work_id;
index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
unsigned y_thread_work_id = itmp / XThreadWork;
unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
index_t y_thread_work_id = itmp / XThreadWork;
index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread;
unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread;
index_t n_thread_data_begin = n_thread_work_id * NPerThread;
index_t k_thread_data_begin = k_thread_work_id * KPerThread;
index_t ho_thread_data_begin = y_thread_work_id * HoPerThread;
index_t wo_thread_data_begin = x_thread_work_id * WoPerThread;
unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding
unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding
index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding
index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding
Float p_out_thread[out_thread_desc.GetElementSpace()];
@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
p_out_thread,
out_thread_desc.GetLengths());
for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
c_thread_data_begin += CPerThread)
{
// threadwise convolution

View File

@@ -1,26 +1,26 @@
#pragma once
#include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
bool TransA,
bool TransB,
bool TransC,
unsigned KPerThreadLoop,
unsigned MThreadPerCluster,
unsigned NThreadPerCluster,
index_t KPerThreadLoop,
index_t MThreadPerCluster,
index_t NThreadPerCluster,
bool DistributeThreadAlongColumnFirst>
struct BlockwiseGemmBlockABlockBThreadC
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned row;
unsigned col;
index_t row;
index_t col;
};
__device__ BlockwiseGemmBlockABlockBThreadC()
@@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
if(TransA && (!TransB) && (!TransC))
@@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!");
constexpr unsigned MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol();
constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr index_t NPerBlock = b_block_mtx.NCol();
constexpr auto c_thread_mtx = ThreadMatrixC{};
// divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
"MPerBlock % (MPerThread * MThreadPerCluster) != 0");
@@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
"NPerBlock % (NPerThread * NThreadPerCluster) != 0");
constexpr unsigned MClusterWork =
constexpr index_t MClusterWork =
(MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);
constexpr unsigned NClusterWork =
constexpr index_t NClusterWork =
(NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);
static_assert(BlockSize ==
@@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst)
{
const unsigned cluster_work_block_id =
const index_t cluster_work_block_id =
thread_id / (MThreadPerCluster * NThreadPerCluster);
const unsigned thread_work_cluster_id =
const index_t thread_work_cluster_id =
thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);
const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const unsigned n_cluster_work_block_id =
const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const index_t n_cluster_work_block_id =
cluster_work_block_id - m_cluster_work_block_id * NClusterWork;
const unsigned m_thread_work_cluster_id =
thread_work_cluster_id / NThreadPerCluster;
const unsigned n_thread_work_cluster_id =
const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
const index_t n_thread_work_cluster_id =
thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;
#if 0
@@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC
}
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
return MatrixIndex{m_in_c, n_in_c};
}
@@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// a is transposed, b is not
constexpr auto a_thread_mtx =
@@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
// loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
@@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC
// if following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop>
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
{
unsigned row;
unsigned col;
index_t row;
index_t col;
};
unsigned mMyThreadOffsetA;
unsigned mMyThreadOffsetB;
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
@@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
@@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned level1_id = thread_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = thread_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
@@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
//#pragma unroll
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths());
}
#pragma unroll
//#pragma unroll
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
False,
p_c_thread,
f_accum);
#else
#elif 0
// inline asm
static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
"asm is only for 8x8");
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
{
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
asm volatile("\n \
v_mac_f32 %0, %8, %9 \n \
@@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// preload A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
@@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
@@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
bool even_loop = true;
#pragma unroll
for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K;
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
@@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// preload next A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
@@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
@@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A-sub, B-sub, C-sub
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
@@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
// C-sub(s) in first row-wise subblock of C
{
@@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// copy next B-sub, and do GEMM
for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -805,7 +804,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// loop over rest of row-wise subblock
// all B-sub(s) has been copied, so only A-sub(s) need to be copied
for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
{
// copy a A-sub
threadwise_matrix_copy(
@@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths());
// do some GEMMs
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_gemm(
a_thread_sub_mtx,

View File

@@ -5,9 +5,9 @@
#include "Array.hip.hpp"
#include "functional.hip.hpp"
__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ unsigned get_block_1d_id() { return blockIdx.x; }
__device__ index_t get_block_1d_id() { return blockIdx.x; }
template <class T1, class T2>
struct is_same
@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
}
#endif
__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b)
__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
{
return (a + b - 1) / b;
}

View File

@@ -11,3 +11,5 @@
#include "nvToolsExt.h"
#include "helper_cuda.h"
#endif
using index_t = uint32_t;

View File

@@ -8,5 +8,5 @@ struct integral_constant
__host__ __device__ constexpr T Get() const { return value; }
};
template <unsigned N>
using Number = integral_constant<unsigned, N>;
template <index_t N>
using Number = integral_constant<index_t, N>;

View File

@@ -1,7 +1,7 @@
#pragma once
#include "config.h"
template <class T, unsigned N>
template <class T, index_t N>
struct vector_type
{
};

View File

@@ -1,7 +1,7 @@
#pragma once
#include "constant_integral.hip.hpp"
template <unsigned NLoop>
template <index_t NLoop>
struct static_loop_n
{
template <class F>
@@ -24,7 +24,7 @@ struct static_loop_n<1>
}
};
template <unsigned NLoop>
template <index_t NLoop>
struct static_const_reduce_n
{
template <class F, class Reduce>

View File

@@ -8,18 +8,18 @@ template <class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned BlockSize,
unsigned GridSize>
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t BlockSize,
index_t GridSize>
__global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global)
@@ -33,16 +33,16 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
constexpr auto wei_global_desc = WeiGlobalDesc{};
constexpr auto out_global_desc = OutGlobalDesc{};
constexpr unsigned Y = wei_global_desc.GetLength(I2);
constexpr unsigned X = wei_global_desc.GetLength(I3);
constexpr index_t Y = wei_global_desc.GetLength(I2);
constexpr index_t X = wei_global_desc.GetLength(I3);
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
constexpr auto in_block_global_desc = make_ConstantTensorDescriptor(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_global_desc.GetStrides());
@@ -59,31 +59,31 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
constexpr auto out_block_desc =
make_ConstantTensorDescriptor(out_block_global_desc.GetLengths());
constexpr unsigned in_block_size = in_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace();
constexpr unsigned out_block_size = out_block_desc.GetElementSpace();
constexpr index_t in_block_size = in_block_desc.GetElementSpace();
constexpr index_t wei_block_size = wei_block_desc.GetElementSpace();
constexpr index_t out_block_size = out_block_desc.GetElementSpace();
__shared__ Float p_in_block[in_block_size];
__shared__ Float p_wei_block[wei_block_size];
__shared__ Float p_out_block[out_block_size];
const unsigned block_id = blockIdx.x;
const index_t block_id = blockIdx.x;
unsigned itmp = block_id;
unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
index_t itmp = block_id;
index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
unsigned h_block_work_id = itmp / WBlockWork;
unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
index_t h_block_work_id = itmp / WBlockWork;
index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
unsigned n_block_work_begin = n_block_work_id * NPerBlock;
unsigned k_block_work_begin = k_block_work_id * KPerBlock;
unsigned ho_block_work_begin = h_block_work_id * HoPerBlock;
unsigned wo_block_work_begin = w_block_work_id * WoPerBlock;
index_t n_block_work_begin = n_block_work_id * NPerBlock;
index_t k_block_work_begin = k_block_work_id * KPerBlock;
index_t ho_block_work_begin = h_block_work_id * HoPerBlock;
index_t wo_block_work_begin = w_block_work_id * WoPerBlock;
unsigned hi_block_work_begin = ho_block_work_begin; // minus padding
unsigned wi_block_work_begin = wo_block_work_begin; // minus padding
index_t hi_block_work_begin = ho_block_work_begin; // minus padding
index_t wi_block_work_begin = wo_block_work_begin; // minus padding
constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
@@ -109,7 +109,7 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
// set output tensor in LDS to 0
blockwise_4d_tensor_set_zero<BlockSize>(out_block_desc, p_out_block);
for(unsigned c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
for(index_t c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
c_block_work_begin += CPerBlock)
{
// copy input tensor to LDS

View File

@@ -11,20 +11,20 @@ template <class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead,
unsigned BlockSize,
unsigned GridSize>
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t BlockSize,
index_t GridSize>
__global__ void
gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -39,17 +39,17 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
constexpr unsigned N = in_nchw_global_desc.GetLength(I0);
constexpr unsigned K = wei_kcyx_global_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_global_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_global_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_global_desc.GetLength(I3);
constexpr index_t N = in_nchw_global_desc.GetLength(I0);
constexpr index_t K = wei_kcyx_global_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_global_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);
constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor(
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
@@ -63,21 +63,21 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});
// shared mem
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
// threadwise tensors
constexpr unsigned HiPerThread = HoPerThread + Y - 1;
constexpr unsigned WiPerThread = WoPerThread + X - 1;
constexpr index_t HiPerThread = HoPerThread + Y - 1;
constexpr index_t WiPerThread = WoPerThread + X - 1;
constexpr auto in_nchw_thread_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
@@ -93,56 +93,54 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
// divide block work
constexpr unsigned NBlockWork =
(out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork =
(out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork =
constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork =
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork =
constexpr index_t WBlockWork =
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
const unsigned block_id = blockIdx.x;
const index_t block_id = blockIdx.x;
unsigned itmp = block_id;
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
index_t itmp = block_id;
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const unsigned h_block_work_id = itmp / WBlockWork;
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
const index_t h_block_work_id = itmp / WBlockWork;
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding
const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
// divide thread work
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
const unsigned thread_id = threadIdx.x;
const index_t thread_id = threadIdx.x;
itmp = thread_id;
const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp = thread_id;
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork);
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
const unsigned h_thread_work_id = itmp / WThreadWork;
const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const index_t h_thread_work_id = itmp / WThreadWork;
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread;
const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread;
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
const unsigned hi_thread_data_begin = ho_thread_data_begin;
const unsigned wi_thread_data_begin = wo_thread_data_begin;
const index_t hi_thread_data_begin = ho_thread_data_begin;
const index_t wi_thread_data_begin = wo_thread_data_begin;
constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
@@ -172,7 +170,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, __syncthreads())
{
// copy input tensor to LDS
@@ -191,7 +189,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
__syncthreads();
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// threadwise convolution
#if 1

View File

@@ -13,21 +13,21 @@ template <class TInWei,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned ScalarPerVector,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead,
unsigned BlockSize,
unsigned GridSize>
index_t ScalarPerVector,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t BlockSize,
index_t GridSize>
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
const typename vector_type<TInWei,
ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
@@ -49,17 +49,17 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
constexpr unsigned N = in_nchw_vec_global_desc.GetLength(I0);
constexpr unsigned K = wei_kcyx_vec_global_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_vec_global_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_vec_global_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_vec_global_desc.GetLength(I3);
constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);
constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);
constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor(
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
@@ -73,15 +73,15 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});
// shared mem
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ in_vector_mem_t
p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
@@ -89,8 +89,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
// threadwise tensors
constexpr unsigned HiPerThread = HoPerThread + Y - 1;
constexpr unsigned WiPerThread = WoPerThread + X - 1;
constexpr index_t HiPerThread = HoPerThread + Y - 1;
constexpr index_t WiPerThread = WoPerThread + X - 1;
constexpr auto in_nchw_vec_thread_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
@@ -106,56 +106,54 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
// divide block work
constexpr unsigned NBlockWork =
(out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork =
(out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork =
constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork =
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork =
constexpr index_t WBlockWork =
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
const unsigned block_id = blockIdx.x;
const index_t block_id = blockIdx.x;
unsigned itmp = block_id;
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
index_t itmp = block_id;
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const unsigned h_block_work_id = itmp / WBlockWork;
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
const index_t h_block_work_id = itmp / WBlockWork;
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding
const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
// divide thread work
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
const unsigned thread_id = threadIdx.x;
const index_t thread_id = threadIdx.x;
itmp = thread_id;
const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp = thread_id;
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork);
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
const unsigned h_thread_work_id = itmp / WThreadWork;
const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const index_t h_thread_work_id = itmp / WThreadWork;
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread;
const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread;
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
const unsigned hi_thread_data_begin = ho_thread_data_begin;
const unsigned wi_thread_data_begin = wo_thread_data_begin;
const index_t hi_thread_data_begin = ho_thread_data_begin;
const index_t wi_thread_data_begin = wo_thread_data_begin;
constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
@@ -188,7 +186,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
#endif
for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, __syncthreads())
{
// copy input tensor to LDS
@@ -207,7 +205,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
__syncthreads();
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// threadwise convolution
#if 1

View File

@@ -8,32 +8,32 @@
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template <unsigned GridSize,
unsigned BlockSize,
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
class InBlockCopyThreadPerDims,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned OutThreadCopyDataPerWrite>
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t OutThreadCopyDataPerWrite>
__global__ void
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -55,39 +55,39 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const unsigned w_block_work_id = itmp / NBlockWork;
const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned hi_block_data_begin = ho_block_data_begin;
const unsigned wi_block_data_begin = wo_block_data_begin;
const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;
// flattend (2d) tensor view of gridwise weight
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
@@ -164,15 +164,15 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
HoPerThread>{};
// LDS: be careful of alignment
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_chwn_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
@@ -191,10 +191,10 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
const Float* p_wei_global_block_begin =
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0),
p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads())
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0),
p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads())
{
// input: global mem to LDS
blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);
@@ -205,9 +205,9 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
__syncthreads();
// a series of batched GEMM
for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
#if 0
blockwise_batch_gemm.Run
@@ -227,26 +227,26 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
const unsigned ho_thread =
const index_t ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const unsigned wo_thread = b_thread / NPerBlock;
const unsigned n_thread = b_thread % NPerBlock;
const index_t wo_thread = b_thread / NPerBlock;
const index_t n_thread = b_thread % NPerBlock;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
@@ -261,19 +261,19 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
// this is for v2 GEMM
// output is a 8d tensor
if(NPerThread < NPerBlock && WoPerThread == 1)
{
constexpr unsigned N1_ = GemmNPerThreadSubC;
constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
constexpr unsigned K2_ = GemmMPerThreadSubC;
constexpr unsigned K1_ = KPerBlock / KPerThread;
constexpr index_t N1_ = GemmNPerThreadSubC;
constexpr index_t W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
constexpr index_t K2_ = GemmMPerThreadSubC;
constexpr index_t K1_ = KPerBlock / KPerThread;
constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1_ * K2_), K1_, K2_, Ho, Wo / W1_, W1_, N / N1_, N1_>{});

View File

@@ -7,26 +7,26 @@
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"
template <unsigned GridSize,
unsigned BlockSize,
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class LowerPads,
class UpperPads,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1>
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -48,42 +48,42 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr unsigned HPadLow = LowerPads{}.Get(I0);
constexpr unsigned WPadLow = LowerPads{}.Get(I1);
constexpr index_t HPadLow = LowerPads{}.Get(I0);
constexpr index_t WPadLow = LowerPads{}.Get(I1);
constexpr unsigned HPadUp = UpperPads{}.Get(I0);
constexpr unsigned WPadUp = UpperPads{}.Get(I1);
constexpr index_t HPadUp = UpperPads{}.Get(I0);
constexpr index_t WPadUp = UpperPads{}.Get(I1);
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
// divide block work: [K, Ho, Wo, N]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const unsigned w_block_work_id = itmp / NBlockWork;
const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
// flattened (2d) tensor view of wei in global mem
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
@@ -114,11 +114,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
// blockwise copy
// input: format is [C, Hi, Wi, N]
const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
#if 0
if(get_thread_local_1d_id() == 0)
@@ -204,8 +204,8 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
true>{};
// LDS
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_cyxk_block_desc.GetElementSpace();
constexpr index_t in_block_size = in_chwn_block_desc.GetElementSpace();
constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace();
__shared__ Float p_in_block[in_block_size];
__shared__ Float p_wei_block[wei_block_size];
@@ -219,9 +219,9 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
const Float* p_wei_global_block_begin =
p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
__syncthreads())
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
__syncthreads())
{
#if 1
// input: global mem to LDS,
@@ -245,9 +245,9 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
__syncthreads();
// a series of batched GEMM
for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
@@ -262,10 +262,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
const auto matrix_c_index =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
const index_t ho_thread_data_begin = matrix_c_index.batch;
const index_t k_thread_data_begin = matrix_c_index.row;
const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const index_t n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",

View File

@@ -8,32 +8,32 @@
#include "blockwise_gemm.hip.hpp"
// define B = flatten(N, Hi, Wi)
template <unsigned GridSize,
unsigned BlockSize,
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned BPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
index_t BPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t BPerThread,
index_t KPerThread,
index_t GemmThreadPerColumnPerCluster,
index_t GemmThreadPerRowPerCluster,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead>
__global__ void
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -48,30 +48,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1);
constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2);
constexpr unsigned N = in_chwn_global_desc.GetLength(I3);
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
constexpr index_t N = in_chwn_global_desc.GetLength(I3);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr unsigned B = N * Hi * Wi;
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
constexpr index_t B = N * Hi * Wi;
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t b_block_data_begin = b_block_work_id * BPerBlock;
// flattend (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
@@ -192,15 +192,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
GemmKPerThreadLoop>{};
// LDS: be careful of alignment
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
// LDS
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
@@ -218,10 +218,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
// set threadwise output tensor to 0
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads())
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads())
{
// load data
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
@@ -231,18 +231,16 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
// compute on current data
// a series of GEMM
for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0
blockwise_gemm.Run
#elif 1
#elif 0
blockwise_gemm.Run_asm
#elif 0
blockwise_gemm.Run_v2
#elif 0
#elif 1
blockwise_gemm.Run_RegisterDoubleBuffer
#endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
@@ -257,23 +255,23 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
unsigned h_data = b_data / (Wi * N);
unsigned itmp = b_data - h_data * (Wi * N);
unsigned w_data = itmp / N;
unsigned n_data = itmp - w_data * N;
index_t h_data = b_data / (Wi * N);
index_t itmp = b_data - h_data * (Wi * N);
index_t w_data = itmp / N;
index_t n_data = itmp - w_data * N;
if(n_data < N && h_data < Ho && w_data < Wo)
{

View File

@@ -8,32 +8,32 @@
#include "blockwise_gemm.hip.hpp"
// define B = flatten(N, Hi, Wi)
template <unsigned GridSize,
unsigned BlockSize,
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned BPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
index_t BPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t BPerThread,
index_t KPerThread,
index_t GemmThreadPerColumnPerCluster,
index_t GemmThreadPerRowPerCluster,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead>
__global__ void
#if 0
__launch_bounds__(256,2)
@@ -52,30 +52,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1);
constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2);
constexpr unsigned N = in_chwn_global_desc.GetLength(I3);
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
constexpr index_t N = in_chwn_global_desc.GetLength(I3);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr unsigned B = N * Hi * Wi;
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
constexpr index_t B = N * Hi * Wi;
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t b_block_data_begin = b_block_work_id * BPerBlock;
// flattend (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
@@ -210,15 +210,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
#endif
// LDS: be careful of alignment
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
// LDS double buffer
__shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)];
@@ -248,11 +248,11 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
bool even_loop = true;
for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
even_loop = !even_loop)
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
even_loop = !even_loop)
{
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
@@ -279,12 +279,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
// compute on current data
// a series of GEMM
for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0
#if 1
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
@@ -309,12 +309,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
__syncthreads();
for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0
#if 1
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
@@ -331,8 +331,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
#if 0
if(get_block_1d_id() == 0)
@@ -348,20 +348,20 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
}
#endif
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
unsigned h_data = b_data / (Wi * N);
unsigned itmp = b_data - h_data * (Wi * N);
unsigned w_data = itmp / N;
unsigned n_data = itmp - w_data * N;
index_t h_data = b_data / (Wi * N);
index_t itmp = b_data - h_data * (Wi * N);
index_t w_data = itmp / N;
index_t n_data = itmp - w_data * N;
if(n_data < N && h_data < Ho && w_data < Wo)
{

View File

@@ -16,11 +16,11 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re
}
#endif
for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
{
const unsigned dindex = desc.Get1dIndex(did0, did1);
const index_t dindex = desc.Get1dIndex(did0, did1);
f(p[dindex]);
}
@@ -47,22 +47,22 @@ __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_d
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
const unsigned aindex = src_desc.Get1dIndex(did0, did1);
const index_t aindex = src_desc.Get1dIndex(did0, did1);
const unsigned did[2] = {did0, did1};
const index_t did[2] = {did0, did1};
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]);
}
@@ -118,21 +118,21 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
}
#endif
constexpr unsigned nshift = NShift::mValue;
constexpr index_t nshift = NShift::mValue;
constexpr unsigned did0_end =
constexpr index_t did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
constexpr unsigned did1_end =
constexpr index_t did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
for(unsigned did0 = 0; did0 < did0_end; ++did0)
for(index_t did0 = 0; did0 < did0_end; ++did0)
{
for(unsigned did1 = 0; did1 < did1_end; ++did1)
for(index_t did1 = 0; did1 < did1_end; ++did1)
{
const unsigned dindex = desc.Get1dIndex(did0, did1);
const index_t dindex = desc.Get1dIndex(did0, did1);
const unsigned sindex = dindex + nshift * desc.GetStride(IDim{});
const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
p[dindex] = p[sindex];
}

View File

@@ -18,15 +18,15 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
}
#endif
for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < desc.GetLength(I3); ++did3)
for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3)
{
const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
f(p[dindex]);
}
@@ -58,28 +58,28 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{
const unsigned aindex = src_desc.Get1dIndex(did0, did1, did2, did3);
const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3);
const unsigned did[4] = {did0, did1, did2, did3};
const index_t did[4] = {did0, did1, did2, did3};
const unsigned bindex =
const index_t bindex =
dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[aindex], p_dst[bindex]);
@@ -129,7 +129,7 @@ __device__ void threadwise_4d_tensor_copy(
}
// need to assume src and dst is aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
@@ -163,24 +163,24 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L3 = SrcOpLengths{}.Get(I3);
constexpr index_t L3 = SrcOpLengths{}.Get(I3);
static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");
constexpr unsigned nloop_d3 = L3 / DataPerRead;
constexpr index_t nloop_d3 = L3 / DataPerRead;
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const unsigned src_index =
const index_t src_index =
src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
const unsigned dst_index =
const index_t dst_index =
dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
if(DataPerRead == 1)
@@ -224,31 +224,31 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
}
#endif
constexpr unsigned nshift = NShift::mValue;
constexpr index_t nshift = NShift::mValue;
constexpr unsigned did0_end =
constexpr index_t did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
constexpr unsigned did1_end =
constexpr index_t did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
constexpr unsigned did2_end =
constexpr index_t did2_end =
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
constexpr unsigned did3_end =
constexpr index_t did3_end =
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
for(unsigned did0 = 0; did0 < did0_end; ++did0)
for(index_t did0 = 0; did0 < did0_end; ++did0)
{
for(unsigned did1 = 0; did1 < did1_end; ++did1)
for(index_t did1 = 0; did1 < did1_end; ++did1)
{
for(unsigned did2 = 0; did2 < did2_end; ++did2)
for(index_t did2 = 0; did2 < did2_end; ++did2)
{
for(unsigned did3 = 0; did3 < did3_end; ++did3)
for(index_t did3 = 0; did3 < did3_end; ++did3)
{
const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
const unsigned sindex = dindex + nshift * desc.GetStride(IDim{});
const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
p[dindex] = p[sindex];
}

View File

@@ -28,28 +28,28 @@ __device__ void threadwise_direct_convolution_1(InDesc,
}
#endif
for(unsigned n = 0; n < out_desc.GetLength(I0); ++n)
for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
{
for(unsigned k = 0; k < out_desc.GetLength(I1); ++k)
for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
{
for(unsigned ho = 0; ho < out_desc.GetLength(I2); ++ho)
for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
{
for(unsigned wo = 0; wo < out_desc.GetLength(I3); ++wo)
for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
{
for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c)
for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
{
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{
const unsigned hi = ho + y;
const unsigned wi = wo + x;
const index_t hi = ho + y;
const index_t wi = wo + x;
const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi);
const index_t in_index = in_desc.Get1dIndex(n, c, hi, wi);
const unsigned wei_index = wei_desc.Get1dIndex(k, c, y, x);
const index_t wei_index = wei_desc.Get1dIndex(k, c, y, x);
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
const index_t out_index = out_desc.Get1dIndex(n, k, ho, wo);
fused_multiply_accumulate(
p_out[out_index], p_wei[wei_index], p_in[in_index]);
@@ -125,7 +125,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
Data p_in_reg[in_reg_desc.GetElementSpace()];
Data p_wei_reg[wei_reg_desc.GetElementSpace()];
constexpr unsigned in_w_new_read = 1;
constexpr index_t in_w_new_read = 1;
constexpr auto in_desc_reg_new_read =
make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
@@ -136,7 +136,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
#if 0
// this verison reused old input data in register, and read new data from LDS
// loop over vertical direction
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
// read first input
threadwise_4d_tensor_copy(in_desc,
@@ -157,7 +157,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
// loop over horizontal direction
for(unsigned x = 1; x < wei_desc.GetLength(I3); ++x)
for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
{
// read new weight
threadwise_4d_tensor_copy(wei_desc,
@@ -186,10 +186,10 @@ __device__ void threadwise_direct_convolution_3(InDesc,
#elif 1
// this version read all input from LDS when filter moves
// loop over vertical direction
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
// loop over horizontal direction
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{
// read new weight
threadwise_4d_tensor_copy(wei_desc,

View File

@@ -1,6 +1,6 @@
#pragma once
template <class Float, class SrcMatrix, class DstMatrix, unsigned NRow, unsigned NCol>
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
__device__ void threadwise_matrix_copy(SrcMatrix,
const Float* __restrict__ p_src,
DstMatrix,
@@ -10,16 +10,39 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{};
for(unsigned i = 0; i < NRow; ++i)
#if 0
for(index_t i = 0; i < NRow; ++i)
{
for(unsigned j = 0; j < NCol; ++j)
for(index_t j = 0; j < NCol; ++j)
{
const unsigned src_index = src_mtx.Get1dIndex(i, j);
const unsigned dst_index = dst_mtx.Get1dIndex(i, j);
const index_t src_index = src_mtx.Get1dIndex(i, j);
const index_t dst_index = dst_mtx.Get1dIndex(i, j);
p_dst[dst_index] = p_src[src_index];
}
}
#elif 1
static_assert(NCol == 4, "only for NCol == 4");
using vector_t = typename vector_type<Float, 4>::MemoryType;
for(index_t i = 0; i < NRow; ++i)
{
const index_t src_index = src_mtx.Get1dIndex(i, 0);
const index_t dst_index = dst_mtx.Get1dIndex(i, 0);
#if 1
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
#elif 1
asm volatile("\n \
ds_read_b128 %0, %1, offset:0 \n \
"
: "=v"(*(reinterpret_cast<vector_t*>(p_dst+dst_index)))
: "v"((uint32_t)(p_src + src_index)));
#endif
}
#endif
}
template <class MatrixA,
@@ -49,21 +72,31 @@ __device__ void threadwise_gemm(MatrixA,
constexpr auto b_mtx = MatrixB{};
constexpr auto c_mtx = MatrixC{};
constexpr unsigned M = c_mtx.NRow();
constexpr unsigned N = c_mtx.NCol();
constexpr unsigned K = a_mtx.NRow(); // A is transposed
constexpr index_t M = c_mtx.NRow();
constexpr index_t N = c_mtx.NCol();
constexpr index_t K = a_mtx.NRow(); // A is transposed
for(unsigned k = 0; k < K; ++k)
for(index_t k = 0; k < K; ++k)
{
for(unsigned i = 0; i < M; ++i)
for(index_t i = 0; i < M; ++i)
{
for(unsigned j = 0; j < N; ++j)
for(index_t j = 0; j < N; ++j)
{
const unsigned aindex = a_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_mtx.Get1dIndex(k, j);
const unsigned cindex = c_mtx.Get1dIndex(i, j);
const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
const index_t bindex = b_mtx.Get1dIndex(k, j);
const index_t cindex = c_mtx.Get1dIndex(i, j);
#if 0
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
#elif 1
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(p_c_thread[cindex])
: "v"(p_a_thread[aindex]),
"v"(p_b_thread[bindex]),
"0"(p_c_thread[cindex]));
#endif
}
}
}

View File

@@ -2,7 +2,7 @@
#include "ConstantTensorDescriptor.hip.hpp"
// need to assume src and dst is aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_6d_tensor_copy(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
@@ -37,28 +37,28 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
DstDesc{}.GetStride(I4) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L5 = SrcOpLengths{}.Get(I5);
constexpr index_t L5 = SrcOpLengths{}.Get(I5);
static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead");
constexpr unsigned nloop_d5 = L5 / DataPerRead;
constexpr index_t nloop_d5 = L5 / DataPerRead;
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{
for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
{
for(unsigned iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
for(index_t iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
{
const unsigned src_index = src_desc.Get1dIndex(
const index_t src_index = src_desc.Get1dIndex(
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
const unsigned dst_index = dst_desc.Get1dIndex(
const index_t dst_index = dst_desc.Get1dIndex(
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
@@ -72,7 +72,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
}
// need to assume src and dst is aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_8d_tensor_copy(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
@@ -109,29 +109,29 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
DstDesc{}.GetStride(I6) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L7 = SrcOpLengths{}.Get(I7);
constexpr index_t L7 = SrcOpLengths{}.Get(I7);
static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");
constexpr unsigned nloop_d7 = L7 / DataPerRead;
constexpr index_t nloop_d7 = L7 / DataPerRead;
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{
for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
{
for(unsigned did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
{
for(unsigned did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
{
for(unsigned iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
for(index_t iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
{
const unsigned src_index =
const index_t src_index =
src_desc.Get1dIndex(did0,
did1,
did2,
@@ -141,7 +141,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
did6,
iloop_d7 * DataPerRead);
const unsigned dst_index =
const index_t dst_index =
dst_desc.Get1dIndex(did0,
did1,
did2,