From 50c36f352ac065e4953230f6947acf6cba52bfaf Mon Sep 17 00:00:00 2001 From: rocking Date: Fri, 29 Mar 2024 02:55:16 -0400 Subject: [PATCH] Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline --- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 23 +++++++++++++++---- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 23 +++++++++++++++---- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 20 ++++++++++++---- 3 files changed, 54 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 098d9d363d..07930258c4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -111,7 +111,10 @@ struct BlockFmhaPipelineQRKSVS typename KElementFunction, typename VElementFunction, typename BiasElementFunction, - typename LSEElementFunction> + typename LSEElementFunction, + typename SAccElementFunction, + typename PComputeElementFunction, + typename OAccElementFunction> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -123,6 +126,9 @@ struct BlockFmhaPipelineQRKSVS const BiasElementFunction& bias_element_func, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, FmhaMask mask, float scale, void* smem_ptr) const @@ -319,13 +325,15 @@ struct BlockFmhaPipelineQRKSVS // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + tile_elementwise_inout(s_acc_element_func, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_TILE_FMHA_FWD_FAST_EXP2 - x = scale * x + type_convert(bias_element_func(y)); + x += type_convert(bias_element_func(y)); #else - x = scale * x + log2e_v * - type_convert(bias_element_func(y)); + x += log2e_v * + type_convert(bias_element_func(y)); #endif }, s_acc, @@ -335,6 +343,7 @@ struct BlockFmhaPipelineQRKSVS { #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + tile_elementwise_inout(s_acc_element_func, s_acc); #endif } move_tile_window(bias_dram_window, {0, kN0}); @@ -459,6 +468,7 @@ struct BlockFmhaPipelineQRKSVS } move_tile_window(v_dram_window, {0, kK1}); + tile_elementwise_inout(p_compute_element_func, p_compute); const auto p = cast_tile(p_compute); // STAGE 3, KV gemm @@ -545,6 +555,8 @@ struct BlockFmhaPipelineQRKSVS }); }); + tile_elementwise_inout(o_acc_element_func, o_acc); + return o_acc; } @@ -573,6 +585,9 @@ struct BlockFmhaPipelineQRKSVS identity{}, lse_dram_block_window_tmp, identity{}, + identity{}, + identity{}, + identity{}, mask, scale, smem_ptr); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index c7e2f3ae4b..e7d22984cd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -122,7 +122,10 @@ struct BlockFmhaPipelineQRKSVSAsync typename KElementFunction, typename VElementFunction, typename BiasElementFunction, - typename LSEElementFunction> + typename LSEElementFunction, + typename SAccElementFunction, + typename PComputeElementFunction, + typename OAccElementFunction> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -134,6 +137,9 @@ struct BlockFmhaPipelineQRKSVSAsync const BiasElementFunction& bias_element_func, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, FmhaMask mask, float scale, void* smem_ptr) const @@ -362,13 +368,15 @@ struct BlockFmhaPipelineQRKSVSAsync // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + tile_elementwise_inout(s_acc_element_func, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_TILE_FMHA_FWD_FAST_EXP2 - x = scale * x + type_convert(bias_element_func(y)); + x += type_convert(bias_element_func(y)); #else - x = scale * x + log2e_v * - type_convert(bias_element_func(y)); + x += log2e_v * + type_convert(bias_element_func(y)); #endif }, s_acc, @@ -378,6 +386,7 @@ struct BlockFmhaPipelineQRKSVSAsync { #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + tile_elementwise_inout(s_acc_element_func, s_acc); #endif } move_tile_window(bias_dram_window, {0, kN0}); @@ -521,6 +530,7 @@ struct BlockFmhaPipelineQRKSVSAsync }); }); + tile_elementwise_inout(p_compute_element_func, p_compute); const auto p = cast_tile(p_compute); // STAGE 3, KV gemm @@ -640,6 +650,8 @@ struct BlockFmhaPipelineQRKSVSAsync }); }); + tile_elementwise_inout(o_acc_element_func, o_acc); + return o_acc; } @@ -668,6 +680,9 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, lse_dram_block_window_tmp, identity{}, + identity{}, + identity{}, + identity{}, mask, scale, smem_ptr); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index e04f6660d1..8150326adf 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -97,7 +97,10 @@ struct BlockFmhaPipelineQSKSVS typename KElementFunction, typename VElementFunction, typename BiasElementFunction, - typename LSEElementFunction> + typename LSEElementFunction, + typename SAccElementFunction, + typename PComputeElementFunction, + typename OAccElementFunction> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -109,6 +112,9 @@ struct BlockFmhaPipelineQSKSVS const BiasElementFunction& bias_element_func, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, FmhaMask mask, float scale, void* smem_ptr) const @@ -310,13 +316,15 @@ struct BlockFmhaPipelineQSKSVS // STAGE 2, scale, add bias, mask, softmax if constexpr(kHasBias) { + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + tile_elementwise_inout(s_acc_element_func, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_TILE_FMHA_FWD_FAST_EXP2 - x = scale * x + type_convert(bias_element_func(y)); + x += type_convert(bias_element_func(y)); #else - x = scale * x + log2e_v * - type_convert(bias_element_func(y)); + x += log2e_v * + type_convert(bias_element_func(y)); #endif }, s_acc, @@ -326,6 +334,7 @@ struct BlockFmhaPipelineQSKSVS { #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + tile_elementwise_inout(s_acc_element_func, s_acc); #endif } move_tile_window(bias_dram_window, {0, kN0}); @@ -450,6 +459,7 @@ struct BlockFmhaPipelineQSKSVS } move_tile_window(v_dram_window, {0, kK1}); + tile_elementwise_inout(p_compute_element_func, p_compute); const auto p = cast_tile(p_compute); // STAGE 3, KV gemm @@ -536,6 +546,8 @@ struct BlockFmhaPipelineQSKSVS }); }); + tile_elementwise_inout(o_acc_element_func, o_acc); + return o_acc; }