Revert "[CK_TILE] Tile loop persistent gemm kernel (#2191)" (#2293)

This reverts commit ffb52783d0.
This commit is contained in:
Illia Silin
2025-06-05 09:24:00 -07:00
committed by GitHub
parent 7ea1508b59
commit 233e274077
10 changed files with 18 additions and 232 deletions

View File

@@ -18,12 +18,9 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
bool Persistent>
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;

View File

@@ -213,8 +213,7 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent");
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -227,6 +226,5 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
bool Persistent = false>
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -162,8 +162,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat,
bool persistent)
int n_repeat)
{
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
@@ -177,31 +176,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time;
if(persistent)
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
true>(
float ave_time =
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
else
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
false>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
@@ -216,8 +193,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
@@ -252,7 +229,6 @@ int run_gemm_example_with_layouts(int argc,
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool persistent = arg_parser.get_int("persistent");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -340,8 +316,7 @@ int run_gemm_example_with_layouts(int argc,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent);
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -19,8 +19,7 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
bool Persistent>
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -49,8 +48,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
Persistent>;
GemmConfig::UseStructuredSparsity>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -100,15 +98,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))

View File

@@ -4,7 +4,6 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include <tuple>
#include <type_traits>
#include <stdint.h>
@@ -139,33 +138,4 @@ struct is_specialization_of<RefTemplate<Args...>, RefTemplate> : std::true_type
{
};
// Helper to get a tuple element or default type
namespace detail {
template <bool IsWithinBounds, std::size_t Idx, typename Tuple, typename DefaultType>
struct tuple_element_or_default_dispatch
{
using type = DefaultType;
};
template <std::size_t Idx, typename Tuple, typename DefaultType>
struct tuple_element_or_default_dispatch<true, Idx, Tuple, DefaultType>
{
using type = std::tuple_element_t<Idx, Tuple>;
};
} // namespace detail
template <typename Tuple_, std::size_t Idx, typename DefaultType>
struct tuple_element_or_default
{
using Tuple = remove_cvref_t<Tuple_>;
static constexpr bool is_within_bounds = Idx < std::tuple_size_v<Tuple>;
using type = typename detail::
tuple_element_or_default_dispatch<is_within_bounds, Idx, Tuple, DefaultType>::type;
};
template <typename Tuple_, std::size_t Idx, typename DefaultType>
using tuple_element_or_default_t =
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
} // namespace ck_tile

View File

@@ -9,9 +9,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -144,21 +142,6 @@ struct GemmKernel
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
struct has_persistent_kernel
{
template <typename T>
using has_persistent_type = decltype(T::UsePersistentKernel);
static constexpr bool value = []() {
if constexpr(is_detected<has_persistent_type, GemmPipeline>{})
return GemmPipeline::UsePersistentKernel;
else
return false;
}();
};
static constexpr bool PersistentKernel = has_persistent_kernel::value;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
@@ -180,23 +163,6 @@ struct GemmKernel
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
* @return The maximum occupancy grid size.
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
*/
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using Kernel = GemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<KernelBlockSize, 1, Kernel, GemmKernelArgs>;
int occupancy;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
@@ -727,8 +693,6 @@ struct GemmKernel
c_block_window, c_block_tile, smem_ptr_0);
}
// Non-persistent kernel entry point
template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
@@ -775,74 +739,6 @@ struct GemmKernel
}
}
}
// Persistent kernel entry point
template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
const auto num_tiles =
__builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
auto block_id = __builtin_amdgcn_readfirstlane(get_block_id());
while(block_id < num_work)
{
// Get the tile index for this block
const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
// Get the SplitK offset for this block
const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
// Run the GEMM
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
}
// Advance to the next work item
block_id += grid_size;
if(block_id >= num_work)
{
break;
}
}
}
};
} // namespace ck_tile

View File

@@ -23,8 +23,3 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
else()
message("Skipping ck_tile_gemm tests for current target")
endif()
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a")
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -2,7 +2,6 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <type_traits>
#include "gtest/gtest.h"
@@ -22,9 +21,6 @@ using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType:
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using Persistent = std::true_type;
using NonPersistent = std::false_type;
// clang-format off
using KernelTypesMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
@@ -63,9 +59,4 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
>;
using KernelTypesPersistent = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent>
>;
// clang-format on

View File

@@ -1,16 +0,0 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline<T>
{
};
#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistent
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistent);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -76,8 +76,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
static constexpr bool Persistent =
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
// TODO: expose tile size through test t-param ?
template <bool PadM, bool PadN, bool PadK>
@@ -119,17 +117,14 @@ class TestCkTileGemmPipeline : public ::testing::Test
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
static constexpr bool StructuredSparsity = false;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC,
StructuredSparsity,
Persistent>;
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -182,15 +177,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
@@ -359,6 +346,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
EXPECT_TRUE(pass);
}
};