From 87d16738bfd4ef57c26efb74cd3e810b44d8f21e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 Apr 2026 19:09:59 +0000 Subject: [PATCH] WIP: CK-UA KV-segment parallelism - kernel args and split range Added split-KV fields to UnifiedAttentionVarlenKargs (num_splits, i_split, lse_acc_ptr, o_acc_ptr + strides). Modified operator() to compute per-split KV range using blocks_per_split. INCOMPLETE: The pipeline returns normalized o_acc but the split-KV combine kernel needs unnormalized o_acc + lse. Need to modify the pipeline to optionally return m and l values alongside o_acc. The kernel changes compile but the epilogue needs the split path (write to float accumulators instead of final output). Made-with: Cursor --- .../kernel/unified_attention_kernel.hpp | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index e970230611..7c1facc545 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -96,6 +96,16 @@ struct UnifiedAttentionKernel const int32_t* query_start_len_ptr; // [num_seqs+1] ck_tile::index_t num_seqs; // number of batches for q + + // KV-segment parallelism (split-KV within unified attention) + ck_tile::index_t num_splits = 1; + ck_tile::index_t i_split = 0; + void* lse_acc_ptr = nullptr; // [nhead, num_splits, total_q] float + void* o_acc_ptr = nullptr; // [nhead, num_splits, total_q, hdim_v] float + ck_tile::index_t split_stride_lse_acc = 0; + ck_tile::index_t split_stride_o_acc = 0; + ck_tile::index_t nhead_stride_lse_acc = 0; + ck_tile::index_t nhead_stride_o_acc = 0; }; using Kargs = UnifiedAttentionVarlenKargs; @@ -305,11 +315,23 @@ struct UnifiedAttentionKernel } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = + index_t total_num_kv_blocks = amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize); - // TODO sliding window - const index_t num_blocks_start = 0; + // KV-segment parallelism: split KV range across workgroups + index_t num_blocks_start = 0; + index_t num_blocks = total_num_kv_blocks; + if(kargs.num_splits > 1) + { + const index_t blocks_per_split = + ck_tile::max(index_t(1), (total_num_kv_blocks + kargs.num_splits - 1) / kargs.num_splits); + num_blocks_start = ck_tile::min(blocks_per_split * kargs.i_split, total_num_kv_blocks); + num_blocks = ck_tile::min(blocks_per_split * (kargs.i_split + 1), total_num_kv_blocks); + if(num_blocks_start >= num_blocks) + { + return; // this split has no work + } + } long_index_t kv_head_offset = static_cast(kv_head_idx) * kargs.stride_k_cache_2; // Q/K/V DRAM and DRAM window