This commit is contained in:
Chao Liu
2019-01-21 11:36:31 -06:00
parent 2096847297
commit c64f63d5ec
5 changed files with 73 additions and 46 deletions

View File

@@ -22,7 +22,7 @@ template <unsigned GridSize,
unsigned HoPerThread,
unsigned WoPerThread>
__global__ void
gridwise_implicit_gemm_convolution_1_nchw_srck(InGlobalDesc,
gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
Float* const __restrict__ p_in_global,
WeiGlobalDesc,
Float* const __restrict__ p_wei_global,

View File

@@ -19,7 +19,8 @@ template <unsigned GridSize,
unsigned CPerBlock,
unsigned BPerThread,
unsigned KPerThread,
unsigned CPerThread>
unsigned CPerThread,
unsigned BPerBatch>
__global__ void
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
Float* const __restrict__ p_in_global,
@@ -111,15 +112,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0, "B cannot be evenly divided\n");
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{},
Number<BPerBlock>{},
Number<BPerBatch>{},
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
const auto blockwise_gemm =
const auto blockwise_batched_gemm =
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc),
@@ -128,9 +131,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
false,
false,
0,
BPerBatch,
0,
0,
1,
BPerBlock/BPerBatch,
1,
CPerThread,
true>{};
@@ -179,7 +182,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
{
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
p_in_block + s * Wi + r,
p_out_thread,
f_accum);
@@ -189,10 +192,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
// output: register to global mem,
const auto matrix_c_index =
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
const unsigned b_thread_data_begin = matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin;
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;