From 2d6dab29ebae0ab5564efe734ac3279a7f908514 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 13 Oct 2025 10:18:23 +0000 Subject: [PATCH 1/2] refactor the q tensor view transformation --- .../kernel/unified_attention_kernel.hpp | 58 +++++++------------ 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 67d6372c31..31bf24fa31 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/core/numeric/math.hpp" #include #include @@ -314,7 +315,7 @@ struct FmhaFwdV3Kernel } const index_t query_pos = q_block_local_idx * BLOCK_Q; - const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; + const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; // should be cu_seqlens_q rather const index_t context_len = seq_len - cur_batch_query_len; @@ -330,62 +331,47 @@ struct FmhaFwdV3Kernel index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; // Q/K/V DRAM and DRAM window - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr); + index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start + index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start + index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; + const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + q_ptr_offset; const KDataType* 
k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr); const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr); ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr); + + index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q; + bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0); + // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_base = make_naive_tensor_view( q_ptr, - make_tuple(seq_len, num_head_q, HEAD_SIZE), + make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), number{}, number<1>{}); - const auto q_dram_unmerged = transform_tensor_view( - q_dram_base, - make_tuple( - make_pass_through_transform(seq_len), - make_unmerge_transform(make_tuple(num_head_q / num_queries_per_kv, num_queries_per_kv)), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}) - ); + const auto q_dram_pad = pad_tensor_view( // align cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + q_dram_base, + // block sizes + make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), + sequence{} + ); // pads to (seq_len_padded, num_queries_per_kv, HEAD_SIZE_PADDED) - const auto q_dram_permuted = transform_tensor_view( - q_dram_unmerged, - make_tuple( - make_pass_through_transform(num_head_q / num_queries_per_kv), - make_pass_through_transform(seq_len), - make_pass_through_transform(num_queries_per_kv), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}, sequence<3>{}) - ); const auto q_dram_merged = transform_tensor_view( - q_dram_permuted, + q_dram_pad, make_tuple( - make_merge_transform_v3_division_mod( - make_tuple(num_head_q / num_queries_per_kv, 
seq_len, num_queries_per_kv) + make_merge_transform( + make_tuple(seq_len, num_queries_per_kv) ), make_pass_through_transform(HEAD_SIZE_PADDED) ), - make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}) ); - const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED - q_dram_merged, - // block sizes - make_tuple(BLOCK_Q, HEAD_SIZE_PADDED), - sequence{} - ); - - return q_dram_pad; + return q_dram_merged; }(); // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) From af94aaf1cbb1b4bbd2145a02112ee4c695ff5e45 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 13 Oct 2025 10:22:52 +0000 Subject: [PATCH 2/2] refactor the q tensor view transformation --- .../kernel/unified_attention_kernel.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ffceec8aa2..ea1c4f3bf0 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -322,13 +322,13 @@ struct FmhaFwdV3Kernel ); // Q/K/V DRAM and DRAM window - index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start - index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start + index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start + index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start index_t q_ptr_offset = 
q_ptr_offset_0 + q_ptr_offset_1; - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + q_ptr_offset; - const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr); - const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr); - ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr); + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr); + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr); + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q; @@ -339,7 +339,7 @@ struct FmhaFwdV3Kernel const auto q_dram_base = make_naive_tensor_view( q_ptr, make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), - make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), + make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, number<1>{});