mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
refactored implicit gemm v1r3
[ROCm/composable_kernel commit: 284e7bb317]
This commit is contained in:
@@ -98,8 +98,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
-constexpr auto wei_c_k_global_desc =
|
||||
-make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
|
||||
+constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
@@ -212,44 +211,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
|
||||
|
||||
-#if 1
|
||||
-const Float* p_in_global_block_offset =
|
||||
-p_in_global +
|
||||
-in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
-0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
-const Float* p_wei_global_block_offset =
|
||||
-p_wei_global +
|
||||
-wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
-for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
-p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
|
||||
-p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
-{
|
||||
-for(index_t y = 0; y < Y; ++y)
|
||||
-{
|
||||
-#pragma unroll
|
||||
-for(index_t x = 0; x < X; ++x)
|
||||
-{
|
||||
-blockwise_in_copy.Run(
|
||||
-p_in_global_block_offset +
|
||||
-in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
-p_in_block);
|
||||
|
||||
-blockwise_wei_copy.Run(
|
||||
-p_wei_global_block_offset +
|
||||
-wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
-p_wei_block);
|
||||
|
||||
-__syncthreads();
|
||||
|
||||
-blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
-__syncthreads();
|
||||
-}
|
||||
-}
|
||||
-}
|
||||
-#else
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
@@ -282,7 +243,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
}
|
||||
}
|
||||
}
|
||||
-#endif
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
|
||||
@@ -128,17 +128,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
-// blockwise copy
|
||||
-// input: format is [C, Hi, Wi, N]
|
||||
-#if 0
|
||||
-const auto blockwise_in_copy =
|
||||
-Blockwise4dTensorCopy1<BlockSize,
|
||||
-Float,
|
||||
-decltype(in_c_h_w_n_global_desc),
|
||||
-decltype(in_c_h_w_n_block_desc),
|
||||
-decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
-InBlockCopyDataPerRead_N>{};
|
||||
-#else
|
||||
+// blockwise copy
|
||||
+// input: format is [C, Hi, Wi, N]
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
@@ -147,7 +138,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerRead_N>{};
|
||||
-#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
|
||||
Reference in New Issue
Block a user