mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] FMHA Fix synchronization issues in BWD pipelines (#2876)
* Run ctest with --output-on-failure

* Fix synchronization issues in bwd pipelines

The bwd kernel reuses the same area of LDS for ds (SGrad), bias and dbias (BiasGrad). This means that there must be a block_sync_lds between loading one tensor and storing another to the same area. Heavy instructions like MFMA/WMMA and global loads are executed between reuses of the same memory, so in MOST cases loading is finished by all warps before storing is started. However, sometimes warps progress at different speeds. Running the tests multiple times and, preferably, with multiple processes on the same GPU helps to trigger this issue: bin/test_ck_tile_fmha_bwd_bf16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure
This commit is contained in:
@@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
// SGrad and Bias use the same address in LDS, finish loading ds on the previous
|
||||
// iteration to reuse LDS.
|
||||
block_sync_lds();
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
@@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
// SGrad and Bias use the same address in LDS, finish loading ds in the hot loop to
|
||||
// reuse LDS.
|
||||
block_sync_lds();
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
@@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
return cast_tile<BiasGradDataType>(ds);
|
||||
}
|
||||
}();
|
||||
// Finish loading bias_s to reuse LDS.
|
||||
block_sync_lds();
|
||||
store_tile(bias_lds_write_window, dbias);
|
||||
block_sync_lds();
|
||||
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
|
||||
@@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
// SGrad and BiasGrad use the same address in LDS.
|
||||
block_sync_lds();
|
||||
}
|
||||
// SGrad and Bias/BiasGrad use the same address in LDS, finish loading bias/dbias or, when
|
||||
// bias is not used, loading ds in the hot loop to reuse LDS.
|
||||
block_sync_lds();
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
// SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse
|
||||
// LDS.
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
s_waitcnt</*vmcnt=*/0>();
|
||||
|
||||
@@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
// SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse
|
||||
// LDS.
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
|
||||
@@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
|
||||
constexpr index_t smem_size_stage0_1 = smem_size_v;
|
||||
-        constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
|
||||
+        constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
|
||||
smem_size_do + smem_size_lse + smem_size_d +
|
||||
max(smem_size_bias, smem_size_ds);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user