From a64deec3ba9e700992ae1eefccbcd1b065fae12e Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Fri, 19 Sep 2025 12:34:45 +0600 Subject: [PATCH] [CK_TILE] FMHA Fix synchronization issues in BWD pipelines (#2876) * Run ctest with --output-on-failure * Fix synchronization issues in bwd pipelines The bwd kernel reuses the same area of LDS for ds (SGrad), bias and dbias (BiasGrad). This means that there must be block_sync_lds between loading one tensor and storing another to the same area. Heavy instructions like MFMA/WMMA and global loads are executed between reuses of the same memory so in MOST cases loading is finished by all warps before storing is started. However, sometimes warps progress at different speeds. Running the tests multiple times and, preferably, with multiple processes on the same GPU helps to trigger this issue: bin/test_ck_tile_fmha_bwd_bf16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure [ROCm/composable_kernel commit: 2aec38f9ec67bfbdccbdb3a5c25913e5a9ba6136] --- ...fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 16 +++++++++++----- ...ha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 6 ++++++ ...a_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 6 ++++++ .../block_fmha_bwd_pipeline_default_policy.hpp | 2 +- script/launch_tests.sh | 4 +--- 5 files changed, 25 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index b883aad155..c402eaeac4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds on the previous + // iteration to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds in the hot loop to + // reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP return cast_tile(ds); } }(); + // Finish loading bias_s to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, dbias); block_sync_lds(); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); @@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); - if constexpr(kHasBiasGrad) - { - // SGrad and BiasGrad use the same address in LDS. - block_sync_lds(); - } + // SGrad and Bias/BiasGrad use the same address in LDS, finish loading bias/dbias or, when + // bias is not used, loading ds in the hot loop to reuse LDS. + block_sync_lds(); store_tile(ds_lds_window, ds_gemm); block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 81950bd30a..41cb4fc306 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } s_waitcnt(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 16d9f695df..8c8d2af486 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } __builtin_amdgcn_s_waitcnt(3952); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 68ead7c765..ad9e2959f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt; constexpr index_t smem_size_stage0_1 = smem_size_v; - constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + + constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot + smem_size_do + smem_size_lse + smem_size_d + max(smem_size_bias, smem_size_ds); diff --git a/script/launch_tests.sh b/script/launch_tests.sh index 5e71e25478..17a99e62a3 100755 --- a/script/launch_tests.sh +++ b/script/launch_tests.sh @@ -49,7 +49,7 @@ with open('$TEST_FILE', 'r') as f: if tests: # Extract just the filename after the last '/' clean_tests = [os.path.basename(test) for test in tests] - print('ctest -R \"' + '|'.join(clean_tests) + '\"') + print('ctest --output-on-failure -R \"' + '|'.join(clean_tests) + '\"') else: print('# No tests to run') ") @@ -57,5 +57,3 @@ with open('$TEST_FILE', 'r') as f: echo "$command" eval "$command" - -