From fb64a4453c62a602145fd13564c6f8327c187c98 Mon Sep 17 00:00:00 2001
From: Yi DING <28386673+DDEle@users.noreply.github.com>
Date: Mon, 30 Mar 2026 01:45:16 +0000
Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#5915 (commit a72cf7d) [CK_TILE] Fix FMHA BWD register pressure by wrapping num_total_loop with amd_wave_read_first_lane (#5915)

## Motivation

In three FMHA backward pipelines, `num_total_loop` is computed without `amd_wave_read_first_lane()`, so the compiler treats it as a VGPR even though it is logically uniform across all lanes. This raises register pressure, and under high pressure the compiler may reuse VGPRs across overlapping live ranges. This was confirmed via assembly inspection: the compiler reused `v52:v53` as both the B-matrix input for dK MFMAs and an intermediate value for dV, producing incorrect dK/dV gradients.

## Technical Details

Wrap `num_total_loop` with `amd_wave_read_first_lane()` in three pipelines:

- `block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr`
- `block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp`
- `block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr`

This promotes `num_total_loop` to an SGPR, eliminating the excess register pressure and the incorrect VGPR reuse.

## Test Plan

## Test Result

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
--- .../pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 3 ++- .../block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 3 ++- .../block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index d12310add3..f4cc4bf3e7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -159,7 +159,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR const auto [seqlen_q_start, seqlen_q_end] = mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + const auto num_total_loop = + amd_wave_read_first_lane(integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0)); // check early exit if no work to do. if(num_total_loop <= 0) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 79bf963cf7..97db0f95c4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -159,7 +159,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP const auto [seqlen_q_start, seqlen_q_end] = mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + const auto num_total_loop = + amd_wave_read_first_lane(integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0)); // check early exit if no work to do. 
if(num_total_loop <= 0) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 966e2ddff4..b65d8ec8f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -245,7 +245,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR const auto [seqlen_q_start, seqlen_q_end] = mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + const auto num_total_loop = + amd_wave_read_first_lane(integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0)); // check early exit if no work to do. if(num_total_loop <= 0)