diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 79df4e624d..475c13166d 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,2 +1 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) -add_executable(tile_example_grouped_gemm_tileloop EXCLUDE_FROM_ALL grouped_gemm_tileloop.cpp) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index bb0a0d5840..897952f03c 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -16,19 +16,11 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -template -float grouped_gemm(const std::vector& gemm_descs, - const ck_tile::stream_config& s, - void* kargs_ptr) +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Memory friendly for Interwave scheduler @@ -83,8 +75,6 @@ float grouped_gemm(const std::vector& gemm_descs, constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr bool TransposeC = false; - constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -97,54 +87,41 @@ float grouped_gemm(const std::vector& gemm_descs, GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using 
BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; - - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; constexpr auto memory_operation = memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + scheduler>; using GemmPipeline = GEMM_PIPELINE; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, - DsLayout, + ck_tile::tuple<>, CLayout, - CDEElementWise, + ck_tile::element_wise::PassThrough, GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -156,20 +133,8 @@ float grouped_gemm(const std::vector& gemm_descs, UniversalGemmProblem::TransposeC, memory_operation>>; using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - constexpr dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + const dim3 grids = 
Kernel::MaxOccupancyGridSize(s); if(s.log_level_ > 0) { @@ -186,45 +151,26 @@ float grouped_gemm(const std::vector& gemm_descs, blocks, 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); + num_groups)); return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } return ave_time; } #include "run_grouped_gemm_example.inc" -constexpr bool Persistent = false; -int main(int argc, char* argv[]) -{ - try - { - return !run_grouped_gemm_example(argc, argv); - } - catch(const std::runtime_error& e) - { - std::cerr << "Runtime error: " << e.what() << '\n'; - return EXIT_FAILURE; - } -} +constexpr bool Persistent = true; +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 74efb1bdeb..89d91fbef6 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -15,7 +15,7 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V4 #endif #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp deleted file mode 100644 index 897952f03c..0000000000 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp +++ /dev/null @@ -1,176 +0,0 @@ -// 
SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/host.hpp" -#include "grouped_gemm.hpp" - -template -float grouped_gemm_tileloop(const ck_tile::stream_config& s, - const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) -{ -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 32; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 1; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - constexpr bool DoubleSmemBuffer = false; -#endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - // Compute friendly for Intrawave scheduler - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = false; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - 
constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = true; -#endif - - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - constexpr ck_tile::index_t TileParitionerGroupNum = 8; - constexpr ck_tile::index_t TileParitionerM01 = 4; - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile:: - GemmSpatiallyLocalTilePartitioner; - - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - - float ave_time{0}; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; - constexpr auto memory_operation = memory_operation_.value; - - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. 
- using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - GemmPipelineProblem::kBlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - constexpr dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - - if(!splitk) - { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); - } - - return ave_time; -} - -#include "run_grouped_gemm_example.inc" - -constexpr bool Persistent = true; -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 921ea11720..477a87d42f 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -252,13 +252,6 @@ struct GroupedGemmKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } - CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs, - const tuple& block_idx_2d, - const index_t block_idx_z) const - { - Run(kargs.group_karg, 
block_idx_2d, block_idx_z); - } - CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs, const tuple& block_idx_2d, const index_t block_idx_z) const @@ -277,24 +270,56 @@ struct GroupedGemmKernel CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + __shared__ char smem_ptr_0[GetSmemSize()]; - if constexpr(UsePersistentKernel) + if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - RunGemmWithPipelineSelection( - a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection2LDS(a_ptr, + b_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else + { + Base::RunGemm2LDS({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } } else { - Base::RunGemm({a_ptr}, - {b_ptr}, - {/*ds_ptr*/}, - c_ptr, - smem_ptr, - kargs, - splitk_batch_offset, - i_m, - i_n); + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection( + a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } + else + { + Base::RunGemm({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } } } @@ -358,6 +383,69 @@ struct GroupedGemmKernel c_block_window, c_block_tile, d_block_window, smem_ptr_0); } + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note The GEMM pipeline is selected in-kernel based on the number of K-loops + * and the tail-number. This is needed for the persistent tile-loop when + * we didn't have access to the K dimension on the host. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. 
+ * @param smem_ptr_1 The second start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset Utility structure used to calculate k + * batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void + RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const UniversalGemmKernelArgs<>& kargs, + const typename Base::SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + Base::template MakeGemmTensorViews( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + const auto& a_block_window = gemm_tile_windows.at(Base::I0); + const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); + + // Get hot-loop and tail configuration + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0], + b_block_window[Base::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0, + smem_ptr_1); + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + 
operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, index_t block_id, index_t group_count) const @@ -401,7 +489,7 @@ struct GroupedGemmKernel kargs.group_karg.M, kargs.group_karg.N, (block_id - kargs.block_start) % grid_size_2d); - Run(kargs, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); } // For persistent kernels diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index ac91c2f58f..22c8cf383b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -18,12 +18,14 @@ struct BaseGemmPipelineAgBgCrCompV4 static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; - CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; } - CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { if(num_loop % PrefetchStages == 1) {