[CK_TILE] FMHA Fix synchronization issues in BWD pipelines (#2876)

* Run ctest with --output-on-failure

* Fix synchronization issues in bwd pipelines

The bwd kernel reuses the same area of LDS for ds (SGrad), bias and
dbias (BiasGrad). This means that there must be a block_sync_lds between
loading one tensor and storing another to the same area.

Heavy instructions like MFMA/WMMA and global loads are executed between
reuses of the same memory so in MOST cases loading is finished by all
warps before storing is started. However, sometimes warps progress at
different speeds.
Running the tests multiple times — preferably with multiple processes
on the same GPU — helps to trigger this issue:

bin/test_ck_tile_fmha_bwd_bf16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure
This commit is contained in:
Anton Gorenko
2025-09-19 12:34:45 +06:00
committed by GitHub
parent dd249f1cd6
commit 2aec38f9ec
5 changed files with 25 additions and 9 deletions

View File

@@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
// SGrad and Bias use the same address in LDS, finish loading ds on the previous
// iteration to reuse LDS.
block_sync_lds();
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
@@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
// SGrad and Bias use the same address in LDS, finish loading ds in the hot loop to
// reuse LDS.
block_sync_lds();
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
@@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
return cast_tile<BiasGradDataType>(ds);
}
}();
// Finish loading bias_s to reuse LDS.
block_sync_lds();
store_tile(bias_lds_write_window, dbias);
block_sync_lds();
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
@@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
if constexpr(kHasBiasGrad)
{
// SGrad and BiasGrad use the same address in LDS.
block_sync_lds();
}
// SGrad and Bias/BiasGrad use the same address in LDS, finish loading bias/dbias or, when
// bias is not used, loading ds in the hot loop to reuse LDS.
block_sync_lds();
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();

View File

@@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
if constexpr(kHasBiasGrad)
{
// SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse
// LDS.
block_sync_lds();
}
store_tile(ds_lds_window, ds_gemm);
}
s_waitcnt</*vmcnt=*/0>();

View File

@@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);
if constexpr(kHasBiasGrad)
{
// SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse
// LDS.
block_sync_lds();
}
store_tile(ds_lds_window, ds_gemm);
}
__builtin_amdgcn_s_waitcnt(3952);

View File

@@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
constexpr index_t smem_size_stage0_1 = smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d +
max(smem_size_bias, smem_size_ds);