diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 386fe93715..de9608bcb4 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -18,9 +18,12 @@ template + typename CLayout, + bool Persistent> float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { + if constexpr(Persistent) + std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; constexpr bool kPadN = false; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 4c9fecaba6..aec5f6a116 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -213,7 +213,8 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") - .insert("init", "0", "0:random, 1:linear, 2:constant(1)"); + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -226,5 +227,6 @@ template + typename CLayout, + bool Persistent = false> float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 3010130e6c..bf455a6415 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -162,7 +162,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::index_t stride_C, ck_tile::index_t kbatch, int n_warmup, - int n_repeat) + int n_repeat, + bool persistent) { ck_tile::GemmHostArgs args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); @@ -176,9 +177,31 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = - gemm_calc( + float ave_time; + if(persistent) + { + ave_time = gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + } + else + { + ave_time = gemm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + } std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -193,8 +216,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, << " B_Type=" << DataTypeTraits::name << " C_Type=" << DataTypeTraits::name << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") - << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; + << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } @@ -229,6 +252,7 @@ int run_gemm_example_with_layouts(int argc, int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); ck_tile::index_t init_method = arg_parser.get_int("init"); + bool persistent = arg_parser.get_int("persistent"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); @@ -316,7 +340,8 @@ int run_gemm_example_with_layouts(int argc, stride_C, kbatch, n_warmup, - n_repeat); + n_repeat, + persistent); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 0a094c29fe..645263d26d 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -32,7 +32,8 @@ template + typename CLayout, + bool Persistent> float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -61,7 +62,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& BLayout, CLayout, GemmConfig::TransposeC, - GemmConfig::UseStructuredSparsity>; + GemmConfig::UseStructuredSparsity, + Persistent>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -111,7 +113,15 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 2e82e21ba1..95fb1bd834 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/config.hpp" +#include #include #include @@ -138,4 +139,33 @@ struct is_specialization_of, RefTemplate> : std::true_type { }; +// Helper to get a tuple element or default type +namespace detail { + +template +struct tuple_element_or_default_dispatch +{ + using type = DefaultType; +}; + +template +struct tuple_element_or_default_dispatch +{ + using type = std::tuple_element_t; +}; + +} // namespace detail + +template +struct tuple_element_or_default +{ + using Tuple = remove_cvref_t; + static constexpr bool is_within_bounds = Idx < std::tuple_size_v; + using type = typename detail:: + tuple_element_or_default_dispatch::type; +}; +template +using tuple_element_or_default_t = + typename tuple_element_or_default::type; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 9c25104cd7..fea6633f9f 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -9,7 +9,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/host/concat.hpp" +#include "ck_tile/host/stream_utils.hpp" #include "ck_tile/core/utility/env.hpp" +#include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -142,6 +144,21 @@ struct GemmKernel using CLayout = remove_cvref_t; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + // Get the persistent kernel if the pipeline has it available + struct has_persistent_kernel + { + template + using has_persistent_type = decltype(T::UsePersistentKernel); + + static constexpr bool value = []() { + if constexpr(is_detected{}) + return GemmPipeline::UsePersistentKernel; + else + return false; + }(); + }; + static constexpr bool PersistentKernel = has_persistent_kernel::value; + using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. @@ -163,6 +180,23 @@ struct GemmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } + /** + * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + * @return The maximum occupancy grid size. + * @note This function queries the maximum occupancy of the kernel using + * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. + */ + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + using Kernel = GemmKernel; + const auto kernel = kentry; + int occupancy; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; + return dim3(grid_size, 1, 1); + } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) @@ -693,6 +727,8 @@ struct GemmKernel c_block_window, c_block_tile, smem_ptr_0); } + // Non-persistent kernel entry point + template > CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const { const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); @@ -739,6 +775,74 @@ struct GemmKernel } } } + + // Persistent kernel entry point + template , typename = void> + CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const + { + const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); + const auto num_tiles = + __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); + const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); + auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); + + while(block_id < num_work) + { + // Get the tile index for this block + const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + // Get the SplitK offset for this block + const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); + const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_ptr = + static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + // Run the GEMM + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(a_ptr, + b_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } + } + // Advance to the next work item + block_id += grid_size; + if(block_id >= num_work) + { + break; + } + } + } }; } // namespace ck_tile diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index fc04af5cdb..598bd68666 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -23,3 +23,8 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") else() message("Skipping ck_tile_gemm tests for current target") endif() + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a") + add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index bd1502516b..b9d3f57dbb 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -2,6 +2,7 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include +#include #include "gtest/gtest.h" @@ -21,6 +22,9 @@ using Mem = ck_tile::integral_constant; using CompV4 = ck_tile::integral_constant; +using Persistent = std::true_type; +using NonPersistent = std::false_type; + // clang-format off using KernelTypesMem = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, @@ -59,4 +63,9 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4> >; +using KernelTypesPersistent = ::testing::Types< + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent> +>; + // clang-format on diff --git a/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp b/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp new file mode 100644 index 0000000000..1dea1ab48c --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp @@ -0,0 +1,16 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistent + +TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistent); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 85742cb3de..1892aa0e31 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -89,6 +89,8 @@ class TestCkTileGemmPipeline : public ::testing::Test using CDataType = std::tuple_element_t<6, Tuple>; static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value; + static constexpr bool Persistent = + ck_tile::tuple_element_or_default_t::value; // TODO: expose tile size through test t-param ? template @@ -130,14 +132,17 @@ class TestCkTileGemmPipeline : public ::testing::Test GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + TransposeC, + StructuredSparsity, + Persistent>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -190,7 +195,15 @@ class TestCkTileGemmPipeline : public ::testing::Test using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) @@ -442,9 +455,6 @@ class TestCkTileGemmPipeline : public ::testing::Test "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) - << std::endl; EXPECT_TRUE(pass); } };