diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index c6d231924d..ac1462ae47 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -35,8 +35,6 @@ struct HstuAttentionFwdKernel using QKVDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; - using GemmAccDataType = - ck_tile::remove_cvref_t; using VLayout = ck_tile::remove_cvref_t; @@ -742,6 +740,7 @@ struct HstuAttentionFwdKernel bias_dram_window, mask, kargs.scale_s, + kargs.max_seqlen, smem_ptr, dropout); }(); @@ -766,13 +765,6 @@ struct HstuAttentionFwdKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - tile_elementwise_inout( - [&](auto& x) { - x = x * type_convert( - __builtin_amdgcn_rcpf(static_cast(kargs.max_seqlen))); - }, - o_acc_tile); - EpiloguePipeline{}(o_dram_window, o_acc_tile); } }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index d49f36cc75..6c3cb1e638 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -133,6 +133,7 @@ struct HstuAttentionFwdPipelineQRKSVS const OAccElementFunction& o_acc_element_func, HstuMask mask, float scale_s, + index_t max_seqlen, void* smem_ptr, DropoutType& dropout) const { @@ -463,6 +464,13 @@ struct HstuAttentionFwdPipelineQRKSVS __builtin_amdgcn_s_barrier(); } while(++i_loop < num_loops); + tile_elementwise_inout( + [&](auto& x) { + x = x * type_convert( + __builtin_amdgcn_rcpf(static_cast(max_seqlen))); + }, + o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; @@ -479,6 +487,7 @@ struct HstuAttentionFwdPipelineQRKSVS const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile HstuMask mask, float scale_s, + int max_seqlen, void* smem_ptr, DropoutType& dropout) const { @@ -495,6 +504,7 @@ struct HstuAttentionFwdPipelineQRKSVS identity{}, mask, scale_s, + max_seqlen, smem_ptr, dropout); }