mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK] Fix async pivot mismatch in persistent GEMM kernel scheduler (#5776)

## Motivation

Fix a pivot mismatch in the persistent GEMM kernel's async input scheduler that causes **GPU hangs** and incorrect results when used with AsyncTP (Asynchronous Tensor Parallelism) on ROCm.

PyTorch's `_fused_all_gather_matmul_native` uses this persistent GEMM kernel with chunk signals to overlap communication and computation. The pivot mechanism ensures each rank starts computing from its own local shard first (which is already available), then moves to remote chunks as they arrive over the network.

Because of the pivot mismatch, the kernel frequently waits on the signal for one chunk while reading data from a completely different chunk, one that may not have arrived yet. This desynchronization reliably triggers infinite hangs during multi-GPU native AsyncTP execution. This fix is required to enable functional AsyncTP support on ROCm.

## Technical Details

In the persistent kernel loop (`UniversalGemmKernel::operator()`), the M-tile coordinate used for data selection (`i_m`) and the M-tile coordinate used for the chunk-signal wait (`chunk_idx`) were derived from inconsistent bases:

* `i_m` was computed from the **unpivoted** tile index `iM`.
* `chunk_idx` was computed from the **pivoted** expression `(iM + tile_idx_pivot)`.

This means the kernel could wait for chunk N's signal but then read from chunk M's memory, or vice versa. The mismatch scales with GPU count: with 2 GPUs ~50% of tiles are wrong, with 4 GPUs ~75%, etc.

**The Fix:** Introduce a single pivoted M-tile index (`iM_eff`) and derive both `i_m` and `chunk_idx` from it. This guarantees the kernel always waits for the correct chunk before reading its data.

*(Note: Minor cosmetic `clang-format` changes were also pulled in alongside the fix.)*

## Test Plan

1. Build PyTorch with this CK change.
2. Run the specific multi-GPU AsyncTP native test: `timeout 180s env HIP_VISIBLE_DEVICES=0,1 pytest test/distributed/test_symmetric_memory.py -k test_fused_all_gather_matmul_native -q -s -x`

## Test Result

Tests verify correct overlapping execution without hangs or accuracy mismatches when running the AsyncTP native path with non-zero pivots.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
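
To make the mismatch from Technical Details concrete, here is a small host-side model in Python. It is only a sketch, not the kernel code. It assumes that each rank's pivot equals `rank * tiles_per_chunk`, that the pivoted index wraps modulo the number of M-tiles, and that a tile's chunk is `tile // tiles_per_chunk`. The helper names `schedule` and `mismatch_fraction` are hypothetical; only `iM`, `iM_eff`, `chunk_idx`, and the pivot come from the description above.

```python
def schedule(num_m_tiles, tiles_per_chunk, pivot, fixed):
    """Return (waited_chunk, read_chunk) for every M-tile on one rank.

    Assumption: the pivoted index wraps modulo num_m_tiles, and a chunk
    holds tiles_per_chunk consecutive M-tiles.
    """
    pairs = []
    for iM in range(num_m_tiles):
        # Pivoted M-tile index; the signal wait uses it both before and after the fix.
        iM_eff = (iM + pivot) % num_m_tiles
        chunk_idx = iM_eff // tiles_per_chunk

        # Before the fix the data coordinate came from the unpivoted iM;
        # after the fix it comes from the same iM_eff as chunk_idx.
        data_tile = iM_eff if fixed else iM
        read_chunk = data_tile // tiles_per_chunk

        pairs.append((chunk_idx, read_chunk))
    return pairs


def mismatch_fraction(world_size, tiles_per_chunk, fixed):
    """Fraction of (rank, tile) pairs that wait on one chunk but read another.

    Assumption: one chunk per rank, and rank r uses pivot r * tiles_per_chunk
    so that it starts on its own local shard.
    """
    num_m_tiles = world_size * tiles_per_chunk
    total = 0
    wrong = 0
    for rank in range(world_size):
        pivot = rank * tiles_per_chunk
        for waited, read in schedule(num_m_tiles, tiles_per_chunk, pivot, fixed):
            total += 1
            wrong += waited != read
    return wrong / total


if __name__ == "__main__":
    for world_size in (2, 4):
        before = mismatch_fraction(world_size, 4, fixed=False)
        after = mismatch_fraction(world_size, 4, fixed=True)
        print(f"{world_size} ranks: before fix {before:.2f}, after fix {after:.2f}")
```

Under these assumptions the model gives a mismatch fraction of 0.5 for 2 ranks and 0.75 for 4 ranks before the fix, in line with the ~50% and ~75% figures quoted above, and 0 once both coordinates derive from `iM_eff`. In this model rank 0 (pivot 0) is never affected, which is why the fraction is (N-1)/N for N ranks.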