From 191f17903866bdec19ddda69869576f965604f6e Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 9 Oct 2025 08:47:19 +0000 Subject: [PATCH] unified attention rename --- ...ernel.hpp => unified_attention_kernel.hpp} | 23 +++++++++++++++---- ...ine.hpp => unified_attention_pipeline.hpp} | 0 ...ied_attention_pipeline_default_policy.hpp} | 0 ...pp => unified_attention_pipeline_enum.hpp} | 0 ...=> unified_attention_pipeline_problem.hpp} | 0 5 files changed, 19 insertions(+), 4 deletions(-) rename include/ck_tile/ops/unified_attention/kernel/{fmha_fwd_v3_kernel.hpp => unified_attention_kernel.hpp} (94%) rename include/ck_tile/ops/unified_attention/pipeline/{block_fmha_fwd_v3_pipeline.hpp => unified_attention_pipeline.hpp} (100%) rename include/ck_tile/ops/unified_attention/pipeline/{block_fmha_fwd_v3_pipeline_default_policy.hpp => unified_attention_pipeline_default_policy.hpp} (100%) rename include/ck_tile/ops/unified_attention/pipeline/{block_fmha_pipeline_enum.hpp => unified_attention_pipeline_enum.hpp} (100%) rename include/ck_tile/ops/unified_attention/pipeline/{block_fmha_pipeline_problem.hpp => unified_attention_pipeline_problem.hpp} (100%) diff --git a/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp similarity index 94% rename from include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp rename to include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 9d164b639e..f49e560a96 100644 --- a/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -84,6 +84,7 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_seqs; // number of batches for q ck_tile::index_t BLOCK_SIZE; // Block size for kv cache. to 2's exponent???? ck_tile::index_t BLOCK_Q; // Block size for kv cache. to 2's exponent???? + ck_tile::index_t BLOCK_M; // Block size for kv cache. to 2's exponent???? }; struct Kargs { @@ -125,7 +126,8 @@ struct FmhaFwdV3Kernel const int32_t* query_start_len_ptr, ck_tile::index_t num_seqs, ck_tile::index_t BLOCK_SIZE, - ck_tile::index_t BLOCK_Q + ck_tile::index_t BLOCK_Q, + ck_tile::index_t BLOCK_M ) { Kargs kargs{{q_ptr, @@ -161,6 +163,7 @@ struct FmhaFwdV3Kernel num_seqs, BLOCK_SIZE, BLOCK_Q, + BLOCK_M }}; return kargs; @@ -301,7 +304,17 @@ struct FmhaFwdV3Kernel } const index_t query_pos = q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q; + const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; + const index_t context_len = seq_len - cur_batch_query_len; + + + const index_t max_seq_prefix_len = ( + context_len + + q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q + + (kargs.unifiedAttentionVarlenKargs.BLOCK_M - 1) // num_queries_per_kv + + 1 + ); // for simplicity, batch stride we just modify the pointer const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + @@ -323,14 +336,16 @@ struct FmhaFwdV3Kernel const auto q_dram = [&]() { const auto q_dram_naive = make_naive_tensor_view( q_ptr, - make_tuple(kargs.seqlen_q, kargs.unifiedAttentionVarlenKargs.), - make_tuple(kargs.stride_q, 1), + make_tuple(seq_len, kargs.unifiedAttentionCommonKargs.HEAD_SIZE_PADDED), + make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, 1), number{}, number<1>{}); return pad_tensor_view( q_dram_naive, - make_tuple(number{}, number{}), + // block sizes + make_tuple(number{}, number{}), + // bool defining should we pad sequence{}); }(); const auto k_dram = [&]() { diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp similarity index 100% rename from include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp rename to include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp similarity index 100% rename from include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp rename to include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp similarity index 100% rename from include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp rename to include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp similarity index 100% rename from include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp rename to include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp