mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
WIP: CK-UA KV-segment parallelism - kernel args and split range
Added split-KV fields to UnifiedAttentionVarlenKargs (num_splits, i_split, lse_acc_ptr, o_acc_ptr, plus their split and nhead strides). Modified operator() to compute the per-split KV block range using blocks_per_split. INCOMPLETE: the pipeline currently returns a normalized o_acc, but the split-KV combine kernel needs the unnormalized o_acc together with lse, so the pipeline must be modified to optionally return the m and l values alongside o_acc. The kernel changes compile, but the epilogue still needs a split path that writes to the float accumulators instead of the final output. Made-with: Cursor
This commit is contained in:
@@ -96,6 +96,16 @@ struct UnifiedAttentionKernel
|
||||
const int32_t* query_start_len_ptr; // [num_seqs+1]
|
||||
|
||||
ck_tile::index_t num_seqs; // number of batches for q
|
||||
|
||||
// KV-segment parallelism (split-KV within unified attention)
|
||||
ck_tile::index_t num_splits = 1;
|
||||
ck_tile::index_t i_split = 0;
|
||||
void* lse_acc_ptr = nullptr; // [nhead, num_splits, total_q] float
|
||||
void* o_acc_ptr = nullptr; // [nhead, num_splits, total_q, hdim_v] float
|
||||
ck_tile::index_t split_stride_lse_acc = 0;
|
||||
ck_tile::index_t split_stride_o_acc = 0;
|
||||
ck_tile::index_t nhead_stride_lse_acc = 0;
|
||||
ck_tile::index_t nhead_stride_o_acc = 0;
|
||||
};
|
||||
|
||||
using Kargs = UnifiedAttentionVarlenKargs;
|
||||
@@ -305,11 +315,23 @@ struct UnifiedAttentionKernel
|
||||
}
|
||||
|
||||
const auto max_seq_prefix_len = _max_seq_prefix_len;
|
||||
const index_t num_blocks =
|
||||
index_t total_num_kv_blocks =
|
||||
amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize);
|
||||
|
||||
// TODO sliding window
|
||||
const index_t num_blocks_start = 0;
|
||||
// KV-segment parallelism: split KV range across workgroups
|
||||
index_t num_blocks_start = 0;
|
||||
index_t num_blocks = total_num_kv_blocks;
|
||||
if(kargs.num_splits > 1)
|
||||
{
|
||||
const index_t blocks_per_split =
|
||||
ck_tile::max(index_t(1), (total_num_kv_blocks + kargs.num_splits - 1) / kargs.num_splits);
|
||||
num_blocks_start = ck_tile::min(blocks_per_split * kargs.i_split, total_num_kv_blocks);
|
||||
num_blocks = ck_tile::min(blocks_per_split * (kargs.i_split + 1), total_num_kv_blocks);
|
||||
if(num_blocks_start >= num_blocks)
|
||||
{
|
||||
return; // this split has no work
|
||||
}
|
||||
}
|
||||
long_index_t kv_head_offset = static_cast<long_index_t>(kv_head_idx) * kargs.stride_k_cache_2;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
|
||||
Reference in New Issue
Block a user