[rocm-libraries] ROCm/rocm-libraries#4742 (commit d340a14)

[CK_TILE] Fix FMHA async pipeline LDS sync issue

## Motivation

Fix an LDS synchronization issue in the FMHA forward async pipeline
(`block_fmha_pipeline_qr_ks_vs_async.hpp`).
Some attention test cases intermittently fail due to a race condition
where the V tile store to LDS overwrites K tile data that is still being
read by other threads during the tail `gemm_0` operation.

## Technical Details

In the `BlockFmhaPipelineQRKSVSAsync` pipeline, K and V tiles share the
same LDS memory through a rotation schedule (`LdsSeq`).
After the tail `gemm_0` (line 458), some fast threads may proceed to
store V to LDS (line 617) before slow threads finish reading K data from
the same LDS buffer.

The fix adds an `s_barrier` synchronization after the tail `gemm_0` when
K's last sub-tile and V's first sub-tile use the same LDS buffer (i.e.,
`LdsSeq[k0_loops - 1] == LdsSeq[k0_loops]`):

`if constexpr(LdsSeq.at(number<k0_loops - 1>{}) ==
LdsSeq.at(number<k0_loops>{}))
    __builtin_amdgcn_s_barrier();`
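The aliasing condition can be sketched on the host as a plain compile-time check. This is a minimal illustration, not the CK_TILE implementation: the helper `k_tail_aliases_v_head` and the schedule `example_seq` are hypothetical, and `std::array` stands in for the pipeline's `LdsSeq` sequence type.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for LdsSeq: entry i is the LDS buffer index used by
// the i-th tile in the rotation (K sub-tiles first, then V sub-tiles).
template <std::size_t N>
constexpr bool k_tail_aliases_v_head(const std::array<int, N>& lds_seq,
                                     std::size_t k0_loops)
{
    // The last K sub-tile is entry k0_loops - 1; the first V sub-tile is
    // entry k0_loops. A barrier is needed only when both map to one buffer.
    return lds_seq[k0_loops - 1] == lds_seq[k0_loops];
}

// Made-up schedule with k0_loops == 2: the K tail and the V head both use
// buffer 1, so this configuration would require the barrier.
constexpr std::array<int, 4> example_seq{0, 1, 1, 2};
static_assert(k_tail_aliases_v_head(example_seq, 2),
              "K tail and V head share an LDS buffer in this schedule");
```

Because the check is a constant expression, configurations whose schedule never aliases the two buffers pay no cost for the barrier, which matches the intent of guarding it with `if constexpr`.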

Why `s_barrier` alone is sufficient (no `s_waitcnt lgkmcnt(0)` needed):

- The MFMA instructions in `gemm_0` cannot execute until their LDS operands
  (`ds_read`) have been loaded.
- Therefore, each thread's `ds_read` of K data is already complete by the
  time `gemm_0` finishes.
- Only cross-thread synchronization (`s_barrier`) is needed to ensure all
  threads have finished reading before any thread starts writing V.
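The ordering argument can be modeled on the host as a check over an interleaving of events. This is only a sketch of the reasoning, not GPU code: the `Event` type, the `Op` enum, and `all_reads_see_k` are hypothetical names, and a single shared buffer stands in for the aliased LDS slot.

```cpp
#include <vector>

// Hypothetical event model: each thread reads K from the shared buffer
// (as part of gemm_0) and later overwrites that buffer with V.
enum class Op { ReadK, WriteV };

struct Event
{
    int thread; // which thread issued the operation (for readability only)
    Op op;
};

// Returns true if every ReadK in the schedule happens before any WriteV,
// i.e. no thread reads the buffer after it has been overwritten with V.
bool all_reads_see_k(const std::vector<Event>& schedule)
{
    bool buffer_holds_k = true;
    for (const Event& e : schedule)
    {
        if (e.op == Op::WriteV)
        {
            buffer_holds_k = false; // buffer now holds V data
        }
        else if (!buffer_holds_k)
        {
            return false; // a slow thread read K after a fast thread wrote V
        }
    }
    return true;
}
```

Without a barrier, a schedule such as `{0, ReadK}, {0, WriteV}, {1, ReadK}` is possible and fails the check. A barrier between the reads and the writes permits only schedules in which all `ReadK` events precede all `WriteV` events, and those pass.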
rocking
2026-03-09 18:06:54 +00:00
committed by assistant-librarian[bot]
parent 683865895e
commit fe8b7d0c27

@@ -589,6 +589,12 @@ struct BlockFmhaPipelineQRKSVSAsync
         s.get_tile_distribution()); // Pcompute{j}
     __builtin_amdgcn_sched_barrier(0x7F);
+    // Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store
+    // Only needed when K tail and V use the same LDS buffer
+    if constexpr(LdsSeq.at(number<k0_loops - 1>{}) == LdsSeq.at(number<k0_loops>{}))
+    {
+        __builtin_amdgcn_s_barrier();
+    }
     // store & prefetch next v, after the max reduction
     if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
     {