WIP: CK-UA KV-segment parallelism - kernel args and split range

Added split-KV fields to UnifiedAttentionVarlenKargs (num_splits,
i_split, lse_acc_ptr, o_acc_ptr + strides). Modified operator() to
compute per-split KV range using blocks_per_split.

INCOMPLETE: The pipeline returns normalized o_acc but the split-KV
combine kernel needs unnormalized o_acc + lse. Need to modify the
pipeline to optionally return m and l values alongside o_acc.

The kernel changes compile but the epilogue needs the split path
(write to float accumulators instead of final output).

Made-with: Cursor
This commit is contained in:
root
2026-04-01 19:09:59 +00:00
parent 63821af1ff
commit 87d16738bf

View File

@@ -96,6 +96,16 @@ struct UnifiedAttentionKernel
const int32_t* query_start_len_ptr; // [num_seqs+1]
ck_tile::index_t num_seqs; // number of batches for q
// KV-segment parallelism (split-KV within unified attention)
ck_tile::index_t num_splits = 1;
ck_tile::index_t i_split = 0;
void* lse_acc_ptr = nullptr; // [nhead, num_splits, total_q] float
void* o_acc_ptr = nullptr; // [nhead, num_splits, total_q, hdim_v] float
ck_tile::index_t split_stride_lse_acc = 0;
ck_tile::index_t split_stride_o_acc = 0;
ck_tile::index_t nhead_stride_lse_acc = 0;
ck_tile::index_t nhead_stride_o_acc = 0;
};
using Kargs = UnifiedAttentionVarlenKargs;
@@ -305,11 +315,23 @@ struct UnifiedAttentionKernel
}
const auto max_seq_prefix_len = _max_seq_prefix_len;
const index_t num_blocks =
index_t total_num_kv_blocks =
amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize);
// TODO sliding window
const index_t num_blocks_start = 0;
// KV-segment parallelism: split KV range across workgroups
index_t num_blocks_start = 0;
index_t num_blocks = total_num_kv_blocks;
if(kargs.num_splits > 1)
{
const index_t blocks_per_split =
ck_tile::max(index_t(1), (total_num_kv_blocks + kargs.num_splits - 1) / kargs.num_splits);
num_blocks_start = ck_tile::min(blocks_per_split * kargs.i_split, total_num_kv_blocks);
num_blocks = ck_tile::min(blocks_per_split * (kargs.i_split + 1), total_num_kv_blocks);
if(num_blocks_start >= num_blocks)
{
return; // this split has no work
}
}
long_index_t kv_head_offset = static_cast<long_index_t>(kv_head_idx) * kargs.stride_k_cache_2;
// Q/K/V DRAM and DRAM window