o ptr and window

This commit is contained in:
Tianxing Wu
2025-10-13 11:32:28 +00:00
parent 16129a794a
commit be58d51d36

View File

@@ -328,14 +328,19 @@ struct FmhaFwdV3Kernel
index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start
index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start
index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1;
index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start
index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start
index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1;
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) + q_ptr_offset;
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) + kv_head_offset;
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) + kv_head_offset;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) + o_ptr_offset;
index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q;
bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0);
const bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0);
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
@@ -370,11 +375,11 @@ struct FmhaFwdV3Kernel
}();
// Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim)
// stride for dim 0 (num_queries_per_kv * seq_len, num_queries_per_kv, 1)
// stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1)
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(BLOCK_Q, HEAD_SIZE_PADDED),
{kv_head_idx * seq_len * num_queries_per_kv + q_block_global_idx * num_queries_per_kv, 0}
{q_block_global_idx * num_queries_per_kv * HEAD_SIZE_PADDED, 0}
);
const auto k_dram = [&]() {
@@ -424,23 +429,39 @@ struct FmhaFwdV3Kernel
// O DRAM and O DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o, 1),
make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE),
make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
const auto o_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED
o_dram_base,
// block sizes
make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
sequence<is_seq_len_aligned, false, kPadHeadDimQ>{}
); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED)
const auto o_dram_merged = transform_tensor_view(
o_dram_pad,
make_tuple(
make_merge_transform(
make_tuple(seq_len, num_queries_per_kv)
),
make_pass_through_transform(HEAD_SIZE_PADDED)
),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{})
);
return o_dram_merged;
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
make_tuple(BLOCK_M, HEAD_SIZE_PADDED),
{q_block_global_idx * num_queries_per_kv * HEAD_SIZE_PADDED, 0});
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}