From e91e8e7908300466686545c5ea00149dd9a64e25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Thu, 31 Jul 2025 13:08:45 +0300 Subject: [PATCH] Automatic deduction of split-K value for grouped convolution (#2491) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Split-K autodeduction for DeviceGroupedConvBwdWeight_Xdl_CShuffle and DeviceGroupedConvBwdWeight_Xdl_CShuffleV3. * Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle. * Use simple best occupancy model to calculate the split-K. * Handle split-K autodeduction in explicit gemm conv. * Add unit tests for split-K autodeduction. * Remove oversubscription. * Small fixes. * Added split-K autodeduction for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle. * Run clang formatting. * Fix error handling in the conv profiler. * Add missing documentation for the autodeducted split-K values. * Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver. * Fix clang formatting and split-K profiler documentation. * Rename max_occupancy value variable. * Calculate grid size for split-K autodeduction directly from input array shapes and template params. --------- Co-authored-by: Ville Pietilä <> [ROCm/composable_kernel commit: e962a4163818c1f316172626ea6330be0d6afa5e] --- ...atched_gemm_multiple_d_xdl_cshuffle_v3.hpp | 60 ++++++++++++ ...e_grouped_conv_bwd_weight_explicit_xdl.hpp | 21 ++++- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 55 ++++++++++- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 87 ++++++++++++++++- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 56 ++++++++++- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 85 ++++++++++++++++- .../gpu/device/impl/split_k_arg.hpp | 17 ++++ .../gpu/device/impl/split_k_utils.hpp | 93 +++++++++++++++++++ profiler/README.md | 2 +- .../profile_grouped_conv_bwd_weight_impl.hpp | 42 ++++++--- .../src/profile_grouped_conv_bwd_weight.cpp | 6 +- .../test_grouped_convnd_bwd_weight.cpp | 4 +- ...rouped_convnd_bwd_weight_interface_xdl.cpp | 44 +++++---- ...ped_convnd_bwd_weight_v3_interface_xdl.cpp | 44 +++++---- 14 files changed, 544 insertions(+), 72 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index cf7941195e..64d5fbd509 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -337,6 +337,60 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 } }; + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + int max_occupancy = 0; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + // Invoker struct Invoker : public BaseInvoker { @@ -1044,6 +1098,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 return str.str(); } + + static ck::index_t GetMaxOccupancy() + { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + return active_workgroups_per_cu.max_occupancy_; + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index e5872816f5..5d68ca720a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -13,6 +13,8 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" namespace ck { namespace tensor_operation { @@ -142,6 +144,20 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl end(e_g_k_c_xs_lengths), begin(filter_spatial_lengths_)); + if(split_k < 0) + { + const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); + const index_t grid_size = gdx * gdy * gdz; + split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + } + else + { + split_k_ = split_k; + } + if constexpr(IsTwoStageNeeded) { const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths), @@ -176,7 +192,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl out_element_op, in_element_op, wei_element_op, - split_k}; + split_k_}; } else { @@ -199,7 +215,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl out_element_op, in_element_op, wei_element_op, - split_k}; + split_k_}; } } @@ -236,6 +252,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl bool is_filter_data_packed; CElementwiseGridDesc elementwise_desc_; Block2TileMapElementwise elementwise_block_2_ctile_map_; + ck::index_t split_k_; }; // Invoker diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 4e6b4927fc..b761939642 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -19,6 +19,8 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -542,7 +544,36 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle using Block2CTileMap = decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); - struct Argument : public BaseArgument + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + int max_occupancy = 0; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, + BDataType, + AccDataType, + OutElementwiseOperation, + InElementwiseOperation, + element_wise::PassThrough, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true>, + BlockSize, + dynamic_smem_size)); + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + struct Argument : public BaseArgument, public ArgumentSplitK { Argument( const InDataType* p_in_grid, @@ -591,9 +622,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} + input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + c_space_size_bytes = ck::accumulate_n( e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * @@ -610,6 +642,22 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); + if(split_k < 0) + { + ck::index_t gemmM, gemmN; + std::tie(gemmM, gemmN, std::ignore) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + } + else + { + k_batch_ = split_k; + } + const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -712,7 +760,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; - const index_t k_batch_; long_index_t c_space_size_bytes; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index bfb6707e09..95361287db 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -22,6 +22,8 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -504,7 +506,55 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); - struct Argument : public BaseArgument + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + struct Argument : public BaseArgument, public ArgumentSplitK { Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, @@ -547,9 +597,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} + input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + c_space_size_bytes = ck::accumulate_n( e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * @@ -576,6 +627,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + if(split_k < 0) + { + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = calculate_mn_grid_size(gemmM, gemmN) * + Conv_G_ / NumGroupsToMerge; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + + // Ensure that k_batch_ does not exceed the maximum value + // for the GEMM pipeline. + const auto k_batch_max = static_cast((gemmK - 1) / KPerBlock); + k_batch_ = std::min(k_batch_, k_batch_max); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max + << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_ + << std::endl; + } + } + else + { + k_batch_ = split_k; + } + const auto descs = conv_to_gemm_transformer_v2 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -751,7 +831,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; - const index_t k_batch_; long_index_t c_space_size_bytes; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index b58f6885c7..488dadf512 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -19,6 +19,8 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -419,7 +421,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle using Block2CTileMap = decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); - struct Argument : public BaseArgument + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + int max_occupancy = 0; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, + BDataType, + CDataType, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch<>, + false>, // Both true/false give the same occupancy. + BlockSize, + dynamic_smem_size)); + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + struct Argument : public BaseArgument, public ArgumentSplitK { Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, @@ -463,9 +494,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} + input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + c_space_size_bytes = ck::accumulate_n( e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * @@ -491,6 +523,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle std::array e_g_k_c_xs_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + + if(split_k < 0) + { + ck::index_t gemmM, gemmN; + std::tie(gemmM, gemmN, std::ignore) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + } + else + { + k_batch_ = split_k; + } + const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -656,7 +705,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; - const index_t k_batch_; long_index_t c_space_size_bytes; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 243a6adafc..1cd1f16245 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -20,6 +20,8 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" @@ -381,7 +383,53 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); - struct Argument : public BaseArgument + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + struct Argument : public BaseArgument, public ArgumentSplitK { Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, @@ -424,9 +472,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} + input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + c_space_size_bytes = ck::accumulate_n( e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * @@ -443,6 +492,35 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); + if(split_k < 0) + { + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + + // Ensure that k_batch_ does not exceed the maximum value + // for the GEMM pipeline. + const auto k_batch_max = static_cast((gemmK - 1) / K0PerBlock); + k_batch_ = std::min(k_batch_, k_batch_max); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max + << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_ + << std::endl; + } + } + else + { + k_batch_ = split_k; + } + const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -513,7 +591,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; - const index_t k_batch_; long_index_t c_space_size_bytes; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp new file mode 100644 index 0000000000..de683f3282 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +struct ArgumentSplitK +{ + index_t k_batch_{1}; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp new file mode 100644 index 0000000000..32179d179e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include +#include +#include "ck/utility/env.hpp" +#include "ck/utility/number.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/ck.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DeviceProperties +{ + DeviceProperties() + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + + num_cu_ = dev_prop.multiProcessorCount; + }; + int num_cu_; +}; + +inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size) +{ + static DeviceProperties device_properties; + const int max_capacity = max_occupancy * device_properties.num_cu_; + + ck::index_t k_batch = 1; + const auto optimal_split = + static_cast(std::floor((1.0 * max_capacity) / grid_size)); + if(optimal_split > 1) + { + k_batch = optimal_split; + } + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " + << max_occupancy << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl; + } + return k_batch; +} + +template +inline auto +get_bwd_weight_gemm_sizes(const std::array& a_g_n_k_wos_lengths, + const std::array& e_g_k_c_xs_lengths) +{ + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // The input array has elements in the order: G, N, K, Do, Ho, Wo + // GemmK = N * Do * Ho * Wo for the BWD weight pass. + constexpr index_t spatial_offset = 3; + const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + index_t{1}, + std::multiplies<>{}); + const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo; + + // The GEMM M dimension is the number of output channels. + const auto gemmM = e_g_k_c_xs_lengths[I1]; + + // The output array has elements in the order: G, K, C, X, Y, Z + // GemmN = C * X * Y * Z for the BWD weight pass. + const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + index_t{1}, + std::multiplies<>{}); + const auto gemmN = e_g_k_c_xs_lengths[I2] * XYZ; + return std::make_tuple(gemmM, gemmN, gemmK); +} + +template +inline ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN) +{ + const auto M0 = math::integer_divide_ceil(gemmM, MPerBlock); + const auto N0 = math::integer_divide_ceil(gemmN, NPerBlock); + return M0 * N0; +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/README.md b/profiler/README.md index 4398a878bc..05bbc7b4f9 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -148,7 +148,7 @@ # , (ie Dy, Dx for 2D) # , (ie LeftPy, LeftPx for 2D) # , (ie RightPy, RightPx for 2D) -# SplitK +# SplitK (-1 for internally computed split-K value, positive value to set k batches explicitly, or 'all' to test all internal split-K values) ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK ./bin/ckProfiler grouped_conv_bwd_weight 1 1 0 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1 diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 84acb53425..479fed78e7 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -11,6 +11,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp" @@ -40,7 +41,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, bool do_log, bool time_kernel, const ck::utils::conv::ConvParam& conv_param, - ck::index_t split_k) + const std::string& split_k) { using InElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -138,10 +139,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::string best_op_name; - float best_avg_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - ck::index_t best_split_k = 1; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + std::string best_split_k("1"); // profile device Conv instances bool all_pass = true; @@ -170,11 +171,20 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, range_copy(conv_param.input_left_pads_, begin(input_left_pads)); range_copy(conv_param.input_right_pads_, begin(input_right_pads)); - std::vector split_k_list = {1, 2, 4, 8, 16, 32, 64, 128}; + std::vector split_k_list = {/*auto deduce value*/ -1, 1, 2, 4, 8, 16, 32, 64, 128}; - if(split_k > 0) + if(split_k != "all") { - split_k_list = {split_k}; + try + { + ck::index_t split_k_value = std::stoi(split_k); + split_k_list = {split_k_value}; + } + catch(const std::exception& e) + { + std::cerr << e.what() << '\n'; + exit(EXIT_FAILURE); + } } for(auto& op_ptr : op_ptrs) @@ -200,6 +210,16 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, out_element_op, split_k_list[split_k_id]); + auto split_k_value = split_k_list[split_k_id]; + auto split_k_param_str = std::to_string(split_k_value); + auto* split_k_arg = + dynamic_cast(argument_ptr.get()); + if(split_k_arg && split_k_value < 0) + { + split_k_value = split_k_arg->k_batch_; + split_k_param_str = std::to_string(split_k_value) + " (best occupancy)"; + } + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); DeviceMem workspace_dev(workspace_sz); op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); @@ -222,7 +242,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK " - << split_k_list[split_k_id] << std::endl; + << split_k_param_str << std::endl; if(tflops > best_tflops) { @@ -230,7 +250,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, best_tflops = tflops; best_avg_time = avg_time; best_gb_per_sec = gb_per_sec; - best_split_k = split_k_list[split_k_id]; + best_split_k = split_k_param_str; } if(do_verification) @@ -244,7 +264,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, using AccDataType = std::conditional_t, int32_t, float>; const index_t num_accums = output.GetElementSize() / conv_param.K_; - const index_t num_accums_split_k = split_k_list[split_k_id]; + const index_t num_accums_split_k = split_k_value; // Calculate thresholds auto rtol = ck::utils::get_relative_threshold( diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index 1640b48ffd..8347ce0e42 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -56,7 +56,9 @@ static void print_helper_msg() << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg6: print tensor value (0: no; 1: yes)\n" << "arg7: time kernel (0: no, 1: yes)\n" - << ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() + << " SplitK (-1 for internally computed split-K value, positive value to set k " + "batches explicitly, or 'all' to test all internal split-K values)\n" << std::endl; } @@ -88,7 +90,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); - ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]); + const auto& split_k = std::string(argv[8 + 1 + 4 + 6 * num_dim_spatial]); using F32 = float; using F16 = ck::half_t; diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 95a0a09414..8343629f3a 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -30,7 +30,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test using NDimSpatial = std::tuple_element_t<6, Tuple>; std::vector conv_params; - std::vector split_ks{1, 2}; + std::vector split_ks{-1, 1, 2}; bool skip_case(const ck::index_t split_k) { @@ -108,7 +108,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test false, // do_log false, // time_kernel param, - split_k); + std::to_string(split_k)); } } } diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index cfbf13f00e..2ad1cd11f0 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test // clang-format on ck::utils::conv::ConvParam conv_param; - ck::index_t split_k{2}; + std::vector split_ks{-1, 2}; template bool Run() @@ -96,24 +96,30 @@ class TestGroupedConvndBwdWeight : public ::testing::Test auto conv = GroupedConvBwdWeightDeviceInstance{}; - auto argument = conv.MakeArgument(nullptr, - nullptr, - nullptr, - input_lengths, - input_strides, - filter_lengths, - weights_strides, - output_lengths, - output_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}, - split_k); - return conv.IsSupportedArgument(argument); + bool is_supported = true; + + for(const auto split_k : split_ks) + { + auto argument = conv.MakeArgument(nullptr, + nullptr, + nullptr, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k); + is_supported &= conv.IsSupportedArgument(argument); + } + return is_supported; } }; diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_v3_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_v3_interface_xdl.cpp index 1556f15898..bfd55a7c55 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_v3_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_v3_interface_xdl.cpp @@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test // clang-format on ck::utils::conv::ConvParam conv_param; - ck::index_t split_k{2}; + std::vector split_ks{-1, 2}; template bool Run() @@ -96,24 +96,30 @@ class TestGroupedConvndBwdWeight : public ::testing::Test auto conv = GroupedConvBwdWeightDeviceInstance{}; - auto argument = conv.MakeArgument(nullptr, - nullptr, - nullptr, - input_lengths, - input_strides, - filter_lengths, - weights_strides, - output_lengths, - output_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}, - split_k); - return conv.IsSupportedArgument(argument); + bool is_supported = true; + + for(const auto split_k : split_ks) + { + auto argument = conv.MakeArgument(nullptr, + nullptr, + nullptr, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k); + is_supported &= conv.IsSupportedArgument(argument); + } + return is_supported; } };