mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[rocm-libraries] ROCm/rocm-libraries#5776 (commit ee1bbcb)
[CK] Fix async pivot mismatch in persistent GEMM kernel scheduler (#5776) ## Motivation Fix pivot mismatch in the persistent GEMM kernel's async input scheduler that causes **GPU hangs** and incorrect results when used with AsyncTP (Asynchronous Tensor Parallelism) on ROCm. PyTorch's `_fused_all_gather_matmul_native` uses this persistent GEMM kernel with chunk signals to overlap communication and computation. The pivot mechanism ensures each rank starts computing from its own local shard first (which is already available), then moves to remote chunks as they arrive over the network. Because of the pivot mismatch, the kernel frequently waits on signals for chunks that have not yet arrived, while attempting to read data from completely different chunks. This synchronization desync reliably triggers infinite hangs during multi-GPU native AsyncTP execution. This fix is required to enable functional AsyncTP support on ROCm. ## Technical Details In the persistent kernel loop (`UniversalGemmKernel::operator()`), the M-tile coordinate used for data selection (`i_m`) and the M-tile coordinate used for the chunk-signal wait (`chunk_idx`) were derived from inconsistent bases: * `i_m` was computed from the **unpivoted** tile index `iM`. * `chunk_idx` was computed from the **pivoted** expression `(iM + tile_idx_pivot)`. This means the kernel could wait for chunk N's signal but then read from chunk M's memory, or vice versa. The mismatch scales with GPU count: with 2 GPUs ~50% of tiles are wrong, with 4 GPUs ~75%, etc. **The Fix:** Introduce a single pivoted M-tile index (`iM_eff`) and derive both `i_m` and `chunk_idx` from it. This guarantees the kernel always waits for the correct chunk before reading its data. *(Note: Minor cosmetic `clang-format` changes were also pulled in alongside the fix).* ## Test Plan 1. Build PyTorch with this CK change. 2. Run the specific multi-GPU AsyncTP native test: `timeout 180s env HIP_VISIBLE_DEVICES=0,1 pytest test/distributed/test_symmetric_memory.py -k test_fused_all_gather_matmul_native -q -s -x` ## Test Result Tests verify correct overlapping execution without hangs or accuracy mismatches when running the AsyncTP native path with non-zero pivots. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
9426f49b52
commit
2bb69a24ea
@@ -1226,23 +1226,37 @@ struct UniversalGemmKernel
|
||||
s_waitcnt_barrier();
|
||||
const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
// Apply pivot to M tile index first, then use the same pivoted index
|
||||
// for both data-tile selection and chunk-signal wait.
|
||||
auto iM_eff = amd_wave_read_first_lane(iM);
|
||||
|
||||
if(kargs.async_input_scheduler.chunk_signals != nullptr)
|
||||
{
|
||||
const auto tile_idx_pivot =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
|
||||
const auto tiles_m = amd_wave_read_first_lane(
|
||||
integer_divide_ceil(kargs.M, TilePartitioner::MPerBlock));
|
||||
if(tiles_m > 0)
|
||||
{
|
||||
iM_eff = amd_wave_read_first_lane((iM_eff + tile_idx_pivot) % tiles_m);
|
||||
}
|
||||
}
|
||||
|
||||
const index_t i_m = amd_wave_read_first_lane(iM_eff * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
// Synchronize with producer to ensure input data is ready before processing tile
|
||||
if(kargs.async_input_scheduler.chunk_signals != nullptr)
|
||||
{
|
||||
const auto tiles_per_chunk =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tiles_per_chunk_m);
|
||||
const auto tile_idx_pivot =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.tile_idx_pivot_m);
|
||||
const auto num_chunks =
|
||||
amd_wave_read_first_lane(kargs.async_input_scheduler.num_chunks);
|
||||
if(tiles_per_chunk > 0 && num_chunks > 0)
|
||||
{
|
||||
// Pivot allows rotating chunk assignments for load balancing
|
||||
const auto chunk_idx = amd_wave_read_first_lane(
|
||||
((iM + tile_idx_pivot) / tiles_per_chunk) % num_chunks);
|
||||
const auto chunk_idx =
|
||||
amd_wave_read_first_lane((iM_eff / tiles_per_chunk) % num_chunks);
|
||||
workgroup_barrier chunk_barrier(kargs.async_input_scheduler.chunk_signals);
|
||||
chunk_barrier.wait_eq_wave(/*value=*/1, /*offset=*/chunk_idx);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user