Revert "[CK_TILE] Tile loop persistent gemm kernel (#2191)" (#2293)

This reverts commit ffb52783d0.
This commit is contained in:
Illia Silin
2025-06-05 09:24:00 -07:00
committed by GitHub
parent 7ea1508b59
commit 233e274077
10 changed files with 18 additions and 232 deletions

View File

@@ -9,9 +9,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -144,21 +142,6 @@ struct GemmKernel
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
struct has_persistent_kernel
{
template <typename T>
using has_persistent_type = decltype(T::UsePersistentKernel);
static constexpr bool value = []() {
if constexpr(is_detected<has_persistent_type, GemmPipeline>{})
return GemmPipeline::UsePersistentKernel;
else
return false;
}();
};
static constexpr bool PersistentKernel = has_persistent_kernel::value;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
@@ -180,23 +163,6 @@ struct GemmKernel
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
* @return The maximum occupancy grid size.
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
*/
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using Kernel = GemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<KernelBlockSize, 1, Kernel, GemmKernelArgs>;
int occupancy;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
@@ -727,8 +693,6 @@ struct GemmKernel
c_block_window, c_block_tile, smem_ptr_0);
}
// Non-persistent kernel entry point
template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
@@ -775,74 +739,6 @@ struct GemmKernel
}
}
}
// Persistent kernel entry point
template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{
const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
const auto num_tiles =
__builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
auto block_id = __builtin_amdgcn_readfirstlane(get_block_id());
while(block_id < num_work)
{
// Get the tile index for this block
const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
// Get the SplitK offset for this block
const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
// Run the GEMM
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
}
// Advance to the next work item
block_id += grid_size;
if(block_id >= num_work)
{
break;
}
}
}
};
} // namespace ck_tile