mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Add persistent async input scheduler for GEMM kernels (#3520)
Add signal-based synchronization for persistent GEMM kernels where
input data becomes available incrementally. Uses modulo wraparound
(like PyTorch's AsyncMM) for chunk index calculation:
chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks
Key components:
- PersistentAsyncInputScheduler struct with tiles_per_chunk_m,
chunk_signals, tile_idx_pivot_m, and num_chunks fields
- wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency
- IsSupportedArgument validation for scheduler parameters
- Example demonstrating async input scheduling with simulated producer
- GTest unit tests covering all layout combinations
[ROCm/composable_kernel commit: 91b4102a59]
This commit is contained in:
@@ -26,6 +26,36 @@ struct workgroup_barrier
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Reduces power consumption during polling by leveraging wave-level sleep instructions
|
||||
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
// Limit active polling to first wave to reduce memory traffic and power
|
||||
const uint32_t wave_size = static_cast<uint32_t>(warpSize);
|
||||
if(threadIdx.x < wave_size)
|
||||
{
|
||||
uint32_t loaded_value = 0;
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
|
||||
while(loaded_value != value)
|
||||
{
|
||||
// s_sleep reduces power draw while waiting, as scalar sleep is cheaper than
|
||||
// busy-wait
|
||||
__builtin_amdgcn_s_sleep(1);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Scheduler for persistent GEMM kernels with asynchronous input streaming.
|
||||
///
|
||||
/// This structure enables signal-based synchronization for persistent kernels where input data
|
||||
/// becomes available incrementally. It divides M-dimension tiles into chunks and uses signals
|
||||
/// to coordinate between the input producer and the kernel consumer.
|
||||
///
|
||||
/// Uses modulo wraparound (like PyTorch's AsyncMM) for chunk index calculation:
|
||||
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
|
||||
///
|
||||
/// @par Typical usage pattern:
|
||||
/// 1. Set tiles_per_chunk_m to group tiles into chunks (e.g., 2 or 4 tiles per chunk)
|
||||
/// 2. Set tile_idx_pivot_m as offset for chunk calculation
|
||||
/// 3. Set num_chunks = ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m)
|
||||
/// 4. Allocate chunk_signals array with size = num_chunks
|
||||
/// 5. Producer sets chunk_signals[i] = 1 when chunk i's data is ready
|
||||
/// 6. Kernel waits for chunk_signals[chunk_idx] before processing each tile
|
||||
struct PersistentAsyncInputScheduler
|
||||
{
|
||||
/// @brief Number of M-dimension tiles grouped into each chunk.
|
||||
/// Grouping tiles balances synchronization overhead against input streaming granularity.
|
||||
/// Set to 0 to disable async scheduling.
|
||||
uint32_t tiles_per_chunk_m = 0;
|
||||
|
||||
/// @brief Device pointer to array of signal values (uint32_t), one per chunk.
|
||||
/// Producer sets signals to coordinate when input data is ready for consumption.
|
||||
/// Set to nullptr to disable async scheduling.
|
||||
uint32_t* chunk_signals = nullptr;
|
||||
|
||||
/// @brief Pivot offset for rotating the chunk assignment.
|
||||
/// Allows shifting which tiles map to which chunks, useful for load balancing.
|
||||
/// chunk_idx = ((tile_idx + tile_idx_pivot_m) / tiles_per_chunk_m) % num_chunks
|
||||
int32_t tile_idx_pivot_m = 0;
|
||||
|
||||
/// @brief Number of signal chunks allocated.
|
||||
/// Must equal ceil((tiles_m + tile_idx_pivot_m) / tiles_per_chunk_m).
|
||||
/// Modulo wraparound prevents out-of-bounds access when pivot shifts chunk assignment.
|
||||
uint32_t num_chunks = 0;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user