mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK grouped gemm] Fix grouped gemm two stage HasMainK0BlockLoop (#3466)
* Re-enable two stage kernel
* Only disable on HasMainKBlockLoop mismatch
* Address PR comments
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ios>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
@@ -677,8 +678,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
|
||||
all_have_kbatch_gt_one = arg.K_BATCH > 1;
|
||||
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(
|
||||
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
|
||||
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
|
||||
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1));
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
@@ -709,8 +709,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
|
||||
bool not_all_have_main_k_block_loop_same =
|
||||
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(
|
||||
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
|
||||
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
|
||||
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1));
|
||||
bool not_all_have_kbatch_value_same =
|
||||
all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1);
|
||||
|
||||
@@ -848,21 +847,47 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Fix this.
|
||||
// Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0`
|
||||
if(std::is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
std::is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
std::is_same<ELayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2)
|
||||
// Check if all groups have compatible HasMainLoop values
|
||||
if(!arg.gemm_kernel_args_.empty())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
const auto& first_arg = arg.gemm_kernel_args_[0].karg_;
|
||||
const auto first_desc =
|
||||
GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(first_arg.M,
|
||||
first_arg.MPadded,
|
||||
first_arg.K,
|
||||
first_arg.StrideA,
|
||||
first_arg.k_batch,
|
||||
first_arg.K0Padded,
|
||||
first_arg.KPadded);
|
||||
const bool first_has_main_loop =
|
||||
GridwiseGemm64::CalculateHasMainK0BlockLoop(first_desc.GetLength(I1));
|
||||
|
||||
for(std::size_t i = 1; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
std::cout
|
||||
<< "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not "
|
||||
"supported for all possible shapes!"
|
||||
<< std::endl;
|
||||
const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_;
|
||||
const auto desc =
|
||||
GridwiseGemm64::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M,
|
||||
gemm_arg.MPadded,
|
||||
gemm_arg.K,
|
||||
gemm_arg.StrideA,
|
||||
gemm_arg.k_batch,
|
||||
gemm_arg.K0Padded,
|
||||
gemm_arg.KPadded);
|
||||
const bool has_main_loop =
|
||||
GridwiseGemm64::CalculateHasMainK0BlockLoop(desc.GetLength(I1));
|
||||
|
||||
if(first_has_main_loop != has_main_loop)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << std::boolalpha
|
||||
<< "Not all groups have compatible HasMainLoop values! "
|
||||
<< "Group 0: " << first_has_main_loop << ", Group " << i << ": "
|
||||
<< has_main_loop << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supported = true;
|
||||
|
||||
Reference in New Issue
Block a user