adding implicit gemm

[ROCm/composable_kernel commit: aa0199a31c]
This commit is contained in:
Chao Liu
2019-01-14 11:13:36 -06:00
parent 61e180de4a
commit 50256bbcfe
3 changed files with 70 additions and 109 deletions

View File

@@ -26,53 +26,24 @@ void device_implicit_gemm_convolution(
constexpr auto out_desc = OutDesc{};
#if 1
constexpr unsigned OutTileSizeH = 2;
constexpr unsigned OutTileSizeW = 2;
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned YPerBlock = 1;
constexpr unsigned XPerBlock = 16;
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned BlockSize = 128;
#elif 0
constexpr unsigned OutTileSizeH = 2;
constexpr unsigned OutTileSizeW = 2;
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned YPerBlock = 1;
constexpr unsigned XPerBlock = 27;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned BlockSize = 216;
#elif 0
constexpr unsigned OutTileSizeH = 2;
constexpr unsigned OutTileSizeW = 2;
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned YPerBlock = 1;
constexpr unsigned XPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 4;
constexpr unsigned BlockSize = 256;
#endif
constexpr unsigned GridSize = (out_desc.GetLength(I0) / NPerBlock) *
(out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / (OutTileSizeH * YPerBlock)) *
(out_desc.GetLength(I3) / (OutTileSizeW * XPerBlock));
constexpr unsigned GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
dim3 block_dim(BlockSize);
dim3 grid_dim(GridSize);
@@ -85,22 +56,21 @@ void device_implicit_gemm_convolution(
cudaEventCreate(&start);
cudaEventRecord(start, 0);
gridwise_implicit_gemm_convolution<T,
InDesc,
WeiDesc,
OutDesc,
OutTileSizeH,
OutTileSizeW,
NPerBlock,
KPerBlock,
CPerBlock,
YPerBlock,
XPerBlock,
NPerThread,
KPerThread,
CPerThread,
BlockSize,
GridSize>
gridwise_implicit_gemm_convolution_nchw_kcsr<GridSize,
BlockSize,
T,
InDesc,
WeiDesc,
OutDesc,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
KPerThread,
CPerThread,
HoPerThread,
WoPerThread>
<<<grid_dim, block_dim>>>(InDesc{},
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
WeiDesc{},