Add persistent async input scheduler for GEMM kernels (#3520)

Add signal-based synchronization for persistent GEMM kernels where
input data becomes available incrementally. Uses modulo wraparound
(like PyTorch's AsyncMM) for chunk index calculation:
  chunk_idx = ((tile_idx_m + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
  (computed from the M-dimension tile index, matching the _m-suffixed scheduler fields)

Key components:
- PersistentAsyncInputScheduler struct with tiles_per_chunk_m,
  chunk_signals, tile_idx_pivot_m, and num_chunks fields
- wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency
- IsSupportedArgument validation for scheduler parameters
- Example demonstrating async input scheduling with simulated producer
- GTest unit tests covering all layout combinations
This commit is contained in:
Max Podkorytov
2026-01-20 10:37:09 -08:00
committed by GitHub
parent 8f75869408
commit 91b4102a59
11 changed files with 844 additions and 61 deletions

View File

@@ -13,6 +13,8 @@
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
#include "ck_tile/core/arch/workgroup_barrier.hpp"
namespace ck_tile {
@@ -30,18 +32,20 @@ namespace ck_tile {
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct UniversalGemmHostArgs
{
CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
const std::array<const void*, NumBTensor>& bs_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
const std::array<index_t, NumATensor>& stride_As_,
const std::array<index_t, NumBTensor>& stride_Bs_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
CK_TILE_HOST UniversalGemmHostArgs(
const std::array<const void*, NumATensor>& as_ptr_,
const std::array<const void*, NumBTensor>& bs_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
const std::array<index_t, NumATensor>& stride_As_,
const std::array<index_t, NumBTensor>& stride_Bs_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_,
PersistentAsyncInputScheduler async_input_scheduler_ = PersistentAsyncInputScheduler{})
: as_ptr(as_ptr_),
bs_ptr(bs_ptr_),
ds_ptr(ds_ptr_),
@@ -53,7 +57,8 @@ struct UniversalGemmHostArgs
stride_Bs(stride_Bs_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
k_batch(k_batch_),
async_input_scheduler(async_input_scheduler_)
{
}
@@ -78,6 +83,7 @@ struct UniversalGemmHostArgs
};
index_t k_batch;
PersistentAsyncInputScheduler async_input_scheduler;
};
/// @brief The GEMM kernel device arguments.
@@ -111,6 +117,8 @@ struct UniversalGemmKernelArgs
/// (in memory) of E tensor.
index_t stride_E;
index_t k_batch;
/// @brief Persistent async input scheduler for chunk-based tile scheduling.
PersistentAsyncInputScheduler async_input_scheduler = {};
};
/// @brief The Universal GEMM kernel template.
@@ -201,7 +209,7 @@ struct UniversalGemmKernel
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
// Detect persistent kernel support to select appropriate entry point
struct has_persistent_kernel
{
template <typename T>
@@ -216,7 +224,7 @@ struct UniversalGemmKernel
};
static constexpr bool PersistentKernel = has_persistent_kernel::value;
// Check if TilePartitioner has GetOutputOffset method with kargs and k_id
// Detect custom output offset support for advanced partitioning schemes
struct has_tile_partitioner_output_offset_impl
{
template <typename T, typename KernelArgs>
@@ -272,10 +280,10 @@ struct UniversalGemmKernel
}
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
* @return The maximum occupancy grid size.
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
* @brief Calculate grid size that maximizes hardware utilization for persistent kernels.
* @return Grid size that fills all compute units at maximum occupancy.
* @note Persistent kernels loop over tiles, so grid size should match hardware capacity
* rather than problem size.
*/
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
@@ -315,7 +323,8 @@ struct UniversalGemmKernel
hostArgs.stride_Bs,
hostArgs.stride_Ds,
hostArgs.stride_E,
hostArgs.k_batch};
hostArgs.k_batch,
hostArgs.async_input_scheduler};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -325,11 +334,8 @@ struct UniversalGemmKernel
struct SplitKBatchOffset
{
// This structure distributes work evenly among split-k workgroups
// It's based on a principle that if there is enough work to fill all workgroups,
// then we can distribute the (K / K1) parts among k_batch workgroups in such a way
// that each workgroup will be doing ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1
// and leave the potential tail for last(splitk - 1) indexed workgroup.
// Balances K-dimension work across batches to maximize parallelism while minimizing
// load imbalance. Uses ceil division to distribute remainder work evenly.
__device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
@@ -658,6 +664,28 @@ struct UniversalGemmKernel
return false;
}
}
// Verify async scheduler parameters to prevent division-by-zero and invalid memory access
if(kargs.async_input_scheduler.chunk_signals != nullptr)
{
if(kargs.async_input_scheduler.tiles_per_chunk_m == 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("tiles_per_chunk_m must be positive when chunk_signals is set!");
}
return false;
}
if(kargs.async_input_scheduler.num_chunks == 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("num_chunks must be positive when chunk_signals is set!");
}
return false;
}
}
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
}
@@ -1177,12 +1205,30 @@ struct UniversalGemmKernel
while(block_id < num_work)
{
s_waitcnt_barrier();
// Get the tile index for this block
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
// Synchronize with producer to ensure input data is ready before processing tile
if(kargs.async_input_scheduler.chunk_signals != nullptr)
{
const auto tiles_per_chunk =
amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
const auto tile_idx_pivot =
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
const auto num_chunks =
amd_wave_read_first_lane(kargs.async_input_scheduler.num_chunks);
if(tiles_per_chunk > 0 && num_chunks > 0)
{
// Pivot allows rotating chunk assignments for load balancing
const auto chunk_idx = amd_wave_read_first_lane(
((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
}
}
// Get the SplitK offset for this block
const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);