diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 9b5aab5c85..aeb6cd6d3f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -677,8 +678,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage all_have_kbatch_gt_one = arg.K_BATCH > 1; all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop( - a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * - a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1)); } for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) @@ -709,8 +709,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage bool not_all_have_main_k_block_loop_same = all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop( - a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * - a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1)); bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1); @@ -848,21 +847,47 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage return false; } - // TODO: Fix this. - // Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0` - if(std::is_same::value && - std::is_same::value && - std::is_same::value && - getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2) + // Check if all groups have compatible HasMainLoop values + if(!arg.gemm_kernel_args_.empty()) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + const auto& first_arg = arg.gemm_kernel_args_[0].karg_; + const auto first_desc = + GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(first_arg.M, + first_arg.MPadded, + first_arg.K, + first_arg.StrideA, + first_arg.k_batch, + first_arg.K0Padded, + first_arg.KPadded); + const bool first_has_main_loop = + GridwiseGemm64::CalculateHasMainK0BlockLoop(first_desc.GetLength(I1)); + + for(std::size_t i = 1; i < arg.gemm_kernel_args_.size(); ++i) { - std::cout - << "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not " - "supported for all possible shapes!" - << std::endl; + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + const auto desc = + GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M, + gemm_arg.MPadded, + gemm_arg.K, + gemm_arg.StrideA, + gemm_arg.k_batch, + gemm_arg.K0Padded, + gemm_arg.KPadded); + const bool has_main_loop = + GridwiseGemm64::CalculateHasMainK0BlockLoop(desc.GetLength(I1)); + + if(first_has_main_loop != has_main_loop) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << std::boolalpha + << "Not all groups have compatible HasMainLoop values! " + << "Group 0: " << first_has_main_loop << ", Group " << i << ": " + << has_main_loop << std::endl; + } + return false; + } } - return false; } bool supported = true;