diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 5d194e2167..423aea9054 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -328,14 +328,19 @@ struct FmhaFwdV3Kernel index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; + + index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start + index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start + index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + kv_head_offset; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q; - bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0); + const bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { @@ -370,11 +375,11 @@ struct FmhaFwdV3Kernel }(); // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) - // stride for dim 0 (num_queries_per_kv * seq_len, num_queries_per_kv, 1) + // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) auto q_dram_window = make_tile_window( q_dram, make_tuple(BLOCK_Q, HEAD_SIZE_PADDED), - {kv_head_idx * seq_len * num_queries_per_kv + q_block_global_idx * num_queries_per_kv, 0} + {q_block_global_idx * num_queries_per_kv * HEAD_SIZE_PADDED, 0} ); const auto k_dram = [&]() { @@ -424,23 +429,39 @@ struct FmhaFwdV3Kernel // O DRAM and O DRAM window auto o_dram = [&]() { - const auto o_dram_naive = make_naive_tensor_view( + const auto o_dram_base = make_naive_tensor_view( o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), + make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), + make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1), number{}, number<1>{}); - return pad_tensor_view( - o_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + const auto o_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + o_dram_base, + // block sizes + make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), + sequence{} + ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) + + const auto o_dram_merged = transform_tensor_view( + o_dram_pad, + make_tuple( + make_merge_transform( + make_tuple(seq_len, num_queries_per_kv) + ), + make_pass_through_transform(HEAD_SIZE_PADDED) + ), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}) + ); + + return o_dram_merged; }(); auto o_dram_window = make_tile_window(o_dram, - make_tuple(number{}, number{}), - {i_m0, i_n1}); + make_tuple(BLOCK_M, HEAD_SIZE_PADDED), + {q_block_global_idx * num_queries_per_kv * HEAD_SIZE_PADDED, 0}); EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); }