mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactor
This commit is contained in:
@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3 filter, 28x28 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
|
||||
@@ -199,7 +199,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
#else
|
||||
constexpr auto map_k_e_2_e_k = Sequence<1, 0>{};
|
||||
|
||||
auto blockwise_wei_copy = BlockwiseTensorSliceReorderCopy_v3<
|
||||
const auto blockwise_wei_copy = BlockwiseTensorSliceReorderCopy_v3<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(wei_e_k_global_desc.ReorderGivenNew2Old(map_k_e_2_e_k)),
|
||||
@@ -296,9 +296,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0 // debug
|
||||
return;
|
||||
#endif
|
||||
const Float* p_wei_block_on_global = p_wei_global;
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
@@ -306,7 +304,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
@@ -339,14 +338,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
#if 0
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
#else
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I1, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
@@ -369,14 +369,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
#if 0
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
#else
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I1, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global, p_wei_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
Reference in New Issue
Block a user