[CK grouped gemm] Fix grouped gemm two stage HasMainK0BlockLoop (#3466)

* Re-enable two stage kernel

* Only disable on HasMainKBlockLoop mismatch

* Address PR comments
This commit is contained in:
Johannes Graner
2025-12-23 11:33:09 +01:00
committed by GitHub
parent 4ce7d4c511
commit e1381d6a71

View File

@@ -3,6 +3,7 @@
#pragma once
#include <ios>
#include <iostream>
#include <sstream>
#include <tuple>
@@ -677,8 +678,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
all_have_kbatch_gt_one = arg.K_BATCH > 1;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1));
}
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
@@ -709,8 +709,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1));
bool not_all_have_kbatch_value_same =
all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1);
@@ -848,21 +847,47 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return false;
}
// TODO: Fix this.
// Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0`
if(std::is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
std::is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
std::is_same<ELayout, tensor_layout::gemm::RowMajor>::value &&
getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2)
// Check if all groups have compatible HasMainLoop values
if(!arg.gemm_kernel_args_.empty())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
const auto& first_arg = arg.gemm_kernel_args_[0].karg_;
const auto first_desc =
GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(first_arg.M,
first_arg.MPadded,
first_arg.K,
first_arg.StrideA,
first_arg.k_batch,
first_arg.K0Padded,
first_arg.KPadded);
const bool first_has_main_loop =
GridwiseGemm64::CalculateHasMainK0BlockLoop(first_desc.GetLength(I1));
for(std::size_t i = 1; i < arg.gemm_kernel_args_.size(); ++i)
{
std::cout
<< "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not "
"supported for all possible shapes!"
<< std::endl;
const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_;
const auto desc =
GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M,
gemm_arg.MPadded,
gemm_arg.K,
gemm_arg.StrideA,
gemm_arg.k_batch,
gemm_arg.K0Padded,
gemm_arg.KPadded);
const bool has_main_loop =
GridwiseGemm64::CalculateHasMainK0BlockLoop(desc.GetLength(I1));
if(first_has_main_loop != has_main_loop)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << std::boolalpha
<< "Not all groups have compatible HasMainLoop values! "
<< "Group 0: " << first_has_main_loop << ", Group " << i << ": "
<< has_main_loop << std::endl;
}
return false;
}
}
return false;
}
bool supported = true;