Refactor the Q tensor view transformation

This commit is contained in:
Juuso Korhonen
2025-10-13 10:22:52 +00:00
parent 49ce980c67
commit af94aaf1cb

View File

@@ -322,13 +322,13 @@ struct FmhaFwdV3Kernel
);
// Q/K/V DRAM and DRAM window
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start
index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1;
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.unifiedAttentionCommonKargs.q_ptr) + q_ptr_offset;
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.unifiedAttentionCommonKargs.k_ptr);
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.unifiedAttentionCommonKargs.v_ptr);
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.unifiedAttentionCommonKargs.o_ptr);
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr);
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr);
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q;
@@ -339,7 +339,7 @@ struct FmhaFwdV3Kernel
const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE),
make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1),
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});