diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index d79cfd5665..816ea78f08 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -549,13 +549,29 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const auto& c_grid_desc_m_n = descs_initial[I2]; const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n, M01, N01, k_batch_initial); - const auto gemmK = get_bwd_weight_gemm_k(a_g_n_k_wos_lengths); // Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups. // Hence, the grid is just size of the tile map. const auto grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n); - k_dim_size_ = gemmK; - k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size); + k_dim_size_ = get_bwd_weight_gemm_k(a_g_n_k_wos_lengths); + const bool enable_oversubscription = k_dim_size_ > 1 << 13; + + // For small GemmK size, cap the max value of the k_batch. + k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, BlockSize, enable_oversubscription); + const auto k_batch_max = static_cast((k_dim_size_ - 1) / K0PerBlock); + if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_dim_size: " + << k_dim_size_ << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] K0PerBlock: " << K0PerBlock << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " + << k_batch_max << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Optimal k_batch value: " + << k_batch_ << std::endl; + k_batch_ = std::min(k_batch_, k_batch_max); + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " + << k_batch_ << std::endl; + } } else { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index c10c6062e1..83c80b3150 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -526,7 +526,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Hence, the grid is just size of the tile map. const auto grid_size = GridwiseGemm::Block2CTileMap::CalculateGridSize(GemmM, GemmN); k_dim_size_ = get_bwd_weight_gemm_k(a_g_n_k_wos_lengths); - k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size); + const bool enable_oversubscription = k_dim_size_ > 1 << 13; + k_batch_ = get_k_batch_value(max_occupancy.value_, grid_size, BlockSize, enable_oversubscription); + + // Cap the k_batch_ value such that it doesn't violate the limit for the number of prefetch stages for the pipeline. + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + const auto k_batch_max = static_cast(std::floor( + (k_dim_size_ - 1.0) / ((GridwiseGemm::BlockwiseGemmPipe::PrefetchStages-1.0) * K0PerBlock))); + if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " + << k_batch_max << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Optimal k_batch value: " + << k_batch_ << std::endl; + k_batch_ = std::min(k_batch_, k_batch_max); + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " + << k_batch_ << std::endl; + } + } } else { diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp index af9d0da3cb..6979f7b709 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -24,25 +24,35 @@ struct DeviceProperties hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); num_cu_ = dev_prop.multiProcessorCount; + max_num_active_wavefronts_per_cu_ = dev_prop.maxThreadsPerMultiProcessor / dev_prop.warpSize; + wavefront_size_ = dev_prop.warpSize; }; int num_cu_; + int max_num_active_wavefronts_per_cu_; + int wavefront_size_; }; -inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size) +inline ck::index_t get_k_batch_value(int max_occupancy, ck::index_t grid_size, ck::index_t blockSize, bool enable_oversubscription = true) { static DeviceProperties device_properties; const int num_cu = device_properties.num_cu_; auto k_batch = 1; - const auto optimal_split = static_cast(std::floor((max_occupancy * num_cu) / (grid_size))); + const ck::index_t oversubscription = enable_oversubscription + ? static_cast(std::round((1.0 *device_properties.max_num_active_wavefronts_per_cu_ * device_properties.wavefront_size_) / blockSize)) + : 1; + + const auto optimal_split = static_cast(std::floor((1.0 *max_occupancy * num_cu) / (grid_size))); if (optimal_split > 1) { - k_batch = optimal_split; + k_batch = oversubscription * optimal_split; } if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " << max_occupancy << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Block size: " << blockSize << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Oversubscription factor: " << oversubscription << " (oversubscription enabled = " << std::to_string(enable_oversubscription) << ")"<< std::endl; std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl; std::cout << "[SPLIT-K AUTODEDUCE] Optimal split value: " << optimal_split << std::endl; std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << " for K-batch."<< std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index b7947309e4..bd7463ee9b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -556,7 +556,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight return false; if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + { return false; + } if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) { diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index b518f635b8..28c2a47901 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -378,7 +378,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, range_copy(conv_param.input_left_pads_, begin(input_left_pads)); range_copy(conv_param.input_right_pads_, begin(input_right_pads)); - std::vector split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128, 256}; + std::vector split_k_list = {/*Split-k parameter autodeduction*/-1, 1, 2, 4, 8, 16, 32, 64, 128}; bool profile_all = true; if(split_k != "all") {