From f25da17c3684fdfb79b4933bc3d04f8e8602e63e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 28 Jul 2025 22:39:07 +0200 Subject: [PATCH] Enable multiple D for grouped conv fwd large tensors (#2572) [ROCm/composable_kernel commit: 5b244105d9faaef58486c815e436c1bb03be2dd9] --- Jenkinsfile | 4 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 4 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 12 +- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 358 +++++++++++------- .../transform_conv_fwd_to_gemm.hpp | 8 + .../CMakeLists.txt | 4 + ...uped_convnd_fwd_bias_clamp_large_cases.cpp | 135 +++++++ 7 files changed, 377 insertions(+), 148 deletions(-) create mode 100644 test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp_large_cases.cpp diff --git a/Jenkinsfile b/Jenkinsfile index b34e366f1b..f08e247a06 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1046,8 +1046,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases && \ - ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases""" + make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f90f9b457b..1448914dd3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -106,9 +106,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -121,7 +123,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; }); if constexpr(isMultiA || isMultiB) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 55ec0d21e9..bb31d64a93 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -88,13 +88,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; using DsGridPointer = typename GridwiseGemm::DsGridPointer; DsGridPointer p_ds_grid_grp{}; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; + }); const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); @@ -168,13 +170,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; using DsGridPointer = typename GridwiseGemm::DsGridPointer; DsGridPointer p_ds_grid_grp{}; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; + }); const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 9279f7547a..8f3feee1c1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -63,11 +63,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const long_index_t a_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -89,10 +91,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) group_id = index_t((left + right) / 2); } + using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_); + DsPointer p_ds_grid_grp; + static constexpr index_t NumDTensor = DsPointer::Size(); + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = + gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i]; + }); + GridwiseGemm::template Run( gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, - Tuple<>{}, + p_ds_grid_grp, gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, p_shared, a_element_op, @@ -100,7 +110,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) c_element_op, gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - Tuple<>{}, + gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_desc_kernel_args[group_id].block_2_etile_map_); #else @@ -259,18 +269,44 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor return out_gemmm_gemmn_desc; } + static auto + MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + static auto CastDsPointers(const std::array& p_ds) + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return static_cast(p_ds[i]); + }, + Number{}); + } + + using DsPointer = decltype(CastDsPointers(std::array{})); // desc for problem definition constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer; using AGridDesc_M_K = remove_cvref_t(dummy_conv_to_gemm_transformer))>; using BGridDesc_N_K = remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; using EGridDesc_M_N = remove_cvref_t(dummy_conv_to_gemm_transformer))>; static auto GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base, const ADataType* a_grid_ptr_base, + DsPointer ds_grid_ptr_base, EDataType* c_grid_ptr_base) { // Max number of splits @@ -279,11 +315,13 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Arrays to store transformers with smaller descs than 2GB Array conv_to_gemm_transformers_arr; Array a_grid_ptrs_arr; + Array ds_grid_ptrs_arr; Array c_grid_ptrs_arr; // Queue for spliting std::queue conv_to_gemm_transformers_queue( {conv_to_gemm_transformer_base}); std::queue a_grid_ptrs_queue({a_grid_ptr_base}); + std::queue ds_grid_ptrs_queue({ds_grid_ptr_base}); std::queue c_grid_ptrs_queue({c_grid_ptr_base}); index_t gemms_number = 0; @@ -300,6 +338,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Get transformer from the queue const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front(); const ADataType* a_grid_ptr = a_grid_ptrs_queue.front(); + DsPointer ds_grid_ptr = ds_grid_ptrs_queue.front(); EDataType* c_grid_ptr = c_grid_ptrs_queue.front(); // Check if convolution not exceed 2GB @@ -308,8 +347,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // If yes, push into result array conv_to_gemm_transformers_arr(gemms_number) = ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer}; - a_grid_ptrs_arr(gemms_number) = a_grid_ptr; - c_grid_ptrs_arr(gemms_number) = c_grid_ptr; + a_grid_ptrs_arr(gemms_number) = a_grid_ptr; + ds_grid_ptrs_arr(gemms_number) = ds_grid_ptr; + c_grid_ptrs_arr(gemms_number) = c_grid_ptr; gemms_number++; } else @@ -318,19 +358,23 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part, conv_to_gemm_transformers_right_part; const ADataType* a_grid_right_ptr; + DsPointer ds_grid_right_ptr; EDataType* c_grid_right_ptr; ck::tie(conv_to_gemm_transformers_left_part, conv_to_gemm_transformers_right_part, a_grid_right_ptr, + ds_grid_right_ptr, c_grid_right_ptr) = - conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr); + conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, ds_grid_ptr, c_grid_ptr); conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part); conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part); // Left offsets remain the same a_grid_ptrs_queue.push(a_grid_ptr); a_grid_ptrs_queue.push(a_grid_right_ptr); + ds_grid_ptrs_queue.push(ds_grid_ptr); + ds_grid_ptrs_queue.push(ds_grid_right_ptr); c_grid_ptrs_queue.push(c_grid_ptr); c_grid_ptrs_queue.push(c_grid_right_ptr); split_numbers++; @@ -338,6 +382,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Remove from the queue conv_to_gemm_transformers_queue.pop(); a_grid_ptrs_queue.pop(); + ds_grid_ptrs_queue.pop(); c_grid_ptrs_queue.pop(); } @@ -345,6 +390,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor return ck::make_tuple(conv_to_gemm_transformers_arr, a_grid_ptrs_arr, + ds_grid_ptrs_arr, c_grid_ptrs_arr, gemms_number, is_split_valid); @@ -375,6 +421,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; @@ -388,11 +437,14 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // pointers const ADataType* a_ptr_; const BDataType* b_ptr_; + DsPointer ds_ptr_; EDataType* e_ptr_; // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e-tile map @@ -405,16 +457,16 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor { Argument(const void* p_a, const void* p_b, - const std::array& /*p_ds*/, + const std::array& p_ds, void* p_e, const std::array& a_g_n_c_wis_lengths, const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, const std::array& b_g_k_c_xs_strides, const std::array, NumDTensor>& - /*ds_g_n_k_wos_lengths*/, + ds_g_n_k_wos_lengths, const std::array, NumDTensor>& - /*ds_g_n_k_wos_strides*/, + ds_g_n_k_wos_strides, const std::array& e_g_n_k_wos_lengths, const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, @@ -434,6 +486,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, conv_filter_strides_{conv_filter_strides}, @@ -441,94 +495,105 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { - if constexpr(NumDTensor == 0) + // Perform grouped gemm, generate array of tranformer for convolution + Array conv_to_gemm_transformer_arr; + Array a_grid_ptrs; + Array ds_grid_ptrs; + Array c_grid_ptrs; + + DsPointer p_ds_casted = CastDsPointers(p_ds); + + ck::tie(conv_to_gemm_transformer_arr, + a_grid_ptrs, + ds_grid_ptrs, + c_grid_ptrs, + gemms_count_, + is_split_valid_) = + GenerateConvToGemmTransforms( + ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + static_cast(p_a), + p_ds_casted, + static_cast(p_e)); + + grid_size_ = 0; + valid_gemms_count_ = 0; + + if(is_split_valid_) { - // Perform grouped gemm, generate array of tranformer for convolution - Array conv_to_gemm_transformer_arr; - Array a_grid_ptrs; - Array c_grid_ptrs; - - ck::tie(conv_to_gemm_transformer_arr, - a_grid_ptrs, - c_grid_ptrs, - gemms_count_, - is_split_valid_) = - GenerateConvToGemmTransforms( - ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, - a_g_n_c_wis_strides_, - b_g_k_c_xs_lengths_, - b_g_k_c_xs_strides_, - e_g_n_k_wos_lengths_, - e_g_n_k_wos_strides_, - conv_filter_strides_, - conv_filter_dilations_, - input_left_pads_, - input_right_pads_}, - static_cast(p_a), - static_cast(p_e)); - - grid_size_ = 0; - valid_gemms_count_ = 0; - - if(is_split_valid_) + // Create GemmArg for each gemm(conv) + for(index_t i = 0; i < gemms_count_; i++) { - // Create GemmArg for each gemm(conv) - for(index_t i = 0; i < gemms_count_; i++) + const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K( + conv_to_gemm_transformer_arr[i])}; + const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K( + conv_to_gemm_transformer_arr[i])}; + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_arr[i]); + + const auto ds_grid_desc_m_n = + generate_tuple([&](auto) { return e_grid_desc_m_n; }, Number{}); + + const auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + const index_t grid_size_grp = + block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) { - const AGridDesc_M_K a_grid_desc_m_k{ - DeviceOp::MakeAGridDescriptor_M_K( - conv_to_gemm_transformer_arr[i])}; - const BGridDesc_N_K b_grid_desc_n_k{ - DeviceOp::MakeBGridDescriptor_N_K( - conv_to_gemm_transformer_arr[i])}; - const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N( - conv_to_gemm_transformer_arr[i]); + gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ + a_grid_ptrs[i], + static_cast(p_b), + ds_grid_ptrs[i], + c_grid_ptrs[i], + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; - const auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - - const index_t grid_size_grp = - block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); - - const index_t BlockStart = grid_size_; - const index_t BlockEnd = grid_size_ + grid_size_grp; - - grid_size_ += grid_size_grp; - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - Tuple<>{}, - e_grid_desc_m_n, - block_2_etile_map)) - { - - gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ - a_grid_ptrs[i], - static_cast(p_b), - c_grid_ptrs[i], - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n), - block_2_etile_map, - BlockStart, - BlockEnd}; - - valid_gemms_count_++; - } + valid_gemms_count_++; } - // N is the same for all convs - conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); } - - // Strides for G and N remain the same - compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; - - compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; - compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; + // N is the same for all convs + conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); } + + // Strides for G and N remain the same + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + }); } void Print() const @@ -558,8 +623,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor bool is_split_valid_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -571,6 +636,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor std::array a_g_n_c_wis_strides_; std::array b_g_k_c_xs_lengths_; std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; std::array e_g_n_k_wos_lengths_; std::array e_g_n_k_wos_strides_; std::array conv_filter_strides_; @@ -584,63 +651,55 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor { float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if constexpr(NumDTensor == 0) + if(stream_config.log_level_ > 0) { - if(stream_config.log_level_ > 0) - { - arg.Print(); - } + arg.Print(); + } - const index_t num_workgroups_per_Conv_N = - arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; - const index_t gdx = arg.grid_size_; - const index_t gdy = arg.num_group_; - const index_t gdz = num_workgroups_per_Conv_N; + const index_t gdx = arg.grid_size_; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; - // K is constant for all gemms - const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * - arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); + // K is constant for all gemms + const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = - kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< - GridwiseGemm, - MaxGemmsNum, - GemmArgs, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - has_main_loop>; + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< + GridwiseGemm, + MaxGemmsNum, + GemmArgs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; - return launch_and_time_kernel(stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg.gemm_desc_kernel_args_, - arg.gemms_count_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_); - }; + return launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.gemm_desc_kernel_args_, + arg.gemms_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + }; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); } else { - return 0.f; + return launch_kernel(integral_constant{}); } } @@ -657,9 +716,26 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor const long_index_t K = arg.b_g_k_c_xs_lengths_[I1]; const long_index_t C = arg.b_g_k_c_xs_lengths_[I2]; - // Move this to runtime check to align Conv instances - // with Conv Multiple D instances - if constexpr(NumDTensor != 0) + + bool ds_valid = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + for(int d = 0; d < NDimSpatial + I3; d++) + { + if(arg.ds_g_n_k_wos_strides_[i][d] != arg.e_g_n_k_wos_strides_[d]) + { + ds_valid = false; + } + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + ds_valid = false; + } + } + + using DDataType = remove_cvref_t>; + static_assert(is_same_v); + }); + + if(!ds_valid) { return false; } diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 92b48c44b3..50f6ba3b53 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -389,7 +389,9 @@ struct TransformConvFwdToGemm return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; } + template __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + DsPointer& ds_grid_ptr_base, CDataType* c_grid_ptr_base) const { // Create copies @@ -480,11 +482,17 @@ struct TransformConvFwdToGemm a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; c_right_offset = (Wo_ / 2) * WoStride_; } + + static constexpr index_t NumDTensor = DsPointer::Size(); + const auto ds_grid_right_ptr = generate_tuple( + [&](auto i) { return ds_grid_ptr_base(i) + c_right_offset; }, Number{}); + // Return left transform, right transformer, right offset to Input and right offset to // Output return ck::make_tuple(conv_to_gemm_transformer_left, conv_to_gemm_transformer_right, a_grid_ptr_base + a_right_offset, + ds_grid_right_ptr, c_grid_ptr_base + c_right_offset); } diff --git a/test/grouped_convnd_fwd_activation/CMakeLists.txt b/test/grouped_convnd_fwd_activation/CMakeLists.txt index 8bded647b6..f964325c06 100644 --- a/test/grouped_convnd_fwd_activation/CMakeLists.txt +++ b/test/grouped_convnd_fwd_activation/CMakeLists.txt @@ -7,4 +7,8 @@ if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_grouped_convnd_fwd_clamp test_grouped_convnd_fwd_clamp.cpp) target_link_libraries(test_grouped_convnd_fwd_clamp PRIVATE utility device_grouped_conv2d_fwd_clamp_instance device_grouped_conv3d_fwd_clamp_instance) + + add_executable(test_grouped_convnd_fwd_bias_clamp_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases.cpp) + target_compile_options(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) endif() diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp_large_cases.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp_large_cases.cpp new file mode 100644 index 0000000000..7a59a95527 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp_large_cases.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using AddClamp = ck::tensor_operation::element_wise::AddClamp; + +template +class TestGroupedConvndFwd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = ck::long_index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple>; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwdBiasClamp2d : public TestGroupedConvndFwd +{ +}; + +template +class TestGroupedConvndFwdBiasClamp3d : public TestGroupedConvndFwd +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdBiasClamp2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwdBiasClamp3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdBiasClamp2d, Test2D) +{ + // Case larger than 2GB + this->conv_params.push_back( + {2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back( + {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back( + {2, 2, 2, 128, 128, {3, 3}, {4096, 2048}, {300, 300}, {3, 3}, {1, 1}, {1, 1}}); + // Split N and G > 1 + this->conv_params.push_back( + {2, 4, 112, 8, 8, {3, 3}, {469, 724}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwdBiasClamp3d, Test3D) +{ + // Case larger than 2GB + this->conv_params.push_back({3, + 1, + 128, + 4, + 192, + {2, 2, 2}, + {2, 224, 224}, + {1, 224, 224}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back({3, + 32, + 64, + 1, + 1, + {2, 2, 2}, + {360, 2, 672}, + {360, 2, 672}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back({3, + 1, + 2, + 128, + 128, + {3, 1, 3}, + {900, 2, 2048}, + {300, 1, 300}, + {3, 2, 3}, + {1, 1, 1}, + {1, 1, 1}}); + this->template Run<3>(); +}