fix mha bwd dbias random mismatch (#2570)

* fix mha bwd dbias random mismatch

* formatting code
This commit is contained in:
shay-li77
2025-07-28 14:39:31 +08:00
committed by GitHub
parent 685771b875
commit 8ae528a1b4

View File

@@ -738,6 +738,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
if constexpr(kHasBiasGrad)
{
// SGrad and BiasGrad use the same address in LDS.
block_sync_lds();
}
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();
@@ -976,6 +981,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
if constexpr(kHasBiasGrad)
{
// SGrad and BiasGrad use the same address in LDS.
block_sync_lds();
}
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();