mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add persistent async input scheduler for GEMM kernels (#3520)
Add signal-based synchronization for persistent GEMM kernels where input data becomes available incrementally. Uses modulo wraparound (like PyTorch's AsyncMM) for chunk index calculation: chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks Key components: - PersistentAsyncInputScheduler struct with tiles_per_chunk_m, chunk_signals, tile_idx_pivot_m, and num_chunks fields - wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency - IsSupportedArgument validation for scheduler parameters - Example demonstrating async input scheduling with simulated producer - GTest unit tests covering all layout combinations
This commit is contained in:
@@ -26,6 +26,36 @@ struct workgroup_barrier
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Reduces power consumption during polling by leveraging wave-level sleep instructions
|
||||
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
// Limit active polling to first wave to reduce memory traffic and power
|
||||
const uint32_t wave_size = static_cast<uint32_t>(warpSize);
|
||||
if(threadIdx.x < wave_size)
|
||||
{
|
||||
uint32_t loaded_value = 0;
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
|
||||
while(loaded_value != value)
|
||||
{
|
||||
// s_sleep reduces power draw while waiting, as scalar sleep is cheaper than
|
||||
// busy-wait
|
||||
__builtin_amdgcn_s_sleep(1);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
|
||||
Reference in New Issue
Block a user