From e98451e42ff32df3a1047c889dc2876a3e36492b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 11 Jun 2025 23:41:03 +0200 Subject: [PATCH] Move SetZero functions inside the kernels for Grouped Conv (#2255) * Disable SetZero before launch kernel for grouped conv fwd * Move set zero to kernel * wmma fix * fix --------- Co-authored-by: BrianHarrisonAMD <169072757+BrianHarrisonAMD@users.noreply.github.com> [ROCm/composable_kernel commit: 8c1ed6f4c152ac29aa535afabf7b5cb7da4ba316] --- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 30 +++++++++++++- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 29 +++++++++++++- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 8 +++- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 15 +++++-- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 16 ++++---- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 40 +++++++++++++------ .../profile_grouped_conv_bwd_data_impl.hpp | 6 --- .../profile_grouped_conv_bwd_weight_impl.hpp | 3 -- .../profile_grouped_conv_fwd_impl.hpp | 3 -- .../test_grouped_convnd_bwd_data_xdl.cpp | 10 +++++ 10 files changed, 121 insertions(+), 39 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 5e41c96dfc..651e730b63 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -6,6 +6,7 @@ #include #include +#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -244,6 +245,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle input_right_pads_{input_right_pads}, k_batch_{split_k} { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal to the we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + // populate Ds pointer static_for<0, NumDTensor, 1>{}([&](auto i) { using DDataType = remove_cvref_t>; @@ -449,6 +466,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle std::array input_right_pads_; const index_t k_batch_; + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; }; // Invoker @@ -474,6 +493,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && i == 0) + { + hip_check_error(hipMemsetAsync( + arg.p_e_grid_, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; @@ -494,8 +521,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ComputePtrOffsetOfStridedBatch, has_main_loop>; - return launch_and_time_kernel( + return launch_and_time_kernel_with_preprocess( stream_config, + clear_workspace, kernel, dim3(grid_size), dim3(BlockSize), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index f18ce40fc5..f6f354f98e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -517,6 +517,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 input_right_pads_{input_right_pads}, k_batch_{split_k} { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal to the we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, a_g_n_k_wos_strides); @@ -887,6 +903,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const index_t k_batch_; index_t num_workgroups_per_Conv_N_; + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; }; // Invoker @@ -940,6 +958,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1); + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && i == 0) + { + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; @@ -961,8 +987,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 has_main_loop, ElementOp>; - return launch_and_time_kernel( + return launch_and_time_kernel_with_preprocess( stream_config, + clear_workspace, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 33b6d7c585..672c7dd2f7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -595,6 +595,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle input_right_pads_{input_right_pads}, k_batch_{split_k} { + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(AccDataType); + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -709,6 +714,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; const index_t k_batch_; + long_index_t c_space_size_bytes; }; // Invoker @@ -757,7 +763,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle auto preprocess = [&]() { hip_check_error(hipMemsetAsync( - p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 6a708a9e7e..c7c463f43d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -550,6 +550,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle input_right_pads_{input_right_pads}, k_batch_{split_k} { + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(AccDataType); + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -747,6 +752,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; const index_t k_batch_; + long_index_t c_space_size_bytes; }; // Invoker @@ -810,10 +816,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; const auto clear_workspace = [&]() { - hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid, - 0, - arg.GetWorkspaceETensorSizeBytes(), - stream_config.stream_id_)); + if(arg.k_batch_ > 1) + { + hip_check_error(hipMemsetAsync( + gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); + } }; const auto Run = [&](const auto& kernel) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index c904b4e7d5..6c53161ded 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -468,6 +468,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle input_right_pads_{input_right_pads}, k_batch_{split_k} { + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(WeiDataType); + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -654,6 +659,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; const index_t k_batch_; + long_index_t c_space_size_bytes; }; // Invoker @@ -773,14 +779,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle has_main_loop>; const auto clear_workspace = [&]() { - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) - { - hip_check_error(hipMemsetAsync(p_e_grid, - 0, - arg.GetWorkspaceETensorSizeBytes(), - stream_config.stream_id_)); - } + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; avg_time += launch_and_time_kernel_with_preprocess( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index b28b7347b6..f13a256d6b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -427,6 +427,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 input_right_pads_{input_right_pads}, k_batch_{split_k} { + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(WeiDataType); + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -509,6 +514,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const std::array& input_left_pads_; const std::array& input_right_pads_; const index_t k_batch_; + long_index_t c_space_size_bytes; }; // Invoker @@ -559,6 +565,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const auto num_k_per_block = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + const auto clear_workspace = [&]() { + if(arg.k_batch_ > 1) + { + hip_check_error(hipMemsetAsync( + gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); + } + }; + const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) { @@ -575,6 +589,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 ck::utility::flush_icache(); // rotating mem rotating_mem.Next(); + clear_workspace(); }; ave_time += ck::utility::launch_and_time_kernel_with_preprocess( stream_config, @@ -592,18 +607,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } else { - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.compute_ptr_offset_of_batch_, - num_k_per_block); + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); } }; diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index 4e0ced347d..6cd8440e58 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -86,9 +86,6 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, out_device_buf.ToDevice(out.mData.data()); wei_device_buf.ToDevice(wei.mData.data()); - // reset input to zero - in_device_buf.SetZero(); - float max_accumulated_value = 0; if(do_verification) { @@ -136,9 +133,6 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, if(op_ptr->IsSupportedArgument(argument_ptr.get())) { - // re-init output to zero before profiling next kernel - in_device_buf.SetZero(); - std::string op_name = op_ptr->GetTypeString(); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index a13f79182e..ca9b2f1d24 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -11,7 +11,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp" @@ -207,8 +206,6 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, if(op_ptr->IsSupportedArgument(argument_ptr.get())) { - // using atomic add, so need to reset input - wei_device_buf.SetZero(); std::string op_name = op_ptr->GetTypeString(); diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index dfa6bc1edd..08e707b665 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -155,9 +155,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification, if(op_ptr->IsSupportedArgument(argument_ptr.get())) { - // re-init output to zero before profiling next kernel - out_device_buf.SetZero(); - std::string op_name = op_ptr->GetTypeString(); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp index c4404b95ba..7f8f64c2e2 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp @@ -104,6 +104,12 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) {2, 2, 2, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 2, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 2, 32, 32, {2, 2}, {12, 12}, {3, 3}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 2, 32, 32, {2, 2}, {12, 12}, {2, 2}, {2, 2}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 1, 6, 448, 896, {1, 1}, {118, 182}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); @@ -119,6 +125,10 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D) {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 2, 2, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 2, 2}, {1, 2, 2}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 1, 1, 1, 32, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back(