diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index d12310add3..f4cc4bf3e7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -159,7 +159,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR const auto [seqlen_q_start, seqlen_q_end] = mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + const auto num_total_loop = + amd_wave_read_first_lane(integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0)); // check early exit if no work to do. if(num_total_loop <= 0) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 79bf963cf7..97db0f95c4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -159,7 +159,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP const auto [seqlen_q_start, seqlen_q_end] = mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + const auto num_total_loop = + amd_wave_read_first_lane(integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0)); // check early exit if no work to do. if(num_total_loop <= 0) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 966e2ddff4..b65d8ec8f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -245,7 +245,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR const auto [seqlen_q_start, seqlen_q_end] = mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + const auto num_total_loop = + amd_wave_read_first_lane(integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0)); // check early exit if no work to do. if(num_total_loop <= 0)