From a69e4ed8b78cfc8682abff484b054dea0c31d8aa Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Mon, 29 Sep 2025 18:03:56 -0400 Subject: [PATCH] Extend Grouped GEMM with MultiD (Single & Double Shared Memory) feature to use persistent kernel option (#2933) * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * style: clang format * WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults * feat(grouped_gemm_multi_d): add functionality to run persistent kernel * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * style: clang format * fix: incorrect validation method and Dtensor layout in test suite * docs: improved README text based on review comments * fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint [ROCm/composable_kernel commit: bebf0e9d158c13d34c9f263a9551f60fa463bc66] --- example/ck_tile/17_grouped_gemm/README.md | 10 +- .../17_grouped_gemm/grouped_gemm_multi_d.cpp | 106 ++++++++++++++++++ .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 5 +- .../run_grouped_gemm_multi_d_example.inc | 46 ++++++-- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 15 ++- 5 files changed, 163 insertions(+), 19 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 0821065098..09bf3e167a 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -10,16 +10,15 @@ The grouped GEMM examples include two advanced optimization features: Weight 
preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches. - **Implementation**: Available in `grouped_gemm_preshuffle.cpp` -- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration +- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configurations - **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts -- **Benefits**: Improved memory efficiency and reduced data movement + #### Persistence Mode Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy. - **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm` - **Usage**: `invoke_gemm` enables persistence -- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes #### Multi-D Operations Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output. @@ -31,7 +30,8 @@ Multi-D operations extend the standard GEMM operation by supporting additional e - **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call - **Build Target**: `make tile_example_grouped_gemm_multi_d -j` -Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads. +Multi-D operations support both persistence and non-persistence modes. +Weight preshuffle is supported only in non-persistence mode. 
## Build ``` @@ -48,7 +48,7 @@ make tile_example_grouped_gemm_multi_d -j # The quant grouped gemm fp8 example make tile_example_quant_grouped_gemm -j ``` -This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`. +Each example will result in a corresponding executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`. ## example diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 409eda8de4..98b0428d39 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -166,6 +166,112 @@ float grouped_gemm_multi_d(const std::vector& gemm_d return ave_time; } +template +float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. 
+ using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + + return ave_time; +} + #include "run_grouped_gemm_multi_d_example.inc" int main(int argc, char* argv[]) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index f7727d854c..d5203a799c 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -95,6 +95,7 @@ struct GemmConfigV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; @@ -170,7 +171,7 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; }; -using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<2>; +using 
grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs; std::pair create_args(int argc, char* argv[]) { @@ -201,7 +202,7 @@ std::pair create_args(int argc, char* argv[]) inline std::size_t get_workspace_size(const std::vector& gemm_descs) { - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<2>); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } template > kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = args[0].k_batch > 1; + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr}, + {arg.b_ptr}, + arg.ds_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + arg.stride_Ds, + arg.stride_E, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream( + kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = + grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr, splitk); } return ave_time; } @@ -322,12 +356,6 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, b_k_n_tensors[i], {d0_m_n_tensors[i], d1_m_n_tensors[i]}, e_m_n_host_refs[i]); - std::cout << "e_m_n_host_refs[i]: " << std::endl; - e_m_n_host_refs[i].print_first_n(std::cout, 10); - std::cout << std::endl; - std::cout << "e_m_n_tensors[i]: " << std::endl; - e_m_n_tensors[i].print_first_n(std::cout, 10); - std::cout << std::endl; const float max_accumulated_value = *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 217637d605..551dc6f50d 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -324,10 +324,18 @@ struct GroupedGemmKernel } 
else // SingleSmemBuffer { + if constexpr(UsePersistentKernel) { - RunGemmWithPipelineSelection( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + RunGemmWithPipelineSelection(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } else // Non-persistent kernel { @@ -365,6 +373,7 @@ struct GroupedGemmKernel CK_TILE_DEVICE static void RunGemmWithPipelineSelection(const ADataType* a_ptr, const BDataType* b_ptr, + const std::array& ds_ptr, CDataType* c_ptr, void* smem_ptr_0, const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs, @@ -375,7 +384,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); + {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows =