Add persistent async input scheduler for GEMM kernels (#3520)

Add signal-based synchronization for persistent GEMM kernels where
input data becomes available incrementally. Uses modulo wraparound
(like PyTorch's AsyncMM) for chunk index calculation:
  chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks

Key components:
- PersistentAsyncInputScheduler struct with tiles_per_chunk_m,
  chunk_signals, tile_idx_pivot_m, and num_chunks fields
- wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency
- IsSupportedArgument validation for scheduler parameters
- Example demonstrating async input scheduling with simulated producer
- GTest unit tests covering all layout combinations

[ROCm/composable_kernel commit: 91b4102a59]
This commit is contained in:
Max Podkorytov
2026-01-20 10:37:09 -08:00
committed by GitHub
parent 30ac278911
commit b8595c5684
11 changed files with 844 additions and 61 deletions

View File

@@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for gfx1153 target.
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
* Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
### Changed

View File

@@ -456,7 +456,8 @@ inline auto create_args()
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "gemm.json", "json file name to dump results")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
.insert("rotating_count", "1000", "rotating count, defaults to 1000")
.insert("test_async", "0", "0: normal gemm, 1: test async input scheduler");
return arg_parser;
}

View File

@@ -12,6 +12,169 @@
#include "run_gemm_example_common.hpp"
#include "universal_gemm_invoker.hpp"
// Universal GEMM-specific wrapper that handles test_async flag
template <typename GemmConfig,
typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts_universal(ck_tile::ArgParser& arg_parser,
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const CLayout c_layout = CLayout{})
{
using Invoker = UniversalInvoker;
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
// Check for async input scheduler test mode
bool test_async = arg_parser.get_int("test_async");
if(test_async)
{
// Extract parameters for async test (same as shared implementation)
const ck_tile::index_t M = arg_parser.get_int("m");
const ck_tile::index_t N = arg_parser.get_int("n");
const ck_tile::index_t K = arg_parser.get_int("k");
const ck_tile::index_t kbatch = arg_parser.get_int("split_k");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
const ck_tile::index_t stride_A = is_a_row_major ? K : M;
const ck_tile::index_t stride_B = is_b_row_major ? N : K;
const ck_tile::index_t stride_C = is_c_row_major ? N : M;
// Allocate and initialize tensors
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
stride_C};
Invoker::template test_async_input_scheduler<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough>(
args, ck_tile::stream_config{nullptr, false, 1});
// Copy result from device for verification
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
// Compute CPU reference
ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
c_m_n_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_ref);
// Verify results
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl;
return pass;
}
// Normal path - delegate to shared implementation
return run_gemm_example_with_layouts<GemmConfig, Invoker, ADataType, BDataType, CDataType>(
arg_parser, a_layout, b_layout, c_layout);
}
// Universal GEMM-specific prec_type dispatcher that uses the wrapper
template <typename GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
int run_gemm_example_prec_type_universal(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
bool preshuffle = GemmConfig::Preshuffle;
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
{
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
}
if(preshuffle && a_layout != "R" && b_layout != "C")
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
}
using LayoutVariant = std::variant<Row, Col>;
auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
if(layout == "R")
return Row{};
if(layout == "C")
return Col{};
throw std::runtime_error("Unsupported layout: " + layout);
};
auto a_layout_variant = string_to_layout(a_layout);
auto b_layout_variant = string_to_layout(b_layout);
return std::visit(
[&](auto a_layout_type, auto b_layout_type) -> int {
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t> &&
std::is_same_v<decltype(b_layout_type), Row>)
{
throw std::runtime_error("Unsupported memory layout for the input matrices when "
"BPrecType is ck_tile::pk_int4_t!");
}
else
{
return run_gemm_example_with_layouts_universal<GemmConfig,
APrecType,
BPrecType,
CPrecType>(
arg_parser, a_layout_type, b_layout_type, Row{});
}
},
a_layout_variant,
b_layout_variant);
}
template <template <typename PrecType> typename GemmConfig>
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
@@ -19,52 +182,50 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
using Invoker = UniversalInvoker;
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, Invoker, ck_tile::half_t>(
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, Invoker, ck_tile::bf16_t>(
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Invoker,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
Invoker,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "int8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
Invoker,
ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::int8_t>,
ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp16i4")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
Invoker,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else
{
@@ -75,11 +236,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Invoker,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else
{
@@ -90,11 +251,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
Invoker,
ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else
{

View File

@@ -2,7 +2,11 @@
// SPDX-License-Identifier: MIT
#pragma once
#include <functional>
#include <chrono>
#include <thread>
#include "gemm_utils.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/device_memory.hpp"
struct UniversalInvoker
{
@@ -150,4 +154,170 @@ struct UniversalInvoker
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename CDEElementWise>
static void test_async_input_scheduler(const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
true, // Persistent = true for async test
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
constexpr auto scheduler = GemmConfig::Scheduler;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
GemmConfig::NumWaveGroups,
false, /*FixedVectorSize_*/
1, /*VectorSizeC_*/
false, /*TiledMMAPermuteN_*/
1, /*BlockedXDLN_PerWarp_*/
GemmConfig::DoubleSmemBuffer>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const ck_tile::index_t tiles_m =
ck_tile::integer_divide_ceil(args.M, TilePartitioner::MPerBlock);
// Balance signal granularity (smaller chunks = finer control) vs overhead (more signals)
const ck_tile::index_t tiles_per_chunk = 2;
// Shift chunk assignments to test wraparound behavior
const ck_tile::index_t tile_idx_pivot = tiles_per_chunk;
// Account for pivot when allocating signal buffer
const ck_tile::index_t num_chunks =
ck_tile::integer_divide_ceil(tiles_m + tile_idx_pivot, tiles_per_chunk);
std::cout << "Async Input Scheduler Test:" << std::endl;
std::cout << " M tiles: " << tiles_m << std::endl;
std::cout << " Tiles per chunk: " << tiles_per_chunk << std::endl;
std::cout << " Tile index pivot: " << tile_idx_pivot << std::endl;
std::cout << " Number of signal chunks: " << num_chunks << std::endl;
// Signals must start as zero so kernel blocks until producer sets them
ck_tile::DeviceMem signal_buf(num_chunks * sizeof(uint32_t));
signal_buf.SetZero();
uint32_t* d_chunk_signals = static_cast<uint32_t*>(signal_buf.GetDeviceBuffer());
// Setup async input scheduler
ck_tile::PersistentAsyncInputScheduler async_scheduler;
async_scheduler.tiles_per_chunk_m = tiles_per_chunk;
async_scheduler.chunk_signals = d_chunk_signals;
async_scheduler.tile_idx_pivot_m = tile_idx_pivot;
async_scheduler.num_chunks = num_chunks;
// Create modified host args with async scheduler
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({args.a_ptr},
{args.b_ptr},
{},
args.e_ptr,
args.k_batch,
args.M,
args.N,
args.K,
{args.stride_A},
{args.stride_B},
{},
args.stride_E,
async_scheduler);
auto kargs = Kernel::UniversalGemmKernel::MakeKernelArgs(host_args);
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
const dim3 blocks = Kernel::BlockSize();
std::cout << " Grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< std::endl;
std::cout << " Blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
// Separate stream prevents deadlock: kernel and signal producer must run concurrently
hipStream_t signal_stream;
HIP_CHECK_ERROR(hipStreamCreateWithFlags(&signal_stream, hipStreamNonBlocking));
const auto start = std::chrono::high_resolution_clock::now();
ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
// Simulate incremental input arrival by delaying signal activation
const int sleep_us = 100;
for(ck_tile::index_t i = 0; i < num_chunks; ++i)
{
std::this_thread::sleep_for(std::chrono::microseconds(sleep_us));
const uint32_t signal_val = 1;
HIP_CHECK_ERROR(hipMemcpyAsync(d_chunk_signals + i,
&signal_val,
sizeof(uint32_t),
hipMemcpyHostToDevice,
signal_stream));
}
HIP_CHECK_ERROR(hipStreamSynchronize(signal_stream));
HIP_CHECK_ERROR(hipStreamDestroy(signal_stream));
// Wait for kernel completion
HIP_CHECK_ERROR(hipDeviceSynchronize());
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::high_resolution_clock::now() - start);
std::cout << " Total time: " << duration.count() << " us" << std::endl;
std::cout << " Sleep time: " << (num_chunks * sleep_us) << " us" << std::endl;
}
};

View File

@@ -91,6 +91,7 @@
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/utility/random.hpp"

View File

@@ -26,6 +26,36 @@ struct workgroup_barrier
__syncthreads();
}
// Reduces power consumption during polling by leveraging wave-level sleep instructions
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset = 0)
{
// Limit active polling to first wave to reduce memory traffic and power
const uint32_t wave_size = static_cast<uint32_t>(warpSize);
if(threadIdx.x < wave_size)
{
uint32_t loaded_value = 0;
if(threadIdx.x == 0)
{
loaded_value = ld(offset);
}
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
while(loaded_value != value)
{
// s_sleep reduces power draw while waiting, as scalar sleep is cheaper than
// busy-wait
__builtin_amdgcn_s_sleep(1);
if(threadIdx.x == 0)
{
loaded_value = ld(offset);
}
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
}
}
__syncthreads();
}
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
{
if(threadIdx.x == 0)

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
namespace ck_tile {
/// @brief Scheduler for persistent GEMM kernels with asynchronous input streaming.
///
/// This structure enables signal-based synchronization for persistent kernels where input data
/// becomes available incrementally. It divides M-dimension tiles into chunks and uses signals
/// to coordinate between the input producer and the kernel consumer.
///
/// Uses modulo wraparound (like PyTorch's AsyncMM) for chunk index calculation:
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
///
/// @par Typical usage pattern:
/// 1. Set tiles_per_chunk_m to group tiles into chunks (e.g., 2 or 4 tiles per chunk)
/// 2. Set tile_idx_pivot_m as offset for chunk calculation
/// 3. Set num_chunks = ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m)
/// 4. Allocate chunk_signals array with size = num_chunks
/// 5. Producer sets chunk_signals[i] = 1 when chunk i's data is ready
/// 6. Kernel waits for chunk_signals[chunk_idx] before processing each tile
struct PersistentAsyncInputScheduler
{
/// @brief Number of M-dimension tiles grouped into each chunk.
/// Grouping tiles balances synchronization overhead against input streaming granularity.
/// Set to 0 to disable async scheduling.
uint32_t tiles_per_chunk_m = 0;
/// @brief Device pointer to array of signal values (uint32_t), one per chunk.
/// Producer sets signals to coordinate when input data is ready for consumption.
/// Set to nullptr to disable async scheduling.
uint32_t* chunk_signals = nullptr;
/// @brief Pivot offset for rotating the chunk assignment.
/// Allows shifting which tiles map to which chunks, useful for load balancing.
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
int32_t tile_idx_pivot_m = 0;
/// @brief Number of signal chunks allocated.
/// Must equal ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m).
/// Modulo wraparound prevents out-of-bounds access when pivot shifts chunk assignment.
uint32_t num_chunks = 0;
};
} // namespace ck_tile

View File

@@ -13,6 +13,8 @@
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
#include "ck_tile/core/arch/workgroup_barrier.hpp"
namespace ck_tile {
@@ -30,18 +32,20 @@ namespace ck_tile {
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct UniversalGemmHostArgs
{
CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
const std::array<const void*, NumBTensor>& bs_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
const std::array<index_t, NumATensor>& stride_As_,
const std::array<index_t, NumBTensor>& stride_Bs_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
CK_TILE_HOST UniversalGemmHostArgs(
const std::array<const void*, NumATensor>& as_ptr_,
const std::array<const void*, NumBTensor>& bs_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
const std::array<index_t, NumATensor>& stride_As_,
const std::array<index_t, NumBTensor>& stride_Bs_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_,
PersistentAsyncInputScheduler async_input_scheduler_ = PersistentAsyncInputScheduler{})
: as_ptr(as_ptr_),
bs_ptr(bs_ptr_),
ds_ptr(ds_ptr_),
@@ -53,7 +57,8 @@ struct UniversalGemmHostArgs
stride_Bs(stride_Bs_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
k_batch(k_batch_),
async_input_scheduler(async_input_scheduler_)
{
}
@@ -78,6 +83,7 @@ struct UniversalGemmHostArgs
};
index_t k_batch;
PersistentAsyncInputScheduler async_input_scheduler;
};
/// @brief The GEMM kernel device arguments.
@@ -111,6 +117,8 @@ struct UniversalGemmKernelArgs
/// (in memory) of E tensor.
index_t stride_E;
index_t k_batch;
/// @brief Persistent async input scheduler for chunk-based tile scheduling.
PersistentAsyncInputScheduler async_input_scheduler = {};
};
/// @brief The Universal GEMM kernel template.
@@ -201,7 +209,7 @@ struct UniversalGemmKernel
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
// Detect persistent kernel support to select appropriate entry point
struct has_persistent_kernel
{
template <typename T>
@@ -216,7 +224,7 @@ struct UniversalGemmKernel
};
static constexpr bool PersistentKernel = has_persistent_kernel::value;
// Check if TilePartitioner has GetOutputOffset method with kargs and k_id
// Detect custom output offset support for advanced partitioning schemes
struct has_tile_partitioner_output_offset_impl
{
template <typename T, typename KernelArgs>
@@ -272,10 +280,10 @@ struct UniversalGemmKernel
}
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
* @return The maximum occupancy grid size.
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
* @brief Calculate grid size that maximizes hardware utilization for persistent kernels.
* @return Grid size that fills all compute units at maximum occupancy.
* @note Persistent kernels loop over tiles, so grid size should match hardware capacity
* rather than problem size.
*/
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
@@ -315,7 +323,8 @@ struct UniversalGemmKernel
hostArgs.stride_Bs,
hostArgs.stride_Ds,
hostArgs.stride_E,
hostArgs.k_batch};
hostArgs.k_batch,
hostArgs.async_input_scheduler};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -325,11 +334,8 @@ struct UniversalGemmKernel
struct SplitKBatchOffset
{
// This structure distributes work evenly among splitkk workgroups
// It's based on a principle that if there is enough work to fill all workgroups,
// then we can distribute the (K / K1) parts among k_batch workgroups in such a way
// that each workgroup will be doing ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1
// and leave the potential tail for last(splitk - 1) indexed workgroup.
// Balances K-dimension work across batches to maximize parallelism while minimizing
// load imbalance. Uses ceil division to distribute remainder work evenly.
__device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
@@ -658,6 +664,28 @@ struct UniversalGemmKernel
return false;
}
}
// Verify async scheduler parameters to prevent division-by-zero and invalid memory access
if(kargs.async_input_scheduler.chunk_signals != nullptr)
{
if(kargs.async_input_scheduler.tiles_per_chunk_m == 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("tiles_per_chunk_m must be positive when chunk_signals is set!");
}
return false;
}
if(kargs.async_input_scheduler.num_chunks == 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("num_chunks must be positive when chunk_signals is set!");
}
return false;
}
}
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
}
@@ -1177,12 +1205,30 @@ struct UniversalGemmKernel
while(block_id < num_work)
{
s_waitcnt_barrier();
// Get the tile index for this block
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
// Synchronize with producer to ensure input data is ready before processing tile
if(kargs.async_input_scheduler.chunk_signals != nullptr)
{
const auto tiles_per_chunk =
amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
const auto tile_idx_pivot =
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
const auto num_chunks =
amd_wave_read_first_lane(kargs.async_input_scheduler.num_chunks);
if(tiles_per_chunk > 0 && num_chunks > 0)
{
// Pivot allows rotating chunk assignments for load balancing
const auto chunk_idx = amd_wave_read_first_lane(
((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
}
}
// Get the SplitK offset for this block
const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);

View File

@@ -3,6 +3,7 @@
add_subdirectory(image_to_column)
add_subdirectory(gemm)
add_subdirectory(gemm_persistent_async_input)
add_subdirectory(gemm_weight_preshuffle)
add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)

View File

@@ -0,0 +1,19 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Test for persistent async input GEMM - currently targeting gfx95
set(PERSISTENT_ASYNC_INPUT_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND PERSISTENT_ASYNC_INPUT_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND PERSISTENT_ASYNC_INPUT_COMPILE_OPTIONS
-mllvm
-enable-noalias-to-md-conversion=0
)
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_gemm_persistent_async_input test_gemm_persistent_async_input.cpp)
target_compile_options(test_ck_tile_gemm_persistent_async_input PRIVATE ${PERSISTENT_ASYNC_INPUT_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping test_ck_tile_gemm_persistent_async_input for current target - requires gfx95")
endif()

View File

@@ -0,0 +1,304 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
#include <chrono>
#include <thread>
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using F16 = ck_tile::fp16_t;
using F32 = ck_tile::fp32_t;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>;
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
class TestGemmPersistentAsyncInput : public ::testing::Test
{
protected:
// Use larger M to ensure tiles_m > tile_idx_pivot, exercising the async scheduler
static constexpr ck_tile::index_t M = 1536; // 6 tiles with M_Tile=256
static constexpr ck_tile::index_t N = 1024;
static constexpr ck_tile::index_t K = 512;
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
template <bool IsRowMajor>
static constexpr ck_tile::index_t get_default_stride(ck_tile::index_t row, ck_tile::index_t col)
{
if constexpr(IsRowMajor)
return col;
else
return row;
}
void Run()
{
constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
ck_tile::index_t stride_A = get_default_stride<is_a_row_major>(M, K);
ck_tile::index_t stride_B = get_default_stride<is_b_row_major>(K, N);
ck_tile::index_t stride_C = get_default_stride<is_c_row_major>(M, N);
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
// Fill input tensors with random values
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
// Copy input data to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
c_m_n_host_ref.SetZero();
// Compute reference result on host
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
// Setup kernel configuration for persistent async input GEMM
constexpr int kBlockPerCu = 1;
constexpr bool kPadM = true;
constexpr bool kPadN = true;
constexpr bool kPadK = true;
constexpr bool DoubleSmemBuffer = true;
constexpr bool TransposeC = false;
constexpr bool StructuredSparsity = false;
constexpr bool Persistent = true;
constexpr int NumWaveGroup = 1;
constexpr bool Preshuffle = false;
constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
constexpr ck_tile::index_t TilePartitionerM01 = 4;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TilePartitionerGroupNum,
TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC,
StructuredSparsity,
Persistent,
NumWaveGroup,
Preshuffle>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
Intrawave::value>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<UniversalGemmProblem>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
1, // kNumWaveGroups_
false, // FixedVectorSize_
1, // VectorSizeC_
false, // TiledMMAPermuteN_
1, // BlockedXDLN_PerWarp_
DoubleSmemBuffer>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Calculate tiles and chunks for async scheduler.
// Uses modulo wraparound like PyTorch - chunk_idx = (iM + pivot) / tiles_per_chunk %
// num_chunks
constexpr ck_tile::index_t tiles_per_chunk = 2;
constexpr ck_tile::index_t tile_idx_pivot = 2;
const ck_tile::index_t tiles_m = ck_tile::integer_divide_ceil(M, M_Tile);
// With add logic, max chunk_idx = (tiles_m - 1 + pivot) / tiles_per_chunk
// So num_chunks = ceil((tiles_m + pivot) / tiles_per_chunk)
const ck_tile::index_t num_chunks =
ck_tile::integer_divide_ceil(tiles_m + tile_idx_pivot, tiles_per_chunk);
// Validate async scheduler configuration
// With M=1536, M_Tile=256: tiles_m=6, num_chunks=ceil((6+2)/2)=4
ASSERT_GT(num_chunks, 0) << "Test requires num_chunks > 0 to exercise async scheduler";
ASSERT_GT(tiles_per_chunk, 0) << "tiles_per_chunk must be positive";
// Allocate chunk signals (initialized to zero)
ck_tile::DeviceMem signal_buf(num_chunks * sizeof(uint32_t));
signal_buf.SetZero();
uint32_t* d_chunk_signals = static_cast<uint32_t*>(signal_buf.GetDeviceBuffer());
ASSERT_NE(d_chunk_signals, nullptr) << "Failed to allocate signal buffer";
// Setup async input scheduler
ck_tile::PersistentAsyncInputScheduler async_scheduler;
async_scheduler.tiles_per_chunk_m = tiles_per_chunk;
async_scheduler.chunk_signals = d_chunk_signals;
async_scheduler.tile_idx_pivot_m = tile_idx_pivot;
async_scheduler.num_chunks = num_chunks;
// Create UniversalGemmHostArgs with async scheduler
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({a_m_k_dev_buf.GetDeviceBuffer()},
{b_k_n_dev_buf.GetDeviceBuffer()},
{},
c_m_n_dev_buf.GetDeviceBuffer(),
1, // k_batch
M,
N,
K,
{stride_A},
{stride_B},
{},
stride_C,
async_scheduler);
// Create kernel args using UniversalGemmKernel
auto kargs = Kernel::UniversalGemmKernel::MakeKernelArgs(host_args);
// Validate kernel args match host configuration
ASSERT_EQ(kargs.async_input_scheduler.chunk_signals, d_chunk_signals)
<< "Kernel args chunk_signals doesn't match host configuration";
ASSERT_EQ(kargs.async_input_scheduler.tiles_per_chunk_m,
static_cast<uint32_t>(tiles_per_chunk))
<< "Kernel args tiles_per_chunk_m doesn't match host configuration";
ASSERT_EQ(kargs.async_input_scheduler.tile_idx_pivot_m,
static_cast<int32_t>(tile_idx_pivot))
<< "Kernel args tile_idx_pivot_m doesn't match host configuration";
// Setup grid and blocks for persistent kernel
ck_tile::stream_config stream_cfg{nullptr, false};
const dim3 grids = Kernel::MaxOccupancyGridSize(stream_cfg);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
GTEST_SKIP() << "Kernel arguments not supported, skipping test";
return;
}
// Create a separate stream for setting signals
// Using the same stream would deadlock - memcpy waits for kernel, kernel waits for signal
hipStream_t signal_stream;
HIP_CHECK_ERROR(hipStreamCreateWithFlags(&signal_stream, hipStreamNonBlocking));
// Launch kernel
ck_tile::ignore = ck_tile::launch_kernel(
stream_cfg, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
// Simulate producer setting chunk signals with interleaved sleep
// This simulates async input becoming available over time
const int sleep_us = 100; // microseconds between chunks
for(ck_tile::index_t i = 0; i < num_chunks; ++i)
{
std::this_thread::sleep_for(std::chrono::microseconds(sleep_us));
const uint32_t signal_val = 1;
HIP_CHECK_ERROR(hipMemcpyAsync(d_chunk_signals + i,
&signal_val,
sizeof(uint32_t),
hipMemcpyHostToDevice,
signal_stream));
}
// Wait for all signals to be set
HIP_CHECK_ERROR(hipStreamSynchronize(signal_stream));
HIP_CHECK_ERROR(hipStreamDestroy(signal_stream));
// Wait for kernel completion
HIP_CHECK_ERROR(hipDeviceSynchronize());
// Copy result back to host
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
// Validate results
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol = ck_tile::get_relative_threshold<ADataType, CDataType, AccDataType>(K);
const auto atol = ck_tile::get_absolute_threshold<ADataType, CDataType, AccDataType>(
max_accumulated_value, K);
bool pass = ck_tile::check_err(
c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
EXPECT_TRUE(pass);
}
};
// Define test types for different layout combinations
using RowRowRow_F16F16F32F16 = TestGemmPersistentAsyncInput<Row, Row, Row, F16, F16, F32, F16>;
using RowColRow_F16F16F32F16 = TestGemmPersistentAsyncInput<Row, Col, Row, F16, F16, F32, F16>;
using ColRowRow_F16F16F32F16 = TestGemmPersistentAsyncInput<Col, Row, Row, F16, F16, F32, F16>;
using ColColRow_F16F16F32F16 = TestGemmPersistentAsyncInput<Col, Col, Row, F16, F16, F32, F16>;
// Test case for Row-Row-Row layout
TEST_F(RowRowRow_F16F16F32F16, BasicTest) { this->Run(); }
// Test case for Row-Col-Row layout
TEST_F(RowColRow_F16F16F32F16, BasicTest) { this->Run(); }
// Test case for Col-Row-Row layout
TEST_F(ColRowRow_F16F16F32F16, BasicTest) { this->Run(); }
// Test case for Col-Col-Row layout
TEST_F(ColColRow_F16F16F32F16, BasicTest) { this->Run(); }