mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add persistent async input scheduler for GEMM kernels (#3520)
Add signal-based synchronization for persistent GEMM kernels where input data becomes available incrementally. Uses modulo wraparound (like PyTorch's AsyncMM) for the chunk index calculation: `chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks`. Key components: - `PersistentAsyncInputScheduler` struct with `tiles_per_chunk_m`, `chunk_signals`, `tile_idx_pivot_m`, and `num_chunks` fields; - `wait_eq_wave` method using `__builtin_amdgcn_s_sleep` for power efficiency; - `IsSupportedArgument` validation for scheduler parameters; - an example demonstrating async input scheduling with a simulated producer; - GTest unit tests covering all layout combinations.
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
|
||||
#include "ck_tile/core/arch/workgroup_barrier.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -30,18 +32,20 @@ namespace ck_tile {
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
|
||||
struct UniversalGemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST UniversalGemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
|
||||
const std::array<const void*, NumBTensor>& bs_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const std::array<index_t, NumATensor>& stride_As_,
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
CK_TILE_HOST UniversalGemmHostArgs(
|
||||
const std::array<const void*, NumATensor>& as_ptr_,
|
||||
const std::array<const void*, NumBTensor>& bs_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const std::array<index_t, NumATensor>& stride_As_,
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_,
|
||||
PersistentAsyncInputScheduler async_input_scheduler_ = PersistentAsyncInputScheduler{})
|
||||
: as_ptr(as_ptr_),
|
||||
bs_ptr(bs_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
@@ -53,7 +57,8 @@ struct UniversalGemmHostArgs
|
||||
stride_Bs(stride_Bs_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
k_batch(k_batch_),
|
||||
async_input_scheduler(async_input_scheduler_)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -78,6 +83,7 @@ struct UniversalGemmHostArgs
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
PersistentAsyncInputScheduler async_input_scheduler;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel device arguments.
|
||||
@@ -111,6 +117,8 @@ struct UniversalGemmKernelArgs
|
||||
/// (in memory) of E tensor.
|
||||
index_t stride_E;
|
||||
index_t k_batch;
|
||||
/// @brief Persistent async input scheduler for chunk-based tile scheduling.
|
||||
PersistentAsyncInputScheduler async_input_scheduler = {};
|
||||
};
|
||||
|
||||
/// @brief The Universal GEMM kernel template.
|
||||
@@ -201,7 +209,7 @@ struct UniversalGemmKernel
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
// Get the persistent kernel if the pipeline has it available
|
||||
// Detect persistent kernel support to select appropriate entry point
|
||||
struct has_persistent_kernel
|
||||
{
|
||||
template <typename T>
|
||||
@@ -216,7 +224,7 @@ struct UniversalGemmKernel
|
||||
};
|
||||
static constexpr bool PersistentKernel = has_persistent_kernel::value;
|
||||
|
||||
// Check if TilePartitioner has GetOutputOffset method with kargs and k_id
|
||||
// Detect custom output offset support for advanced partitioning schemes
|
||||
struct has_tile_partitioner_output_offset_impl
|
||||
{
|
||||
template <typename T, typename KernelArgs>
|
||||
@@ -272,10 +280,10 @@ struct UniversalGemmKernel
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
* @return The maximum occupancy grid size.
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
* @brief Calculate grid size that maximizes hardware utilization for persistent kernels.
|
||||
* @return Grid size that fills all compute units at maximum occupancy.
|
||||
* @note Persistent kernels loop over tiles, so grid size should match hardware capacity
|
||||
* rather than problem size.
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
@@ -315,7 +323,8 @@ struct UniversalGemmKernel
|
||||
hostArgs.stride_Bs,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch};
|
||||
hostArgs.k_batch,
|
||||
hostArgs.async_input_scheduler};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -325,11 +334,8 @@ struct UniversalGemmKernel
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
// This structure distributes work evenly among splitkk workgroups
|
||||
// It's based on a principle that if there is enough work to fill all workgroups,
|
||||
// then we can distribute the (K / K1) parts among k_batch workgroups in such a way
|
||||
// that each workgroup will be doing ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1
|
||||
// and leave the potential tail for last(splitk - 1) indexed workgroup.
|
||||
// Balances K-dimension work across batches to maximize parallelism while minimizing
|
||||
// load imbalance. Uses ceil division to distribute remainder work evenly.
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
@@ -658,6 +664,28 @@ struct UniversalGemmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify async scheduler parameters to prevent division-by-zero and invalid memory access
|
||||
if(kargs.async_input_scheduler.chunk_signals != nullptr)
|
||||
{
|
||||
if(kargs.async_input_scheduler.tiles_per_chunk_m == 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("tiles_per_chunk_m must be positive when chunk_signals is set!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.async_input_scheduler.num_chunks == 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("num_chunks must be positive when chunk_signals is set!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
|
||||
}
|
||||
|
||||
@@ -1177,12 +1205,30 @@ struct UniversalGemmKernel
|
||||
while(block_id < num_work)
|
||||
{
|
||||
s_waitcnt_barrier();
|
||||
// Get the tile index for this block
|
||||
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
// Synchronize with producer to ensure input data is ready before processing tile
|
||||
if(kargs.async_input_scheduler.chunk_signals != nullptr)
|
||||
{
|
||||
const auto tiles_per_chunk =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
|
||||
const auto tile_idx_pivot =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
|
||||
const auto num_chunks =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.num_chunks);
|
||||
if(tiles_per_chunk > 0 && num_chunks > 0)
|
||||
{
|
||||
// Pivot allows rotating chunk assignments for load balancing
|
||||
const auto chunk_idx = amd_wave_read_first_lane(
|
||||
((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
|
||||
workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
|
||||
chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Get the SplitK offset for this block
|
||||
const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles);
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
|
||||
|
||||
Reference in New Issue
Block a user