From e1120fffb0c4d635bcbd7c859ffa354944cb6526 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 14 Oct 2025 09:58:27 +0000 Subject: [PATCH] pipeline api --- .../kernel/unified_attention_kernel.hpp | 13 ++++++++++++- .../pipeline/unified_attention_pipeline.hpp | 8 ++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index e8cd551417..386319f28b 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -96,7 +96,6 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_seqs; // number of batches for q }; - using Kargs = UnifiedAttentionVarlenKargs; CK_TILE_HOST static constexpr Kargs MakeKargs( @@ -332,6 +331,8 @@ struct FmhaFwdV3Kernel index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; + index_t block_table_offset = seq_idx * kargs.block_table_stride; + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + kv_head_offset; @@ -445,10 +446,20 @@ struct FmhaFwdV3Kernel auto v_dram_window = make_tile_window( v_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0}); + // Create mask for causal attention + auto mask = [&]() { + return make_casual_mask(query_pos, BLOCK_Q, max_seq_prefix_len, BLOCK_SIZE); + }(); + + // Define LSE dram window (or use a dummy if not needed by pipeline) + auto lse_dram_window = make_dummy_tile_window(); + auto o_acc_tile = [&]() { return FmhaPipeline{}(q_dram_window, k_dram_window, v_dram_window, + block_tables_ptr, + block_table_offset, lse_dram_window, mask, kargs.scale_s, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b151b61028..7bc3dc1d7d 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -391,6 +391,8 @@ struct UnifiedAttentionPipeline [[maybe_unused]] const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile [[maybe_unused]] const VElementFunction& v_element_func, + const void* block_tables_ptr, + index_t block_table_offset, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, @@ -402,6 +404,8 @@ struct UnifiedAttentionPipeline { using namespace ck_tile; + index_t block_idx_prev = 0; + static_assert( std::is_same_v> && std::is_same_v> && @@ -1231,6 +1235,8 @@ struct UnifiedAttentionPipeline CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const void* block_tables_ptr, + index_t block_table_offset, LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale_s, @@ -1244,6 +1250,8 @@ struct UnifiedAttentionPipeline identity{}, v_dram_block_window_tmp, identity{}, + block_tables_ptr, + block_table_offset, lse_dram_block_window_tmp, identity{}, identity{},