mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK Grouped Gemm] Fix workspace stride in two stage kernel (#3412)
* Use correct workspace stride
* Use correct stride in elementwise kernel
* Fix test by adding padder
* No UTF-8 in comments
* Remove unnecessary changes
* Remove more unnecessary changes
* Use non-padded strides for workspace
* Disable two stage kernel for RRR+MNKPadding+kbatch>2

Partially fixes AICK-441
This commit is contained in:
@@ -156,10 +156,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
ComputeDataType>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
// Use gemm_padder for consistent descriptor creation
|
||||
static constexpr auto gemm_padder =
|
||||
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
|
||||
{
|
||||
const auto c_grid_desc_m_n = [&]() {
|
||||
const auto e_grid_desc_m_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
|
||||
@@ -170,26 +176,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
|
||||
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
// Use gemm_padder for consistent padding (same as C descriptor)
|
||||
return gemm_padder.PadCDescriptor_M_N(e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
|
||||
@@ -226,7 +214,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
}
|
||||
|
||||
using CGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N;
|
||||
using EGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(0, 0, 0));
|
||||
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
|
||||
@@ -339,6 +327,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
|
||||
gemm_kernel_args_.reserve(group_count_);
|
||||
elementwise_c_grid_descs_m_n_.reserve(group_count_);
|
||||
elementwise_e_grid_descs_m_n_.reserve(group_count_);
|
||||
elementwise_d_grid_descs_m_n_.reserve(group_count_);
|
||||
ds_grid_pointer_.reserve(group_count_);
|
||||
group_grid_size_.reserve(group_count_);
|
||||
@@ -358,15 +347,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
|
||||
const index_t stride_a = gemm_descs[i].stride_A_;
|
||||
const index_t stride_b = gemm_descs[i].stride_B_;
|
||||
const index_t stride_e = gemm_descs[i].stride_C_;
|
||||
|
||||
const index_t m_padded = GridwiseGemm64::CalculateMPadded(M);
|
||||
const index_t n_padded = GridwiseGemm64::CalculateNPadded(N);
|
||||
const index_t k_padded = GridwiseGemm64::CalculateKPadded(K, K_BATCH);
|
||||
const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(K, K_BATCH);
|
||||
|
||||
// Two different strides are needed for TwoStage split-K:
|
||||
// 1. workspace_stride (below): Stride for intermediate workspace (C grid)
|
||||
// - Used by GEMM kernel to write workspace tiles
|
||||
// 2. gemm_descs[i].stride_C_: User-provided stride for final output (E tensor)
|
||||
// - Used by elementwise kernel to write final results
|
||||
const index_t workspace_stride =
|
||||
is_same<tensor_layout::gemm::RowMajor, ELayout>::value ? N : M;
|
||||
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm64::MakeCGridDescriptor_M_N(M, N, stride_e);
|
||||
GridwiseGemm64::MakeCGridDescriptor_M_N(M, N, workspace_stride);
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
DsGridPointer p_ds_grid;
|
||||
@@ -415,7 +411,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_e,
|
||||
workspace_stride,
|
||||
m_padded,
|
||||
n_padded,
|
||||
k_padded,
|
||||
@@ -425,7 +421,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
gemm_kernel_args_.emplace_back(
|
||||
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
|
||||
|
||||
// Create E grid descriptor with user-provided stride (not workspace stride)
|
||||
const auto e_grid_desc_m_n =
|
||||
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, gemm_descs[i].stride_C_);
|
||||
|
||||
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
|
||||
elementwise_e_grid_descs_m_n_.push_back(e_grid_desc_m_n);
|
||||
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
|
||||
ds_grid_pointer_.push_back(p_ds_grid);
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
@@ -548,6 +549,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
std::vector<index_t> group_grid_size_;
|
||||
|
||||
std::vector<CGridDesc_M_N> elementwise_c_grid_descs_m_n_;
|
||||
std::vector<EGridDesc_M_N> elementwise_e_grid_descs_m_n_;
|
||||
std::vector<DsGridDesc_M_N> elementwise_d_grid_descs_m_n_;
|
||||
std::vector<DsGridPointer> ds_grid_pointer_;
|
||||
std::vector<void*> e_ptrs_;
|
||||
@@ -810,7 +812,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
0,
|
||||
concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
|
||||
arg.elementwise_d_grid_descs_m_n_[i]),
|
||||
make_tuple(arg.elementwise_c_grid_descs_m_n_[i]),
|
||||
make_tuple(arg.elementwise_e_grid_descs_m_n_[i]),
|
||||
concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
|
||||
arg.ds_grid_pointer_[i]),
|
||||
type_convert<EDataType*>(arg.e_ptrs_[i]),
|
||||
@@ -846,6 +848,23 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Fix this.
|
||||
// Error appears in `script/profiler_grouped_gemm.sh grouped_gemm 1 0 1 1 0 0`
|
||||
if(std::is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
std::is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
std::is_same<ELayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
getGemmSpecializationString(GemmSpec) == "MNKPadding" && arg.K_BATCH > 2)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout
|
||||
<< "All RowMajor layout with MNKPadding specialization and KBatch > 2 is not "
|
||||
"supported for all possible shapes!"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supported = true;
|
||||
bool isWave64 = get_warp_size() == 64;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
|
||||
Reference in New Issue
Block a user