Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-11 17:00:18 +00:00)

Commit: experimenting
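The change below is mechanical: every `unsigned` used for tensor lengths, block/thread tiling constants, and loop counters in the device-side convolution drivers becomes the repository-wide `index_t` alias, so the index type can later be switched in one place. A minimal sketch of the pattern, assuming `index_t` is a plain integer alias in a common config header (the alias name is from the diff; `int32_t` and the header location are illustrative assumptions, not taken from this commit):

    // common config header (hypothetical location): one switch point for index width
    using index_t = int32_t; // assumption for illustration; the actual underlying type may differ

    // kernels and drivers then write index_t instead of unsigned:
    constexpr index_t KPerBlock = 64;                  // tiling constant
    for(index_t i = 0; i < nrepeat; ++i) { /* ... */ } // timing loop counter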
@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
 const Tensor<T>& wei,
 OutDesc,
 Tensor<T>& out,
-unsigned nrepeat)
+index_t nrepeat)
 {
 std::size_t data_sz = sizeof(T);
 DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,

 #if 1
 // 3x3, 34x34
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 16;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 16;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;

-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #endif

-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
 (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);

 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(gridwise_direct_convolution_1<T,
 InDesc,

@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
 const Tensor<T>& wei,
 OutDesc,
 Tensor<T>& out,
-unsigned nrepeat)
+index_t nrepeat)
 {
 std::size_t data_sz = sizeof(T);
 DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,

 #if 1
 // 3x3, 34x34, 128 thread
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;

-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;

-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 3x3, 34x34, 128 thread, fp16
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;

-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;

-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #endif

-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
 (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);

 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time =
 launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,

@@ -10,13 +10,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 const Tensor<TInWei>& wei_kcyx,
 OutDesc,
 Tensor<TOut>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 // this suppose in / wei data type is int8x4
-constexpr unsigned NVector = 4;
-using accum_t = int32_t;
-using vector_t = vector_type<TInWei, NVector>;
-using vector_mem_t = typename vector_t::MemoryType;
+constexpr index_t NVector = 4;
+using accum_t = int32_t;
+using vector_t = vector_type<TInWei, NVector>;
+using vector_mem_t = typename vector_t::MemoryType;

 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};

-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);

-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);

 // vectorized input
 auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,

 #if 0
 // 3x3, 34x34, 128 thread, fp32, vector = 1
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;

-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;

-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3, 34x34, 128 thread, fp32, vector = 2
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 32;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 32;

-constexpr unsigned NPerThread = 2;
-constexpr unsigned KPerThread = 4;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 2;
-constexpr unsigned WoPerThread = 2;
+constexpr index_t NPerThread = 2;
+constexpr index_t KPerThread = 4;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 2;
+constexpr index_t WoPerThread = 2;

-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3, 34x34, 128 thread, int8, vector = 4
-constexpr unsigned NPerBlock = 2;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
+constexpr index_t NPerBlock = 2;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;

-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 4;
-constexpr unsigned WoPerThread = 2;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 8;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 4;
+constexpr index_t WoPerThread = 2;

-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 1x1, 32x32, 128 thread, int8, vector = 4
-constexpr unsigned NPerBlock = 1;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 16;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 32;
+constexpr index_t NPerBlock = 1;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 16;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 32;

-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 4;
-constexpr unsigned WoPerThread = 2;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 8;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 4;
+constexpr index_t WoPerThread = 2;

-constexpr unsigned InBlockCopyDataPerRead = 2;
-constexpr unsigned WeiBlockCopyDataPerRead = 2;
+constexpr index_t InBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #endif

-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);

 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,

@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 const Tensor<T>& wei_kcyx,
 OutDesc,
 Tensor<T>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};

-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);

-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);

 // reorder weight
 auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,

 #if 0
 // for 3x3, 34x34
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopy_ThreadPerDimC = 4;
+constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;

-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;

-constexpr unsigned OutThreadCopyDataPerWrite = 2;
+constexpr index_t OutThreadCopyDataPerWrite = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5, 36x36
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 8;
-constexpr unsigned KPerThread = 8;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 8;
+constexpr index_t KPerThread = 8;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

-constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopy_ThreadPerDimC = 2;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;

-constexpr unsigned WeiBlockCopyDataPerRead = 2;
+constexpr index_t WeiBlockCopyDataPerRead = 2;

-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;

-constexpr unsigned OutThreadCopyDataPerWrite = 2;
+constexpr index_t OutThreadCopyDataPerWrite = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 64, 64, 256
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

-constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 16,256,128
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 7x7, 38x38
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

-constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 3x3, 56x56
-constexpr unsigned NPerBlock = 32;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
+constexpr index_t NPerBlock = 32;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 1x1, 28x28
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;

-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;

-constexpr unsigned OutThreadCopyDataPerWrite = 2;
+constexpr index_t OutThreadCopyDataPerWrite = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 1
 // for 1x1, 14x14
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-constexpr unsigned InBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;

-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;

-constexpr unsigned OutThreadCopyDataPerWrite = 2;
+constexpr index_t OutThreadCopyDataPerWrite = 2;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #endif

-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
 ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);

 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<GridSize,

@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 Tensor<T>& out_nkhw,
 LowerPads,
 UpperPads,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};

-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);

-constexpr unsigned N = out_nkhw_desc.GetLength(I0);
-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t N = out_nkhw_desc.GetLength(I0);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);

 // reorder weight
 auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
 out_khwn_device_buf.ToDevice(out_khwn.mData.data());

 #if 0
-constexpr unsigned NPerBlock = 1;
-constexpr unsigned KPerBlock = 1;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 1;
+constexpr index_t KPerBlock = 1;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 1;
-constexpr unsigned KPerThread = 1;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 1;
+constexpr index_t KPerThread = 1;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 1;

-constexpr unsigned BlockSize = 8;
+constexpr index_t BlockSize = 8;
 #elif 1
 // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 3x3 58x58, NKC = 16,256,128
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5, 36x36
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 7x7, 38x38
-constexpr unsigned NPerBlock = 8;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 8;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 3x3, 56x56
-constexpr unsigned NPerBlock = 32;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
+constexpr index_t NPerBlock = 32;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 1
 // 3x3 56x56, NKC = 16,256,128, with padding
 // 3x3 28x28, NKC = 16,512,256, with padding
 // 3x3 20x84, NKC = 16,256,256, with padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 64;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 5x5 filter, 20x84 image, 1x1 padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 1;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 1;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 5x5 filter, 28x28 image, 2x2 padding
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 32;
-constexpr unsigned CPerBlock = 2;
-constexpr unsigned HoPerBlock = 4;
-constexpr unsigned WoPerBlock = 4;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 32;
+constexpr index_t CPerBlock = 2;
+constexpr index_t HoPerBlock = 4;
+constexpr index_t WoPerBlock = 4;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 1;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 1;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // for 1x1, 28x28
-constexpr unsigned NPerBlock = 16;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
-constexpr unsigned HoPerBlock = 2;
-constexpr unsigned WoPerBlock = 2;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 2;

-constexpr unsigned NPerThread = 4;
-constexpr unsigned KPerThread = 16;
-constexpr unsigned CPerThread = 2;
-constexpr unsigned HoPerThread = 1;
-constexpr unsigned WoPerThread = 1;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 16;
+constexpr index_t CPerThread = 2;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 1;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #endif

-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
 ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);

 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,

@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 const Tensor<T>& wei_kcyx,
 OutDesc,
 Tensor<T>& out_nkhw,
-unsigned nrepeat)
+index_t nrepeat)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 constexpr auto wei_kcyx_desc = WeiDesc{};
 constexpr auto out_nkhw_desc = OutDesc{};

-constexpr unsigned N = in_nchw_desc.GetLength(I0);
-constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+constexpr index_t N = in_nchw_desc.GetLength(I0);
+constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+constexpr index_t Wi = in_nchw_desc.GetLength(I3);

-constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

-constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+constexpr index_t X = wei_kcyx_desc.GetLength(I3);

-constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
+constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);

 // convert in_nchw to in_cnhw
 auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 #if 0
 // 3x3, 34x34
 // need to use register double buffer for GEMM
-constexpr unsigned BPerBlock = 128;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 4;
+constexpr index_t BPerBlock = 128;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 4;

-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;

-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 8;
-constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 8;
+constexpr index_t GemmKPerThreadLoop = 1;

-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;

-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 1x1, 28x28, 64 threads
-constexpr unsigned BPerBlock = 64;
-constexpr unsigned KPerBlock = 64;
-constexpr unsigned CPerBlock = 8;
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 64;
+constexpr index_t CPerBlock = 8;

-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;

-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 2;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 2;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;

-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;

-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned BlockSize = 64;
-#elif 1
+constexpr index_t BlockSize = 64;
+#elif 0
-// 1x1, 28x28, 128 threads, no lds-double-buffer
+// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
-constexpr unsigned BPerBlock = 64;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;

-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;

-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 4;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;

-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;

-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned BlockSize = 128;
+constexpr index_t BlockSize = 128;
 #elif 0
 // 1x1, 28x28, 256 thread
-constexpr unsigned BPerBlock = 128;
-constexpr unsigned KPerBlock = 128;
-constexpr unsigned CPerBlock = 8;
+constexpr index_t BPerBlock = 128;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;

-constexpr unsigned BPerThread = 8;
-constexpr unsigned KPerThread = 8;
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;

-constexpr unsigned GemmMPerThreadSubC = 4;
-constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 4;
-constexpr unsigned GemmNLevel0Cluster = 4;
-constexpr unsigned GemmMLevel1Cluster = 4;
-constexpr unsigned GemmNLevel1Cluster = 4;
-constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 4;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;

-constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-constexpr unsigned GemmThreadPerRowPerCluster = 8;
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;

-constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;

-constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

-constexpr unsigned InBlockCopyDataPerRead = 4;
-constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;

-constexpr unsigned BlockSize = 256;
+constexpr index_t BlockSize = 256;
+#elif 1
+// 1x1, 14x14, Vega 10
+constexpr index_t BPerBlock = 64;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+
+constexpr index_t BPerThread = 8;
+constexpr index_t KPerThread = 8;
+
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 2;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 4;
+constexpr index_t GemmKPerThreadLoop = 1;
+
+constexpr index_t GemmThreadPerColumnPerCluster = 8;
+constexpr index_t GemmThreadPerRowPerCluster = 8;
+
+constexpr index_t InBlockCopyThreadPerDim0 = 4;
+constexpr index_t InBlockCopyThreadPerDim1 = 16;
+
+constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+
+constexpr index_t InBlockCopyDataPerRead = 4;
+constexpr index_t WeiBlockCopyDataPerRead = 4;
+
+constexpr index_t BlockSize = 128;
 #endif

-constexpr unsigned GridSize =
+constexpr index_t GridSize =
 ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

 printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
 out_khwn_device_buf.ToDevice(out_khwn.mData.data());

-for(unsigned i = 0; i < nrepeat; ++i)
+for(index_t i = 0; i < nrepeat; ++i)
 {
 float time = launch_kernel(
 #if 1

@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
 template <class... Ts>
 double operator()(Ts... Xs) const
 {
-std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
+std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
 return std::accumulate(dims.begin(),
 dims.end(),
 true,
-[](bool init, unsigned long x) -> int { return init != (x % 2); })
+[](bool init, index_t x) -> int { return init != (x % 2); })
 ? 1
 : -1;
 }
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
constexpr auto I3 = Number<3>{};
constexpr auto desc = TConstTensorDesc{};

std::initializer_list<unsigned> lengths = {
std::initializer_list<index_t> lengths = {
desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
std::initializer_list<unsigned> strides = {
std::initializer_list<index_t> strides = {
desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};

return TensorDescriptor(lengths, strides);
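A hedged usage sketch of this bridge from compile-time to runtime descriptors (the single-argument make_ConstantTensorDescriptor overload shown here is an assumption, not part of this diff):

constexpr auto desc_const = make_ConstantTensorDescriptor(Sequence<128, 512, 14, 14>{});
auto desc_runtime = make_TensorDescriptor(desc_const); // lengths/strides copied out as above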
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
LowerPads,
UpperPads)
{
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});

unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});

auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
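The lambda body is truncated here; a hedged sketch of how the reference accumulation presumably continues (wei_kcyx is a hypothetical name for the elided weight tensor, and int is used so the padding check stays valid):

for(index_t c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
    for(index_t y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
        for(index_t x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
        {
            int hi = ho + y - h_pad_low; // input row; may land in the padding
            int wi = wo + x - w_pad_low;
            if(hi >= 0 && hi < (int)in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
               wi < (int)in_nchw.mDesc.GetLengths()[3])
                v += in_nchw(n, c, hi, wi) * wei_kcyx(k, c, y, x);
        }
out_nkhw(n, k, ho, wo) = v;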
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];

unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});

unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});

std::size_t HiPerTile = HoPerTile + Y - 1;
std::size_t WiPerTile = WoPerTile + X - 1;
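Worked tile arithmetic (illustration only): the classic Winograd F(2x2, 3x3) case has HoPerTile = WoPerTile = 2 and Y = X = 3, so HiPerTile = WiPerTile = 2 + 3 - 1 = 4; every 4x4 input tile yields one 2x2 output tile.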
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
#if 0
constexpr unsigned N = 1;
constexpr unsigned C = 1;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 1;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 1;
constexpr index_t C = 1;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 34;
constexpr unsigned WI = 34;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 56x56
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 5x5, 36x36
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 36;
constexpr unsigned WI = 36;
constexpr unsigned K = 64;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 36;
constexpr index_t WI = 36;
constexpr index_t K = 64;
constexpr index_t Y = 5;
constexpr index_t X = 5;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 7x7, 38x38
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 38;
constexpr unsigned WI = 38;
constexpr unsigned K = 64;
constexpr unsigned Y = 7;
constexpr unsigned X = 7;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 38;
constexpr index_t WI = 38;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 56x56 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 1x1 filter, 28x28 image
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 84;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 84;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 112x112 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 64;
constexpr unsigned HI = 112;
constexpr unsigned WI = 112;
constexpr unsigned K = 128;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 64;
constexpr index_t HI = 112;
constexpr index_t WI = 112;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 20x86 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 86;
constexpr unsigned K = 512;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 86;
constexpr index_t K = 512;
constexpr index_t Y = 5;
constexpr index_t X = 5;

constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned N = 16;
constexpr unsigned C = 192;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 32;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 16;
constexpr index_t C = 192;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 32;
constexpr index_t Y = 5;
constexpr index_t X = 5;

constexpr unsigned HPad = 2;
constexpr unsigned WPad = 2;
constexpr index_t HPad = 2;
constexpr index_t WPad = 2;
#elif 0
// 1x1 filter, 32x32 image
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 32;
constexpr unsigned WI = 32;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 32;
constexpr index_t WI = 32;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 1x1 filter, 14x14 image
constexpr unsigned N = 128;
constexpr unsigned C = 2048;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
// 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 512
constexpr unsigned N = 128;
constexpr unsigned C = 512;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;

constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#endif

auto lower_pads = Sequence<HPad, WPad>{};
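The pads travel as compile-time Sequences, so callees recover them with Get; a hedged sketch (the matching upper_pads line is elided by the diff):

auto upper_pads = Sequence<HPad, WPad>{};
// LowerPads{}.Get(Number<0>{}) == HPad, Get(Number<1>{}) == WPad (see host_direct_convolution above)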
@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
}

bool do_verification = atoi(argv[1]);
unsigned nrepeat = atoi(argv[2]);
index_t nrepeat = atoi(argv[2]);

if(do_verification)
{

@@ -1,18 +1,18 @@
#pragma once

template <class TData, unsigned NSize>
template <class TData, index_t NSize>
struct Array
{
using Type = Array<TData, NSize>;

static constexpr unsigned nSize = NSize;
static constexpr index_t nSize = NSize;

unsigned mData[nSize];
index_t mData[nSize];

template <class... Xs>
__host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
{
}

__host__ __device__ TData operator[](unsigned i) const { return mData[i]; }
__host__ __device__ TData operator[](index_t i) const { return mData[i]; }
};
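A hedged usage sketch of Array, the POD multi-index that Get1dIndex later builds from its parameter pack:

Array<index_t, 4> multi_id(1, 2, 3, 4); // static_cast<TData>(xs)... in the constructor above
index_t i2 = multi_id[2];               // == 3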
@@ -1,7 +1,7 @@
#pragma once
#include "common.hip.hpp"

template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor
{
__host__ __device__ constexpr ConstantMatrixDescriptor()
@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
}

__host__ __device__ constexpr unsigned NRow() const { return NRow_; }
__host__ __device__ constexpr index_t NRow() const { return NRow_; }

__host__ __device__ constexpr unsigned NCol() const { return NCol_; }
__host__ __device__ constexpr index_t NCol() const { return NCol_; }

__host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }
__host__ __device__ constexpr index_t RowStride() const { return RowStride_; }

__host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }

__host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }
__host__ __device__ constexpr index_t GetElementSize() const { return NRow_ * NCol_; }

__host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }
__host__ __device__ constexpr index_t GetElementSpace() const { return NRow_ * RowStride_; }

__host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
__host__ __device__ index_t Get1dIndex(index_t irow, index_t icol) const
{
#if DEVICE_BACKEND_HIP
return __mul24(irow, RowStride_) + icol;
#else
return irow * RowStride_ + icol;
#endif
}

template <unsigned SubNRow, unsigned SubNCol>
template <index_t SubNRow, index_t SubNCol>
__host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>) const
{
@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
}
};

template <unsigned NRow, unsigned NCol>
template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}

template <unsigned NRow, unsigned NCol, unsigned RowStride>
template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
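A hedged usage sketch (the three-argument factory's body is cut off just above, but presumably returns ConstantMatrixDescriptor<NRow, NCol, RowStride>{}): an 8x16 matrix padded to a row stride of 20:

constexpr auto mat = make_ConstantMatrixDescriptor(Number<8>{}, Number<16>{}, Number<20>{});
static_assert(mat.GetElementSize() == 8 * 16, "8x16 payload");
static_assert(mat.GetElementSpace() == 8 * 20, "row padding included");
// mat.Get1dIndex(irow, icol) == irow * 20 + icol (__mul24 on the HIP backend)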
@@ -2,35 +2,35 @@
#include "common.hip.hpp"

// this is ugly, only for 2d
template <unsigned L0, unsigned L1>
template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{
return Sequence<L1, 1>{};
}

// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}

// this is ugly, only for 6d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
}

// this is ugly, only for 8d
template <unsigned L0,
unsigned L1,
unsigned L2,
unsigned L3,
unsigned L4,
unsigned L5,
unsigned L6,
unsigned L7>
template <index_t L0,
index_t L1,
index_t L2,
index_t L3,
index_t L4,
index_t L5,
index_t L6,
index_t L7>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{
@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto
}

// this is ugly, only for 2d
template <unsigned L0, unsigned L1, unsigned Align>
template <index_t L0, index_t L1, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
Number<Align>)
{
constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
return Sequence<L1_align, 1>{};
}

// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
Number<Align>)
{
constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
}
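Worked stride arithmetic (illustration only): packed row-major strides, and the aligned variant padding the innermost length up to a multiple of Align:

static_assert(is_same<decltype(calculate_default_strides(Sequence<2, 3, 4, 5>{})),
                      Sequence<60, 20, 5, 1>>::value, "packed strides");
// calculate_default_strides_aligned(Sequence<8, 8, 3, 30>{}, Number<4>{}):
//   L3_align = 4 * ((30 + 3) / 4) = 32  ->  Sequence<8 * 3 * 32, 3 * 32, 32, 1> = <768, 96, 32, 1>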
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr unsigned nDim = Lengths::nDim;
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr index_t nDim = Lengths::nDim;

__host__ __device__ constexpr ConstantTensorDescriptor()
{
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
}

__host__ __device__ constexpr unsigned GetDimension() const { return nDim; }
__host__ __device__ constexpr index_t GetDimension() const { return nDim; }

__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }

__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }

template <unsigned I>
__host__ __device__ constexpr unsigned GetLength(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t GetLength(Number<I>) const
{
return Lengths{}.Get(Number<I>{});
}

template <unsigned I>
__host__ __device__ constexpr unsigned GetStride(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t GetStride(Number<I>) const
{
return Strides{}.Get(Number<I>{});
}
@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
struct GetElementSize_f
{
template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return Type{}.GetLength(idim);
}
};

__host__ __device__ constexpr unsigned GetElementSize() const
__host__ __device__ constexpr index_t GetElementSize() const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct multiply
{
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{
return a * b;
}
@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
struct GetElementSpace_f
{
template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
}
};

template <class Align = Number<1>>
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
__host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
{
return a + b;
}
@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
}

template <class... Is>
__host__ __device__ unsigned Get1dIndex(Is... is) const
__host__ __device__ index_t Get1dIndex(Is... is) const
{
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");

const auto multi_id = Array<unsigned, nDim>(is...);
const auto multi_id = Array<index_t, nDim>(is...);

unsigned id = 0;
index_t id = 0;

static_loop_n<nDim>{}([&](auto IDim) {
constexpr unsigned idim = IDim.Get();
constexpr index_t idim = IDim.Get();
#if DEVICE_BACKEND_HIP
id += __mul24(multi_id[idim], GetStride(IDim));
#else
id += multi_id[idim] * GetStride(IDim);
#endif
});

return id;
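Get1dIndex is the dot product of the multi-index with the strides; a worked example (illustration only):

// lengths <2, 3, 4, 5> with default strides <60, 20, 5, 1>:
//   Get1dIndex(1, 2, 3, 4) = 1*60 + 2*20 + 3*5 + 4*1 = 119
//   (and GetElementSize() = 2*3*4*5 = 120, so 119 is the last element)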
@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
}

template <unsigned IDim, unsigned NVector>
template <index_t IDim, index_t NVector>
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
{
assert(false); // not implemented
@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
return ConstantTensorDescriptor<Lengths, Strides>{};
}

template <class Lengths, unsigned Align>
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
@@ -193,8 +197,8 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
constexpr auto desc = TDesc{};
constexpr unsigned ndim = desc.GetDimension();
constexpr auto desc = TDesc{};
constexpr index_t ndim = desc.GetDimension();

static_assert(ndim >= 2 && ndim <= 8, "wrong!");


@@ -2,38 +2,38 @@
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"

template <unsigned... Is>
template <index_t... Is>
struct Sequence
{
using Type = Sequence<Is...>;

static constexpr unsigned nDim = sizeof...(Is);
static constexpr index_t nDim = sizeof...(Is);

const unsigned mData[nDim] = {Is...};
const index_t mData[nDim] = {Is...};

template <unsigned I>
__host__ __device__ constexpr unsigned Get(Number<I>) const
template <index_t I>
__host__ __device__ constexpr index_t Get(Number<I>) const
{
return mData[I];
}

// this is ugly, only for nDIm = 4
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
{
static_assert(nDim == 4, "nDim != 4");

constexpr auto old_sequence = Type{};

constexpr unsigned NR0 = old_sequence.mData[I0];
constexpr unsigned NR1 = old_sequence.mData[I1];
constexpr unsigned NR2 = old_sequence.mData[I2];
constexpr unsigned NR3 = old_sequence.mData[I3];
constexpr index_t NR0 = old_sequence.mData[I0];
constexpr index_t NR1 = old_sequence.mData[I1];
constexpr index_t NR2 = old_sequence.mData[I2];
constexpr index_t NR3 = old_sequence.mData[I3];

return Sequence<NR0, NR1, NR2, NR3>{};
}

template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
{
// don't know how to implement this
@@ -41,7 +41,7 @@ struct Sequence
assert(false);
}

template <unsigned I>
template <index_t I>
__host__ __device__ constexpr auto PushBack(Number<I>) const
{
return Sequence<Is..., I>{};
@@ -56,14 +56,14 @@ struct Sequence
}
};

template <unsigned... Is, unsigned I>
template <index_t... Is, index_t I>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
static_assert(sizeof...(Is) >= 1, "empty Sequence!");
return Sequence<Is...>{};
}

template <class F, unsigned... Xs, unsigned... Ys>
template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
{
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
return Sequence<f(Xs, Ys)...>{};
}

template <unsigned... Xs, unsigned... Ys>
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
{
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const
__host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
{
return x + y;
}
@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
}

template <unsigned... Is>
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
return sequence_pop_back(Type{});
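A worked reorder example (illustration only): ReorderByGetNewFromOld picks old entries at the given positions, e.g. NCHW lengths to CHWN:

// Sequence<64, 256, 14, 14>{}.ReorderByGetNewFromOld(Sequence<1, 2, 3, 0>{})
//   == Sequence<256, 14, 14, 64>{}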
@@ -1,7 +1,7 @@
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

template <unsigned BlockSize, class Float, class DstDesc, class F>
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
@@ -20,19 +20,19 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
}
#endif

constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;

for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;

const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);

is -= did0 * desc.GetStride(I0);

const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);

const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

f(p_dst[dindex]);
}
@@ -41,17 +41,17 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst

if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;

if(is < desc.GetElementSize())
{
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);

is -= did0 * desc.GetStride(I0);

const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);

const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

f(p_dst[dindex]);
}
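The divide-and-subtract chain above is the usual flat-to-multi index decomposition; worked through (illustration only):

// 4x10 destination, strides <10, 1>, is = 37:
//   did0 = 37 / 10 = 3;  is -= 3 * 10;  did1 = 7 / 1 = 7
// so flat element 37 is (3, 7), and Get1dIndex(3, 7) folds it back to 37.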
@@ -61,7 +61,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -80,20 +80,20 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};

constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);

constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;

unsigned did[2];
index_t did[2];

did[0] = is / ref_desc.GetStride(I0);

@@ -101,9 +101,9 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds

did[1] = is / ref_desc.GetStride(I1);

const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);

const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);

f(p_src[aindex], p_dst[bindex]);
}
@@ -112,11 +112,11 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds

if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;

if(is < ref_desc.GetElementSize())
{
unsigned did[2];
index_t did[2];

did[0] = is / ref_desc.GetStride(I0);

@@ -124,16 +124,16 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds

did[1] = is / ref_desc.GetStride(I1);

const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);

const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);

f(p_src[aindex], p_dst[bindex]);
}
}
}

template <unsigned BlockSize, class Float, class DstDesc>
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
auto f_set_zero = [](Float& v) { v = Float(0); };
@@ -141,7 +141,7 @@ __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}

template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -161,7 +161,7 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}

template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
template <index_t BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
struct Blockwise2dTensorCopy1
{
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
@@ -175,17 +175,17 @@ struct Blockwise2dTensorCopy1

// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
unsigned ThreadPerDim0,
unsigned ThreadPerDim1>
index_t ThreadPerDim0,
index_t ThreadPerDim1>
struct Blockwise2dTensorCopy2
{
unsigned mThreadId0;
unsigned mThreadId1;
index_t mThreadId0;
index_t mThreadId1;

__device__ Blockwise2dTensorCopy2()
{
@@ -222,61 +222,61 @@ struct Blockwise2dTensorCopy2
constexpr bool align_v2 =
src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
constexpr index_t L0 = SrcOpLengths{}.Get(I0);
constexpr index_t L1 = SrcOpLengths{}.Get(I1);

constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);

constexpr unsigned Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;

constexpr unsigned Dim1V2Loop =
constexpr index_t Dim1V2Loop =
align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

constexpr unsigned Dim1V1Loop =
constexpr index_t Dim1V1Loop =
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
ThreadPerDim1;

constexpr bool d1_has_tail =
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop)
for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
{
unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0;
index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;

// v4
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

*(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex));
}

// v2
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{
unsigned did1 =
index_t did1 =
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

*(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex));
}

// v1
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;

const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

p_dst[dindex] = p_src[sindex];
}
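How the dim-1 work splits across the float4/float2/float loops, worked through (illustration only):

// ThreadPerDim1 = 16, L1 = 70, both alignments hold:
//   Dim1V4Loop = 70 / 64 = 1   (64 floats move as float4)
//   Dim1V2Loop =  6 / 32 = 0
//   Dim1V1Loop =  6 / 16 = 0
// leaving 6 elements for the dim-1 tail, so d1_has_tail == true.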
@@ -284,13 +284,13 @@ struct Blockwise2dTensorCopy2
// dim-1 tail
if(d1_has_tail)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;

if(did1 < L1)
{
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

p_dst[dindex] = p_src[sindex];
}
@@ -300,45 +300,44 @@ struct Blockwise2dTensorCopy2
// dim-0 tail
if(d0_has_tail)
{
unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;

if(did0 < L0)
{

// v4
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

*(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex));
}

// v2
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
2 * mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
2 * mThreadId1;

const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

*(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex));
}

// v1
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;

const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

p_dst[dindex] = p_src[sindex];
}
@@ -346,14 +345,13 @@ struct Blockwise2dTensorCopy2
// tail
if(d1_has_tail)
{
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
mThreadId1;
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;

if(did1 < L1)
{
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
const index_t sindex = src_desc.Get1dIndex(did0, did1);
const index_t dindex = dst_desc.Get1dIndex(did0, did1);

p_dst[dindex] = p_src[sindex];
}
@@ -365,18 +363,18 @@ struct Blockwise2dTensorCopy2

// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise2dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

unsigned mSrcMyThreadOffset;
unsigned mDstMyThreadOffset;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;

__device__ Blockwise2dTensorCopy3()
{
@@ -394,11 +392,11 @@ struct Blockwise2dTensorCopy3
DstDesc{}.GetStride(I0) % DataPerRead == 0,
"src and dst stride should be multiple of DataPerRead to keep alignment");

constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);

constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

// we allow out-of-bound read from src in D1 dimension,
// but we need to make sure dst stride is big enough,
@@ -408,7 +406,7 @@ struct Blockwise2dTensorCopy3

static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

if(BlockSize > num_active_thread)
{
@@ -418,8 +416,8 @@ struct Blockwise2dTensorCopy3
}
}

const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;

mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
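The resulting thread-to-tile mapping, worked through (illustration only):

// CopyLengths = <4, 32>, DataPerRead = 4, BlockSize = 64:
//   thread_per_d1 = (32 + 3) / 4 = 8,  thread_per_d0 = 64 / 8 = 8,  num_active_thread = 64
// thread t reads row t / 8, columns 4 * (t % 8) .. 4 * (t % 8) + 3 in one vector load.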
@@ -430,13 +428,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};

constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);

constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

if(BlockSize > num_active_thread)
{
@@ -446,18 +444,18 @@ struct Blockwise2dTensorCopy3
}
}

constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;

constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
};

for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
@@ -466,7 +464,7 @@ struct Blockwise2dTensorCopy3

if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
@@ -475,18 +473,18 @@ struct Blockwise2dTensorCopy3
}
}

__device__ constexpr unsigned GetRegisterClipboardSize() const
__device__ constexpr index_t GetRegisterClipboardSize() const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};

constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);

constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
}
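Worked through (illustration only): with L0 = 4 and the thread_per_d0 = 8 mapping above, DataPerRead * (4 + 8 - 1) / 8 = 44 / 8 = 5 floats per thread. Note the multiply happens before the rounding division, so this is slightly larger than DataPerRead * ceil(L0 / thread_per_d0) = 4, which is what RunLoadRegisterClipboard below actually fills; the extra element looks like harmless over-allocation.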
@@ -497,13 +495,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};

constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);

constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

if(BlockSize > num_active_thread)
{
@@ -513,18 +511,18 @@ struct Blockwise2dTensorCopy3
}
}

constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;

constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_clipboard + iloop * 4)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
};

for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
@@ -533,7 +531,7 @@ struct Blockwise2dTensorCopy3

if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
@@ -548,13 +546,13 @@ struct Blockwise2dTensorCopy3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};

constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);

constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

if(BlockSize > num_active_thread)
{
@@ -564,17 +562,17 @@ struct Blockwise2dTensorCopy3
}
}

constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d0 = L0 / thread_per_d0;

constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

auto f_copy = [&](unsigned iloop) {
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_clipboard + iloop * 4));
};

for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
@@ -583,7 +581,7 @@ struct Blockwise2dTensorCopy3

if(has_tail_d0)
{
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{

@@ -1,7 +1,7 @@
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

template <unsigned BlockSize, class Float, class DstDesc, class F>
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
@@ -22,27 +22,27 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
}
#endif

constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;

for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;

const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);

is -= did0 * desc.GetStride(I0);

const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);

is -= did1 * desc.GetStride(I1);

const unsigned did2 = is / desc.GetStride(I2);
const index_t did2 = is / desc.GetStride(I2);

is -= did2 * desc.GetStride(I2);

const unsigned did3 = is / desc.GetStride(I3);
const index_t did3 = is / desc.GetStride(I3);

const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);

f(p_dst[dindex]);
}
@@ -51,25 +51,25 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst

if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;

if(is < desc.GetElementSize())
{
const unsigned did0 = is / desc.GetStride(I0);
const index_t did0 = is / desc.GetStride(I0);

is -= did0 * desc.GetStride(I0);

const unsigned did1 = is / desc.GetStride(I1);
const index_t did1 = is / desc.GetStride(I1);

is -= did1 * desc.GetStride(I1);

const unsigned did2 = is / desc.GetStride(I2);
const index_t did2 = is / desc.GetStride(I2);

is -= did2 * desc.GetStride(I2);

const unsigned did3 = is / desc.GetStride(I3);
const index_t did3 = is / desc.GetStride(I3);

const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);

f(p_dst[dindex]);
}
@@ -79,7 +79,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -100,22 +100,22 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);

constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;

unsigned did[4];
index_t did[4];

did[0] = is / ref_desc.GetStride(I0);

@@ -131,9 +131,9 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds

did[3] = is / ref_desc.GetStride(I3);

const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

f(p_src[src_index], p_dst[dst_index]);
}
@@ -142,11 +142,11 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds

if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;

if(is < ref_desc.GetElementSize())
{
unsigned did[4];
index_t did[4];

did[0] = is / ref_desc.GetStride(I0);

@@ -162,16 +162,16 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds

did[3] = is / ref_desc.GetStride(I3);

const unsigned src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

const unsigned dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
const index_t dst_index = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

f(p_src[src_index], p_dst[dst_index]);
}
}
}
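The did0..did3 chain above recovers 4d coordinates from a flat element index by divide-and-subtract on the descriptor strides. A minimal host-side sketch of the same arithmetic, with the strides passed explicitly (an assumption; the kernels read them from a ConstantTensorDescriptor):

// Decompose flat offset `is` into did[0..3] given strides ordered from the
// slowest- to the fastest-varying dimension (packed layout: stride[3] == 1).
void decompose_4d(index_t is, const index_t stride[4], index_t did[4])
{
    for(int d = 0; d < 4; ++d)
    {
        did[d] = is / stride[d];
        is -= did[d] * stride[d];
    }
}
// Example: lengths {2,3,4,5} give strides {60,20,5,1}, so is = 37 yields
// did = {0,1,3,2} because 37 = 0*60 + 1*20 + 3*5 + 2*1.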
template <unsigned BlockSize, class Float, class DstDesc>
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
auto f_set_zero = [](Float& v) { v = Float(0); };
@@ -179,7 +179,7 @@ __device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}

template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -199,12 +199,12 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}

template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise4dTensorCopy1
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
@@ -230,8 +230,8 @@ struct Blockwise4dTensorCopy1
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);

static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
@@ -247,20 +247,20 @@ struct Blockwise4dTensorCopy1
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};

constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);

constexpr unsigned read_per_d3 = integer_divide_ceil(L3, DataPerRead);
constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);

constexpr auto ref_desc =
make_ConstantTensorDescriptor(Sequence<L0, L1, L2, read_per_d3>{});

constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

auto f_copy = [&](unsigned is) {
unsigned did[4];
auto f_copy = [&](index_t is) {
index_t did[4];

did[0] = is / ref_desc.GetStride(I0);

@@ -276,18 +276,18 @@ struct Blockwise4dTensorCopy1

did[3] = is / ref_desc.GetStride(I3);

const unsigned src_index =
const index_t src_index =
src_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);
const unsigned dst_index =
const index_t dst_index =
dst_desc.Get1dIndex(did[0], did[1], did[2], did[3] * DataPerRead);

*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};

for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;

f_copy(is);
}
@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1

if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;

if(is < ref_desc.GetElementSize())
{
@@ -306,7 +306,7 @@ struct Blockwise4dTensorCopy1
}
};
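Blockwise4dTensorCopy1 above moves DataPerRead elements per transaction by casting both pointers to vector_t, and deliberately over-reads past L3; the static_assert guarantees the matching over-write still lands inside the destination row. A hedged sketch of the ceiling division it leans on (integer_divide_ceil itself is not shown in this diff):

// Assumed shape of the helper: with L3 = 5 floats and DataPerRead = 4,
// integer_divide_ceil(5, 4) == 2 float4 transactions cover 8 floats,
// 3 of which are the tolerated out-of-bound tail.
constexpr index_t integer_divide_ceil(index_t a, index_t b)
{
    return (a + b - 1) / b;
}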
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
@@ -315,15 +315,15 @@ template <unsigned BlockSize,
struct BlockwiseChwnTensorCopyPadded
{
__device__ void Run(const Float* __restrict__ p_src,
unsigned c_block_data_begin,
unsigned ho_block_data_begin,
unsigned wo_block_data_begin,
unsigned n_block_data_begin,
index_t c_block_data_begin,
index_t ho_block_data_begin,
index_t wo_block_data_begin,
index_t n_block_data_begin,
Float* __restrict__ p_dst,
unsigned h_block_pad_low,
unsigned w_block_pad_low,
unsigned h_block_pad_up,
unsigned w_block_pad_up) const
index_t h_block_pad_low,
index_t w_block_pad_low,
index_t h_block_pad_up,
index_t w_block_pad_up) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -337,7 +337,7 @@ struct BlockwiseChwnTensorCopyPadded
constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);

constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

const Float* p_src_tmp =
p_src +
@@ -368,11 +368,11 @@ struct BlockwiseChwnTensorCopyPadded
}
#endif

for(unsigned iloop = 0; iloop < NLoop; ++iloop)
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
unsigned is = threadIdx.x + iloop * BlockSize;
index_t is = threadIdx.x + iloop * BlockSize;

unsigned did[4];
index_t did[4];

did[0] = is / ref_desc.GetStride(I0);

@@ -388,7 +388,7 @@ struct BlockwiseChwnTensorCopyPadded

did[3] = is / ref_desc.GetStride(I3);

const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

p_dst[bindex] =
(did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
@@ -401,11 +401,11 @@ struct BlockwiseChwnTensorCopyPadded

if(has_tail)
{
unsigned is = threadIdx.x + NLoop * BlockSize;
index_t is = threadIdx.x + NLoop * BlockSize;

if(is < ref_desc.GetElementSize())
{
unsigned did[4];
index_t did[4];

did[0] = is / ref_desc.GetStride(I0);

@@ -421,7 +421,7 @@ struct BlockwiseChwnTensorCopyPadded

did[3] = is / ref_desc.GetStride(I3);

const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

p_dst[bindex] =
(did[1] < h_block_pad_low ||
@@ -436,19 +436,19 @@ struct BlockwiseChwnTensorCopyPadded

// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
class ThreadPerDims,
unsigned DataPerRead>
index_t DataPerRead>
struct Blockwise4dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

unsigned mSrcMyThreadOffset;
unsigned mDstMyThreadOffset;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;

__device__ Blockwise4dTensorCopy3()
{
@@ -469,20 +469,20 @@ struct Blockwise4dTensorCopy3
DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");

constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);

constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
@@ -493,7 +493,7 @@ struct Blockwise4dTensorCopy3
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
"wrong! BlockSize is not big enough for ThreadPerDims!");

constexpr unsigned num_active_thread =
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

if(BlockSize > num_active_thread)
@@ -504,14 +504,14 @@ struct Blockwise4dTensorCopy3
}
}

const unsigned thread_id_d0 =
const index_t thread_id_d0 =
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
unsigned itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const unsigned thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
index_t itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
const unsigned thread_id_d2 = itmp / thread_per_d3;
const unsigned thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
const index_t thread_id_d2 = itmp / thread_per_d3;
const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;

mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
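The constructor above turns a linear thread id into a 4d thread coordinate the same way the element index is decomposed, only with thread counts in place of strides. A standalone sketch with an assumed 2 x 2 x 4 x 8 thread grid:

void thread_coords(index_t id, index_t tid[4])
{
    const index_t s0 = 2 * 4 * 8; // threads spanned by one step in d0
    const index_t s1 = 4 * 8;     // ... in d1
    const index_t s2 = 8;         // ... in d2
    tid[0] = id / s0; id -= tid[0] * s0;
    tid[1] = id / s1; id -= tid[1] * s1;
    tid[2] = id / s2;
    tid[3] = id - tid[2] * s2;
}
// id 77 -> tid = {1, 0, 1, 5}: 77 = 1*64 + 0*32 + 1*8 + 5.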
@@ -526,17 +526,17 @@ struct Blockwise4dTensorCopy3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr unsigned L0 = CopyLengths{}.Get(I0);
constexpr unsigned L1 = CopyLengths{}.Get(I1);
constexpr unsigned L2 = CopyLengths{}.Get(I2);
constexpr unsigned L3 = CopyLengths{}.Get(I3);
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);

constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

constexpr unsigned num_active_thread =
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

if(BlockSize > num_active_thread)
@@ -547,30 +547,30 @@ struct Blockwise4dTensorCopy3
}
}

constexpr unsigned nloop_d0 = L0 / thread_per_d0;
constexpr unsigned nloop_d1 = L1 / thread_per_d1;
constexpr unsigned nloop_d2 = L2 / thread_per_d2;
constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

#pragma unroll
for(unsigned iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(unsigned iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(unsigned iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
#pragma unroll
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const unsigned src_offset =
const index_t src_offset =
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);

const unsigned dst_offset =
const index_t dst_offset =
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,

@@ -1,30 +1,30 @@
#pragma once
#include "threadwise_gemm.hip.hpp"

template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
bool TransA,
bool TransB,
bool TransC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned BatchPerThread,
unsigned KPerThreadLoop,
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t BatchPerThread,
index_t KPerThreadLoop,
bool DistributeThreadAlongColumnFirst>
struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;

struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
index_t batch;
index_t row;
index_t col;
};

__device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
@@ -61,7 +61,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
#endif
}

__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{

if(TransA && (!TransB) && (!TransC))
@@ -72,22 +72,22 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!");

constexpr unsigned MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol();
constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr index_t NPerBlock = b_block_mtx.NCol();

constexpr auto c_thread_mtx = ThreadMatrixC{};

// divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");

constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr index_t MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;

static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
"wrong! wrong BlockSize");
@@ -95,10 +95,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst)
{
// num of operations can be reduced
const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
unsigned itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const unsigned m_work_id = itmp / NThreadWork;
const unsigned n_work_id = itmp - m_work_id * NThreadWork;
const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
index_t itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const index_t m_work_id = itmp / NThreadWork;
const index_t n_work_id = itmp - m_work_id * NThreadWork;

return MatrixIndex{
b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
@@ -118,7 +118,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC

// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
return MatrixIndex{batch_in_c, m_in_c, n_in_c};
}
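GetBeginOfThreadMatrixC above carves the block's threads into BatchThreadWork x MThreadWork x NThreadWork work items. A worked sketch under assumed sizes (BatchSize 8, BatchPerThread 2, MPerBlock 64, MPerThread 4, NPerBlock 64, NPerThread 8, so BlockSize must be 4 * 16 * 8 = 512):

void begin_of_thread_matrix_c(index_t thread_id, index_t& batch, index_t& row, index_t& col)
{
    const index_t MThreadWork = 16, NThreadWork = 8; // from the assumed sizes
    const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
    index_t itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);
    const index_t m_work_id = itmp / NThreadWork;
    const index_t n_work_id = itmp - m_work_id * NThreadWork;
    batch = b_work_id * 2; // BatchPerThread
    row   = m_work_id * 4; // MPerThread
    col   = n_work_id * 8; // NPerThread
}
// thread_id 77 -> batch 0, row 36, col 40.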
@@ -138,10 +138,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

// a is transposed, b is not
constexpr auto a_thread_mtx =
@@ -154,7 +154,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

// loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of a, b
threadwise_matrix_copy(a_block_mtx,
@@ -172,7 +172,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
b_thread_mtx.GetLengths());

// loop over batch
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
@@ -226,32 +226,32 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
}
};

template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop,
unsigned BatchPerThread>
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t BatchPerThread>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;

struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
index_t batch;
index_t row;
index_t col;
};

__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
@@ -259,9 +259,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(BatchSize % BatchPerThread == 0,
"wrong! BatchSize is not divisible by BatchPerThread");

constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
@@ -274,31 +274,31 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");

constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");

constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;

static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");

constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
@@ -335,28 +335,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#endif
}
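The GetBeginOfThreadMatrixC that follows peels a linear thread id apart hierarchically: batch group first, then level-1 cluster, then level-0 cluster. A worked example under assumed cluster sizes:

// Assumed: MLevel0Cluster = NLevel0Cluster = 4, MLevel1Cluster = NLevel1Cluster = 2,
// so ThreadPerLevel1Cluster = 64 and ThreadPerLevel0Cluster = 16. For thread_id 90:
//   batch_work_id = 90 / 64 = 1,  cluster_id = 90 - 64 = 26
//   level1_id = 26 / 16 = 1  ->  level1_m_id = 0, level1_n_id = 1
//   level0_id = 26 % 16 = 10 ->  level0_m_id = 2, level0_n_id = 2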
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;

unsigned level1_id = cluster_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;

unsigned level0_id = cluster_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;

constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

return MatrixIndex{batch_work_id * BatchPerThread,
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
@@ -365,24 +365,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2

// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;

unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;

return MatrixIndex{batch_in_c,
m_repeat * MPerLevel1Cluster + m_in_sub_c,
@@ -402,10 +402,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

// thread A, B for GEMM
// A is transposed, b is not
@@ -425,20 +425,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

// loop over k
#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -451,7 +451,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2

// copy B-sub to form B
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -464,7 +464,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2

// loop over batch
#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
@@ -482,7 +482,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0)
{
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -498,7 +498,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0)
{
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -539,10 +539,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

// thread A, B for GEMM
// A is transposed, b is not
@@ -562,25 +562,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

// loop over k
//#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{
#if 1
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
@@ -596,11 +596,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2

// copy B-sub to form B
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
@@ -612,20 +612,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2

// loop over batch
//#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
const index_t aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const index_t cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;

f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
@@ -635,11 +635,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");

const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);

asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \
@@ -668,11 +668,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0)
{
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i,
m_repeat * MPerThreadSubC + j)] =
@@ -687,11 +687,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0)
{
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i,
n_repeat * NPerThreadSubC + j)] =
@@ -705,16 +705,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}

// do last batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const index_t cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC;

f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
@@ -724,11 +724,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");

const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex =
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex =
c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;

asm volatile("\n \
@@ -756,34 +756,34 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}

template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const
{
constexpr auto c_block_mtx = BlockMatrixC{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

const unsigned c_thread_offset =
const index_t c_thread_offset =
c_thread_mtx_begin.batch * BlockMatrixStrideC +
c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);

for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
c_thread_sub_mtx,

@@ -3,16 +3,16 @@
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"

template <unsigned BlockSize,
template <index_t BlockSize,
class Float,
class InBlockDesc,
class WeiBlockDesc,
class OutBlockDesc,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread>
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread>
__device__ void blockwise_direct_convolution(InBlockDesc,
Float* const __restrict__ p_in_block,
WeiBlockDesc,
@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto wei_block_desc = WeiBlockDesc{};
constexpr auto out_block_desc = OutBlockDesc{};

constexpr unsigned Y = wei_block_desc.GetLength(I2);
constexpr unsigned X = wei_block_desc.GetLength(I3);
constexpr index_t Y = wei_block_desc.GetLength(I2);
constexpr index_t X = wei_block_desc.GetLength(I3);

constexpr unsigned InTileSizeH = HoPerThread + Y - 1;
constexpr unsigned InTileSizeW = WoPerThread + X - 1;
constexpr index_t InTileSizeH = HoPerThread + Y - 1;
constexpr index_t InTileSizeW = WoPerThread + X - 1;

// divide thread work
constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
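Each of the *ThreadWork counts above is a ceiling division, so output tiles that are not exact multiples of the per-thread tile are still covered. A small assumed example:

// Assumed: out_block_desc.GetLength(I2) = 6 with HoPerThread = 4 gives
// YThreadWork = (6 + 4 - 1) / 4 = 2; the second work item starts at ho = 4
// and nominally spans rows 4..7, of which only 4..5 exist. How that overhang
// is handled lies outside this hunk.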
#if 0
if(threadIdx.x == 0)
@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto out_thread_block_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());

const unsigned thread_id = threadIdx.x;
const index_t thread_id = threadIdx.x;

for(unsigned thread_work_id = thread_id;
for(index_t thread_work_id = thread_id;
thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
thread_work_id += BlockSize)
{
unsigned itmp = thread_work_id;
unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
index_t itmp = thread_work_id;
index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
unsigned y_thread_work_id = itmp / XThreadWork;
unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
index_t y_thread_work_id = itmp / XThreadWork;
index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;

unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread;
unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread;
index_t n_thread_data_begin = n_thread_work_id * NPerThread;
index_t k_thread_data_begin = k_thread_work_id * KPerThread;
index_t ho_thread_data_begin = y_thread_work_id * HoPerThread;
index_t wo_thread_data_begin = x_thread_work_id * WoPerThread;

unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding
unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding
index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding
index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding

Float p_out_thread[out_thread_desc.GetElementSpace()];

@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
p_out_thread,
out_thread_desc.GetLengths());

for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
c_thread_data_begin += CPerThread)
{
// threadwise convolution

@@ -1,26 +1,26 @@
#pragma once
#include "threadwise_gemm.hip.hpp"

template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
bool TransA,
bool TransB,
bool TransC,
unsigned KPerThreadLoop,
unsigned MThreadPerCluster,
unsigned NThreadPerCluster,
index_t KPerThreadLoop,
index_t MThreadPerCluster,
index_t NThreadPerCluster,
bool DistributeThreadAlongColumnFirst>
struct BlockwiseGemmBlockABlockBThreadC
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;

struct MatrixIndex
{
unsigned row;
unsigned col;
index_t row;
index_t col;
};

__device__ BlockwiseGemmBlockABlockBThreadC()
@@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC
#endif
}

__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{

if(TransA && (!TransB) && (!TransC))
@@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!");

constexpr unsigned MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol();
constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr index_t NPerBlock = b_block_mtx.NCol();

constexpr auto c_thread_mtx = ThreadMatrixC{};

// divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
"MPerBlock % (MPerThread * MThreadPerCluster) != 0");
@@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
"NPerBlock % (NPerThread * NThreadPerCluster) != 0");

constexpr unsigned MClusterWork =
constexpr index_t MClusterWork =
(MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);

constexpr unsigned NClusterWork =
constexpr index_t NClusterWork =
(NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);

static_assert(BlockSize ==
@@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC

if(DistributeThreadAlongColumnFirst)
{
const unsigned cluster_work_block_id =
const index_t cluster_work_block_id =
thread_id / (MThreadPerCluster * NThreadPerCluster);

const unsigned thread_work_cluster_id =
const index_t thread_work_cluster_id =
thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);

const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const unsigned n_cluster_work_block_id =
const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const index_t n_cluster_work_block_id =
cluster_work_block_id - m_cluster_work_block_id * NClusterWork;

const unsigned m_thread_work_cluster_id =
thread_work_cluster_id / NThreadPerCluster;
const unsigned n_thread_work_cluster_id =
const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
const index_t n_thread_work_cluster_id =
thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;

#if 0
@@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC
}

// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
return MatrixIndex{m_in_c, n_in_c};
}
@@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

// a is transposed, b is not
constexpr auto a_thread_mtx =
@@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

// loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
@@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC

// if following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop>
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
{
unsigned row;
unsigned col;
index_t row;
index_t col;
};

unsigned mMyThreadOffsetA;
unsigned mMyThreadOffsetB;
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;

__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
|
||||
@@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
|
||||
"wrong! K dimension not consistent\n");
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
|
||||
"wrong! Cannot evenly divide thread work among repeat \n");
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
|
||||
"wrong! Cannot evenly divide work among repeat\n");
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
|
||||
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
|
||||
constexpr index_t MPerLevel1Cluster = M / MRepeat;
|
||||
constexpr index_t NPerLevel1Cluster = N / NRepeat;
|
||||
|
||||
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
|
||||
(NPerLevel1Cluster % NLevel1Cluster == 0),
|
||||
"wrong! Cannot evenly divide work among Level1Cluster\n");
|
||||
|
||||
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
|
||||
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
|
||||
|
||||
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
|
||||
(NPerLevel0Cluster % NLevel0Cluster == 0),
|
||||
@@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
|
||||
}
|
||||
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
|
||||
{
|
||||
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
|
||||
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
|
||||
|
||||
unsigned level1_id = thread_id / ThreadPerLevel0Cluster;
|
||||
unsigned level1_m_id = level1_id / NLevel1Cluster;
|
||||
unsigned level1_n_id = level1_id % NLevel1Cluster;
|
||||
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
|
||||
index_t level1_m_id = level1_id / NLevel1Cluster;
|
||||
index_t level1_n_id = level1_id % NLevel1Cluster;
|
||||
|
||||
unsigned level0_id = thread_id % ThreadPerLevel0Cluster;
|
||||
unsigned level0_m_id = level0_id / NLevel0Cluster;
|
||||
unsigned level0_n_id = level0_id % NLevel0Cluster;
|
||||
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
|
||||
index_t level0_m_id = level0_id / NLevel0Cluster;
|
||||
index_t level0_n_id = level0_id % NLevel0Cluster;
|
||||
|
||||
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
|
||||
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
|
||||
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
|
||||
|
||||
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
}
|
||||
|
||||
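
// A worked example for GetBeginOfThreadMatrixC (hypothetical parameters, not
// from this commit): with MLevel0Cluster = NLevel0Cluster = 2 (so
// ThreadPerLevel0Cluster = 4), NLevel1Cluster = 2 and MPerThreadSubC =
// NPerThreadSubC = 4, thread_id = 13 gives level1_id = 3 (level1_m_id = 1,
// level1_n_id = 1) and level0_id = 1 (level0_m_id = 0, level0_n_id = 1), so the
// thread's C sub-block begins at row = 1 * 8 + 0 * 4 = 8, col = 1 * 8 + 1 * 4 = 12.
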
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;

unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;

return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
@@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2

#pragma unroll
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
//#pragma unroll
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths());
}

#pragma unroll
//#pragma unroll
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
False,
p_c_thread,
f_accum);
#else
#elif 0
// inline asm
static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
"asm is only for 8x8");

for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
{
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);

for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);

asm volatile("\n \
v_mac_f32 %0, %8, %9 \n \
@@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

// preload A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
@@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}

#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
@@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
bool even_loop = true;

#pragma unroll
for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K;
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
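// A reading of the code above (not text from the commit): even_loop implements
// ping-pong double buffering. On even k-steps the GEMM consumes the *_0 thread
// buffers while the next A- and B-subs are preloaded into the *_1 buffers; on
// odd k-steps the roles swap, overlapping register preloads with computation.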
@@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2

// preload next A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
@@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}

#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
@@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};

constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();

constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();

// thread A-sub, B-sub, C-sub
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
@@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
// C-sub(s) in first row-wise subblock of C
{
@@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2

#pragma unroll
// copy next B-sub, and do GEMM
for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -805,7 +804,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// loop over rest of row-wise subblock
// all B-sub(s) has been copied, so only A-sub(s) need to be copied
for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
{
// copy a A-sub
threadwise_matrix_copy(
@@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths());

// do some GEMMs
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_gemm(
a_thread_sub_mtx,

@@ -5,9 +5,9 @@
#include "Array.hip.hpp"
#include "functional.hip.hpp"

__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

__device__ unsigned get_block_1d_id() { return blockIdx.x; }
__device__ index_t get_block_1d_id() { return blockIdx.x; }

template <class T1, class T2>
struct is_same
@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
}
#endif

__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b)
__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
{
return (a + b - 1) / b;
}
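
// A quick check of the rounding-up identity above (illustrative only):
// integer_divide_ceil(10, 4) == (10 + 4 - 1) / 4 == 13 / 4 == 3 == ceil(10 / 4).
// The same (x + y - 1) / y pattern is written out inline in the *BlockWork and
// LDS padding computations in the kernels below.
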
@@ -11,3 +11,5 @@
#include "nvToolsExt.h"
#include "helper_cuda.h"
#endif

using index_t = uint32_t;
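
// Note (a reading of the commit, not its text): index_t is a 32-bit unsigned
// integer, so the unsigned -> index_t renames throughout this commit preserve
// behavior while centralizing the index type in one place; widening indices
// later would then be a one-line change, e.g. the hypothetical
//   using index_t = uint64_t;
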
@@ -8,5 +8,5 @@ struct integral_constant
__host__ __device__ constexpr T Get() const { return value; }
};

template <unsigned N>
using Number = integral_constant<unsigned, N>;
template <index_t N>
using Number = integral_constant<index_t, N>;
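
// Minimal usage sketch for Number<> (illustrative, not part of the commit): a
// Number<N> value carries N as a compile-time constant, which is how alignment
// arguments are passed below, e.g.
//   in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
// and Number<4>{}.Get() == 4 holds at compile time via integral_constant.
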
@@ -1,7 +1,7 @@
#pragma once
#include "config.h"

template <class T, unsigned N>
template <class T, index_t N>
struct vector_type
{
};

@@ -1,7 +1,7 @@
#pragma once
#include "constant_integral.hip.hpp"

template <unsigned NLoop>
template <index_t NLoop>
struct static_loop_n
{
template <class F>
@@ -24,7 +24,7 @@ struct static_loop_n<1>
}
};

template <unsigned NLoop>
template <index_t NLoop>
struct static_const_reduce_n
{
template <class F, class Reduce>

@@ -8,18 +8,18 @@ template <class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned BlockSize,
unsigned GridSize>
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t BlockSize,
index_t GridSize>
__global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global)
@@ -33,16 +33,16 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
constexpr auto wei_global_desc = WeiGlobalDesc{};
constexpr auto out_global_desc = OutGlobalDesc{};

constexpr unsigned Y = wei_global_desc.GetLength(I2);
constexpr unsigned X = wei_global_desc.GetLength(I3);
constexpr index_t Y = wei_global_desc.GetLength(I2);
constexpr index_t X = wei_global_desc.GetLength(I3);

constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;

constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

constexpr auto in_block_global_desc = make_ConstantTensorDescriptor(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_global_desc.GetStrides());
@@ -59,31 +59,31 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
constexpr auto out_block_desc =
make_ConstantTensorDescriptor(out_block_global_desc.GetLengths());

constexpr unsigned in_block_size = in_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace();
constexpr unsigned out_block_size = out_block_desc.GetElementSpace();
constexpr index_t in_block_size = in_block_desc.GetElementSpace();
constexpr index_t wei_block_size = wei_block_desc.GetElementSpace();
constexpr index_t out_block_size = out_block_desc.GetElementSpace();

__shared__ Float p_in_block[in_block_size];
__shared__ Float p_wei_block[wei_block_size];
__shared__ Float p_out_block[out_block_size];

const unsigned block_id = blockIdx.x;
const index_t block_id = blockIdx.x;

unsigned itmp = block_id;
unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
index_t itmp = block_id;
index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
unsigned h_block_work_id = itmp / WBlockWork;
unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
index_t h_block_work_id = itmp / WBlockWork;
index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

unsigned n_block_work_begin = n_block_work_id * NPerBlock;
unsigned k_block_work_begin = k_block_work_id * KPerBlock;
unsigned ho_block_work_begin = h_block_work_id * HoPerBlock;
unsigned wo_block_work_begin = w_block_work_id * WoPerBlock;
index_t n_block_work_begin = n_block_work_id * NPerBlock;
index_t k_block_work_begin = k_block_work_id * KPerBlock;
index_t ho_block_work_begin = h_block_work_id * HoPerBlock;
index_t wo_block_work_begin = w_block_work_id * WoPerBlock;

unsigned hi_block_work_begin = ho_block_work_begin; // minus padding
unsigned wi_block_work_begin = wo_block_work_begin; // minus padding
index_t hi_block_work_begin = ho_block_work_begin; // minus padding
index_t wi_block_work_begin = wo_block_work_begin; // minus padding

constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
@@ -109,7 +109,7 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
// set output tensor in LDS to 0
blockwise_4d_tensor_set_zero<BlockSize>(out_block_desc, p_out_block);

for(unsigned c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
for(index_t c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
c_block_work_begin += CPerBlock)
{
// copy input tensor to LDS

@@ -11,20 +11,20 @@ template <class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead,
unsigned BlockSize,
unsigned GridSize>
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t BlockSize,
index_t GridSize>
__global__ void
gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -39,17 +39,17 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};

constexpr unsigned N = in_nchw_global_desc.GetLength(I0);
constexpr unsigned K = wei_kcyx_global_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_global_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_global_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_global_desc.GetLength(I3);
constexpr index_t N = in_nchw_global_desc.GetLength(I0);
constexpr index_t K = wei_kcyx_global_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_global_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);

constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor(
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy

constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;

constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
@@ -63,21 +63,21 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});

// shared mem
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;

__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];

// threadwise tensors
constexpr unsigned HiPerThread = HoPerThread + Y - 1;
constexpr unsigned WiPerThread = WoPerThread + X - 1;
constexpr index_t HiPerThread = HoPerThread + Y - 1;
constexpr index_t WiPerThread = WoPerThread + X - 1;

constexpr auto in_nchw_thread_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
@@ -93,56 +93,54 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

// divide block work
constexpr unsigned NBlockWork =
(out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork =
(out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork =
constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork =
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork =
constexpr index_t WBlockWork =
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

const unsigned block_id = blockIdx.x;
const index_t block_id = blockIdx.x;

unsigned itmp = block_id;
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
index_t itmp = block_id;
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const unsigned h_block_work_id = itmp / WBlockWork;
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
const index_t h_block_work_id = itmp / WBlockWork;
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;

const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding
const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding

// divide thread work
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;

const unsigned thread_id = threadIdx.x;
const index_t thread_id = threadIdx.x;

itmp = thread_id;
const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp = thread_id;
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork);
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
const unsigned h_thread_work_id = itmp / WThreadWork;
const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const index_t h_thread_work_id = itmp / WThreadWork;
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;

const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread;
const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread;
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;

const unsigned hi_thread_data_begin = ho_thread_data_begin;
const unsigned wi_thread_data_begin = wo_thread_data_begin;
const index_t hi_thread_data_begin = ho_thread_data_begin;
const index_t wi_thread_data_begin = wo_thread_data_begin;

constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
@@ -172,7 +170,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);

for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, __syncthreads())
{
// copy input tensor to LDS
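// A note on the loop header above (a reading of the code, not text from the
// commit): the comma operator in the increment expression runs __syncthreads()
// at the end of every iteration, so all threads are done reading the current
// LDS tile before the next c_block_data_begin iteration overwrites it.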
@@ -191,7 +189,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i

__syncthreads();

for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// threadwise convolution
#if 1

@@ -13,21 +13,21 @@ template <class TInWei,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned ScalarPerVector,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead,
unsigned BlockSize,
unsigned GridSize>
index_t ScalarPerVector,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t BlockSize,
index_t GridSize>
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
const typename vector_type<TInWei,
ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
@@ -49,17 +49,17 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};

constexpr unsigned N = in_nchw_vec_global_desc.GetLength(I0);
constexpr unsigned K = wei_kcyx_vec_global_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_vec_global_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_vec_global_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_vec_global_desc.GetLength(I3);
constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);
constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);

constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor(
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy

constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;

constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
@@ -73,15 +73,15 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});

// shared mem
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});

constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;

__shared__ in_vector_mem_t
p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
@@ -89,8 +89,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];

// threadwise tensors
constexpr unsigned HiPerThread = HoPerThread + Y - 1;
constexpr unsigned WiPerThread = WoPerThread + X - 1;
constexpr index_t HiPerThread = HoPerThread + Y - 1;
constexpr index_t WiPerThread = WoPerThread + X - 1;

constexpr auto in_nchw_vec_thread_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
@@ -106,56 +106,54 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

// divide block work
constexpr unsigned NBlockWork =
(out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork =
(out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork =
constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork =
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork =
constexpr index_t WBlockWork =
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

const unsigned block_id = blockIdx.x;
const index_t block_id = blockIdx.x;

unsigned itmp = block_id;
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
index_t itmp = block_id;
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const unsigned h_block_work_id = itmp / WBlockWork;
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
const index_t h_block_work_id = itmp / WBlockWork;
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;

const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding
const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding

// divide thread work
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;

const unsigned thread_id = threadIdx.x;
const index_t thread_id = threadIdx.x;

itmp = thread_id;
const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp = thread_id;
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork);
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
const unsigned h_thread_work_id = itmp / WThreadWork;
const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const index_t h_thread_work_id = itmp / WThreadWork;
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;

const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread;
const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread;
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;

const unsigned hi_thread_data_begin = ho_thread_data_begin;
const unsigned wi_thread_data_begin = wo_thread_data_begin;
const index_t hi_thread_data_begin = ho_thread_data_begin;
const index_t wi_thread_data_begin = wo_thread_data_begin;

constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
@@ -188,7 +186,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
#endif

for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, __syncthreads())
{
// copy input tensor to LDS
@@ -207,7 +205,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(

__syncthreads();

for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// threadwise convolution
#if 1

@@ -8,32 +8,32 @@
|
||||
#include "threadwise_4d_tensor_op.hip.hpp"
|
||||
#include "blockwise_batched_gemm.hip.hpp"
|
||||
|
||||
template <unsigned GridSize,
|
||||
unsigned BlockSize,
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
unsigned NPerBlock,
|
||||
unsigned KPerBlock,
|
||||
unsigned CPerBlock,
|
||||
unsigned HoPerBlock,
|
||||
unsigned WoPerBlock,
|
||||
unsigned NPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned HoPerThread,
|
||||
unsigned WoPerThread,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
class InBlockCopyThreadPerDims,
|
||||
unsigned InBlockCopyDataPerRead,
|
||||
unsigned WeiBlockCopyDataPerRead,
|
||||
unsigned GemmMPerThreadSubC,
|
||||
unsigned GemmNPerThreadSubC,
|
||||
unsigned GemmMLevel0Cluster,
|
||||
unsigned GemmNLevel0Cluster,
|
||||
unsigned GemmMLevel1Cluster,
|
||||
unsigned GemmNLevel1Cluster,
|
||||
unsigned GemmKPerThreadLoop,
|
||||
unsigned OutThreadCopyDataPerWrite>
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t OutThreadCopyDataPerWrite>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
@@ -55,39 +55,39 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
|
||||
|
||||
constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
|
||||
constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
|
||||
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
|
||||
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
|
||||
|
||||
const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
|
||||
unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
|
||||
const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
|
||||
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
|
||||
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
|
||||
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
|
||||
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
|
||||
const unsigned w_block_work_id = itmp / NBlockWork;
|
||||
const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
|
||||
const index_t w_block_work_id = itmp / NBlockWork;
|
||||
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
|
||||
|
||||
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
|
||||
const unsigned hi_block_data_begin = ho_block_data_begin;
|
||||
const unsigned wi_block_data_begin = wo_block_data_begin;
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// flattend (2d) tensor view of gridwise weight
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
|
||||
@@ -164,15 +164,15 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
|
||||
HoPerThread>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr unsigned in_block_size =
|
||||
constexpr index_t in_block_size =
|
||||
in_chwn_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr unsigned wei_block_size =
|
||||
constexpr index_t wei_block_size =
|
||||
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
|
||||
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
|
||||
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
|
||||
@@ -191,10 +191,10 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
|
||||
const Float* p_wei_global_block_begin =
|
||||
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0),
|
||||
p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0),
|
||||
p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
// input: global mem to LDS
|
||||
blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);
|
||||
@@ -205,9 +205,9 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
|
||||
__syncthreads();
|
||||
|
||||
// a series of batched GEMM
|
||||
for(unsigned y = 0; y < Y; ++y)
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(unsigned x = 0; x < X; ++x)
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 0
|
||||
blockwise_batch_gemm.Run
|
||||
@@ -227,26 +227,26 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);

const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);

const unsigned ho_thread =
const index_t ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;

const unsigned wo_thread = b_thread / NPerBlock;
const unsigned n_thread = b_thread % NPerBlock;
const index_t wo_thread = b_thread / NPerBlock;
const index_t n_thread = b_thread % NPerBlock;

p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
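
The column of the thread's C sub-matrix packs (wo, n) with n running fastest, which is why a divide/modulo by NPerBlock recovers the two coordinates. A tiny self-checking sketch (NPerBlock and WoPerBlock values are made up):

    #include <cassert>

    int main()
    {
        constexpr unsigned NPerBlock = 2, WoPerBlock = 4; // illustrative

        for(unsigned col = 0; col < WoPerBlock * NPerBlock; ++col)
        {
            const unsigned wo = col / NPerBlock;
            const unsigned n  = col % NPerBlock;
            assert(col == wo * NPerBlock + n); // column packs (wo, n), n fastest
        }
        return 0;
    }
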
@@ -261,19 +261,19 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;

// this is for v2 GEMM
// output is an 8d tensor
if(NPerThread < NPerBlock && WoPerThread == 1)
{
constexpr unsigned N1_ = GemmNPerThreadSubC;
constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
constexpr unsigned K2_ = GemmMPerThreadSubC;
constexpr unsigned K1_ = KPerBlock / KPerThread;
constexpr index_t N1_ = GemmNPerThreadSubC;
constexpr index_t W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
constexpr index_t K2_ = GemmMPerThreadSubC;
constexpr index_t K1_ = KPerBlock / KPerThread;

constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1_ * K2_), K1_, K2_, Ho, Wo / W1_, W1_, N / N1_, N1_>{});

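
The 8d view factors K into K/(K1_*K2_) x K1_ x K2_ and likewise splits Wo and N, so it must cover exactly the original [K, Ho, Wo, N] tensor. A hedged sketch of that bookkeeping with invented sizes (the kernel derives the real factors from its Gemm* template parameters):

    #include <cassert>

    int main()
    {
        // Invented sizes, chosen only so every division is exact.
        constexpr unsigned long long K = 128, Ho = 28, Wo = 28, N = 16;
        constexpr unsigned long long K1 = 4, K2 = 4, W1 = 2, N1 = 2;

        // Lengths of Sequence<K/(K1*K2), K1, K2, Ho, Wo/W1, W1, N/N1, N1>
        const unsigned long long lengths[8] = {
            K / (K1 * K2), K1, K2, Ho, Wo / W1, W1, N / N1, N1};

        unsigned long long product = 1;
        for(unsigned long long l : lengths)
            product *= l;

        // The factored 8d view spans exactly the 4d output tensor.
        assert(product == K * Ho * Wo * N);
        return 0;
    }
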
@@ -7,26 +7,26 @@
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"

template <unsigned GridSize,
unsigned BlockSize,
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class LowerPads,
class UpperPads,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1>
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t CPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -48,42 +48,42 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};

constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr index_t C = in_chwn_global_desc.GetLength(I0);

constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t N = out_khwn_global_desc.GetLength(I3);

constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

constexpr unsigned HPadLow = LowerPads{}.Get(I0);
constexpr unsigned WPadLow = LowerPads{}.Get(I1);
constexpr index_t HPadLow = LowerPads{}.Get(I0);
constexpr index_t WPadLow = LowerPads{}.Get(I1);

constexpr unsigned HPadUp = UpperPads{}.Get(I0);
constexpr unsigned WPadUp = UpperPads{}.Get(I1);
constexpr index_t HPadUp = UpperPads{}.Get(I0);
constexpr index_t WPadUp = UpperPads{}.Get(I1);

constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;

// divide block work: [K, Ho, Wo, N]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const unsigned w_block_work_id = itmp / NBlockWork;
const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;

// flattened (2d) tensor view of wei in global mem
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
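
The work-id arithmetic above is a mixed-radix decode of the 1d block id into (k, h, w, n); recomposing the digits must return the original id. A standalone check with made-up work counts:

    #include <cassert>

    int main()
    {
        // Illustrative work counts; the kernel computes these from the
        // tensor lengths and the *PerBlock parameters.
        constexpr unsigned KW = 4, HW = 7, WW = 7, NW = 2;

        for(unsigned id = 0; id < KW * HW * WW * NW; ++id)
        {
            const unsigned k = id / (HW * WW * NW);
            unsigned itmp    = id - k * (HW * WW * NW);
            const unsigned h = itmp / (WW * NW);
            itmp -= h * (WW * NW);
            const unsigned w = itmp / NW;
            const unsigned n = itmp - w * NW;

            // Mixed-radix recomposition recovers the block id.
            assert(id == ((k * HW + h) * WW + w) * NW + n);
        }
        return 0;
    }
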
@@ -114,11 +114,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(

// blockwise copy
// input: format is [C, Hi, Wi, N]
const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;

const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;

#if 0
if(get_thread_local_1d_id() == 0)
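
Only boundary blocks see any padding: the first block along a dimension applies the low pad, the last applies the up pad, interior blocks apply none. A small sketch of that selection (pad sizes and work count invented):

    #include <cstdio>

    int main()
    {
        constexpr unsigned HPadLow = 1, HPadUp = 1, HBlockWork = 4; // illustrative

        for(unsigned h_id = 0; h_id < HBlockWork; ++h_id)
        {
            const unsigned pad_low = (h_id == 0) ? HPadLow : 0;
            const unsigned pad_up  = (h_id == HBlockWork - 1) ? HPadUp : 0;
            printf("block row %u: pad_low %u, pad_up %u\n", h_id, pad_low, pad_up);
        }
        // block row 0 pads below, the last row pads above, interior rows pad nothing
        return 0;
    }
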
@@ -204,8 +204,8 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
true>{};

// LDS
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_cyxk_block_desc.GetElementSpace();
constexpr index_t in_block_size = in_chwn_block_desc.GetElementSpace();
constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace();

__shared__ Float p_in_block[in_block_size];
__shared__ Float p_wei_block[wei_block_size];
@@ -219,9 +219,9 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
const Float* p_wei_global_block_begin =
p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin);

for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
__syncthreads())
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
__syncthreads())
{
#if 1
// input: global mem to LDS,
@@ -245,9 +245,9 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
__syncthreads();

// a series of batched GEMM
for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

@@ -262,10 +262,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
const auto matrix_c_index =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
const index_t ho_thread_data_begin = matrix_c_index.batch;
const index_t k_thread_data_begin = matrix_c_index.row;
const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const index_t n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;

#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",

@@ -8,32 +8,32 @@
#include "blockwise_gemm.hip.hpp"

// define B = flatten(N, Hi, Wi)
template <unsigned GridSize,
unsigned BlockSize,
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned BPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
index_t BPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t BPerThread,
index_t KPerThread,
index_t GemmThreadPerColumnPerCluster,
index_t GemmThreadPerRowPerCluster,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead>
__global__ void
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -48,30 +48,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};

constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1);
constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2);
constexpr unsigned N = in_chwn_global_desc.GetLength(I3);
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
constexpr index_t N = in_chwn_global_desc.GetLength(I3);

constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);

constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

constexpr unsigned B = N * Hi * Wi;
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
constexpr index_t B = N * Hi * Wi;
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);

// divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;

const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;

const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t b_block_data_begin = b_block_work_id * BPerBlock;

// flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
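
B collapses (Hi, Wi, N) into one GEMM dimension with n fastest, i.e. b = (h * Wi + w) * N + n; the store-out code near the end of this kernel inverts that with the same divide/remainder chain. A round-trip check with invented sizes:

    #include <cassert>

    int main()
    {
        constexpr unsigned Hi = 6, Wi = 6, N = 4; // illustrative
        constexpr unsigned B = N * Hi * Wi;

        for(unsigned b = 0; b < B; ++b)
        {
            // Same arithmetic as the h_data/w_data/n_data lines below.
            const unsigned h    = b / (Wi * N);
            const unsigned itmp = b - h * (Wi * N);
            const unsigned w    = itmp / N;
            const unsigned n    = itmp - w * N;

            assert(b == (h * Wi + w) * N + n);
        }
        return 0;
    }
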
@@ -192,15 +192,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
GemmKPerThreadLoop>{};

// LDS: be careful of alignment
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});

constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;

// LDS
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
@@ -218,10 +218,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
// set threadwise output tensor to 0
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);

for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads())
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads())
{
// load data
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
@@ -231,18 +231,16 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric

// compute on current data
// a series of GEMM
for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0
blockwise_gemm.Run
#elif 1
#elif 0
blockwise_gemm.Run_asm
#elif 0
blockwise_gemm.Run_v2
#elif 0
#elif 1
blockwise_gemm.Run_RegisterDoubleBuffer
#endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
@@ -257,23 +255,23 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;

for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);

unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;

unsigned h_data = b_data / (Wi * N);
unsigned itmp = b_data - h_data * (Wi * N);
unsigned w_data = itmp / N;
unsigned n_data = itmp - w_data * N;
index_t h_data = b_data / (Wi * N);
index_t itmp = b_data - h_data * (Wi * N);
index_t w_data = itmp / N;
index_t n_data = itmp - w_data * N;

if(n_data < N && h_data < Ho && w_data < Wo)
{

@@ -8,32 +8,32 @@
#include "blockwise_gemm.hip.hpp"

// define B = flatten(N, Hi, Wi)
template <unsigned GridSize,
unsigned BlockSize,
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned BPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
index_t BPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t BPerThread,
index_t KPerThread,
index_t GemmThreadPerColumnPerCluster,
index_t GemmThreadPerRowPerCluster,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead>
__global__ void
#if 0
__launch_bounds__(256,2)
@@ -52,30 +52,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};

constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1);
constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2);
constexpr unsigned N = in_chwn_global_desc.GetLength(I3);
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
constexpr index_t N = in_chwn_global_desc.GetLength(I3);

constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);

constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

constexpr unsigned B = N * Hi * Wi;
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
constexpr index_t B = N * Hi * Wi;
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);

// divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;

const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;

const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t b_block_data_begin = b_block_work_id * BPerBlock;

// flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
@@ -210,15 +210,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
#endif

// LDS: be careful of alignment
constexpr unsigned in_block_size =
constexpr index_t in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});

constexpr unsigned wei_block_size =
constexpr index_t wei_block_size =
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;

// LDS double buffer
__shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)];
@@ -248,11 +248,11 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(

bool even_loop = true;

for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
even_loop = !even_loop)
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
even_loop = !even_loop)
{
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
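
The loop above flips even_loop each tile so consecutive iterations use opposite LDS buffers, letting the copy for one tile overlap the GEMM on the other; note the bound c_block_data_begin + CPerBlock < C, which stops one tile early so a tail step after the loop can drain the last buffer. A schematic sketch of just the ping-pong control flow (buffer contents elided, sizes invented):

    #include <cstdio>

    int main()
    {
        constexpr unsigned C = 16, CPerBlock = 4; // illustrative
        bool even_loop = true;

        for(unsigned c = 0; c + CPerBlock < C; c += CPerBlock, even_loop = !even_loop)
        {
            const int now   = even_loop ? 0 : 1; // buffer picked this iteration
            const int other = 1 - now;           // buffer owned by the overlapped stage
            printf("c=%2u: use LDS buffer %d, buffer %d in flight\n", c, now, other);
        }
        // a final step on the last-selected buffer follows the loop
        return 0;
    }
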
@@ -279,12 +279,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(

// compute on current data
// a series of GEMM
for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0
#if 1
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
@@ -309,12 +309,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(

__syncthreads();

for(unsigned y = 0; y < Y; ++y)
for(index_t y = 0; y < Y; ++y)
{
for(unsigned x = 0; x < X; ++x)
for(index_t x = 0; x < X; ++x)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0
#if 1
blockwise_gemm.Run
#else
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
|
||||
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
|
||||
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
|
||||
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
@@ -348,20 +348,20 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
|
||||
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
|
||||
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
|
||||
{
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
|
||||
|
||||
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
|
||||
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
|
||||
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
|
||||
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
|
||||
|
||||
unsigned h_data = b_data / (Wi * N);
|
||||
unsigned itmp = b_data - h_data * (Wi * N);
|
||||
unsigned w_data = itmp / N;
|
||||
unsigned n_data = itmp - w_data * N;
|
||||
index_t h_data = b_data / (Wi * N);
|
||||
index_t itmp = b_data - h_data * (Wi * N);
|
||||
index_t w_data = itmp / N;
|
||||
index_t n_data = itmp - w_data * N;
|
||||
|
||||
if(n_data < N && h_data < Ho && w_data < Wo)
|
||||
{
|
||||
|
||||
@@ -16,11 +16,11 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re
}
#endif

for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
{
const unsigned dindex = desc.Get1dIndex(did0, did1);
const index_t dindex = desc.Get1dIndex(did0, did1);

f(p[dindex]);
}
@@ -47,22 +47,22 @@ __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_d
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};

constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);

constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
const unsigned aindex = src_desc.Get1dIndex(did0, did1);
const index_t aindex = src_desc.Get1dIndex(did0, did1);

const unsigned did[2] = {did0, did1};
const index_t did[2] = {did0, did1};

const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);

f(p_src[aindex], p_dst[bindex]);
}
@@ -118,21 +118,21 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
}
#endif

constexpr unsigned nshift = NShift::mValue;
constexpr index_t nshift = NShift::mValue;

constexpr unsigned did0_end =
constexpr index_t did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);

constexpr unsigned did1_end =
constexpr index_t did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);

for(unsigned did0 = 0; did0 < did0_end; ++did0)
for(index_t did0 = 0; did0 < did0_end; ++did0)
{
for(unsigned did1 = 0; did1 < did1_end; ++did1)
for(index_t did1 = 0; did1 < did1_end; ++did1)
{
const unsigned dindex = desc.Get1dIndex(did0, did1);
const index_t dindex = desc.Get1dIndex(did0, did1);

const unsigned sindex = dindex + nshift * desc.GetStride(IDim{});
const index_t sindex = dindex + nshift * desc.GetStride(IDim{});

p[dindex] = p[sindex];
}

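
shift_down overwrites each element with the one nshift positions further along IDim, so the loop bound on that dimension shrinks by nshift; iterating in ascending order is safe because the source index is always ahead of the destination. The 1d essence of the move:

    #include <cassert>

    int main()
    {
        constexpr unsigned len = 6, nshift = 2; // illustrative
        float p[len] = {0, 1, 2, 3, 4, 5};

        // Ascending copy is safe: source d + nshift never precedes destination d.
        for(unsigned d = 0; d < len - nshift; ++d)
            p[d] = p[d + nshift];

        assert(p[0] == 2.0f && p[3] == 5.0f); // window shifted down by nshift
        return 0;
    }
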
@@ -18,15 +18,15 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
}
#endif

for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < desc.GetLength(I3); ++did3)
for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3)
{
const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);

f(p[dindex]);
}
@@ -58,28 +58,28 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);

constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{
const unsigned aindex = src_desc.Get1dIndex(did0, did1, did2, did3);
const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3);

const unsigned did[4] = {did0, did1, did2, did3};
const index_t did[4] = {did0, did1, did2, did3};

const unsigned bindex =
const index_t bindex =
dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

f(p_src[aindex], p_dst[bindex]);
@@ -129,7 +129,7 @@ __device__ void threadwise_4d_tensor_copy(
}

// need to assume src and dst are aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
@@ -163,24 +163,24 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

constexpr unsigned L3 = SrcOpLengths{}.Get(I3);
constexpr index_t L3 = SrcOpLengths{}.Get(I3);

static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");

constexpr unsigned nloop_d3 = L3 / DataPerRead;
constexpr index_t nloop_d3 = L3 / DataPerRead;

for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const unsigned src_index =
const index_t src_index =
src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

const unsigned dst_index =
const index_t dst_index =
dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

if(DataPerRead == 1)
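
The copy walks the innermost dimension in steps of DataPerRead so each step can be a single vector move rather than DataPerRead scalar ones; the static_asserts above exist precisely to keep every such access aligned. A plain C++ sketch of the stepping, with a local vec4 standing in for the HIP vector type:

    #include <cassert>
    #include <cstring>

    struct vec4 { float x, y, z, w; }; // stand-in for vector_type<float, 4>::MemoryType

    int main()
    {
        constexpr unsigned L3 = 8, DataPerRead = 4; // illustrative; L3 % DataPerRead == 0
        constexpr unsigned nloop = L3 / DataPerRead;

        alignas(16) float src[L3] = {0, 1, 2, 3, 4, 5, 6, 7};
        alignas(16) float dst[L3] = {};

        for(unsigned i = 0; i < nloop; ++i)
        {
            // One 128-bit move per step, the same reinterpret_cast idiom as the
            // kernel; correctness relies on the 16-byte alignment asserted above.
            *reinterpret_cast<vec4*>(dst + i * DataPerRead) =
                *reinterpret_cast<const vec4*>(src + i * DataPerRead);
        }

        assert(std::memcmp(src, dst, sizeof(src)) == 0);
        return 0;
    }
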
@@ -224,31 +224,31 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
}
#endif

constexpr unsigned nshift = NShift::mValue;
constexpr index_t nshift = NShift::mValue;

constexpr unsigned did0_end =
constexpr index_t did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);

constexpr unsigned did1_end =
constexpr index_t did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);

constexpr unsigned did2_end =
constexpr index_t did2_end =
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);

constexpr unsigned did3_end =
constexpr index_t did3_end =
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);

for(unsigned did0 = 0; did0 < did0_end; ++did0)
for(index_t did0 = 0; did0 < did0_end; ++did0)
{
for(unsigned did1 = 0; did1 < did1_end; ++did1)
for(index_t did1 = 0; did1 < did1_end; ++did1)
{
for(unsigned did2 = 0; did2 < did2_end; ++did2)
for(index_t did2 = 0; did2 < did2_end; ++did2)
{
for(unsigned did3 = 0; did3 < did3_end; ++did3)
for(index_t did3 = 0; did3 < did3_end; ++did3)
{
const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3);
const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);

const unsigned sindex = dindex + nshift * desc.GetStride(IDim{});
const index_t sindex = dindex + nshift * desc.GetStride(IDim{});

p[dindex] = p[sindex];
}

@@ -28,28 +28,28 @@ __device__ void threadwise_direct_convolution_1(InDesc,
}
#endif

for(unsigned n = 0; n < out_desc.GetLength(I0); ++n)
for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
{
for(unsigned k = 0; k < out_desc.GetLength(I1); ++k)
for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
{
for(unsigned ho = 0; ho < out_desc.GetLength(I2); ++ho)
for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
{
for(unsigned wo = 0; wo < out_desc.GetLength(I3); ++wo)
for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
{
for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c)
for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
{
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{
const unsigned hi = ho + y;
const unsigned wi = wo + x;
const index_t hi = ho + y;
const index_t wi = wo + x;

const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi);
const index_t in_index = in_desc.Get1dIndex(n, c, hi, wi);

const unsigned wei_index = wei_desc.Get1dIndex(k, c, y, x);
const index_t wei_index = wei_desc.Get1dIndex(k, c, y, x);

const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
const index_t out_index = out_desc.Get1dIndex(n, k, ho, wo);

fused_multiply_accumulate(
p_out[out_index], p_wei[wei_index], p_in[in_index]);
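
threadwise_direct_convolution_1 is the naive seven-loop form: every output point accumulates over (c, y, x) with hi = ho + y and wi = wo + x, i.e. stride 1 and no padding. A self-contained host reference of the same loop nest, handy for validating the implicit-GEMM kernels on tiny sizes (all sizes invented):

    #include <cstdio>
    #include <vector>

    int main()
    {
        constexpr int N = 1, K = 2, C = 2, Y = 3, X = 3, Ho = 4, Wo = 4;
        constexpr int Hi = Ho + Y - 1, Wi = Wo + X - 1;

        std::vector<float> in(N * C * Hi * Wi, 1.0f);  // NCHW
        std::vector<float> wei(K * C * Y * X, 1.0f);   // KCYX
        std::vector<float> out(N * K * Ho * Wo, 0.0f); // NKHW

        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                for(int ho = 0; ho < Ho; ++ho)
                    for(int wo = 0; wo < Wo; ++wo)
                        for(int c = 0; c < C; ++c)
                            for(int y = 0; y < Y; ++y)
                                for(int x = 0; x < X; ++x)
                                {
                                    const int hi = ho + y;
                                    const int wi = wo + x;
                                    out[((n * K + k) * Ho + ho) * Wo + wo] +=
                                        wei[((k * C + c) * Y + y) * X + x] *
                                        in[((n * C + c) * Hi + hi) * Wi + wi];
                                }

        printf("out[0] = %f (expect %d)\n", out[0], C * Y * X); // all-ones data
        return 0;
    }
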
@@ -125,7 +125,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
Data p_in_reg[in_reg_desc.GetElementSpace()];
Data p_wei_reg[wei_reg_desc.GetElementSpace()];

constexpr unsigned in_w_new_read = 1;
constexpr index_t in_w_new_read = 1;

constexpr auto in_desc_reg_new_read =
make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
@@ -136,7 +136,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
#if 0
// this version reuses old input data in register, and reads new data from LDS
// loop over vertical direction
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
// read first input
threadwise_4d_tensor_copy(in_desc,
@@ -157,7 +157,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);

// loop over horizontal direction
for(unsigned x = 1; x < wei_desc.GetLength(I3); ++x)
for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
{
// read new weight
threadwise_4d_tensor_copy(wei_desc,
@@ -186,10 +186,10 @@ __device__ void threadwise_direct_convolution_3(InDesc,
#elif 1
// this version reads all input from LDS when the filter moves
// loop over vertical direction
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
// loop over horizontal direction
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{
// read new weight
threadwise_4d_tensor_copy(wei_desc,

@@ -1,6 +1,6 @@
#pragma once

template <class Float, class SrcMatrix, class DstMatrix, unsigned NRow, unsigned NCol>
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
__device__ void threadwise_matrix_copy(SrcMatrix,
const Float* __restrict__ p_src,
DstMatrix,
@@ -10,16 +10,39 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{};

for(unsigned i = 0; i < NRow; ++i)
#if 0
for(index_t i = 0; i < NRow; ++i)
{
for(unsigned j = 0; j < NCol; ++j)
for(index_t j = 0; j < NCol; ++j)
{
const unsigned src_index = src_mtx.Get1dIndex(i, j);
const unsigned dst_index = dst_mtx.Get1dIndex(i, j);
const index_t src_index = src_mtx.Get1dIndex(i, j);
const index_t dst_index = dst_mtx.Get1dIndex(i, j);

p_dst[dst_index] = p_src[src_index];
}
}
#elif 1
static_assert(NCol == 4, "only for NCol == 4");

using vector_t = typename vector_type<Float, 4>::MemoryType;

for(index_t i = 0; i < NRow; ++i)
{
const index_t src_index = src_mtx.Get1dIndex(i, 0);
const index_t dst_index = dst_mtx.Get1dIndex(i, 0);

#if 1
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
#elif 1
asm volatile("\n \
ds_read_b128 %0, %1, offset:0 \n \
"
: "=v"(*(reinterpret_cast<vector_t*>(p_dst+dst_index)))
: "v"((uint32_t)(p_src + src_index)));
#endif
}
#endif
}

template <class MatrixA,
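
The vectorized branch above replaces the scalar j-loop with one four-wide move per row, hence the static_assert on NCol == 4; each row still computes its own base offset because consecutive rows need not be contiguous. A host sketch with explicit row strides (strides invented, vec4 a stand-in):

    #include <cassert>

    struct vec4 { float v[4]; }; // stand-in for vector_type<Float, 4>::MemoryType

    int main()
    {
        constexpr unsigned NRow = 3, NCol = 4;
        constexpr unsigned src_stride = 8, dst_stride = 4; // illustrative row strides

        alignas(16) float src[NRow * src_stride];
        alignas(16) float dst[NRow * dst_stride] = {};
        for(unsigned i = 0; i < NRow * src_stride; ++i)
            src[i] = float(i);

        for(unsigned i = 0; i < NRow; ++i)
        {
            // One vector move per row replaces the NCol scalar copies.
            *reinterpret_cast<vec4*>(dst + i * dst_stride) =
                *reinterpret_cast<const vec4*>(src + i * src_stride);
        }

        for(unsigned i = 0; i < NRow; ++i)
            for(unsigned j = 0; j < NCol; ++j)
                assert(dst[i * dst_stride + j] == src[i * src_stride + j]);
        return 0;
    }
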
@@ -49,21 +72,31 @@ __device__ void threadwise_gemm(MatrixA,
constexpr auto b_mtx = MatrixB{};
constexpr auto c_mtx = MatrixC{};

constexpr unsigned M = c_mtx.NRow();
constexpr unsigned N = c_mtx.NCol();
constexpr unsigned K = a_mtx.NRow(); // A is transposed
constexpr index_t M = c_mtx.NRow();
constexpr index_t N = c_mtx.NCol();
constexpr index_t K = a_mtx.NRow(); // A is transposed

for(unsigned k = 0; k < K; ++k)
for(index_t k = 0; k < K; ++k)
{
for(unsigned i = 0; i < M; ++i)
for(index_t i = 0; i < M; ++i)
{
for(unsigned j = 0; j < N; ++j)
for(index_t j = 0; j < N; ++j)
{
const unsigned aindex = a_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_mtx.Get1dIndex(k, j);
const unsigned cindex = c_mtx.Get1dIndex(i, j);
const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
const index_t bindex = b_mtx.Get1dIndex(k, j);
const index_t cindex = c_mtx.Get1dIndex(i, j);

#if 0
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
#elif 1
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(p_c_thread[cindex])
: "v"(p_a_thread[aindex]),
"v"(p_b_thread[bindex]),
"0"(p_c_thread[cindex]));
#endif
}
}
}

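
threadwise_gemm keeps A transposed (K x M), so k indexes rows of both A and B; with k outermost, each iteration is a rank-1 update of the M x N accumulator, which the v_mac_f32 branch performs one element at a time. A host sketch of the same k-i-j ordering:

    #include <cassert>

    int main()
    {
        constexpr unsigned M = 2, N = 3, K = 4; // illustrative
        float a[K][M], b[K][N], c[M][N] = {};   // A stored transposed: K x M

        for(unsigned k = 0; k < K; ++k)
        {
            for(unsigned m = 0; m < M; ++m) a[k][m] = 1.0f;
            for(unsigned n = 0; n < N; ++n) b[k][n] = 2.0f;
        }

        // k outermost: each step is the rank-1 update c += a_k * b_k^T,
        // i.e. one fused multiply-accumulate per C element.
        for(unsigned k = 0; k < K; ++k)
            for(unsigned i = 0; i < M; ++i)
                for(unsigned j = 0; j < N; ++j)
                    c[i][j] += a[k][i] * b[k][j];

        assert(c[0][0] == 8.0f); // K * 1.0f * 2.0f
        return 0;
    }
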
@@ -2,7 +2,7 @@
#include "ConstantTensorDescriptor.hip.hpp"

// need to assume src and dst are aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_6d_tensor_copy(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
@@ -37,28 +37,28 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
DstDesc{}.GetStride(I4) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

constexpr unsigned L5 = SrcOpLengths{}.Get(I5);
constexpr index_t L5 = SrcOpLengths{}.Get(I5);

static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead");

constexpr unsigned nloop_d5 = L5 / DataPerRead;
constexpr index_t nloop_d5 = L5 / DataPerRead;

for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{
for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
{
for(unsigned iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
for(index_t iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
{
const unsigned src_index = src_desc.Get1dIndex(
const index_t src_index = src_desc.Get1dIndex(
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);

const unsigned dst_index = dst_desc.Get1dIndex(
const index_t dst_index = dst_desc.Get1dIndex(
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);

*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
@@ -72,7 +72,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
}

// need to assume src and dst are aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_8d_tensor_copy(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
@@ -109,29 +109,29 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
DstDesc{}.GetStride(I6) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

constexpr unsigned L7 = SrcOpLengths{}.Get(I7);
constexpr index_t L7 = SrcOpLengths{}.Get(I7);

static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");

constexpr unsigned nloop_d7 = L7 / DataPerRead;
constexpr index_t nloop_d7 = L7 / DataPerRead;

for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{
for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
{
for(unsigned did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
{
for(unsigned did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
{
for(unsigned iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
for(index_t iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
{
const unsigned src_index =
const index_t src_index =
src_desc.Get1dIndex(did0,
did1,
did2,
@@ -141,7 +141,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
did6,
iloop_d7 * DataPerRead);

const unsigned dst_index =
const index_t dst_index =
dst_desc.Get1dIndex(did0,
did1,
did2,