From d732e1e379f3ddcacb00ef8a9b4e7ea678541644 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Fri, 22 Aug 2025 10:01:10 +0800 Subject: [PATCH] [CK_TILE] FMHA BWD Fix Compilation with Bias (#2682) * [CK_TILE] FMHA BWD Fix Compilation with Bias * Fix appendkv kApplyRoPE [ROCm/composable_kernel commit: 4cfa2c715876fb170bace7d564403b796d5045ba] --- example/ck_tile/01_fmha/fmha_bwd.cpp | 14 -------- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 8 +++-- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 18 +++++----- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 18 +++++----- ...mha_bwd_pipeline_trload_default_policy.hpp | 35 +++---------------- 5 files changed, 28 insertions(+), 65 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 9c2907778f..9f1e0f6948 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -809,20 +809,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::stream_config stream_config_v{ nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; - - printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, " - "bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n", - fmha_traits.hdim_q, - fmha_traits.hdim_v, - fmha_traits.data_type.c_str(), - fmha_traits.is_group_mode, - static_cast(fmha_traits.mask_type), - static_cast(fmha_traits.bias_type), - fmha_traits.has_dbias, - fmha_traits.has_dropout, - fmha_traits.is_store_randval, - fmha_traits.is_deterministic); - fflush(stdout); fmha_bwd(fmha_traits, fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 81075d0ec6..66f51459af 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -649,8 +649,12 @@ struct FmhaFwdAppendKVKernel {0, i_n0}); // If kApplyRoPe is false, we set the rotary_dim to 0 - auto rotary_dim = kApplyRoPE ? kargs.rotary_dim : 0; - + auto rotary_dim = [&]() { + if constexpr(kApplyRoPE) + return kargs.rotary_dim; + else + return 0; + }(); FmhaPipeline{}(q_dram_window, k_dram_window, i_page_block_k, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 1d95bc2801..9a31498dd1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -347,22 +347,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - bias_dram_block_window_tmp.get_bottom_tensor_view()), + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), {seqlen_q_start, bias_origin.at(number<1>{})}, Policy::template MakeBiasTileDistribution()); auto bias_lds = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); auto bias_lds_write_window = make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); - auto bias_lds_read = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); auto bias_s_lds_read_window = - make_tile_window(bias_lds_read, - make_tuple(number{}, number{}), + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), bias_lds_write_window.get_window_origin(), Policy::template MakeBiasSTileDistribution()); @@ -500,8 +497,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - async_load_tile(bias_lds_write_window, bias_dram_window); - __builtin_amdgcn_s_waitcnt(3952); + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); tile_elementwise_inout( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 65f70c4f62..3112070271 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -323,22 +323,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = - make_tile_window(Policy::template TransformXDramTensorView( - bias_dram_block_window_tmp.get_bottom_tensor_view()), + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), {bias_origin.at(number<0>{}), seqlen_kv_start}, Policy::template MakeBiasTileDistribution()); auto bias_lds = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor()); auto bias_lds_write_window = make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); - auto bias_lds_read = make_tensor_view( - bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); auto bias_s_lds_read_window = - make_tile_window(bias_lds_read, - make_tuple(number{}, number{}), + make_tile_window(bias_lds_write_window.get_bottom_tensor_view(), + bias_lds_write_window.get_window_lengths(), bias_lds_write_window.get_window_origin(), Policy::template MakeBiasSTileDistribution()); @@ -490,8 +487,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - async_load_tile(bias_lds_write_window, bias_dram_window); - __builtin_amdgcn_s_waitcnt(3952); + const auto bias_tile = load_tile(bias_dram_window); + auto shuffled_bias_tile = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(shuffled_bias_tile, bias_tile); + store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); tile_elementwise_inout( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp index 7849c931f7..6259e5b473 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp @@ -551,11 +551,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy Problem::BlockFmhaShape::kQKHeaddim>(); } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsWriteBlockDescriptor() + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor() { - return MakeXLdsWriteBlockDescriptor(); + return BlockFmhaBwdPipelineDefaultPolicy::MakeBiasLdsBlockDescriptor(); } template @@ -684,13 +682,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kQKHeaddim>(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsReadBlockDescriptor() - { - return MakeXLdsReadBlockDescriptor(); - } template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor() @@ -966,25 +957,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - - constexpr index_t N1 = min(static_cast(GetAlignmentBias()), - kMPerBlock * kNPerBlock / kBlockSize); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M1 = get_warp_size() / N0; - constexpr index_t M2 = kMPerBlock / M1 / M0; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + return BlockFmhaBwdPipelineDefaultPolicy::MakeShuffledBiasTileDistribution(); } template @@ -1048,7 +1021,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy { if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return sizeof(typename Problem::BiasDataType) * - MakeBiasLdsWriteBlockDescriptor().get_element_space_size(); + MakeBiasLdsBlockDescriptor().get_element_space_size(); else return 0; }