From 237c93c85bfaeeedd874138fa1062afd05ff8cbf Mon Sep 17 00:00:00 2001 From: danyao12 Date: Mon, 15 Jul 2024 12:23:27 +0800 Subject: [PATCH] bias support --- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 16 ++++++++------- ...block_fmha_bwd_pipeline_default_policy.hpp | 20 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 6ade3c17df..9011fc02c2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -384,7 +384,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR Policy::template MakeBiasTileDistribution()); BiasDataType* biast_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeQT())); + static_cast(smem_ptr) + Policy::template GetSmemSizeQT() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGradT() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE() + + Policy::template GetSmemSizeD())); auto biast_lds = make_tensor_view( biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); @@ -555,9 +559,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - - const auto bias_tile = load_tile(bias_dram_window); - block_sync_lds(); + const auto bias_tile = load_tile(bias_dram_window); auto bias_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(bias_shuffle_tmp, bias_tile); @@ -571,6 +573,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR st_acc, biast_tile); move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); } else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) { @@ -725,6 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); store_tile(dbias_dram_window, dbiast_shuffle_tmp); move_tile_window(dbias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); } // STAGE 6, SGrad^T@Q^T Gemm3 @@ -807,9 +811,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - - const auto bias_tile = load_tile(bias_dram_window); - block_sync_lds(); + const auto bias_tile = load_tile(bias_dram_window); auto bias_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(bias_shuffle_tmp, bias_tile); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index fdaa4dd768..4cae274db9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -331,21 +331,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias() { - using BiasDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize; - constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType); - constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType); - - constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad) - ? kMaxVecLoad - : (kTotalPixels / kMinVecLoad); - return kVecLoad; + // TODO: not correct! + if constexpr(kTotalPixels > 32) + return 8; + else + return 4; } template @@ -617,7 +613,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() { - return GetAlignmentBias(); + // TODO: this is for 3d layout + using BiasDataType = remove_cvref_t; + return 16 / sizeof(BiasDataType); } template @@ -1682,7 +1680,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t smem_size_stage0_1 = smem_size_v; constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + smem_size_do + smem_size_lse + smem_size_d + - smem_size_ds; + max(smem_size_bias, smem_size_ds); constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias; constexpr index_t smem_size_stage3 = smem_size_qt; constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d;