diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index f9ea3d0b50..44236a734c 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -460,8 +460,8 @@ struct UnifiedAttentionKernel cur_batch_query_len, // x (i.e. extend) seq_len, // y_total (x + y) cur_batch_query_len, // x_total - num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times along x dim of the tile - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + num_queries_per_kv // the same sequence index is repeated num_queries_per_kv times along x dim of the tile + ); else return FmhaMask{cur_batch_query_len, seq_len}; }(); @@ -470,6 +470,7 @@ struct UnifiedAttentionKernel return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, + num_queries_per_kv, kargs.block_tables_ptr, block_table_offset, mask, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index af4d79759f..b2cb1a3da0 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -278,8 +278,8 @@ struct UnifiedAttentionPipeline static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDim; // static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -1208,6 +1208,7 @@ struct UnifiedAttentionPipeline CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, FmhaMask mask, @@ -1216,12 +1217,29 @@ struct UnifiedAttentionPipeline { using namespace ck_tile; + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + [[maybe_unused]] const VElementFunction& v_element_func, + index_t num_queries_per_kv, + const void* block_tables_ptr, + index_t block_table_offset, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + void* smem_ptr) const + return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, identity{}, v_dram_block_window_tmp, identity{}, + num_queries_per_kv, block_tables_ptr, block_table_offset, identity{},