diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 19033499cc..b56210000c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -325,8 +325,8 @@ struct BlockFmhaPipelineQRKSVS // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_TILE_FMHA_FWD_FAST_EXP2 @@ -341,9 +341,9 @@ struct BlockFmhaPipelineQRKSVS } else { + tile_elementwise_inout(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); - tile_elementwise_inout(s_acc_element_func, s_acc); #endif } move_tile_window(bias_dram_window, {0, kN0}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 1fa33de55c..b27bd7250a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -368,8 +368,8 @@ struct BlockFmhaPipelineQRKSVSAsync // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_TILE_FMHA_FWD_FAST_EXP2 @@ -384,9 +384,9 @@ struct BlockFmhaPipelineQRKSVSAsync } else { + tile_elementwise_inout(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); - tile_elementwise_inout(s_acc_element_func, s_acc); #endif } move_tile_window(bias_dram_window, {0, kN0}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 3c56c86e4d..03b94d7e90 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -316,8 +316,8 @@ struct BlockFmhaPipelineQSKSVS // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_TILE_FMHA_FWD_FAST_EXP2 @@ -332,9 +332,9 @@ struct BlockFmhaPipelineQSKSVS } else { + tile_elementwise_inout(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); - tile_elementwise_inout(s_acc_element_func, s_acc); #endif } move_tile_window(bias_dram_window, {0, kN0});