diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ffceec8aa2..ea1c4f3bf0 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -322,13 +322,13 @@ struct FmhaFwdV3Kernel ); // Q/K/V DRAM and DRAM window - index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start - index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start + index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start + index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + q_ptr_offset; - const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr); - const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr); - ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr); + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr); + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr); + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q; @@ -339,7 +339,7 @@ struct FmhaFwdV3Kernel const auto q_dram_base = make_naive_tensor_view( q_ptr, make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), - make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), + make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, number<1>{});