From 44a0e1afdb07b6977fed4890d85ad929ba4c768c Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 27 Oct 2025 19:11:23 +0000 Subject: [PATCH] Merge commit 'a46b725992bdefad16d1c30dcfe4bb8441462907' into develop --- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 1 + .../grouped_gemm_preshuffle.cpp | 107 +++++++++++++ .../run_grouped_gemm_example.inc | 24 +-- .../test_grouped_gemm_preshuffle.cpp | 29 ++-- .../test_grouped_gemm_preshuffle_util.hpp | 150 +++++++++++++++++- 5 files changed, 278 insertions(+), 33 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 10d7befc06..57d3f224d8 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; static constexpr bool Preshuffle = true; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = true; }; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index b9d6a4a1bc..52b84737cc 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -167,6 +167,113 @@ float grouped_gemm(const std::vector& gemm_descs, return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue, // DsDataType (empty for no D tensors) + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout (empty for no D tensors) + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC, + memory_operation>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + + if(splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + + return ave_time; +} + #include "run_grouped_gemm_example.inc" template diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index dbdbe80c5d..4eee165d66 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup, } else { - if(GemmConfig::Preshuffle) - { - // not supported yet - throw std::runtime_error( - "Persistent grouped gemm with preshuffle is not supported yet"); - } - - // NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse - // commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5, - // 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve - // commentCode has comments. Press enter to view. the gemm problems known on the host. - // Instead, we can just pass the pointer to the kernel and let the workgroups figure out - // which tiles to work on. This is useful when the gemm problems are generated dynamically. - // In this example however, we generate the `kargs` using the known gemm_descs, - // and copy the gemm descriptions to the device memory. - // The contents of the memory pointed to by `kargs_ptr` pointer could be - // written by e.g. another kernel from earlier stage. + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm + // problems known on the host. Instead, we can just pass the pointer to the kernel and let + // the workgroups figure out which tiles to work on. This is useful when the gemm problems + // are generated dynamically. In this example however, we generate the `kargs` using the + // known gemm_descs, and copy the gemm descriptions to the device memory. The contents of + // the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from + // earlier stage. std::vector> kargs; void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp index cf10853b3f..a9b61ac7de 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp @@ -8,11 +8,13 @@ #include "ck_tile/host.hpp" #include "test_grouped_gemm_preshuffle_util.hpp" -using F16 = ck_tile::half_t; -using F8 = ck_tile::fp8_t; -using F32 = float; -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using F16 = ck_tile::half_t; +using F8 = ck_tile::fp8_t; +using F32 = float; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using False = std::false_type; +using True = std::true_type; // Custom tuple-like structure for kernel configuration template , - KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>, - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>, - KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2> + // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu + KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>, + + KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index d2f64920fd..35b8f76642 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -39,9 +39,13 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using BDataType = typename Tuple::BDataType; using AccDataType = typename Tuple::AccDataType; using CDataType = typename Tuple::CDataType; - using PrecType = BDataType; - using DsLayout = ck_tile::tuple<>; // not used - using DsDataType = ck_tile::tuple<>; // not used + + using DsLayout = ck_tile::tuple<>; // not used + using DsDataType = ck_tile::tuple<>; // not used + + // Get the persistent value from ck_tile::bool_constant + using PersistentType = typename Tuple::Persistent; + static constexpr bool Persistent = PersistentType::value; static const bool kPadM = false; static const bool kPadN = false; @@ -231,6 +235,129 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } + private: + template + void invoke_grouped_gemm_persistent(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + // Enable persistent mode for preshuffle + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = + ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = + ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } + public: void Run(const std::vector& Ms, const std::vector& Ns, @@ -350,9 +477,20 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); - invoke_grouped_gemm(gemm_descs, - ck_tile::stream_config{nullptr, false, 1}, - gemm_workspace.GetDeviceBuffer()); + if constexpr(Persistent) + { + invoke_grouped_gemm_persistent( + gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + invoke_grouped_gemm( + gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); + } // Copy results back to host for validation for(int i = 0; i < group_count; i++)