From 80117e7ecc0c378c3afd6c5c3f8f42e7e9690751 Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Fri, 19 Dec 2025 10:04:48 +0100 Subject: [PATCH] [CK Grouped Gemm] Fix workspace stride in two stage kernel (#3412) * Use correct workspace stride * Use correct stride in elementwise kernel * Fix test by adding padder * No UTF-8 in comments * Remove unnecessary changes * Remove more unnecessary changes * Use non-padded strides for workspace * Disable two stage kernel for RRR+MNKPadding+kbatch>2 Partially fixes AICK-441 [ROCm/composable_kernel commit: 323e01479940237ea24a078b8616fcf93a6b112e] --- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 71 ++++++++++++------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 92b53bc31b..9b5aab5c85 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -156,10 +156,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ComputeDataType>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; + + // Use gemm_padder for consistent descriptor creation + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, KPerBlock}; + template static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) { - const auto c_grid_desc_m_n = [&]() { + const auto e_grid_desc_m_n = [&]() { if constexpr(is_same::value) { return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); @@ -170,26 +176,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage } }(); - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } + // Use gemm_padder for consistent padding (same as C descriptor) + return gemm_padder.PadCDescriptor_M_N(e_grid_desc_m_n); } static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, @@ -226,7 +214,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage } using CGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; - using EGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(0, 0, 0)); using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); using DsGridPointer = decltype(MakeDsGridPointer()); using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple{}, DsGridDesc_M_N{})); @@ -339,6 +327,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage gemm_kernel_args_.reserve(group_count_); elementwise_c_grid_descs_m_n_.reserve(group_count_); + elementwise_e_grid_descs_m_n_.reserve(group_count_); elementwise_d_grid_descs_m_n_.reserve(group_count_); ds_grid_pointer_.reserve(group_count_); group_grid_size_.reserve(group_count_); @@ -358,15 +347,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage const index_t stride_a = gemm_descs[i].stride_A_; const index_t stride_b = gemm_descs[i].stride_B_; - const index_t stride_e = gemm_descs[i].stride_C_; const index_t m_padded = GridwiseGemm64::CalculateMPadded(M); const index_t n_padded = GridwiseGemm64::CalculateNPadded(N); const index_t k_padded = GridwiseGemm64::CalculateKPadded(K, K_BATCH); const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(K, K_BATCH); + // Two different strides are needed for TwoStage split-K: + // 1. workspace_stride (below): Stride for intermediate workspace (C grid) + // - Used by GEMM kernel to write workspace tiles + // 2. gemm_descs[i].stride_C_: User-provided stride for final output (E tensor) + // - Used by elementwise kernel to write final results + const index_t workspace_stride = + is_same::value ? N : M; + const auto c_grid_desc_m_n = - GridwiseGemm64::MakeCGridDescriptor_M_N(M, N, stride_e); + GridwiseGemm64::MakeCGridDescriptor_M_N(M, N, workspace_stride); DsGridDesc_M_N ds_grid_desc_m_n; DsGridPointer p_ds_grid; @@ -415,7 +411,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage K, stride_a, stride_b, - stride_e, + workspace_stride, m_padded, n_padded, k_padded, @@ -425,7 +421,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage gemm_kernel_args_.emplace_back( std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + // Create E grid descriptor with user-provided stride (not workspace stride) + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(M, N, gemm_descs[i].stride_C_); + elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n); + elementwise_e_grid_descs_m_n_.push_back(e_grid_desc_m_n); elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n); ds_grid_pointer_.push_back(p_ds_grid); // Store a copy of E pointers for elementwise kernel destination @@ -548,6 +549,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage std::vector group_grid_size_; std::vector elementwise_c_grid_descs_m_n_; + std::vector elementwise_e_grid_descs_m_n_; std::vector elementwise_d_grid_descs_m_n_; std::vector ds_grid_pointer_; std::vector e_ptrs_; @@ -810,7 +812,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage 0, concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), arg.elementwise_d_grid_descs_m_n_[i]), - make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + make_tuple(arg.elementwise_e_grid_descs_m_n_[i]), concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid), arg.ds_grid_pointer_[i]), type_convert(arg.e_ptrs_[i]), @@ -846,6 +848,23 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage return false; } + // TODO: Fix this. + // Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0` + if(std::is_same::value && + std::is_same::value && + std::is_same::value && + getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not " + "supported for all possible shapes!" + << std::endl; + } + return false; + } + bool supported = true; bool isWave64 = get_warp_size() == 64; for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)