mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn use khwn for thread C data now
This commit is contained in:
@@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 1x1, 28x28
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 128;
|
||||
|
||||
@@ -104,8 +104,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_hkwn_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
|
||||
constexpr auto out_khwn_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -179,7 +179,9 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
Number<in_chwn_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_kxwn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_khwn_thread_desc.GetStride(I1)>{});
|
||||
|
||||
#if 0
|
||||
const auto blockwise_batch_gemm =
|
||||
@@ -192,7 +194,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
out_khwn_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
GemmKPerThreadLoop,
|
||||
@@ -205,7 +207,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
out_khwn_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
@@ -230,10 +232,10 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
|
||||
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
|
||||
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);
|
||||
|
||||
const Float* p_in_global_block_begin =
|
||||
p_in_global + in_chwn_global_desc.Get1dIndex(
|
||||
@@ -275,33 +277,30 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
|
||||
#if 0
|
||||
// for v1 batch-gemm
|
||||
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock;
|
||||
|
||||
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
|
||||
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
out_hkwn_thread_desc,
|
||||
threadwise_4d_tensor_copy(
|
||||
out_khwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_khwn_global_desc,
|
||||
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_hkwn_thread_desc.GetLengths(),
|
||||
reorder_khwn_from_hkwn);
|
||||
out_khwn_thread_desc.GetLengths());
|
||||
#else
|
||||
for(unsigned ho = 0; ho < out_hkwn_thread_desc.GetLength(I0); ++ho)
|
||||
for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(unsigned k = 0; k < out_hkwn_thread_desc.GetLength(I1); ++k)
|
||||
for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
|
||||
{
|
||||
for(unsigned wo = 0; wo < out_hkwn_thread_desc.GetLength(I2); ++wo)
|
||||
for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
|
||||
{
|
||||
for(unsigned n = 0; n < out_hkwn_thread_desc.GetLength(I3); ++n)
|
||||
for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
|
||||
{
|
||||
const unsigned b = out_hkwn_thread_desc.Get1dIndex(0, 0, wo, n);
|
||||
const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
|
||||
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
|
||||
@@ -312,13 +311,13 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
|
||||
|
||||
const unsigned wo_thread = b_thread / NPerBlock;
|
||||
const unsigned n_thread = b_thread - NPerBlock * wo_thread;
|
||||
const unsigned n_thread = b_thread % NPerBlock;
|
||||
|
||||
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
|
||||
ho_block_data_begin + ho_thread,
|
||||
wo_block_data_begin + wo_thread,
|
||||
n_block_data_begin + n_thread)] =
|
||||
p_out_thread[out_hkwn_thread_desc.Get1dIndex(ho, k, wo, n)];
|
||||
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user