diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index b883aad155..c402eaeac4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds on the previous + // iteration to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds in the hot loop to + // reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP return cast_tile(ds); } }(); + // Finish loading bias_s to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, dbias); block_sync_lds(); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); @@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); - if constexpr(kHasBiasGrad) - { - // SGrad and BiasGrad use the same address in LDS. - block_sync_lds(); - } + // SGrad and Bias/BiasGrad use the same address in LDS, finish loading bias/dbias or, when + // bias is not used, loading ds in the hot loop to reuse LDS. + block_sync_lds(); store_tile(ds_lds_window, ds_gemm); block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 81950bd30a..41cb4fc306 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } s_waitcnt(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 16d9f695df..8c8d2af486 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } __builtin_amdgcn_s_waitcnt(3952); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 68ead7c765..ad9e2959f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt; constexpr index_t smem_size_stage0_1 = smem_size_v; - constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + + constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot + smem_size_do + smem_size_lse + smem_size_d + max(smem_size_bias, smem_size_ds); diff --git a/script/launch_tests.sh b/script/launch_tests.sh index 5e71e25478..17a99e62a3 100755 --- a/script/launch_tests.sh +++ b/script/launch_tests.sh @@ -49,7 +49,7 @@ with open('$TEST_FILE', 'r') as f: if tests: # Extract just the filename after the last '/' clean_tests = [os.path.basename(test) for test in tests] - print('ctest -R \"' + '|'.join(clean_tests) + '\"') + print('ctest --output-on-failure -R \"' + '|'.join(clean_tests) + '\"') else: print('# No tests to run') ") @@ -57,5 +57,3 @@ with open('$TEST_FILE', 'r') as f: echo "$command" eval "$command" - -