Extend Grouped GEMM with MultiD (Single & Double Shared Memory) feature to use persistent kernel option (#2933)

* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature

* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel

* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments

* fix: segfault fix by passing correct parameters for d tensors

* style: clang format

* WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults

* feat(grouped_gemm_multi_d): add functionality to run persistant kernel

* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature

* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel

* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments

* fix: segfault fix by passing correct parameters for d tensors

* style: clang format

* fix: incorrect validation method and Dtensor layout in test suite

* docs: improved README text based on review comments

* fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint
This commit is contained in:
Aviral Goel
2025-09-29 18:03:56 -04:00
committed by GitHub
parent 243118c275
commit bebf0e9d15
5 changed files with 163 additions and 19 deletions

View File

@@ -324,10 +324,18 @@ struct GroupedGemmKernel
}
else // SingleSmemBuffer
{
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else // Non-persistent kernel
{
@@ -365,6 +373,7 @@ struct GroupedGemmKernel
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection(const ADataType* a_ptr,
const BDataType* b_ptr,
const std::array<const void*, NumDTensor_>& ds_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
@@ -375,7 +384,7 @@ struct GroupedGemmKernel
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =