From b56e9f6bc4bfe3fd75e885d4c660a7e5b5e90165 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 14 Aug 2025 09:19:58 +0000 Subject: [PATCH] Add support for occupancy-based splitk --- ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 74 ++++++++++++++++++- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 73 +++++++++++++++++- .../test_grouped_convnd_bwd_weight.cpp | 5 -- 3 files changed, 139 insertions(+), 13 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 00d1406a19..049118597a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -22,6 +22,8 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -440,7 +442,42 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); - struct Argument : public BaseArgument + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO: implement + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + struct Argument : public BaseArgument, public ArgumentSplitK { Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, @@ -483,9 +520,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} + input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -507,6 +545,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + if(split_k < 0) + { + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = calculate_mn_grid_size(gemmM, gemmN) * + Conv_G_ / NumGroupsToMerge; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + + // Ensure that k_batch_ does not exceed the maximum value + // for the GEMM pipeline. + const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + k_batch_ = std::min(k_batch_, k_batch_max); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max + << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_ + << std::endl; + } + } + else + { + k_batch_ = split_k; + } + const auto descs = conv_to_gemm_transformer_v2 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -682,7 +749,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; - const index_t k_batch_; }; // Invoker diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 5c14b03bf9..aa4ff80719 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -22,6 +22,8 @@ #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" @@ -401,7 +403,41 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); - struct Argument : public BaseArgument + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO: implement + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + struct Argument : public BaseArgument, public ArgumentSplitK { Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, @@ -444,9 +480,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} + input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + c_space_size_bytes = ck::accumulate_n( e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * @@ -463,6 +500,35 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); + if(split_k < 0) + { + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + + // Ensure that k_batch_ does not exceed the maximum value + // for the GEMM pipeline. + const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + k_batch_ = std::min(k_batch_, k_batch_max); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max + << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_ + << std::endl; + } + } + else + { + k_batch_ = split_k; + } + std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, a_g_n_k_wos_strides); @@ -633,7 +699,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; - const index_t k_batch_; long_index_t c_space_size_bytes; }; diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index e216ecb27c..fd84f57095 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -44,11 +44,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } - if((split_k < 1) && (ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - return true; - } - return false; }