[CK Grouped Gemm] Fix workspace stride in two stage kernel (#3412)

* Use correct workspace stride

* Use correct stride in elementwise kernel

* Fix test by adding padder

* No UTF-8 in comments

* Remove unnecessary changes

* Remove more unnecessary changes

* Use non-padded strides for workspace

* Disable two stage kernel for RRR + MNKPadding + KBatch > 2

Partially fixes AICK-441
This commit is contained in:
Johannes Graner
2025-12-19 10:04:48 +01:00
committed by GitHub
parent b188a2a896
commit 323e014799

View File

@@ -156,10 +156,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
ComputeDataType>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
// Use gemm_padder for consistent descriptor creation
static constexpr auto gemm_padder =
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
{
const auto c_grid_desc_m_n = [&]() {
const auto e_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
@@ -170,26 +176,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// Use gemm_padder for consistent padding (same as C descriptor)
return gemm_padder.PadCDescriptor_M_N(e_grid_desc_m_n);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
@@ -226,7 +214,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
}
using CGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N;
using EGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(0, 0, 0));
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
using DsGridPointer = decltype(MakeDsGridPointer());
using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
@@ -339,6 +327,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_.reserve(group_count_);
elementwise_c_grid_descs_m_n_.reserve(group_count_);
elementwise_e_grid_descs_m_n_.reserve(group_count_);
elementwise_d_grid_descs_m_n_.reserve(group_count_);
ds_grid_pointer_.reserve(group_count_);
group_grid_size_.reserve(group_count_);
@@ -358,15 +347,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t stride_a = gemm_descs[i].stride_A_;
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_e = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm64::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm64::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm64::CalculateKPadded(K, K_BATCH);
const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(K, K_BATCH);
// Two different strides are needed for TwoStage split-K:
// 1. workspace_stride (below): Stride for intermediate workspace (C grid)
// - Used by GEMM kernel to write workspace tiles
// 2. gemm_descs[i].stride_C_: User-provided stride for final output (E tensor)
// - Used by elementwise kernel to write final results
const index_t workspace_stride =
is_same<tensor_layout::gemm::RowMajor, ELayout>::value ? N : M;
const auto c_grid_desc_m_n =
GridwiseGemm64::MakeCGridDescriptor_M_N(M, N, stride_e);
GridwiseGemm64::MakeCGridDescriptor_M_N(M, N, workspace_stride);
DsGridDesc_M_N ds_grid_desc_m_n;
DsGridPointer p_ds_grid;
@@ -415,7 +411,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
K,
stride_a,
stride_b,
stride_e,
workspace_stride,
m_padded,
n_padded,
k_padded,
@@ -425,7 +421,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
// Create E grid descriptor with user-provided stride (not workspace stride)
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, gemm_descs[i].stride_C_);
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
elementwise_e_grid_descs_m_n_.push_back(e_grid_desc_m_n);
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
ds_grid_pointer_.push_back(p_ds_grid);
// Store a copy of E pointers for elementwise kernel destination
@@ -548,6 +549,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
std::vector<index_t> group_grid_size_;
std::vector<CGridDesc_M_N> elementwise_c_grid_descs_m_n_;
std::vector<EGridDesc_M_N> elementwise_e_grid_descs_m_n_;
std::vector<DsGridDesc_M_N> elementwise_d_grid_descs_m_n_;
std::vector<DsGridPointer> ds_grid_pointer_;
std::vector<void*> e_ptrs_;
@@ -810,7 +812,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
0,
concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
arg.elementwise_d_grid_descs_m_n_[i]),
make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
make_tuple(arg.elementwise_e_grid_descs_m_n_[i]),
concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
arg.ds_grid_pointer_[i]),
type_convert<EDataType*>(arg.e_ptrs_[i]),
@@ -846,6 +848,23 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return false;
}
// TODO: Fix this.
// Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0`
if(std::is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
std::is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
std::is_same<ELayout, tensor_layout::gemm::RowMajor>::value &&
getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout
<< "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not "
"supported for all possible shapes!"
<< std::endl;
}
return false;
}
bool supported = true;
bool isWave64 = get_warp_size() == 64;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)