diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 920bd172fb..f92fdc6825 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -349,25 +349,25 @@ struct HstuAttentionFwdPipelineQRKSVS move_tile_window(v_dram_window, {0, kK1}); }); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + // STAGE 2, scale_s, add bias, mask, siLU if constexpr(kHasBias) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( - [&](auto& x, const auto& y) { - x += type_convert(bias_element_func(y)); + [&scale_s, &bias_element_func](auto& x, const auto& y) { + x = x * scale_s + type_convert(bias_element_func(y)); }, s_acc, bias_tile); + + move_tile_window(bias_dram_window, {0, kN0}); } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); } - move_tile_window(bias_dram_window, {0, kN0}); if constexpr(HstuMask::IsMasking) { const auto k_origin = k_dram_block_window.get_window_origin();