mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
unified attention rename
This commit is contained in:
@@ -84,6 +84,7 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t num_seqs; // number of batches for q
|
||||
ck_tile::index_t BLOCK_SIZE; // Block size of the KV cache. TODO: confirm whether it must be a power of two.
|
||||
ck_tile::index_t BLOCK_Q; // Query-position stride per q block (query_pos = q_block_local_idx * BLOCK_Q). TODO: confirm whether it must be a power of two.
|
||||
ck_tile::index_t BLOCK_M; // Presumably the row count of one query tile (enters max_seq_prefix_len as BLOCK_M - 1) -- TODO confirm, incl. any power-of-two requirement.
|
||||
};
|
||||
|
||||
struct Kargs {
|
||||
@@ -125,7 +126,8 @@ struct FmhaFwdV3Kernel
|
||||
const int32_t* query_start_len_ptr,
|
||||
ck_tile::index_t num_seqs,
|
||||
ck_tile::index_t BLOCK_SIZE,
|
||||
ck_tile::index_t BLOCK_Q
|
||||
ck_tile::index_t BLOCK_Q,
|
||||
ck_tile::index_t BLOCK_M
|
||||
)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
@@ -161,6 +163,7 @@ struct FmhaFwdV3Kernel
|
||||
num_seqs,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_Q,
|
||||
BLOCK_M
|
||||
}};
|
||||
|
||||
return kargs;
|
||||
@@ -301,7 +304,17 @@ struct FmhaFwdV3Kernel
|
||||
}
|
||||
|
||||
const index_t query_pos = q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q;
|
||||
const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx];
|
||||
|
||||
const index_t context_len = seq_len - cur_batch_query_len;
|
||||
|
||||
|
||||
const index_t max_seq_prefix_len = (
|
||||
context_len
|
||||
+ q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q
|
||||
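+ (kargs.unifiedAttentionVarlenKargs.BLOCK_M - 1) // num_queries_per_kv  -- NOTE(review): "//" starts a C++ comment here, so no division by num_queries_per_kv is performed; if integer division was intended, it must be written with "/" -- confirm.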
|
||||
+ 1
|
||||
);
|
||||
|
||||
// For simplicity, the batch stride is applied by offsetting the base pointer directly.
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.unifiedAttentionCommonKargs.q_ptr) +
|
||||
@@ -323,14 +336,16 @@ struct FmhaFwdV3Kernel
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.unifiedAttentionVarlenKargs.),
|
||||
make_tuple(kargs.stride_q, 1),
|
||||
make_tuple(seq_len, kargs.unifiedAttentionCommonKargs.HEAD_SIZE_PADDED),
|
||||
make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, 1),
|
||||
number<FmhaPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
// block sizes
|
||||
make_tuple(number<kargs.unifiedAttentionVarlenKargs.BLOCK_Q>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
// per-dimension flags selecting whether each dimension is padded
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
Reference in New Issue
Block a user