unified attention rename

This commit is contained in:
Tianxing Wu
2025-10-09 08:47:19 +00:00
parent e54cb5a713
commit 191f179038
5 changed files with 19 additions and 4 deletions

View File

@@ -84,6 +84,7 @@ struct FmhaFwdV3Kernel
ck_tile::index_t num_seqs; // number of batches for q
ck_tile::index_t BLOCK_SIZE; // Block size for kv cache. to 2's exponent????
ck_tile::index_t BLOCK_Q; // Block size for kv cache. to 2's exponent????
ck_tile::index_t BLOCK_M; // Block size for kv cache. to 2's exponent????
};
struct Kargs {
@@ -125,7 +126,8 @@ struct FmhaFwdV3Kernel
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs,
ck_tile::index_t BLOCK_SIZE,
ck_tile::index_t BLOCK_Q
ck_tile::index_t BLOCK_Q,
ck_tile::index_t BLOCK_M
)
{
Kargs kargs{{q_ptr,
@@ -161,6 +163,7 @@ struct FmhaFwdV3Kernel
num_seqs,
BLOCK_SIZE,
BLOCK_Q,
BLOCK_M
}};
return kargs;
@@ -301,7 +304,17 @@ struct FmhaFwdV3Kernel
}
const index_t query_pos = q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q;
const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx];
const index_t context_len = seq_len - cur_batch_query_len;
const index_t max_seq_prefix_len = (
context_len
+ q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q
+ (kargs.unifiedAttentionVarlenKargs.BLOCK_M - 1) // num_queries_per_kv
+ 1
);
// for simplicity, batch stride we just modify the pointer
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.unifiedAttentionCommonKargs.q_ptr) +
@@ -323,14 +336,16 @@ struct FmhaFwdV3Kernel
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.unifiedAttentionVarlenKargs.),
make_tuple(kargs.stride_q, 1),
make_tuple(seq_len, kargs.unifiedAttentionCommonKargs.HEAD_SIZE_PADDED),
make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
// block sizes
make_tuple(number<kargs.unifiedAttentionVarlenKargs.BLOCK_Q>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
// bool defining should we pad
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
const auto k_dram = [&]() {