diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index f1c8c349f0..02d7e8df52 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -528,9 +528,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); const auto k_grid_size = k_dim_size_ / K0PerBlock; - // For V3 pipeline, it is beneficial to oversubscribe and consider the total grid size to be only - // the grid of the GEMM output tiles. - const auto total_grid_size = grid_size_mn; + const auto total_grid_size = grid_size_mn * Conv_G_; k_batch_ = split_k_parameters.strategy_== SplitKStrategy::BestOccupancy ? get_best_occupancy_k_batch_value(max_occupancy.value_, total_grid_size) : get_optimized_k_batch_value(max_occupancy.value_, grid_size_mn, k_grid_size);