From d143f1466f8fe5d8d2fadd78813508abfcf22df0 Mon Sep 17 00:00:00 2001 From: shay-li77 Date: Mon, 28 Jul 2025 14:39:31 +0800 Subject: [PATCH] fix mha bwd dbias random mismatch (#2570) * fix mha bwd dbias random mismatch * formatting code [ROCm/composable_kernel commit: 8ae528a1b42913a71c9ca49253b0cfd515e1c6da] --- ...lock_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index d1b6e6f85b..420ae03b7e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -738,6 +738,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); block_sync_lds(); @@ -976,6 +981,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP decltype(ds_gemm)>(dst_reg_tensor, ds_gemm); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); block_sync_lds();