From d07b37c0973139ef972c88601e05e5aeefdefef2 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 12 Jun 2026 08:40:18 +0000 Subject: [PATCH] Remove un-used element-wise functions passed through pipelines' operator() interfaces --- ...hstu_attention_no_softmax_fwd_pipeline.hpp | 61 +------------- ...o_softmax_fwd_splitkv_combine_pipeline.hpp | 14 +--- ...tention_no_softmax_fwd_trload_pipeline.hpp | 61 +------------- ...tu_attention_with_softmax_fwd_pipeline.hpp | 79 +++---------------- ...h_softmax_fwd_splitkv_combine_pipeline.hpp | 36 +-------- ...ntion_with_softmax_fwd_trload_pipeline.hpp | 79 +++---------------- 6 files changed, 29 insertions(+), 301 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index 488e3bda71..529b4079a4 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -115,22 +115,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename QElementFunction, - typename BiasElementFunction, - typename SAccElementFunction, - typename PComputeElementFunction, - typename OAccElementFunction, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile - const QElementFunction& q_element_func, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, index_t seqlen_k_start, index_t seqlen_k_end, HstuMask& mask, @@ -181,7 +171,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS if(seqlen_k_end <= seqlen_k_start) { clear_tile(o_acc); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; }; @@ -302,8 +291,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS clear_tile(o_acc); - q_tile = tile_elementwise_in(q_element_func, q_tile); - auto seqlen_k_curr = seqlen_k_start; using v_tile_type = decltype(load_tile(v_dram_window)); @@ -329,8 +316,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); - sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); - auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, @@ -347,8 +332,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS const auto bias_tile = load_tile(bias_dram_window); tile_elementwise_inout( - [&scale_s, &bias_element_func](auto& x, const auto& y) { - x = x * scale_s + type_convert(bias_element_func(y)); + [&scale_s](auto& x, const auto& y) { + x = x * scale_s + type_convert(y); }, pcomp_tile, bias_tile); @@ -420,7 +405,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); } - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + auto p = cast_tile(pcomp_tile); // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { @@ -458,46 +443,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS }; } while(seqlen_k_curr < seqlen_k_end); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - index_t seqlen_k_start, - index_t seqlen_k_end, - HstuMask mask, - float scale_s, // scaling value exerted on the immediate Q@K result - float scale_p, // scaling value exerted on the SiLU result - void* smem_ptr, - DropoutType& dropout) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - v_dram_block_window_tmp, - bias_dram_block_window_tmp, - identity{}, - identity{}, - identity{}, - identity{}, - seqlen_k_start, - seqlen_k_end, - mask, - scale_s, - scale_p, - smem_ptr, - dropout); - } }; } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp index 90394fb13c..4ef3a976cf 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_splitkv_combine_pipeline.hpp @@ -55,10 +55,9 @@ struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline return Policy::template GetSmemSize(); } - template + template CK_TILE_DEVICE auto operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // M0*kOHeaddim tile - const OAccElementFunction& o_acc_element_func, ck_tile::index_t o_acc_split_stride, ck_tile::index_t num_splits) const { @@ -88,19 +87,8 @@ struct HstuAttentionNoSoftmaxFwdSplitKVCombinePipeline tile_elementwise_inout([](auto& x, const auto& y) { x = x + y; }, o_acc, o_acc_tile); }; - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_DEVICE auto - operator()(const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile - ck_tile::index_t o_acc_split_stride, - ck_tile::index_t num_splits) const - { - return operator()(o_acc_dram_block_window_tmp, identity{}, o_acc_split_stride, num_splits); - }; }; } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 60c367f7ef..289b410abf 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -116,22 +116,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename QElementFunction, - typename BiasElementFunction, - typename SAccElementFunction, - typename PComputeElementFunction, - typename OAccElementFunction, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile - const QElementFunction& q_element_func, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, index_t seqlen_k_start, index_t seqlen_k_end, HstuMask& mask, @@ -182,7 +172,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad if(seqlen_k_end <= seqlen_k_start) { clear_tile(o_acc); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; }; @@ -309,8 +298,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad clear_tile(o_acc); - q_tile = tile_elementwise_in(q_element_func, q_tile); - auto seqlen_k_curr = seqlen_k_start; using v_tile_type = decltype(load_tile(v_dram_window)); @@ -338,8 +325,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); - sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); - auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, @@ -356,8 +341,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad const auto bias_tile = load_tile(bias_dram_window); tile_elementwise_inout( - [&scale_s, &bias_element_func](auto& x, const auto& y) { - x = x * scale_s + type_convert(bias_element_func(y)); + [&scale_s](auto& x, const auto& y) { + x = x * scale_s + type_convert(y); }, pcomp_tile, bias_tile); @@ -408,7 +393,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); } - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + auto p = cast_tile(pcomp_tile); // check whether first V-LdsBufer overlap with last K-LdsBuffer, // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 @@ -442,46 +427,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad }); } while(seqlen_k_curr < seqlen_k_end); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - index_t seqlen_k_start, - index_t seqlen_k_end, - HstuMask mask, - float scale_s, // scaling value exerted on the immediate Q@K result - float scale_p, // scaling value exerted on the SiLU result - void* smem_ptr, - DropoutType& dropout) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - v_dram_block_window_tmp, - bias_dram_block_window_tmp, - identity{}, - identity{}, - identity{}, - identity{}, - seqlen_k_start, - seqlen_k_end, - mask, - scale_s, - scale_p, - smem_ptr, - dropout); - } }; } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 54a456a03d..bdb0f682be 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -120,25 +120,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEorLSEaccDramBlockWindow, - typename QElementFunction, - typename BiasElementFunction, - typename LSEaccElementFunction, - typename SAccElementFunction, - typename PComputeElementFunction, - typename OAccElementFunction, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile - const LSEaccElementFunction& lse_or_lse_acc_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, index_t seqlen_k_start, index_t seqlen_k_end, HstuMask& mask, @@ -202,7 +190,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS if(seqlen_k_end <= seqlen_k_start) { clear_tile(o_acc); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); if constexpr(!is_null_tile_window_v) { @@ -211,8 +198,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS set_tile(lse_or_lse_acc, -numeric::infinity()); - store_tile(lse_or_lse_acc_dram_block_window, - tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc)); + store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc); } return o_acc; @@ -338,8 +324,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - q_tile = tile_elementwise_in(q_element_func, q_tile); - auto seqlen_k_curr = seqlen_k_start; constexpr index_t NumPrefetchV = 2; @@ -381,8 +365,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); - sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); - auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, @@ -397,8 +379,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS const auto bias_tile = load_tile(bias_dram_window); tile_elementwise_inout( - [&scale_s, &bias_element_func](auto& x, const auto& y) { - x = x * scale_s + type_convert(bias_element_func(y)); + [&scale_s](auto& x, const auto& y) { + x = x * scale_s + type_convert(y); }, pcomp_tile, bias_tile); @@ -549,7 +531,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); } - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + auto p = cast_tile(pcomp_tile); __builtin_amdgcn_sched_barrier(0x00000001); @@ -612,8 +594,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); }); - store_tile(lse_or_lse_acc_dram_block_window, - tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc)); + store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc); } constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); @@ -632,50 +613,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS }); }); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile - index_t seqlen_k_start, - index_t seqlen_k_end, - HstuMask mask, - float scale_s, // scaling value exerted on the immediate Q@K result - float scale_p, // scaling value exerted on the SiLU result - void* smem_ptr, - DropoutType& dropout) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - v_dram_block_window_tmp, - bias_dram_block_window_tmp, - identity{}, - lse_or_lse_acc_dram_block_window, - identity{}, - identity{}, - identity{}, - identity{}, - seqlen_k_start, - seqlen_k_end, - mask, - scale_s, - scale_p, - smem_ptr, - dropout); - } }; } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp index fb5d14ef2f..1bb0b29ea7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp @@ -63,17 +63,11 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline template + typename LSEDramBlockWindow> CK_TILE_DEVICE auto operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // kM tile const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile LSEDramBlockWindow& lse_dram_block_window, // kM tile - const OAccElementFunction& o_acc_element_func, - const LSEaccElementFunction& lse_acc_element_func, - const LSEElementFunction& lse_element_func, index_t o_acc_split_stride, index_t num_splits, void* smem_ptr) const @@ -110,8 +104,6 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline auto lse_acc = load_tile(lse_acc_dram_window); - lse_acc = tile_elementwise_in(lse_acc_element_func, lse_acc); - using lse_acc_type = decltype(lse_acc); constexpr auto lse_spans = lse_acc_type::get_distributed_spans(); @@ -173,7 +165,7 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline // in case kStoreLSE is false, LSEDramBlockWindow is null if constexpr(!is_null_tile_window_v) { - store_tile(lse_dram_block_window, tile_elementwise_in(lse_element_func, lse_logsum)); + store_tile(lse_dram_block_window, lse_logsum); } // calculate scale value (used for adjusting the o_acc) for all splits for all rows in @@ -240,32 +232,8 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline }); }; - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_DEVICE auto - operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window_tmp, - const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile - LSEDramBlockWindow& lse_dram_block_window, // kM tile - ck_tile::index_t o_acc_split_stride, - index_t num_splits, - void* smem_ptr) const - { - return operator()(lse_acc_dram_block_window_tmp, - o_acc_dram_block_window_tmp, - lse_dram_block_window, - identity{}, - identity{}, - identity{}, - o_acc_split_stride, - num_splits, - smem_ptr); - }; }; } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 7f3902bfad..8e2783c4ea 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -120,25 +120,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename LSEorLSEaccDramBlockWindow, - typename QElementFunction, - typename BiasElementFunction, - typename LSEaccElementFunction, - typename SAccElementFunction, - typename PComputeElementFunction, - typename OAccElementFunction, typename HstuMask> CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile - const LSEaccElementFunction& lse_or_lse_acc_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, index_t seqlen_k_start, index_t seqlen_k_end, HstuMask& mask, @@ -203,7 +191,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad if(seqlen_k_end <= seqlen_k_start) { clear_tile(o_acc); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); if constexpr(!is_null_tile_window_v) { @@ -212,8 +199,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad set_tile(lse_or_lse_acc, -numeric::infinity()); - store_tile(lse_or_lse_acc_dram_block_window, - tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc)); + store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc); } return o_acc; @@ -345,8 +331,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad set_tile(m, -numeric::infinity()); clear_tile(l); - q_tile = tile_elementwise_in(q_element_func, q_tile); - auto seqlen_k_curr = seqlen_k_start; constexpr index_t NumPrefetchV = 2; @@ -388,8 +372,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); - sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); - auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, @@ -406,8 +388,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad const auto bias_tile = load_tile(bias_dram_window); tile_elementwise_inout( - [&scale_s, &bias_element_func](auto& x, const auto& y) { - x = x * scale_s + type_convert(bias_element_func(y)); + [&scale_s](auto& x, const auto& y) { + x = x * scale_s + type_convert(y); }, pcomp_tile, bias_tile); @@ -556,7 +538,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); } - auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + auto p = cast_tile(pcomp_tile); __builtin_amdgcn_sched_barrier(0x00000001); @@ -618,8 +600,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); }); - store_tile(lse_or_lse_acc_dram_block_window, - tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc)); + store_tile(lse_or_lse_acc_dram_block_window, lse_or_lse_acc); } constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); @@ -638,50 +619,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad }); }); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile - index_t seqlen_k_start, - index_t seqlen_k_end, - HstuMask mask, - float scale_s, // scaling value exerted on the immediate Q@K result - float scale_p, // scaling value exerted on the SiLU result - void* smem_ptr, - DropoutType& dropout) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - v_dram_block_window_tmp, - bias_dram_block_window_tmp, - identity{}, - lse_or_lse_acc_dram_block_window, - identity{}, - identity{}, - identity{}, - identity{}, - seqlen_k_start, - seqlen_k_end, - mask, - scale_s, - scale_p, - smem_ptr, - dropout); - } }; } // namespace ck_tile