From 70f4b54dfde2a8d5aa336d79c3d2bde4679d95fe Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Sun, 7 Sep 2025 17:18:35 -0400 Subject: [PATCH] feat(grouped_gemm): add preshuffle v2 support to grouped gemm example (#2721) * docs(README): update readme with new build instructions * feat(grouped_gemm): add support back for non persistent kernel * refactor(grouped_gemm): simplify tensor creation * refactor(grouped_gemm): Persistence is now GemmConfig value for easier management * chore(grouped_gemm): add print statements to ease debugging * WIP(grouped_gemm): add grouped_gemm_preshuffle example and update CMake configuration * fix(tile_gemm_traits): change default value of Preshuffle_ from 0 to false for clarity * WIP(grouped_gemm): add dummy variables to compile the preshuffle pipelines * chore(grouped_gemm): add print statements and variables to debug numerical error with preshuffle * style: clang format work so far * BUG!(grouped_gemm_kernel.hpp): figured out a potential bug causing numerical errors in preshuffle pipeline * fix(grouped_gemm_kernel): add function in the kernel code to dynamically calculate tail_number resolving numerical errors * refactor(gemm_preshuffle): make preshuffle pipeline v2 compatible with operator () calls from grouped gemm * chore(grouped_gemm): add/remove debug comments and debug print statements * feat(grouped_gemm): integrate preshuffle pipeline v2 into grouped gemm for all supported shapes * chore(gemm_profile): add new argument combinations * fix: branch cleanup, formatting, refactoring * fix: branch cleanup, formatting, refactoring * chore(changelog): update changelog to reflect new feature * address review comments & nit [ROCm/composable_kernel commit: e279e9420ec8cb65b97013ea596c27c32cf42076] --- CHANGELOG.md | 2 +- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 1 + example/ck_tile/17_grouped_gemm/README.md | 78 ++++-- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 210 +++++++++++++++- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 
101 +++++++- .../grouped_gemm_preshuffle.cpp | 234 ++++++++++++++++++ .../run_grouped_gemm_example.inc | 162 ++++-------- .../block/block_wp_asmem_bsmem_creg_v1.hpp | 10 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 56 +++-- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 2 +- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 67 +++++ .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 53 +++- script/gemm_profile.sh | 10 +- 13 files changed, 808 insertions(+), 178 deletions(-) create mode 100644 example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ae97b3d61..2d88da364a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added - +* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. 
diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 475c13166d..cf47dc60f1 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1 +1,2 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) +add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 85a02c2231..9b8950ea2c 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -8,11 +8,11 @@ The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operati Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function. -### Parsing Arguments -The example takes three arguments: `group_count`, `repeat`, and `warmup`: -- `group_count`: the number of GEMM operations in the group, +### Key Arguments +The example takes several arguments including `group_count`, `repeat`, and `warmup`: +- `group_count`: the number of GEMM operations in the group - `repeat`: the number of times to repeat the kernel for benchmarking -- `warmup`: the number of iterations before the actual kernel run time measure. 
+- `warmup`: the number of iterations before the actual kernel run time measure ```cpp // Example @@ -133,6 +133,28 @@ float invoke_gemm(int n_warmup, ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(GetWorkspaceSize(args)); ``` + +### Advanced Features: Preshuffle and Persistence + +The grouped GEMM examples include two advanced optimization features: + +#### Weight Preshuffle +Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches. + +- **Implementation**: Available in `grouped_gemm_preshuffle.cpp` +- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration +- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts +- **Benefits**: Improved memory efficiency and reduced data movement + +#### Persistence Mode +Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy. + +- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm` +- **Usage**: `invoke_gemm` enables persistence +- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes + +Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads. + Finally the arguments are passed to group_gemm and the kernel is launched. 
```cpp // API @@ -151,26 +173,42 @@ mkdir build && cd build ../script/cmake-ck-dev.sh ../ # The basic pipeline method on the gemm calculation make tile_example_grouped_gemm -j +# The preshuffle example +make tile_example_grouped_gemm_preshuffle -j ``` This will result in an executable `build/bin/tile_example_grouped_gemm` ## example ``` args: - -Ms M dimensions - empty by default. (default:) - -Ns N dimensions - empty by default. (default:) - -Ks K dimensions - empty by default. (default:) - -stride_As Tensor A strides - it is empty by default. (default:) - -stride_Bs Tensor B strides - it is empty by default. (default:) - -stride_Cs Tensor C strides - it is empty by default. (default:) - -a_layout A tensor data layout - Row by default. (default:R) - -b_layout B tensor data layout - Row by default. (default:C) - -c_layout C tensor data layout - Row by default. (default:R) - -validate 0. No validation, 1. Validation on CPU. (default:1) - -warmup number of iterations before benchmark the kernel. (default:10) - -repeat number of iterations to benchmark the kernel. (default:100) - -group_count group count. (default:8) - -kbatch kbatch for SplitK (default:1) - -json 0: No Json, 1: Dump Results in Json format (default:0) - -jsonfile json file name to dump results (default:grouped_gemm.json) + -Ms M dimensions - (Default: empty). + -Ns N dimensions - (Default: empty). + -Ks K dimensions - (Default: empty). + -stride_As Tensor A strides - (Default: empty). + -stride_Bs Tensor B strides - (Default: empty). + -stride_Cs Tensor C strides - (Default: empty). + -a_layout A tensor data layout - (Default: Row). + -b_layout B tensor data layout - (Default: Col). + -c_layout C tensor data layout - (Default: Row). + -prec data type. fp16/fp8 - (Default: fp16). + -validate 0. No validation, 1. Validation on CPU. (Default: 1). + -warmup Number of iterations before benchmark the kernel. (Default: 10). + -repeat Number of iterations to benchmark the kernel. (Default: 100). 
+ -group_count Group count. (Default: 16). + -kbatch kbatch for SplitK (Default: 1). + -json 0: No Json, 1: Dump Results in Json format (Default: 0). + -jsonfile json file name to dump results (Default: grouped_gemm.json). +``` + +If any of `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs`, or `stride_Cs` are missing or their sizes +don't match `group_count`, the example generates defaults per group index `i` (0-based): + +```text +M[i] = 256 + 256 * i +N[i] = 256 + 512 * i +K[i] = 512 + 384 * i + +stride_A[i] = K[i] +stride_B[i] = K[i] +stride_C[i] = N[i] ``` diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 527ef1e466..221543c0af 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,6 +16,155 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, 
+ const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; +} + template , ck_tile::sequence, ck_tile:: sequence>; - using TilePartitioner = ck_tile:: - GemmSpatiallyLocalTilePartitioner; + using TilePartitioner = + 
ck_tile::GemmSpatiallyLocalTilePartitioner; using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + + if(a_layout == "R" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A and B tensors!"); + } +} int main(int argc, char* argv[]) { - return !run_grouped_gemm_example(argc, argv); + return !run_grouped_gemm_example(argc, argv); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 39af33ebab..f8e21d5ee4 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" @@ -14,6 +15,7 @@ #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4 #ifndef CK_TILE_PIPELINE_DEFAULT #define CK_TILE_PIPELINE_DEFAULT 
CK_TILE_PIPELINE_COMPUTE_V3 @@ -37,6 +39,22 @@ constexpr ck_tile::index_t get_k_warp_tile() #endif } +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + template struct GemmTypeConfig; @@ -77,6 +95,8 @@ struct GemmConfigBase static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; + static constexpr bool Persistent = false; + static constexpr bool DoubleSmemBuffer = false; }; template @@ -123,6 +143,53 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +template +struct GemmConfigPreshuffleDecode : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr bool kPadK = true; + + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; +}; + +template +struct GemmConfigPreshufflePrefill : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t 
N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr bool kPadK = true; +}; + template struct PipelineTypeTraits; @@ -153,9 +220,19 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; +}; + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; -auto create_args(int argc, char* argv[]) +std::pair create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("Ms", "", "M dimensions - empty by default.") @@ -177,7 +254,7 @@ auto create_args(int argc, char* argv[]) .insert("jsonfile", "grouped_gemm.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); + return std::make_pair(result, arg_parser); } inline std::size_t get_workspace_size(const std::vector& gemm_descs) @@ -185,7 +262,24 @@ inline std::size_t get_workspace_size(const std::vector& gem return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } -template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = 
t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); +} + +template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp new file mode 100644 index 0000000000..00cbe5be83 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "grouped_gemm.hpp" + +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = + // if preshuffle == true then num_loop is recalculated for each group in the 
kernel code + TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + 
tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; +} + +#include "run_grouped_gemm_example.inc" + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + + // Preshuffle is supported only for A(Row major), B(column major) input matrices! + if(a_layout == "R" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } +} +template