From 4fba4073d36bdcd385483d8e5cce96e5fc7766c1 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 5 Jun 2025 09:24:00 -0700 Subject: [PATCH] Revert "[CK_TILE] Tile loop persistent gemm kernel (#2191)" (#2293) This reverts commit 6b2a12ae04a22188acd1444e69d89b270525b79e. [ROCm/composable_kernel commit: 233e274077cae99f2f1deacf5044593ace5be65e] --- example/ck_tile/03_gemm/gemm_basic.cpp | 5 +- example/ck_tile/03_gemm/gemm_utils.hpp | 6 +- example/ck_tile/03_gemm/run_gemm_example.inc | 37 +------ example/ck_tile/03_gemm/universal_gemm.cpp | 16 +-- include/ck_tile/core/utility/type_traits.hpp | 30 ----- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 104 ------------------ test/ck_tile/gemm/CMakeLists.txt | 5 - .../gemm/test_gemm_pipeline_kernel_types.hpp | 9 -- .../gemm/test_gemm_pipeline_persistent.cpp | 16 --- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 22 +--- 10 files changed, 18 insertions(+), 232 deletions(-) delete mode 100644 test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index de9608bcb4..386fe93715 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -18,12 +18,9 @@ template + typename CLayout> float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { - if constexpr(Persistent) - std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; constexpr bool kPadN = false; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index aec5f6a116..4c9fecaba6 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -213,8 +213,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") - .insert("init", "0", "0:random, 1:linear, 2:constant(1)") - .insert("persistent", "0", "0:non-persistent, 1:persistent"); + .insert("init", "0", "0:random, 1:linear, 2:constant(1)"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -227,6 +226,5 @@ template + typename CLayout> float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index bf455a6415..3010130e6c 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -162,8 +162,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::index_t stride_C, ck_tile::index_t kbatch, int n_warmup, - int n_repeat, - bool persistent) + int n_repeat) { ck_tile::GemmHostArgs args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); @@ -177,31 +176,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time; - if(persistent) - { - ave_time = gemm_calc( + float ave_time = + gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); - } - else - { - ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); - } std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -216,8 +193,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, << " B_Type=" << DataTypeTraits::name << " C_Type=" << DataTypeTraits::name << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") - << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; return ave_time; } @@ -252,7 +229,6 @@ int run_gemm_example_with_layouts(int argc, int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); ck_tile::index_t init_method = arg_parser.get_int("init"); - bool persistent = arg_parser.get_int("persistent"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); @@ -340,8 +316,7 @@ int run_gemm_example_with_layouts(int argc, stride_C, kbatch, n_warmup, - n_repeat, - persistent); + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 3a7cc93df8..bc9569d342 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -19,8 +19,7 @@ template + typename CLayout> float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -49,8 +48,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& BLayout, CLayout, GemmConfig::TransposeC, - GemmConfig::UseStructuredSparsity, - Persistent>; + GemmConfig::UseStructuredSparsity>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -100,15 +98,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 95fb1bd834..2e82e21ba1 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core/config.hpp" -#include #include #include @@ -139,33 +138,4 @@ struct is_specialization_of, RefTemplate> : std::true_type { }; -// Helper to get a tuple element or default type -namespace detail { - -template -struct tuple_element_or_default_dispatch -{ - using type = DefaultType; -}; - -template -struct tuple_element_or_default_dispatch -{ - using type = std::tuple_element_t; -}; - -} // namespace detail - -template -struct tuple_element_or_default -{ - using Tuple = remove_cvref_t; - static constexpr bool is_within_bounds = Idx < std::tuple_size_v; - using type = typename detail:: - tuple_element_or_default_dispatch::type; -}; -template -using tuple_element_or_default_t = - typename tuple_element_or_default::type; - } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index fea6633f9f..9c25104cd7 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -9,9 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/host/concat.hpp" -#include "ck_tile/host/stream_utils.hpp" #include "ck_tile/core/utility/env.hpp" -#include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -144,21 +142,6 @@ struct GemmKernel using CLayout = remove_cvref_t; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - // Get the persistent kernel if the pipeline has it available - struct has_persistent_kernel - { - template - using has_persistent_type = decltype(T::UsePersistentKernel); - - static constexpr bool value = []() { - if constexpr(is_detected{}) - return GemmPipeline::UsePersistentKernel; - else - return false; - }(); - }; - static constexpr bool PersistentKernel = has_persistent_kernel::value; - using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. @@ -180,23 +163,6 @@ struct GemmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - /** - * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. - * @return The maximum occupancy grid size. - * @note This function queries the maximum occupancy of the kernel using - * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. - */ - CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 - { - using Kernel = GemmKernel; - const auto kernel = kentry; - int occupancy; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); - const int grid_size = get_available_compute_units(s) * occupancy; - return dim3(grid_size, 1, 1); - } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) @@ -727,8 +693,6 @@ struct GemmKernel c_block_window, c_block_tile, smem_ptr_0); } - // Non-persistent kernel entry point - template > CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const { const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); @@ -775,74 +739,6 @@ struct GemmKernel } } } - - // Persistent kernel entry point - template , typename = void> - CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const - { - const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); - const auto num_tiles = - __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); - const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); - auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); - - while(block_id < num_work) - { - // Get the tile index for this block - const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles); - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - // Get the SplitK offset for this block - const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); - const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); - const ADataType* a_ptr = - static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - CDataType* c_ptr = static_cast(kargs.c_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - // Run the GEMM - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); - } - } - // Advance to the next work item - block_id += grid_size; - if(block_id >= num_work) - { - break; - } - } - } }; } // namespace ck_tile diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 598bd68666..fc04af5cdb 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -23,8 +23,3 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") else() message("Skipping ck_tile_gemm tests for current target") endif() - -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a") - add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index b9d3f57dbb..bd1502516b 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -2,7 +2,6 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include #include "gtest/gtest.h" @@ -22,9 +21,6 @@ using Mem = ck_tile::integral_constant; using CompV4 = ck_tile::integral_constant; -using Persistent = std::true_type; -using NonPersistent = std::false_type; - // clang-format off using KernelTypesMem = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, @@ -63,9 +59,4 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4> >; -using KernelTypesPersistent = ::testing::Types< - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent> ->; - // clang-format on diff --git a/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp b/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp deleted file mode 100644 index 1dea1ab48c..0000000000 --- a/test/ck_tile/gemm/test_gemm_pipeline_persistent.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include "test_gemm_pipeline_kernel_types.hpp" -#include "test_gemm_pipeline_util.hpp" -#include "gtest/gtest.h" - -template -class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline -{ -}; - -#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistent - -TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistent); - -#include "test_gemm_pipeline_ut_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index b3146b5f8e..c388df3a41 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -76,8 +76,6 @@ class TestCkTileGemmPipeline : public ::testing::Test using CDataType = std::tuple_element_t<6, Tuple>; static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value; - static constexpr bool Persistent = - ck_tile::tuple_element_or_default_t::value; // TODO: expose tile size through test t-param ? template @@ -119,17 +117,14 @@ class TestCkTileGemmPipeline : public ::testing::Test GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - static constexpr bool StructuredSparsity = false; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + TransposeC>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -182,15 +177,7 @@ class TestCkTileGemmPipeline : public ::testing::Test using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) @@ -359,6 +346,9 @@ class TestCkTileGemmPipeline : public ::testing::Test "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; EXPECT_TRUE(pass); } };