diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 918d1bc52e..86e8defb83 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -583,7 +583,36 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); - k_batch_ = split_k; +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) + { + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + + // Ensure that k_batch_ does not exceed the maximum value + // for the GEMM pipeline. + const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + k_batch_ = std::min(k_batch_, k_batch_max); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max + << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_ + << std::endl; + } + } + else +#endif + { + k_batch_ = split_k; + } const auto descs = conv_to_gemm_transformer @@ -954,6 +983,13 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index d6a2b639c3..1ab6bc446f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -506,7 +506,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); - +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -532,6 +532,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 } } else +#endif { k_batch_ = split_k; } @@ -1034,6 +1035,12 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *