diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 423aea9054..18bdfe184b 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -351,7 +351,7 @@ struct FmhaFwdV3Kernel number{}, number<1>{}); - const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_base, // block sizes make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), @@ -362,15 +362,13 @@ struct FmhaFwdV3Kernel q_dram_pad, make_tuple( make_merge_transform( - make_tuple(seq_len, num_queries_per_kv) + make_tuple(seq_len_padded, num_queries_per_kv) ), make_pass_through_transform(HEAD_SIZE_PADDED) ), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}) - ); - - // TODO are we padding the tensor view or the block here? + ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim return q_dram_merged; }();