diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index 01c3276bf7..b01052a966 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -306,6 +306,9 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl if (arg.split_k_ < 0) { // TODO: Add split-K autodeduction. + // This will probably require adding an interface to the GEMM operation for + // querying the optimal split-K value, as we cannot easily access the actual GEMM kernel + // from here. return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 070965c299..3803c2ac85 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -655,7 +655,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1); const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1); - const auto grid_size = GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN) * Conv_G_; + const auto grid_size = GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN) * Conv_G_ / NumGroupsToMerge; k_batch_ = get_best_occupancy_k_batch_value(max_occupancy.value_, grid_size); // Ensure that k_batch_ does not exceed the maximum value