mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
refactor
This commit is contained in:
@@ -22,7 +22,7 @@ template <unsigned GridSize,
|
||||
unsigned HoPerThread,
|
||||
unsigned WoPerThread>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck(InGlobalDesc,
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
@@ -19,7 +19,8 @@ template <unsigned GridSize,
|
||||
unsigned CPerBlock,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread>
|
||||
unsigned CPerThread,
|
||||
unsigned BPerBatch>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
@@ -111,15 +112,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
|
||||
|
||||
static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0, "B cannot be evenly divided\n");
|
||||
|
||||
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{},
|
||||
Number<BPerBlock>{},
|
||||
Number<BPerBatch>{},
|
||||
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
|
||||
|
||||
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_gemm =
|
||||
const auto blockwise_batched_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
@@ -128,9 +131,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
BPerBatch,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
BPerBlock/BPerBatch,
|
||||
1,
|
||||
CPerThread,
|
||||
true>{};
|
||||
@@ -179,7 +182,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
@@ -189,10 +192,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
|
||||
// output: register to global mem,
|
||||
const auto matrix_c_index =
|
||||
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin;
|
||||
|
||||
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
|
||||
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
|
||||
|
||||
Reference in New Issue
Block a user