diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 84439cb8d1..0b8cc1b0e9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -552,7 +552,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const auto grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n); const auto k_size = a_grid_desc_kbatch_k0_m_k1.GetLength(I0) * a_grid_desc_kbatch_k0_m_k1.GetLength(I1); - k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, k_size, Conv_G_); + //const auto multiplier = static_cast(-split_k); + k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, k_size, Conv_G_/*, multiplier*/); } else { k_batch_ = split_k; diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp index 6b4ca45de6..8fe22afd50 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -7,6 +7,8 @@ #include "ck/host_utility/hip_check_error.hpp" #include "ck/ck.hpp" +CK_DECLARE_ENV_VAR_UINT64(CK_SPLIT_K_BATCH_SIZE) + namespace ck { namespace tensor_operation { namespace device { @@ -25,22 +27,43 @@ struct DeviceProperties int num_cu_; }; -inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, ck::index_t K_size, ck::index_t conv_G) +inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, ck::index_t K_size, ck::index_t conv_G /*, ck::index_t multiplier*/) { static DeviceProperties device_properties; constexpr ck::index_t k_batch_min = 1; - constexpr ck::index_t batch_size_min = 16; + constexpr ck::index_t batch_size_min = 512; const int num_cu = device_properties.num_cu_; const auto k_batch_max = math::integer_divide_ceil(K_size, batch_size_min); - auto k_batch = static_cast(std::ceil((max_occupancy * num_cu) / (1.0 * grid_size))); // Exclude th egrid size from the occupancy calculation - k_batch = std::min(std::max(k_batch_min, k_batch), k_batch_max); + // Ensure that we do not exceed the maximum capacity. This would lead to wave quantization. + const auto optimal_split = static_cast(std::floor((max_occupancy * num_cu) / (1.0 * grid_size * conv_G))); + auto k_batch = 1; + if (optimal_split > 1) + { + //The optimal value of k_batch is a multiple of the optimal_split. + //We need to find the optimal number K values per batch - this gives the optimal k_batch value. + auto target_batch_size = static_cast(ck::EnvValue(CK_ENV(CK_SPLIT_K_BATCH_SIZE))); + if (target_batch_size < k_batch_min) + { + target_batch_size = k_batch_min; + } + k_batch = optimal_split; + const auto current_batch_size = math::integer_divide_ceil(K_size, k_batch); + if (current_batch_size > target_batch_size) + { + // If the current batch size is larger than the target batch size, we need to increase k_batch. + const ck::index_t multiplier = std::max(1, math::integer_divide_ceil(K_size, target_batch_size * optimal_split)); + k_batch = optimal_split * multiplier; + } + } + if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " << max_occupancy << std::endl; std::cout << "[SPLIT-K AUTODEDUCE] Output grid size (M tiles x N tiles x Conv groups): " << grid_size << std::endl; std::cout << "[SPLIT-K AUTODEDUCE] K-dim size: " << K_size << std::endl; std::cout << "[SPLIT-K AUTODEDUCE] Conv groups: " << conv_G << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Optimal split value: " << optimal_split << std::endl; std::cout << "[SPLIT-K AUTODEDUCE] Maximum k_batch value: " << k_batch_max << std::endl; std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << " for K-batch."<< std::endl; } diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 255346d7ae..4941ae5e3a 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -152,7 +152,8 @@ void write_perf_results_to_file(const PerfResults& perf_results_global, { file << res.opt_split_k_best_op_name_ << separator; } - file << res.opt_split_k_best_arg_ << separator + file << res.opt_split_k_avg_time_ << separator + << res.opt_split_k_best_arg_ << separator << rank << separator << total_num; };