Add persistent async input scheduler for GEMM kernels (#3520)

Add signal-based synchronization for persistent GEMM kernels where
input data becomes available incrementally. Uses modulo wraparound
(like PyTorch's AsyncMM) for chunk index calculation:
  chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks

Key components:
- PersistentAsyncInputScheduler struct with tiles_per_chunk_m,
  chunk_signals, tile_idx_pivot_m, and num_chunks fields
- wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency
- IsSupportedArgument validation for scheduler parameters
- Example demonstrating async input scheduling with simulated producer
- GTest unit tests covering all layout combinations
This commit is contained in:
Max Podkorytov
2026-01-20 10:37:09 -08:00
committed by GitHub
parent 8f75869408
commit 91b4102a59
11 changed files with 844 additions and 61 deletions

View File

@@ -26,6 +26,36 @@ struct workgroup_barrier
__syncthreads();
}
// Reduces power consumption during polling by leveraging wave-level sleep instructions
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset = 0)
{
// Limit active polling to first wave to reduce memory traffic and power
const uint32_t wave_size = static_cast<uint32_t>(warpSize);
if(threadIdx.x < wave_size)
{
uint32_t loaded_value = 0;
if(threadIdx.x == 0)
{
loaded_value = ld(offset);
}
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
while(loaded_value != value)
{
// s_sleep reduces power draw while waiting, as scalar sleep is cheaper than
// busy-wait
__builtin_amdgcn_s_sleep(1);
if(threadIdx.x == 0)
{
loaded_value = ld(offset);
}
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
}
}
__syncthreads();
}
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
{
if(threadIdx.x == 0)