diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index 6624570b27..60cd06eed7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -337,6 +337,60 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 } }; + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + int max_occupancy = 0; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + // Invoker struct Invoker : public BaseInvoker { @@ -1044,6 +1098,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 return str.str(); } + + static ck::index_t GetMaxOccupancy() + { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + return active_workgroups_per_cu.max_occupancy_; + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index f6e2a383ab..b4a90af666 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -13,6 +13,8 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" namespace ck { namespace tensor_operation { @@ -118,8 +120,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, - p_wei_grid_{p_wei_grid}, - split_k_{split_k} + p_wei_grid_{p_wei_grid} { constexpr index_t spatial_offset = 3; const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset, @@ -143,6 +144,19 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl end(e_g_k_c_xs_lengths), begin(filter_spatial_lengths_)); + if (split_k < 0) + { + const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); + const index_t grid_size = gdx * gdy * gdz; + split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + } + else { + split_k_ = split_k; + } + if constexpr(IsTwoStageNeeded) { const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths), @@ -237,7 +251,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl bool is_filter_data_packed; CElementwiseGridDesc elementwise_desc_; Block2TileMapElementwise elementwise_block_2_ctile_map_; - const ck::index_t split_k_; + ck::index_t split_k_; }; // Invoker @@ -303,15 +317,6 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl static bool IsSupportedArgument(const Argument& arg) { - if(arg.split_k_ < 0) - { - // TODO: Add split-K autodeduction. - // This will probably require adding interface to the GEMM operation for - // querying the optimal split-K value, as we cannot easily access the actual GEMM kernel - // from here. - return false; - } - if constexpr(NDimSpatial == 2) { if constexpr(!is_NHWGC_GKYXC_NHWGK())