add another blockwise gemm

This commit is contained in:
Chao Liu
2019-02-06 23:10:08 -06:00
parent 5e77650415
commit 1b323316a8
14 changed files with 840 additions and 111 deletions

View File

@@ -22,9 +22,9 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
struct MatrixIndex
{
unsigned batch_begin;
unsigned row_begin;
unsigned col_begin;
unsigned batch;
unsigned row;
unsigned col;
};
__device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
@@ -32,15 +32,15 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = c_thread_mtx_index.batch_begin * BlockMatrixStrideA +
((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin));
mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row));
mMyThreadOffsetB = c_thread_mtx_index.batch_begin * BlockMatrixStrideB +
((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0));
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -52,16 +52,16 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
printf("%u %u, %u %u %u, %u %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
c_thread_mtx_index.batch_begin,
c_thread_mtx_index.row_begin,
c_thread_mtx_index.col_begin,
c_thread_mtx_index.batch,
c_thread_mtx_index.row,
c_thread_mtx_index.col,
mMyThreadOffsetA,
mMyThreadOffsetB);
}
#endif
}
__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
{
if(TransA && (!TransB) && (!TransC))
@@ -237,8 +237,8 @@ struct BlockwiseGemmBlockABlockBThreadC
struct MatrixIndex
{
unsigned row_begin;
unsigned col_begin;
unsigned row;
unsigned col;
};
__device__ BlockwiseGemmBlockABlockBThreadC()
@@ -246,13 +246,13 @@ struct BlockwiseGemmBlockABlockBThreadC
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin);
mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0);
mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0);
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -264,16 +264,16 @@ struct BlockwiseGemmBlockABlockBThreadC
printf("%u %u, %u %u %u, %u %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
c_thread_mtx_index.batch_begin,
c_thread_mtx_index.row_begin,
c_thread_mtx_index.col_begin,
c_thread_mtx_index.batch,
c_thread_mtx_index.row,
c_thread_mtx_index.col,
mMyThreadOffsetA,
mMyThreadOffsetB);
}
#endif
}
__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
{
if(TransA && (!TransB) && (!TransC))
@@ -359,6 +359,13 @@ struct BlockwiseGemmBlockABlockBThreadC
}
}
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
{
return MatrixIndex{m_in_c, n_in_c};
}
template <class FloatA, class FloatB, class FloatC, class Accumulator>
__device__ void Run(FloatA* const p_a_block,
FloatB* const p_b_block,
@@ -420,3 +427,215 @@ struct BlockwiseGemmBlockABlockBThreadC
}
}
};
// if following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <unsigned BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
{
unsigned row;
unsigned col;
};
unsigned mMyThreadOffsetA;
unsigned mMyThreadOffsetB;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{
constexpr unsigned ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
"wrong! Cannot evenly divide work among Level0Cluster\n");
static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
(NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
"wrong! thread work size is wrong\n");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
{
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned level1_id = thread_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = thread_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
{
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
template <class FloatA, class FloatB, class FloatC, class Accumulator>
__device__ void Run(FloatA* const p_a_block,
FloatB* const p_b_block,
FloatC* p_c_thread,
Accumulator f_accum) const
{
constexpr auto True = Constant<bool, true>{};
constexpr auto False = Constant<bool, false>{};
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
const auto a_thread_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile
const auto b_thread_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile
// thread A-sub, B-sub for copy
const auto a_thread_sub_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
Number<MPerThreadSubC>{},
Number<MPerThread>{}); // constexpr doesn't compile
const auto b_thread_sub_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
Number<NPerThreadSubC>{},
Number<NPerThread>{}); // constexpr doesn't compile
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
k_begin * a_block_mtx.RowStride() +
m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths());
}
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
k_begin * b_block_mtx.RowStride() +
n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths());
}
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread,
f_accum);
}
}
};

View File

@@ -208,13 +208,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
}
const auto matrix_c_index =
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
const unsigned n_thread_data_begin =
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",

View File

@@ -262,13 +262,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
}
const auto matrix_c_index =
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
const unsigned n_thread_data_begin =
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",

View File

@@ -318,13 +318,12 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
}
const auto matrix_c_index =
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
const unsigned n_thread_data_begin =
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
#if 0
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",

View File

@@ -228,15 +228,15 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
}
const auto matrix_c_index =
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
#if 0
printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin);
printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch, matrix_c_index.row, matrix_c_index.col);
#endif
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerThread;
#if 1
// output: register to global mem,

View File

@@ -205,13 +205,12 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
}
const auto matrix_c_index =
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
const unsigned n_thread_data_begin =
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
const unsigned ho_thread_data_begin = matrix_c_index.batch;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]

View File

@@ -75,6 +75,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
// tensor view of blockwise input and weight
// be careful of alignment
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
@@ -245,11 +246,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
}
// output: register to global mem,
const auto matrix_c_index =
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned b_thread_data_begin = matrix_c_index.col;
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
@@ -257,11 +257,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
matrix_c_index.row_begin,
matrix_c_index.col_begin,
matrix_c_index.row,
matrix_c_index.col,
k_data_begin,
b_data_begin,
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);

View File

@@ -0,0 +1,327 @@
#pragma once
#include "common.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_4d_tensor_op.cuh"
#include "blockwise_2d_tensor_op.cuh"
#include "threadwise_2d_tensor_op.cuh"
#include "blockwise_gemm.cuh"
// define B = flatten(N, Hi, Wi)
template <unsigned GridSize,
unsigned BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned BPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead>
__global__ void
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
Float* const __restrict__ p_in_global,
WeiGlobalDesc,
Float* const __restrict__ p_wei_global,
OutGlobalDesc,
Float* __restrict__ p_out_global)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_cnhw_global_desc = InGlobalDesc{};
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
constexpr auto out_knhw_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_cnhw_global_desc.GetLength(I0);
constexpr unsigned N = in_cnhw_global_desc.GetLength(I1);
constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2);
constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3);
constexpr unsigned K = out_knhw_global_desc.GetLength(I0);
constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2);
constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3);
constexpr unsigned S = wei_csrk_global_desc.GetLength(I1);
constexpr unsigned R = wei_csrk_global_desc.GetLength(I2);
constexpr unsigned B = N * Hi * Wi;
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
// divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
// flattend (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
// tensor view of blockwise input and weight
// be careful of alignment
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
// tensor view of threadwise output in register
constexpr auto out_kb_thread_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_cnhw_global_desc, "in_cnhw_global_desc");
print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc");
print_ConstantTensorDescriptor(out_knhw_global_desc, "out_knhw_global_desc");
print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc");
print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc");
print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc");
print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");
printf("KPerBlock %u\n", KPerBlock);
}
#endif
// blockwise in copy
// formmat is [CPerBlock,BPerBlock + BGhostRead]
#if 0
const auto blockwise_in_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyThreadPerDim0,
InBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(in_cb_global_desc),
decltype(in_cb_block_desc),
decltype(in_cb_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy
// format is [CPerBlock*S*R,KPerBlock]
#if 0
const auto blockwise_wei_copy =
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1>{};
#elif 1
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K]
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
// c_mtx[K,B] is out_block[K,B]
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{},
Number<KPerBlock>{},
Number<wei_csrk_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{},
Number<BPerBlock>{},
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
#if 0
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
true,
false,
false,
GemmKPerThreadLoop,
GemmThreadPerColumnPerCluster,
GemmThreadPerRowPerCluster,
true>{};
#else
const auto blockwise_gemm =
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>{};
#endif
// LDS: be careful of alignment
constexpr unsigned in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size =
wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead;
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
// register
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
// set threadwise output tensor to 0
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
Float* p_wei_global_block_offset =
p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
__syncthreads())
{
// input: global mem to LDS,
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
// weight: global mem to LDS,
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
__syncthreads();
// a series of GEMM
for(unsigned s = 0; s < S; ++s)
{
for(unsigned r = 0; r < R; ++r)
{
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
p_in_block + s * Wi + r,
p_out_thread,
f_accum);
}
}
}
// output: register to global mem,
const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
matrix_c_index.row,
matrix_c_index.col,
k_data_begin,
b_data_begin,
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
}
#endif
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
unsigned n_data = b_data / (Hi * Wi);
unsigned itmp = b_data - n_data * (Hi * Wi);
unsigned h_data = itmp / Wi;
unsigned w_data = itmp - h_data * Wi;
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
k,
b,
k_data,
n_data,
h_data,
w_data,
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
}
#endif
if(n_data < N && h_data < Ho && w_data < Wo)
{
p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
}
}
}
}

View File

@@ -290,11 +290,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
}
// output: register to global mem,
const auto matrix_c_index =
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned b_thread_data_begin = matrix_c_index.col;
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
@@ -302,11 +301,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
matrix_c_index.row_begin,
matrix_c_index.col_begin,
matrix_c_index.row,
matrix_c_index.col,
k_data_begin,
b_data_begin,
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);

View File

@@ -217,11 +217,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
}
// output: register to global mem,
const auto matrix_c_index =
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned b_thread_data_begin = matrix_c_index.col;
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
@@ -229,11 +228,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
matrix_c_index.row_begin,
matrix_c_index.col_begin,
matrix_c_index.row,
matrix_c_index.col,
k_data_begin,
b_data_begin,
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);

View File

@@ -276,11 +276,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
}
// output: register to global mem,
const auto matrix_c_index =
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
const unsigned k_thread_data_begin = matrix_c_index.row;
const unsigned b_thread_data_begin = matrix_c_index.col;
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
@@ -288,11 +287,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
matrix_c_index.row_begin,
matrix_c_index.col_begin,
matrix_c_index.row,
matrix_c_index.col,
k_data_begin,
b_data_begin,
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);