Integrate Multi D GEMMs into Grouped GEMMs along with unit tests (#2923)

* feat(grouped_gemm_multi_d): add new example that integrates the grouped_gemm and multi_d_gemm features

* feat: generalized grouped_gemm_kernel.hpp

* feat: generalized grouped_gemm_kernel.hpp even further by removing hardcoded 0

* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel

* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments

* fix: resolve segfault by passing correct parameters for D tensors

* docs: add multi d info and trim down outdated content

* tests: add unit tests for grouped_gemm_multi_d and minor changes in grouped_gemm related test for compatibility

* style: clang format

* fix: incorrect validation method and Dtensor layout in test suite
This commit is contained in:
Aviral Goel
2025-09-26 12:59:58 -04:00
committed by GitHub
parent e40c0acef2
commit a44bea45b2
16 changed files with 1527 additions and 216 deletions

View File

@@ -23,10 +23,13 @@ namespace ck_tile {
/// arguments object. It contains all necessary information required to build proper kernel
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
/// stating all required information like M,N,K sizes and respective strides.
template <index_t NumDTensor = 0>
struct GroupedGemmHostArgs
{
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
@@ -34,15 +37,18 @@ struct GroupedGemmHostArgs
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
ds_ptr(ds_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
{
@@ -50,6 +56,7 @@ struct GroupedGemmHostArgs
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
union
{
void* e_ptr;
@@ -61,7 +68,7 @@ struct GroupedGemmHostArgs
index_t K;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
index_t stride_E;
@@ -71,20 +78,23 @@ struct GroupedGemmHostArgs
index_t k_batch;
};
template <index_t NumDTensor = 0>
struct GemmTransKernelArg
{
UniversalGemmKernelArgs<> group_karg;
UniversalGemmKernelArgs<1, 1, NumDTensor> group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
GemmTransKernelArg() = delete;
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg,
index_t bl_start,
index_t bl_end)
: group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end}
{
}
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg)
: group_karg{karg}, block_start{0}, block_end{0}
GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg)
: group_karg{std::move(karg)}, block_start{0}, block_end{0}
{
}
};
@@ -106,9 +116,12 @@ struct GroupedGemmKernel
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, C/E
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t NumDTensor_ = DsDataType::size();
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
@@ -140,19 +153,21 @@ struct GroupedGemmKernel
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
(UsePersistentKernel ? "Persistent" : "NonPersistent"),
(NumDTensor_ == 2 ? "MultiD" : "NoMultiD"),
(GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer"));
// clang-format on
}
CK_TILE_HOST static auto
GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::size_t
GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs<>>& gemm_descs) -> std::size_t
{
return gemm_descs.size() * sizeof(GemmTransKernelArg);
return gemm_descs.size() * sizeof(GemmTransKernelArg<NumDTensor_>);
}
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
{
return group_count * sizeof(GemmTransKernelArg);
return group_count * sizeof(GemmTransKernelArg<NumDTensor_>);
}
CK_TILE_HOST static auto BlockSize() -> dim3
@@ -184,7 +199,8 @@ struct GroupedGemmKernel
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
CK_TILE_HOST static auto
GridSize(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
@@ -196,9 +212,10 @@ struct GroupedGemmKernel
}
CK_TILE_HOST static auto
MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::vector<GemmTransKernelArg>
MakeKargs(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
-> std::vector<GemmTransKernelArg<NumDTensor_>>
{
std::vector<GemmTransKernelArg> gemm_kernel_args_;
std::vector<GemmTransKernelArg<NumDTensor_>> gemm_kernel_args_;
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
index_t grid_size = 0;
gemm_kernel_args_.reserve(group_count);
@@ -217,6 +234,7 @@ struct GroupedGemmKernel
const index_t stride_a = gemm_descs[i].stride_A;
const index_t stride_b = gemm_descs[i].stride_B;
const index_t stride_e = gemm_descs[i].stride_E;
auto stride_ds = gemm_descs[i].stride_Ds;
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
@@ -225,19 +243,19 @@ struct GroupedGemmKernel
grid_size += grid_size_grp;
auto karg =
UniversalGemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
{/*ds_ptr*/},
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
{stride_a},
{stride_b},
{/*stride_ds*/},
stride_e,
gemm_descs[i].k_batch};
auto karg = UniversalGemmKernelArgs<1, 1, NumDTensor_>{
{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
{gemm_descs[i].ds_ptr},
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
{stride_a},
{stride_b},
stride_ds,
stride_e,
gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
@@ -245,7 +263,8 @@ struct GroupedGemmKernel
return gemm_kernel_args_;
}
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<GemmTransKernelArg>& kargs)
CK_TILE_HOST static bool
IsSupportedArgument(const std::vector<GemmTransKernelArg<NumDTensor_>>& kargs)
{
for(const auto& karg : kargs)
{
@@ -262,7 +281,7 @@ struct GroupedGemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
@@ -292,8 +311,16 @@ struct GroupedGemmKernel
{
__shared__ char smem_ptr_1[GetSmemSize()];
RunGemmWithPipelineSelection2LDS(
a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
kargs.ds_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else // SingleSmemBuffer
{
@@ -306,7 +333,7 @@ struct GroupedGemmKernel
{
Base::RunGemm({a_ptr},
{b_ptr},
{/*ds_ptr*/},
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
@@ -340,7 +367,7 @@ struct GroupedGemmKernel
const BDataType* b_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const UniversalGemmKernelArgs<>& kargs,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
@@ -396,9 +423,10 @@ struct GroupedGemmKernel
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const std::array<const void*, NumDTensor_>& ds_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const UniversalGemmKernelArgs<>& kargs,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
@@ -406,7 +434,7 @@ struct GroupedGemmKernel
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
@@ -453,7 +481,7 @@ struct GroupedGemmKernel
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg<NumDTensor_>* gemm_desc_ptr,
index_t block_id,
index_t group_count) const
{
@@ -485,7 +513,7 @@ struct GroupedGemmKernel
index_t group_count) const
{
const index_t block_id = ck_tile::get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
@@ -508,7 +536,7 @@ struct GroupedGemmKernel
const index_t group_count) const
{
const index_t grid_size = ck_tile::get_grid_size();
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
index_t cum_grid_size = 0;