From 7460d19460af8fc6da4ffa407518ad71d3f03dad Mon Sep 17 00:00:00 2001 From: Dan Yao Date: Fri, 27 Sep 2024 03:18:39 +0800 Subject: [PATCH] [CK_TILE] Fix compiler related FA bwd issues (#1530) * add barriers * tail bias barriers * adjust bf16/hd256 tol * continue adjust bf16/hd256 tol [ROCm/composable_kernel commit: 9d69a099a462f01794cc3ea945403b3f00827806] --- example/ck_tile/01_fmha/fmha_bwd.cpp | 17 +++++++++++++++-- ...mha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 6 ++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index efae4e284a..c2f554f6cc 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[]) // different threshold for different dtype template -auto get_elimit(int /*init_method*/) +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) { double rtol = 1e-2; double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) +{ + double rtol = 1e-2; + double atol = 1e-2; + if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN + { + rtol = 3.2e-2; + atol = 3.2e-2; + } + return ck_tile::make_tuple(rtol, atol); +} + template bool run(const ck_tile::ArgParser& arg_parser) { @@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } // clang-format on - auto [rtol, atol] = get_elimit(init_method); + auto [rtol, atol] = get_elimit(hdim_q, hdim_v); bool dq_cur_pass = ck_tile::check_err(dq_host_result, dq_host_ref, std::string("Error: QGrad Incorrect results!"), diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 9e6a2725c9..3156e4a356 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }, s_acc, bias_s_tile); + __builtin_amdgcn_sched_barrier(0); } else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) { @@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); HotLoopScheduler::template GemmStagedScheduler<1>(); + __builtin_amdgcn_sched_barrier(0); // STAGE 4, OGrad@V Gemm2 auto dp_acc = SPGradBlockTileType{}; @@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); HotLoopScheduler::template GemmStagedScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); // STAGE 5, P^T(PGrad^T - D) auto ds = SPGradBlockTileType{}; @@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP Policy::template MakeBiasTileDistribution()); shuffle_tile(dbias_tile, shuffled_dbias_tile); store_tile(dbias_dram_window, dbias_tile); + __builtin_amdgcn_sched_barrier(0); } // STAGE 6, SGrad^T@Q^T Gemm3 @@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP move_tile_window(ds_lds_read_window, {0, kK4}); HotLoopScheduler::template GemmStagedScheduler<3>(); + __builtin_amdgcn_sched_barrier(0); // STAGE 7, SGrad@K^T Gemm4 auto dq_acc = QGradBlockTileType{}; clear_tile(dq_acc); @@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }); HotLoopScheduler::template GemmStagedScheduler<4>(); + __builtin_amdgcn_sched_barrier(0); // Results Scale if constexpr(FmhaDropout::IsDropout)