mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
add another blockwise gemm
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.cuh"
|
||||
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
|
||||
#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh"
|
||||
#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
|
||||
//#include "device_winograd_convolution.cuh"
|
||||
|
||||
struct GeneratorTensor_1
|
||||
@@ -391,7 +392,7 @@ int main()
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
@@ -484,7 +485,7 @@ int main()
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x1 filter, 28x28 image
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 256;
|
||||
@@ -591,8 +592,10 @@ int main()
|
||||
device_implicit_gemm_convolution_1_chwn_csrk_khwn
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_2_cnhw_srck_knhw
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2
|
||||
#endif
|
||||
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
|
||||
@@ -608,7 +611,7 @@ int main()
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
if(S == 3 && R == 3)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
|
||||
@@ -67,7 +67,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
|
||||
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
@@ -90,31 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// 1x1, 28x28
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 4;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 64;
|
||||
#elif 1
|
||||
// 1x1, 28x28 try
|
||||
// 1x1, 28x28
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
#pragma once
|
||||
#include "gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh"
|
||||
#include <unistd.h>
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcsr,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
unsigned nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcsr_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
|
||||
constexpr unsigned N = in_nchw_desc.GetLength(I0);
|
||||
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
|
||||
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
|
||||
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
|
||||
constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
|
||||
constexpr unsigned S = wei_kcsr_desc.GetLength(I2);
|
||||
constexpr unsigned R = wei_kcsr_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
|
||||
|
||||
// convert in_nchw to in_cnhw
|
||||
auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence<C, N, Hi, Wi>{});
|
||||
ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: ");
|
||||
|
||||
Tensor<T> in_cnhw(make_TensorDescriptor(in_cnhw_desc));
|
||||
|
||||
auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// convert wei_kcsr to wei_csrk
|
||||
auto wei_csrk_desc = make_ConstantTensorDescriptor(Sequence<C, S, R, K>{});
|
||||
ostream_ConstantTensorDescriptor(wei_csrk_desc, std::cout << "wei_csrk_desc: ");
|
||||
|
||||
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
|
||||
|
||||
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
|
||||
wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, S, R)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// conver out_nkhw to out_knhw
|
||||
auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence<K, N, Ho, Wo>{});
|
||||
ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: ");
|
||||
|
||||
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
|
||||
|
||||
#if 0
|
||||
// 1x1, 28x28
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
|
||||
constexpr unsigned GemmMPerThreadSubC = 16;
|
||||
constexpr unsigned GemmNPerThreadSubC = 4;
|
||||
constexpr unsigned GemmMLevel0Cluster = 4;
|
||||
constexpr unsigned GemmNLevel0Cluster = 8;
|
||||
constexpr unsigned GemmMLevel1Cluster = 1;
|
||||
constexpr unsigned GemmNLevel1Cluster = 2;
|
||||
constexpr unsigned GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 4;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 64;
|
||||
#elif 1
|
||||
// 1x1, 28x28 try
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
|
||||
constexpr unsigned BPerThread = 8;
|
||||
constexpr unsigned KPerThread = 8;
|
||||
|
||||
constexpr unsigned GemmMPerThreadSubC = 4;
|
||||
constexpr unsigned GemmNPerThreadSubC = 4;
|
||||
constexpr unsigned GemmMLevel0Cluster = 8;
|
||||
constexpr unsigned GemmNLevel0Cluster = 2;
|
||||
constexpr unsigned GemmMLevel1Cluster = 1;
|
||||
constexpr unsigned GemmNLevel1Cluster = 4;
|
||||
constexpr unsigned GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 4;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 64;
|
||||
#endif
|
||||
|
||||
constexpr unsigned GridSize =
|
||||
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
|
||||
|
||||
dim3 block_dim(BlockSize);
|
||||
dim3 grid_dim(GridSize);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
// mem
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_cnhw_device_buf(data_sz * (in_cnhw.mDesc.GetElementSpace() + BGhostRead +
|
||||
BPerBlock)); // reserve extra space for BGhostRead
|
||||
DeviceMem wei_csrk_device_buf(data_sz * wei_csrk.mDesc.GetElementSpace());
|
||||
DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace());
|
||||
|
||||
in_cnhw_device_buf.ToDevice(in_cnhw.mData.data());
|
||||
wei_csrk_device_buf.ToDevice(wei_csrk.mData.data());
|
||||
out_knhw_device_buf.ToDevice(out_knhw.mData.data());
|
||||
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
cudaEvent_t start, stop;
|
||||
float elapsedTime;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_cnhw_desc),
|
||||
decltype(wei_csrk_desc),
|
||||
decltype(out_knhw_desc),
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead>
|
||||
<<<grid_dim, block_dim>>>(in_cnhw_desc,
|
||||
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
|
||||
wei_csrk_desc,
|
||||
static_cast<T*>(wei_csrk_device_buf.GetDeviceBuffer()),
|
||||
out_knhw_desc,
|
||||
static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
cudaEventCreate(&stop);
|
||||
cudaEventRecord(stop, 0);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
cudaEventElapsedTime(&elapsedTime, start, stop);
|
||||
printf("Elapsed time : %f ms\n", elapsedTime);
|
||||
|
||||
usleep(std::min(elapsedTime * 1000, float(10000)));
|
||||
}
|
||||
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
out_knhw_device_buf.FromDevice(out_knhw.mData.data());
|
||||
|
||||
// convert out_knhw to out_nkhw
|
||||
auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -22,9 +22,9 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
|
||||
|
||||
struct MatrixIndex
|
||||
{
|
||||
unsigned batch_begin;
|
||||
unsigned row_begin;
|
||||
unsigned col_begin;
|
||||
unsigned batch;
|
||||
unsigned row;
|
||||
unsigned col;
|
||||
};
|
||||
|
||||
__device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
|
||||
@@ -32,15 +32,15 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
|
||||
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
|
||||
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
|
||||
|
||||
const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
mMyThreadOffsetA = c_thread_mtx_index.batch_begin * BlockMatrixStrideA +
|
||||
((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
|
||||
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin));
|
||||
mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
|
||||
((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
|
||||
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row));
|
||||
|
||||
mMyThreadOffsetB = c_thread_mtx_index.batch_begin * BlockMatrixStrideB +
|
||||
((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
|
||||
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
|
||||
mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
|
||||
((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
|
||||
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0));
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -52,16 +52,16 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
|
||||
printf("%u %u, %u %u %u, %u %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
c_thread_mtx_index.batch_begin,
|
||||
c_thread_mtx_index.row_begin,
|
||||
c_thread_mtx_index.col_begin,
|
||||
c_thread_mtx_index.batch,
|
||||
c_thread_mtx_index.row,
|
||||
c_thread_mtx_index.col,
|
||||
mMyThreadOffsetA,
|
||||
mMyThreadOffsetB);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
|
||||
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
|
||||
{
|
||||
|
||||
if(TransA && (!TransB) && (!TransC))
|
||||
@@ -237,8 +237,8 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
|
||||
struct MatrixIndex
|
||||
{
|
||||
unsigned row_begin;
|
||||
unsigned col_begin;
|
||||
unsigned row;
|
||||
unsigned col;
|
||||
};
|
||||
|
||||
__device__ BlockwiseGemmBlockABlockBThreadC()
|
||||
@@ -246,13 +246,13 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
|
||||
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
|
||||
|
||||
const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
|
||||
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin);
|
||||
mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
|
||||
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
|
||||
|
||||
mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
|
||||
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0);
|
||||
mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
|
||||
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0);
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -264,16 +264,16 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
printf("%u %u, %u %u %u, %u %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
c_thread_mtx_index.batch_begin,
|
||||
c_thread_mtx_index.row_begin,
|
||||
c_thread_mtx_index.col_begin,
|
||||
c_thread_mtx_index.batch,
|
||||
c_thread_mtx_index.row,
|
||||
c_thread_mtx_index.col,
|
||||
mMyThreadOffsetA,
|
||||
mMyThreadOffsetB);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
|
||||
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
|
||||
{
|
||||
|
||||
if(TransA && (!TransB) && (!TransC))
|
||||
@@ -359,6 +359,13 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
}
|
||||
}
|
||||
|
||||
// this should be optimized away if input is known
|
||||
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
|
||||
unsigned n_in_c)
|
||||
{
|
||||
return MatrixIndex{m_in_c, n_in_c};
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC, class Accumulator>
|
||||
__device__ void Run(FloatA* const p_a_block,
|
||||
FloatB* const p_b_block,
|
||||
@@ -420,3 +427,215 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// if following number are power of 2, index calculation shall be greatly reduced:
|
||||
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
|
||||
template <unsigned BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
class ThreadMatrixC,
|
||||
unsigned MPerThreadSubC,
|
||||
unsigned NPerThreadSubC,
|
||||
unsigned MLevel0Cluster,
|
||||
unsigned NLevel0Cluster,
|
||||
unsigned MLevel1Cluster,
|
||||
unsigned NLevel1Cluster,
|
||||
unsigned KPerThreadLoop>
|
||||
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
{
|
||||
struct MatrixIndex
|
||||
{
|
||||
unsigned row;
|
||||
unsigned col;
|
||||
};
|
||||
|
||||
unsigned mMyThreadOffsetA;
|
||||
unsigned mMyThreadOffsetB;
|
||||
|
||||
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
|
||||
{
|
||||
constexpr unsigned ThreadPerLevel1Cluster =
|
||||
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
|
||||
|
||||
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
|
||||
|
||||
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
|
||||
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
|
||||
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
|
||||
|
||||
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
|
||||
"wrong! K dimension not consistent\n");
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
|
||||
"wrong! Cannot evenly divide thread work among repeat \n");
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
|
||||
"wrong! Cannot evenly divide work among repeat\n");
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
|
||||
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
|
||||
|
||||
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
|
||||
(NPerLevel1Cluster % NLevel1Cluster == 0),
|
||||
"wrong! Cannot evenly divide work among Level1Cluster\n");
|
||||
|
||||
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
|
||||
|
||||
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
|
||||
(NPerLevel0Cluster % NLevel0Cluster == 0),
|
||||
"wrong! Cannot evenly divide work among Level0Cluster\n");
|
||||
|
||||
static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
|
||||
(NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
|
||||
"wrong! thread work size is wrong\n");
|
||||
|
||||
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
mMyThreadOffsetA = a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
|
||||
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
|
||||
}
|
||||
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
|
||||
{
|
||||
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
|
||||
|
||||
unsigned level1_id = thread_id / ThreadPerLevel0Cluster;
|
||||
unsigned level1_m_id = level1_id / NLevel1Cluster;
|
||||
unsigned level1_n_id = level1_id % NLevel1Cluster;
|
||||
|
||||
unsigned level0_id = thread_id % ThreadPerLevel0Cluster;
|
||||
unsigned level0_m_id = level0_id / NLevel0Cluster;
|
||||
unsigned level0_n_id = level0_id % NLevel0Cluster;
|
||||
|
||||
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
|
||||
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
|
||||
|
||||
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
}
|
||||
|
||||
// this should be optimized away if input is known
|
||||
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
|
||||
unsigned n_in_c)
|
||||
{
|
||||
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
unsigned m_repeat = m_in_c / MPerThreadSubC;
|
||||
unsigned n_repeat = n_in_c / NPerThreadSubC;
|
||||
|
||||
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
|
||||
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
|
||||
|
||||
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
|
||||
n_repeat * NPerLevel1Cluster + n_in_sub_c};
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC, class Accumulator>
|
||||
__device__ void Run(FloatA* const p_a_block,
|
||||
FloatB* const p_b_block,
|
||||
FloatC* p_c_thread,
|
||||
Accumulator f_accum) const
|
||||
{
|
||||
constexpr auto True = Constant<bool, true>{};
|
||||
constexpr auto False = Constant<bool, false>{};
|
||||
|
||||
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
|
||||
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
|
||||
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol();
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
const auto a_thread_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto b_thread_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
const auto a_thread_sub_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
|
||||
Number<MPerThreadSubC>{},
|
||||
Number<MPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto b_thread_sub_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
|
||||
Number<NPerThreadSubC>{},
|
||||
Number<NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// loop over k
|
||||
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
// copy A-sub to form A
|
||||
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA +
|
||||
k_begin * a_block_mtx.RowStride() +
|
||||
m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
|
||||
// copy B-sub to form B
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB +
|
||||
k_begin * b_block_mtx.RowStride() +
|
||||
n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -208,13 +208,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
|
||||
const unsigned n_thread_data_begin =
|
||||
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
|
||||
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
|
||||
|
||||
#if 0
|
||||
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
|
||||
|
||||
@@ -262,13 +262,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
|
||||
const unsigned n_thread_data_begin =
|
||||
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
|
||||
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
|
||||
|
||||
#if 0
|
||||
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
|
||||
|
||||
@@ -318,13 +318,12 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
|
||||
const unsigned n_thread_data_begin =
|
||||
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
|
||||
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
|
||||
|
||||
#if 0
|
||||
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
|
||||
|
||||
@@ -228,15 +228,15 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
#if 0
|
||||
printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin);
|
||||
printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch, matrix_c_index.row, matrix_c_index.col);
|
||||
#endif
|
||||
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerThread;
|
||||
|
||||
#if 1
|
||||
// output: register to global mem,
|
||||
|
||||
@@ -205,13 +205,12 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
|
||||
const unsigned n_thread_data_begin =
|
||||
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
|
||||
const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
|
||||
|
||||
// output: register to global mem,
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
|
||||
|
||||
@@ -75,6 +75,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight
|
||||
// be careful of alignment
|
||||
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
@@ -245,11 +246,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto matrix_c_index =
|
||||
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col;
|
||||
|
||||
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
|
||||
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
|
||||
@@ -257,11 +257,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
matrix_c_index.row_begin,
|
||||
matrix_c_index.col_begin,
|
||||
matrix_c_index.row,
|
||||
matrix_c_index.col,
|
||||
k_data_begin,
|
||||
b_data_begin,
|
||||
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "ConstantMatrixDescriptor.cuh"
|
||||
#include "blockwise_4d_tensor_op.cuh"
|
||||
#include "blockwise_2d_tensor_op.cuh"
|
||||
#include "threadwise_2d_tensor_op.cuh"
|
||||
#include "blockwise_gemm.cuh"
|
||||
|
||||
// define B = flatten(N, Hi, Wi)
|
||||
template <unsigned GridSize,
|
||||
unsigned BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
unsigned BPerBlock,
|
||||
unsigned KPerBlock,
|
||||
unsigned CPerBlock,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned GemmMPerThreadSubC,
|
||||
unsigned GemmNPerThreadSubC,
|
||||
unsigned GemmMLevel0Cluster,
|
||||
unsigned GemmNLevel0Cluster,
|
||||
unsigned GemmMLevel1Cluster,
|
||||
unsigned GemmNLevel1Cluster,
|
||||
unsigned GemmKPerThreadLoop,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1,
|
||||
unsigned WeiBlockCopyThreadPerDim0,
|
||||
unsigned WeiBlockCopyThreadPerDim1,
|
||||
unsigned InBlockCopyDataPerRead,
|
||||
unsigned WeiBlockCopyDataPerRead>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
OutGlobalDesc,
|
||||
Float* __restrict__ p_out_global)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_cnhw_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_csrk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_knhw_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr unsigned C = in_cnhw_global_desc.GetLength(I0);
|
||||
constexpr unsigned N = in_cnhw_global_desc.GetLength(I1);
|
||||
constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2);
|
||||
constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned K = out_knhw_global_desc.GetLength(I0);
|
||||
constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2);
|
||||
constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned S = wei_csrk_global_desc.GetLength(I1);
|
||||
constexpr unsigned R = wei_csrk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr unsigned B = N * Hi * Wi;
|
||||
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
|
||||
|
||||
// divide block work by 2d: [K, B]
|
||||
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
|
||||
|
||||
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
|
||||
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
|
||||
|
||||
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
|
||||
|
||||
// flattend (2d) tensor view of gridwise input
|
||||
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight
|
||||
// be careful of alignment
|
||||
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_kb_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_cnhw_global_desc, "in_cnhw_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc");
|
||||
print_ConstantTensorDescriptor(out_knhw_global_desc, "out_knhw_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc");
|
||||
print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");
|
||||
|
||||
printf("KPerBlock %u\n", KPerBlock);
|
||||
}
|
||||
#endif
|
||||
|
||||
// blockwise in copy
|
||||
// formmat is [CPerBlock,BPerBlock + BGhostRead]
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*S*R,KPerBlock]
|
||||
#if 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
|
||||
// a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K]
|
||||
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
|
||||
// c_mtx[K,B] is out_block[K,B]
|
||||
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{},
|
||||
Number<KPerBlock>{},
|
||||
Number<wei_csrk_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
|
||||
|
||||
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{},
|
||||
Number<BPerBlock>{},
|
||||
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
|
||||
|
||||
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
#if 0
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
#else
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop>{};
|
||||
#endif
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr unsigned in_block_size =
|
||||
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr unsigned wei_block_size =
|
||||
wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
|
||||
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
|
||||
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
|
||||
|
||||
Float* p_in_global_block_offset =
|
||||
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
|
||||
|
||||
Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// a series of GEMM
|
||||
for(unsigned s = 0; s < S; ++s)
|
||||
{
|
||||
for(unsigned r = 0; r < R; ++r)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
|
||||
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
matrix_c_index.row,
|
||||
matrix_c_index.col,
|
||||
k_data_begin,
|
||||
b_data_begin,
|
||||
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
|
||||
{
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
|
||||
|
||||
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row;
|
||||
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col;
|
||||
|
||||
unsigned n_data = b_data / (Hi * Wi);
|
||||
unsigned itmp = b_data - n_data * (Hi * Wi);
|
||||
unsigned h_data = itmp / Wi;
|
||||
unsigned w_data = itmp - h_data * Wi;
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
k,
|
||||
b,
|
||||
k_data,
|
||||
n_data,
|
||||
h_data,
|
||||
w_data,
|
||||
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
|
||||
}
|
||||
#endif
|
||||
if(n_data < N && h_data < Ho && w_data < Wo)
|
||||
{
|
||||
p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
|
||||
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -290,11 +290,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto matrix_c_index =
|
||||
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col;
|
||||
|
||||
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
|
||||
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
|
||||
@@ -302,11 +301,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
matrix_c_index.row_begin,
|
||||
matrix_c_index.col_begin,
|
||||
matrix_c_index.row,
|
||||
matrix_c_index.col,
|
||||
k_data_begin,
|
||||
b_data_begin,
|
||||
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
|
||||
|
||||
@@ -217,11 +217,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto matrix_c_index =
|
||||
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col;
|
||||
|
||||
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
|
||||
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
|
||||
@@ -229,11 +228,11 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
matrix_c_index.row_begin,
|
||||
matrix_c_index.col_begin,
|
||||
matrix_c_index.row,
|
||||
matrix_c_index.col,
|
||||
k_data_begin,
|
||||
b_data_begin,
|
||||
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
|
||||
|
||||
@@ -276,11 +276,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto matrix_c_index =
|
||||
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
const auto matrix_c_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col;
|
||||
|
||||
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
|
||||
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
|
||||
@@ -288,11 +287,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
printf("%u %u, row %u col %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
matrix_c_index.row_begin,
|
||||
matrix_c_index.col_begin,
|
||||
matrix_c_index.row,
|
||||
matrix_c_index.col,
|
||||
k_data_begin,
|
||||
b_data_begin,
|
||||
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
|
||||
|
||||
Reference in New Issue
Block a user