diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index d1b6e6f85b..420ae03b7e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -738,6 +738,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); block_sync_lds(); @@ -976,6 +981,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP decltype(ds_gemm)>(dst_reg_tensor, ds_gemm); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); block_sync_lds();