From df4fc8f26c6f1e14de0b16a97d7ce4d539b1fa3a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Jun 2024 19:23:19 +0000 Subject: [PATCH] Re-order split-kv pipeline call operator arguments --- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 12 ++++++------ ...lock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 18 +++++++++--------- ...mha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp | 18 +++++++++--------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 2504a91a65..a2a4d150ae 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -865,13 +865,13 @@ struct FmhaFwdSplitKVKernel identity{}, // s_acc_element_func scales{kargs.scale_p}, // p_compute_element_func identity{}, // o_acc_element_func + kargs.num_splits, + i_split_, mask, position_encoding, kargs.scale_s, smem_ptr, - dropout, - i_split_, - kargs.num_splits); + dropout); } else { @@ -881,13 +881,13 @@ struct FmhaFwdSplitKVKernel bias_dram_window, randval_dram_window, lse_acc_dram_window, + kargs.num_splits, + i_split_, mask, position_encoding, kargs.scale_s, smem_ptr, - dropout, - i_split_, - kargs.num_splits); + dropout); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index f768551755..a4102ce314 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -133,13 +133,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout, - index_t i_split, - index_t num_splits) const + BlockDropout& dropout) const { static_assert( std::is_same_v> && @@ -629,13 +629,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout, - index_t i_split, - index_t num_splits) const + BlockDropout& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -651,13 +651,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS identity{}, identity{}, identity{}, + num_splits, + i_split, mask, position_encoding, scale_s, smem_ptr, - dropout, - i_split, - num_splits); + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp index c9c00123df..6a5269e862 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp @@ -145,13 +145,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout, - index_t i_split, - index_t num_splits) const + BlockDropout& dropout) const { static_assert( std::is_same_v> && @@ -730,13 +730,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout, - index_t i_split, - index_t num_splits) const + BlockDropout& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -752,13 +752,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync identity{}, identity{}, identity{}, + num_splits, + i_split, mask, position_encoding, scale_s, smem_ptr, - dropout, - i_split, - num_splits); + dropout); } };