From 668e10728293e0ccf4884ce98b5e21b4a4191e68 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Sun, 17 May 2026 02:30:48 -0400 Subject: [PATCH] fix(sparse_attn): backport PR #4742 LDS s_barrier Add s_barrier after sched_barrier when K-tail and V share LDS buffer, mirroring upstream PR #4742. Applies to both async_vsa and async_jenga pipelines. Co-Authored-By: Claude Opus 4 --- .../pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 6 ++++++ .../pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index 9fe8b365b0..717d82aca7 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -430,6 +430,12 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga s.get_tile_distribution()); // Pcompute{j} __builtin_amdgcn_sched_barrier(0x7F); + // Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store + // Only needed when K tail and V use the same LDS buffer + if constexpr(LdsSeq.at(number{}) == LdsSeq.at(number{})) + { + __builtin_amdgcn_s_barrier(); + } // store & prefetch next v, after the max reduction auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index 578ad7e603..507c91a585 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -387,6 +387,12 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA s.get_tile_distribution()); // Pcompute{j} __builtin_amdgcn_sched_barrier(0x7F); + // Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store + // Only needed when K tail and V use the same LDS buffer + if constexpr(LdsSeq.at(number{}) == LdsSeq.at(number{})) + { + __builtin_amdgcn_s_barrier(); + } // store & prefetch next v, after the max reduction if constexpr(std::is_same_v) {