diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index e970230611..7c1facc545 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -96,6 +96,16 @@ struct UnifiedAttentionKernel const int32_t* query_start_len_ptr; // [num_seqs+1] ck_tile::index_t num_seqs; // number of batches for q + + // KV-segment parallelism (split-KV within unified attention) + ck_tile::index_t num_splits = 1; + ck_tile::index_t i_split = 0; + void* lse_acc_ptr = nullptr; // [nhead, num_splits, total_q] float + void* o_acc_ptr = nullptr; // [nhead, num_splits, total_q, hdim_v] float + ck_tile::index_t split_stride_lse_acc = 0; + ck_tile::index_t split_stride_o_acc = 0; + ck_tile::index_t nhead_stride_lse_acc = 0; + ck_tile::index_t nhead_stride_o_acc = 0; }; using Kargs = UnifiedAttentionVarlenKargs; @@ -305,11 +315,23 @@ struct UnifiedAttentionKernel } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = + index_t total_num_kv_blocks = amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize); - // TODO sliding window - const index_t num_blocks_start = 0; + // KV-segment parallelism: split KV range across workgroups + index_t num_blocks_start = 0; + index_t num_blocks = total_num_kv_blocks; + if(kargs.num_splits > 1) + { + const index_t blocks_per_split = + ck_tile::max(index_t(1), (total_num_kv_blocks + kargs.num_splits - 1) / kargs.num_splits); + num_blocks_start = ck_tile::min(blocks_per_split * kargs.i_split, total_num_kv_blocks); + num_blocks = ck_tile::min(blocks_per_split * (kargs.i_split + 1), total_num_kv_blocks); + if(num_blocks_start >= num_blocks) + { + return; // this split has no work + } + } long_index_t kv_head_offset = static_cast(kv_head_idx) * kargs.stride_k_cache_2; // Q/K/V DRAM and DRAM window