mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
refactor the q tensor view transformation
This commit is contained in:
@@ -322,13 +322,13 @@ struct FmhaFwdV3Kernel
|
||||
);
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start
|
||||
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start
|
||||
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start
|
||||
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start
|
||||
index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1;
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.unifiedAttentionCommonKargs.q_ptr) + q_ptr_offset;
|
||||
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.unifiedAttentionCommonKargs.k_ptr);
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.unifiedAttentionCommonKargs.v_ptr);
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.unifiedAttentionCommonKargs.o_ptr);
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
|
||||
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr);
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr);
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
|
||||
|
||||
|
||||
index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q;
|
||||
@@ -339,7 +339,7 @@ struct FmhaFwdV3Kernel
|
||||
const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE),
|
||||
make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1),
|
||||
make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
|
||||
number<FmhaPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user