From ec82f3bbd65d2a783a1d1c325cf838765500d93d Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Jun 2024 19:54:30 +0000 Subject: [PATCH] Re-order pipeline call operator arguments --- .../fmha_fwd_splitkv_combine_kernel.hpp | 8 ++++---- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 20 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index fe012ab1ab..e3f7fd0ab7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -421,9 +421,9 @@ struct FmhaFwdSplitKVCombineKernel identity{}, // lse_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func kargs.num_splits, - smem_ptr, kargs.seqlen_q, - kargs.max_seqlen_q); + kargs.max_seqlen_q, + smem_ptr); } else { @@ -431,9 +431,9 @@ struct FmhaFwdSplitKVCombineKernel o_acc_dram_window, lse_dram_window, kargs.num_splits, - smem_ptr, kargs.seqlen_q, - kargs.max_seqlen_q); + kargs.max_seqlen_q, + smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 340df1a094..37a1536413 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -83,9 +83,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline const LSEElementFunction& lse_element_func, const OaccElementFunction& o_acc_element_func, index_t num_splits, - void* smem_ptr, - index_t real_seqlen_q, - index_t max_seqlen_q) const + index_t seqlen_q, + index_t max_seqlen_q, + void* smem_ptr) const { LSEDataType* lse_acc_lds_ptr = static_cast(static_cast(static_cast(smem_ptr))); @@ -116,7 +116,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); - if(row < num_splits && col < real_seqlen_q) + if(row < num_splits && col < seqlen_q) { lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices); } @@ -312,9 +312,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline const OaccDramBlockWindow& o_acc_dram_block_window, LSEDramBlockWindow& lse_dram_block_window, index_t num_splits, - void* smem_ptr, - index_t real_seqlen_q, - index_t max_seqlen_q) const + index_t seqlen_q, + index_t max_seqlen_q, + void* smem_ptr) const { return operator()(lse_acc_dram_block_window, o_acc_dram_block_window, @@ -322,9 +322,9 @@ struct BlockFmhaFwdSplitKVCombinePipeline identity{}, identity{}, num_splits, - smem_ptr, - real_seqlen_q, - max_seqlen_q); + seqlen_q, + max_seqlen_q, + smem_ptr); } };