mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/composable_kernel into tianxing/unified-attention
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
@@ -324,63 +325,48 @@ struct FmhaFwdV3Kernel
|
||||
index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr);
|
||||
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start
|
||||
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start
|
||||
index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1;
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
|
||||
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) + kv_head_offset;
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
|
||||
|
||||
|
||||
index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q;
|
||||
bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0);
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(seq_len, num_head_q, HEAD_SIZE),
|
||||
make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE),
|
||||
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
|
||||
number<FmhaPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
|
||||
const auto q_dram_unmerged = transform_tensor_view(
|
||||
q_dram_base,
|
||||
make_tuple(
|
||||
make_pass_through_transform(seq_len),
|
||||
make_unmerge_transform(make_tuple(num_head_q / num_queries_per_kv, num_queries_per_kv)),
|
||||
make_pass_through_transform(HEAD_SIZE_PADDED)
|
||||
),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})
|
||||
);
|
||||
const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
|
||||
q_dram_base,
|
||||
// block sizes
|
||||
make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
|
||||
sequence<is_seq_len_aligned, false, kPadHeadDimQ>{}
|
||||
); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED)
|
||||
|
||||
const auto q_dram_permuted = transform_tensor_view(
|
||||
q_dram_unmerged,
|
||||
make_tuple(
|
||||
make_pass_through_transform(num_head_q / num_queries_per_kv),
|
||||
make_pass_through_transform(seq_len),
|
||||
make_pass_through_transform(num_queries_per_kv),
|
||||
make_pass_through_transform(HEAD_SIZE_PADDED)
|
||||
),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}, sequence<3>{})
|
||||
);
|
||||
const auto q_dram_merged = transform_tensor_view(
|
||||
q_dram_permuted,
|
||||
q_dram_pad,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(num_head_q / num_queries_per_kv, seq_len, num_queries_per_kv)
|
||||
make_merge_transform(
|
||||
make_tuple(seq_len, num_queries_per_kv)
|
||||
),
|
||||
make_pass_through_transform(HEAD_SIZE_PADDED)
|
||||
),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{})
|
||||
);
|
||||
|
||||
// TODO are we padding the tensor view or the block here?
|
||||
const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
|
||||
q_dram_merged,
|
||||
// block sizes
|
||||
make_tuple(BLOCK_Q, HEAD_SIZE_PADDED),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{}
|
||||
);
|
||||
|
||||
return q_dram_pad;
|
||||
return q_dram_merged;
|
||||
}();
|
||||
|
||||
// Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim)
|
||||
|
||||
Reference in New Issue
Block a user