From d9323ea261d05fd60f29e535909810a0cdf7dfbc Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 4 Apr 2024 03:17:36 +0000 Subject: [PATCH] Fix bug of elementwise op, our elementwise op is not inout --- .../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 8 ++++---- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 6 +++--- .../ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index b56210000c..1c1cb59912 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -325,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { - tile_elementwise_inout(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -341,7 +341,7 @@ struct BlockFmhaPipelineQRKSVS } else { - tile_elementwise_inout(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif @@ -468,7 +468,7 @@ struct BlockFmhaPipelineQRKSVS } move_tile_window(v_dram_window, {0, kK1}); - tile_elementwise_inout(p_compute_element_func, p_compute); + tile_elementwise_in(p_compute_element_func, p_compute); const auto p = cast_tile(p_compute); // STAGE 3, KV gemm @@ -555,7 +555,7 @@ struct BlockFmhaPipelineQRKSVS }); }); - tile_elementwise_inout(o_acc_element_func, o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index b27bd7250a..f6e651132c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -368,7 +368,7 @@ struct BlockFmhaPipelineQRKSVSAsync // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { - tile_elementwise_inout(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -384,7 +384,7 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - tile_elementwise_inout(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif @@ -650,7 +650,7 @@ struct BlockFmhaPipelineQRKSVSAsync }); }); - tile_elementwise_inout(o_acc_element_func, o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 03b94d7e90..4251dcf3e2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -316,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { - tile_elementwise_inout(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -332,7 +332,7 @@ struct BlockFmhaPipelineQSKSVS } else { - tile_elementwise_inout(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif @@ -546,7 +546,7 @@ struct BlockFmhaPipelineQSKSVS }); }); - tile_elementwise_inout(o_acc_element_func, o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; }