Add KV-segment parallelism to CK unified attention pipeline

End-to-end split-KV (FlashDecoding-style) for the CK unified attention
kernel. The host launches a single 3D grid with z == num_splits; each
CTA computes attention over its slice of the KV range and writes a
partial (o_acc, lse) to FP32 workspaces (o_acc already normalized by
the split-local softmax denominator, lse in the natural-log domain);
the caller reduces the partials into the final output.
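
Since each o_acc is already normalized within its split, the caller's
reduce amounts to the standard log-sum-exp combine (stated here for
reference):

    lse = log(sum_s exp(lse_s))
    out = sum_s exp(lse_s - lse) * o_acc_s

computed stably by subtracting max_s lse_s before exponentiating; an
empty split's lse = -inf gives it weight exp(-inf) = 0.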

Pipeline changes:
- operator() returns ck_tile::make_tuple(o_acc, lse) instead of just
  o_acc. The masked-empty early-exit returns lse = -inf so the
  downstream combine weighs the empty partial as zero.
- LSE is built in the natural-log domain from the pipeline's *unscaled*
  rowmax: lse = (scale_s / log2(e)) * m + log(l); see the sketch after
  this list. Previously we used m / log2(e) + log(l), which dropped the
  per-head scale and produced LSE values roughly a factor 1/scale too
  large.
- Fix post-process parity: which SP register is left in the
  alu0-done/alu1-pending state at loop exit depends on the parity of
  the *iteration count* (= num_total_loop - num_blocks_start), not on
  num_total_loop alone; a worked example follows this list. For
  non-split launches (num_blocks_start == 0) the two parities coincide;
  for splits starting at an odd tile index they don't.
- Fix split-KV page-table offset: num_blocks_start is counted in
  kPageBlockSize-sized tiles, but block_tables is indexed in
  page_size-sized pages — shifting block_table_offset by num_blocks_start
  reads the wrong pages whenever kPageBlockSize != page_size. Replaced
  with split_token_offset = num_blocks_start * kPageBlockSize, added to
  logical_token before the division by page_size, so the page lookup
  uses the absolute token position (sketch after this list).
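
A standalone sketch of the corrected LSE reconstruction (C++; the
function name is hypothetical, while m, l, and scale_s mirror the
pipeline's variables, with scale_s carrying a baked-in log2(e) so that
exp2(scale_s * x) == exp(scale * x)):

    #include <cmath>

    // m : unscaled rowmax of the raw qk products
    // l : sum_k exp2(scale_s * (s_k - m)) == sum_k exp(scale * (s_k - m))
    // => sum_k exp(scale * s_k) = exp(scale * m) * l
    // => lse = scale * m + log(l) = (scale_s / log2(e)) * m + log(l)
    float lse_from_partials(float m, float l, float scale_s)
    {
        constexpr float kLog2E = 1.4426950408889634f; // log2(e)
        return (scale_s / kLog2E) * m + std::log(l);
    }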
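
Worked parity example (hypothetical numbers): a split covering tiles
[3, 8) performs 5 iterations, so the old num_total_loop test picks the
wrong register:

    const index_t num_blocks_start = 3; // split starts at tile 3
    const index_t num_total_loop   = 8; // split ends before tile 8
    const index_t num_iters = num_total_loop - num_blocks_start; // 5
    // old: num_total_loop % 2 == 0 -> finalize SP register 0 (wrong)
    // new: num_iters % 2 == 1      -> finalize SP register 1 (correct)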
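
And a sketch of the corrected page lookup (mirrors the pipeline hunk
below; kPageBlockSize = 32 vs page_size = 16 is a hypothetical pairing
where the two units differ, and the final block_tables indexing is
inferred from the fix description):

    // Split start expressed in tokens, not tiles or pages.
    const index_t split_token_offset = num_blocks_start * kPageBlockSize;
    // Absolute token position, then page decomposition:
    const index_t logical_token = split_token_offset +
                                  k_tile_idx * kPageBlockSize + k_thread_n_pos;
    const index_t logical_page = logical_token / page_size;
    const index_t within_page  = logical_token - logical_page * page_size;
    // block_table_offset itself stays unshifted; the split is carried
    // entirely by logical_page.
    const index_t physical_page =
        block_tables_ptr_[block_table_offset + logical_page];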

Kernel + dispatcher:
- Drop kargs.i_split; each CTA reads i_split = blockIdx.z.
- GridSize{2D,Decode} now take num_splits and add it as the z-dim
  (defaults to 1, so non-split callers keep z == 1 and see the same
  launch shape as before).
- New write path: when num_splits > 1, the kernel skips the user
  epilogue and instead writes the FP32 (o_acc, lse) tile pair into
  workspace tensors at [head, split, batch_start_token, ...] using
  Default2DEpilogue (UseRawStore=true) for o_acc and store_tile for
  lse. Host strides come from kargs; a combine sketch over this layout
  follows after this list.
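
A minimal host-side sketch of the combine these workspaces feed
(standalone C++; the function name and per-row loop structure are
hypothetical, and the real combine kernel or Python-side reduce may
differ):

    #include <cmath>
    #include <vector>

    // Merge one (q_head, token) row across splits. Pointers are already
    // offset to [q_head][0][token]; the split strides step the split
    // axis. Assumes at least one non-empty split (empty splits wrote
    // lse = -inf and get weight 0 below).
    void combine_row(const float* o_acc, const float* lse_acc, float* out,
                     int num_splits, int hdim_v,
                     long split_stride_o_acc, long split_stride_lse_acc)
    {
        float lse_max = -INFINITY;
        for(int s = 0; s < num_splits; ++s)
            lse_max = std::fmax(lse_max, lse_acc[s * split_stride_lse_acc]);

        float sum_w = 0.f;
        std::vector<float> acc(hdim_v, 0.f);
        for(int s = 0; s < num_splits; ++s)
        {
            const float w = std::exp(lse_acc[s * split_stride_lse_acc] - lse_max);
            sum_w += w;
            for(int d = 0; d < hdim_v; ++d)
                acc[d] += w * o_acc[s * split_stride_o_acc + d];
        }
        for(int d = 0; d < hdim_v; ++d)
            out[d] = acc[d] / sum_w; // exact softmax over the full KV range
    }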

Co-authored-by: Cursor <cursoragent@cursor.com>


@@ -67,6 +67,25 @@ struct unified_attention_args
index_t num_seqs; // number of batches for q
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
// KV-segment parallelism (split-KV). When num_splits == 1, the kernel
// writes to o_ptr as usual. When num_splits > 1, the kernel is launched
// with a 3D grid whose z-dim is num_splits — each CTA computes its own
// partial (o_acc, lse) and writes them into the FP32 workspaces; a
// separate combine kernel (or a Python-side reduce) merges across
// splits into the final output.
//
// Workspace layout (host-allocated):
// o_acc_ptr : [num_q_heads, num_splits, total_q, hdim_v] (FP32)
// lse_acc_ptr : [num_q_heads, num_splits, total_q] (FP32)
// The corresponding host-set strides are below.
index_t num_splits = 1;
void* o_acc_ptr = nullptr;
void* lse_acc_ptr = nullptr;
index_t split_stride_o_acc = 0;
index_t split_stride_lse_acc = 0;
index_t nhead_stride_o_acc = 0;
index_t nhead_stride_lse_acc = 0;
};
std::ostream& operator<<(std::ostream& stream,


@@ -414,16 +414,27 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
args.block_table_stride,
args.seq_lens_ptr,
args.query_start_len_ptr,
args.num_seqs);
args.num_seqs,
args.num_splits,
args.lse_acc_ptr,
args.o_acc_ptr,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc);
dim3 grids;
if constexpr(UseDecodeGrid)
{
grids = Kernel::GridSizeDecode(args.num_head_q / args.num_queries_per_kv, args.num_seqs);
grids = Kernel::GridSizeDecode(args.num_head_q / args.num_queries_per_kv,
args.num_seqs,
args.num_splits);
}
else
{
grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv,
total_num_q_blocks,
args.num_splits);
}
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;


@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
#include "ck_tile/core/numeric/math.hpp"
@@ -97,9 +98,11 @@ struct UnifiedAttentionKernel
ck_tile::index_t num_seqs; // number of batches for q
// KV-segment parallelism (split-KV within unified attention)
// KV-segment parallelism (split-KV within unified attention).
// Each CTA derives its own `i_split` from `blockIdx.z` — the host
// launches a single 3D grid with z = num_splits and the kernel
// dispatches all splits in parallel.
ck_tile::index_t num_splits = 1;
ck_tile::index_t i_split = 0;
void* lse_acc_ptr = nullptr; // [nhead, num_splits, total_q] float
void* o_acc_ptr = nullptr; // [nhead, num_splits, total_q, hdim_v] float
ck_tile::index_t split_stride_lse_acc = 0;
@@ -140,7 +143,14 @@ struct UnifiedAttentionKernel
ck_tile::index_t block_table_stride,
const int32_t* seq_lens_ptr,
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs)
ck_tile::index_t num_seqs,
ck_tile::index_t num_splits = 1,
void* lse_acc_ptr = nullptr,
void* o_acc_ptr = nullptr,
ck_tile::index_t split_stride_lse_acc = 0,
ck_tile::index_t split_stride_o_acc = 0,
ck_tile::index_t nhead_stride_lse_acc = 0,
ck_tile::index_t nhead_stride_o_acc = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -172,15 +182,25 @@ struct UnifiedAttentionKernel
block_table_stride,
seq_lens_ptr,
query_start_len_ptr,
num_seqs};
num_seqs,
num_splits,
lse_acc_ptr,
o_acc_ptr,
split_stride_lse_acc,
split_stride_o_acc,
nhead_stride_lse_acc,
nhead_stride_o_acc};
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
ck_tile::index_t total_num_q_blocks)
ck_tile::index_t total_num_q_blocks,
ck_tile::index_t num_splits = 1)
{
return dim3(num_kv_heads * total_num_q_blocks);
// z-dim carries the split index; num_splits == 1 is the existing
// (non-split) launch with dim3(N, 1, 1).
return dim3(num_kv_heads * total_num_q_blocks, 1, num_splits);
}
// Binary search to find the sequence index for a given target index
@@ -231,9 +251,10 @@ struct UnifiedAttentionKernel
}
CK_TILE_HOST static constexpr auto GridSizeDecode(ck_tile::index_t num_kv_heads,
ck_tile::index_t num_seqs)
ck_tile::index_t num_seqs,
ck_tile::index_t num_splits = 1)
{
return dim3(num_kv_heads, num_seqs);
return dim3(num_kv_heads, num_seqs, num_splits);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -244,6 +265,11 @@ struct UnifiedAttentionKernel
assert(kBlockM / num_queries_per_kv == kBlockQ);
// Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The
// split index lives in z — when num_splits == 1 (the only z value)
// this is just `0` and costs nothing.
const index_t i_split = blockIdx.z;
index_t kv_head_idx;
index_t seq_idx;
index_t q_block_local_idx;
@@ -252,7 +278,7 @@ struct UnifiedAttentionKernel
if(gridDim.y > 1)
{
// Decode grid: dim3(num_kv_heads, num_seqs)
// Decode grid: dim3(num_kv_heads, num_seqs, num_splits)
// Direct mapping, no binary search, no padding CTAs.
kv_head_idx = blockIdx.x;
seq_idx = blockIdx.y;
@@ -263,7 +289,8 @@ struct UnifiedAttentionKernel
}
else
{
// Standard 1D grid with binary search
// Standard 1D grid (x-folded) with binary search; z-dim carries
// the split index just like in the decode branch.
ck_tile::index_t pid = blockIdx.x;
const auto [kv_head_idx_, q_block_global_idx] = GetTileIndex(pid, kargs);
@@ -318,15 +345,17 @@ struct UnifiedAttentionKernel
index_t total_num_kv_blocks =
amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize);
// KV-segment parallelism: split KV range across workgroups
// KV-segment parallelism: split KV range across workgroups.
// `i_split` came from blockIdx.z above; with num_splits == 1 it's 0
// and these min/max bounds reduce to [0, total_num_kv_blocks).
index_t num_blocks_start = 0;
index_t num_blocks = total_num_kv_blocks;
if(kargs.num_splits > 1)
{
const index_t blocks_per_split =
ck_tile::max(index_t(1), (total_num_kv_blocks + kargs.num_splits - 1) / kargs.num_splits);
num_blocks_start = ck_tile::min(blocks_per_split * kargs.i_split, total_num_kv_blocks);
num_blocks = ck_tile::min(blocks_per_split * (kargs.i_split + 1), total_num_kv_blocks);
num_blocks_start = ck_tile::min(blocks_per_split * i_split, total_num_kv_blocks);
num_blocks = ck_tile::min(blocks_per_split * (i_split + 1), total_num_kv_blocks);
if(num_blocks_start >= num_blocks)
{
return; // this split has no work
@@ -460,56 +489,143 @@ struct UnifiedAttentionKernel
// kernel tile to fit in a cache page) is gone — tiles may span multiple
// pages as long as the inner-N step (Y0_step_N from the K/V tile dist)
// divides page_size cleanly.
//
// Pipeline returns make_tuple(o_acc, lse) where o_acc is the normalized
// attention output (post divide-by-l) and lse is the per-row log-sum-exp
// in natural-log domain. For num_splits == 1 we ignore lse and forward
// o_acc through the user's epilogue (bf16/fp16 cast + store to o_ptr).
// For num_splits > 1 we instead write o_acc and lse to FP32 workspaces
// — a separate combine kernel will merge across splits.
auto o_acc_tile = [&]() {
return UnifiedAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
num_blocks,
num_blocks_start,
kargs.block_tables_ptr,
block_table_offset,
kargs.page_size,
mask,
kargs.scale_s,
smem_ptr,
static_cast<long_index_t>(kargs.stride_k_cache_1),
static_cast<long_index_t>(kargs.stride_v_cache_1));
}();
auto pipeline_result = UnifiedAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
num_blocks,
num_blocks_start,
kargs.block_tables_ptr,
block_table_offset,
kargs.page_size,
mask,
kargs.scale_s,
smem_ptr,
static_cast<long_index_t>(kargs.stride_k_cache_1),
static_cast<long_index_t>(kargs.stride_v_cache_1));
auto& o_acc_tile = pipeline_result[number<0>{}];
auto& lse_tile = pipeline_result[number<1>{}];
// O DRAM and O DRAM window
auto o_dram = [&]() {
const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
number<UnifiedAttentionPipeline::kAlignmentO>{},
number<1>{});
if(kargs.num_splits > 1)
{
// ----- Split-KV write path -----
// Workspaces (FP32) are assumed in layout:
// o_acc_ptr : [num_q_heads, num_splits, total_q, hdim_v]
// lse_acc_ptr : [num_q_heads, num_splits, total_q]
// The host passes nhead/split strides; the q_token axis is contiguous
// (= hdim_v for o_acc, = 1 for lse_acc) so we hardcode that here.
const auto o_dram_pad =
pad_tensor_view( // align cu_seqlen with kBlockQ and head dim with kHeadDimPadded
o_dram_base,
// block sizes
const index_t head_q_base = kv_head_idx * num_queries_per_kv;
float* o_acc_base = reinterpret_cast<float*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(head_q_base) * kargs.nhead_stride_o_acc +
static_cast<long_index_t>(i_split) * kargs.split_stride_o_acc +
static_cast<long_index_t>(cur_batch_in_all_start_index) * kHeadDim;
auto o_acc_dram = [&]() {
const auto o_acc_base_view = make_naive_tensor_view<address_space_enum::global>(
o_acc_base,
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
make_tuple(static_cast<long_index_t>(kHeadDim),
static_cast<long_index_t>(kargs.nhead_stride_o_acc),
static_cast<long_index_t>(1)),
number<1>{},
number<1>{});
const auto o_acc_pad = pad_tensor_view(
o_acc_base_view,
make_tuple(kBlockQ, 1, kHeadDimPadded),
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
// kHeadDimPadded)
sequence<true, false, kPadHeadDimQ>{});
const auto o_dram_merged = transform_tensor_view(
o_dram_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
make_pass_through_transform(kHeadDimPadded)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return transform_tensor_view(
o_acc_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
make_pass_through_transform(kHeadDimPadded)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}();
return o_dram_merged;
}();
auto o_acc_window =
make_tile_window(o_acc_dram,
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
{query_pos * num_queries_per_kv, 0});
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
{query_pos * num_queries_per_kv, 0});
// FP32-out epilogue: cast_tile<float>(o_acc) is a no-op, but the
// pad-aware store path (UseRawStore=true) is the same machinery the
// user's epilogue uses, so storage semantics are unchanged.
using SplitOEpilogue =
Default2DEpilogue<Default2DEpilogueProblem<float, float, true, true, true>>;
SplitOEpilogue{}(o_acc_window, o_acc_tile, nullptr);
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
// ----- LSE write -----
float* lse_acc_base =
reinterpret_cast<float*>(kargs.lse_acc_ptr) +
static_cast<long_index_t>(head_q_base) * kargs.nhead_stride_lse_acc +
static_cast<long_index_t>(i_split) * kargs.split_stride_lse_acc +
static_cast<long_index_t>(cur_batch_in_all_start_index);
auto lse_acc_dram = [&]() {
const auto lse_acc_base_view = make_naive_tensor_view<address_space_enum::global>(
lse_acc_base,
make_tuple(cur_batch_query_len, num_queries_per_kv),
make_tuple(static_cast<long_index_t>(1),
static_cast<long_index_t>(kargs.nhead_stride_lse_acc)),
number<1>{},
number<1>{});
const auto lse_acc_pad = pad_tensor_view(
lse_acc_base_view, make_tuple(kBlockQ, 1), sequence<true, false>{});
return transform_tensor_view(
lse_acc_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv))),
make_tuple(sequence<0, 1>{}),
make_tuple(sequence<0>{}));
}();
auto lse_acc_window =
make_tile_window(lse_acc_dram, make_tuple(number<kBlockM>{}), {query_pos * num_queries_per_kv});
store_tile(lse_acc_window, lse_tile);
}
else
{
// ----- Non-split (current) path -----
auto o_dram = [&]() {
const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
number<UnifiedAttentionPipeline::kAlignmentO>{},
number<1>{});
const auto o_dram_pad =
pad_tensor_view(o_dram_base,
make_tuple(kBlockQ, 1, kHeadDimPadded),
sequence<true, false, kPadHeadDimQ>{});
return transform_tensor_view(
o_dram_pad,
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
make_pass_through_transform(kHeadDimPadded)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
{query_pos * num_queries_per_kv, 0});
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
}
};
} // namespace ck_tile


@@ -348,18 +348,28 @@ struct UnifiedAttentionPipeline
{
if(num_total_loop - num_blocks_start <= 0)
{
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return o_acc;
// Note: o_acc is already cleared above. q loaded but no fence
// (ignored). lse must be -infinity so the split-KV combine
// weighs this empty partial as zero (exp(-inf) == 0); for
// single-split callers the value is harmless (ignored).
auto lse_early =
make_static_distributed_tensor<SMPLComputeDataType>(m.get_tile_distribution());
set_tile(lse_early, -ck_tile::numeric<SMPLComputeDataType>::infinity());
return ck_tile::make_tuple(o_acc, lse_early);
}
}
index_t i_total_loops = num_blocks_start;
const ck_tile::index_t* block_tables_ptr_ =
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
assert(k_block_idx == v_block_idx); // because of the following line
block_table_offset += num_blocks_start;
assert(k_block_idx == v_block_idx);
// Split-KV start offset in *tokens* (not in tiles or pages). We add
// this to logical_token below so the page-table lookup uses the right
// page; we do NOT shift block_table_offset because num_blocks_start is
// counted in kPageBlockSize-sized tiles, while block_tables is indexed
// in page_size-sized pages — the two differ whenever kPageBlockSize !=
// page_size and shifting tiles-as-pages reads the wrong entries.
const index_t split_token_offset = num_blocks_start * kPageBlockSize;
// Pass-2: unified page-offset formula. The kPageBlockSize <= page_size
// constraint is gone. For every (thread, Y0-iter) pair we compute:
@@ -415,7 +425,8 @@ struct UnifiedAttentionPipeline
auto refresh_k_offsets = [&](index_t k_tile_idx) {
static_for<0, KNRepeat, 1>{}([&](auto i) {
const index_t logical_token = k_tile_idx * kPageBlockSize + k_thread_n_pos +
const index_t logical_token = split_token_offset +
k_tile_idx * kPageBlockSize + k_thread_n_pos +
static_cast<index_t>(i.value) * KY0_step_N;
const index_t logical_page = logical_token / page_size;
const index_t within_page = logical_token - logical_page * page_size;
@@ -427,7 +438,8 @@ struct UnifiedAttentionPipeline
};
auto refresh_v_offsets = [&](index_t v_tile_idx) {
static_for<0, VNRepeat, 1>{}([&](auto i) {
const index_t logical_token = v_tile_idx * kPageBlockSize + v_thread_n_pos +
const index_t logical_token = split_token_offset +
v_tile_idx * kPageBlockSize + v_thread_n_pos +
static_cast<index_t>(i.value) * VY0_step_N;
const index_t logical_page = logical_token / page_size;
const index_t within_page = logical_token - logical_page * page_size;
@@ -1155,16 +1167,24 @@ struct UnifiedAttentionPipeline
}
}
label_main_loops_exit:
if(num_total_loop % 2)
// The post-process call finalizes whichever SP register was left in
// an "alu0-done, alu1-pending" state at the end of the main loop.
// Which one that is depends on the parity of the *number of
// iterations performed* (= num_total_loop - num_blocks_start), not
// num_total_loop itself. For the non-split path num_blocks_start
// is always 0 so the two parities coincide; the split-KV path with
// num_blocks_start > 0 needs the corrected expression below.
const index_t num_iters = num_total_loop - num_blocks_start;
if(num_iters % 2)
{
fmha_post_process(number<1>{});
}
if(!(num_total_loop % 2))
if(!(num_iters % 2))
{
fmha_post_process(number<0>{});
}
// finally, O
// finally, O — normalize by l
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -1185,7 +1205,37 @@ struct UnifiedAttentionPipeline
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
// Build the log-sum-exp side-output (natural-log domain) for the
// split-KV combine kernel. For non-split callers this is ignored.
//
// Note `m` here is the *unscaled* rowmax of the raw qk dot products
// (the pipeline computes `m = block_tile_reduce(sp_compute, max)`
// before applying `scale_s`). Likewise `l = sum exp2(scale_s*(s-m))`
// is the natural-domain softmax denominator (since `scale_s` carries
// a baked-in log2(e), `exp2(scale_s*x) == exp(scale*x)`). Combined,
// LSE = log(sum exp(scale * s_k))
// = scale * m + log(l)
// = scale_s/log2(e) * m + log(l).
// The combine kernel re-weights partials with exp(lse - lse_max).
const auto scale_natlog =
scale_s / static_cast<SMPLComputeDataType>(C_LOG2E);
auto lse = make_static_distributed_tensor<SMPLComputeDataType>(m.get_tile_distribution());
sweep_tile_span(o_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if constexpr(FmhaMask::IsMasking)
{
lse(i_idx) =
(l_[i_idx] == 0.f)
? -ck_tile::numeric<SMPLComputeDataType>::infinity()
: scale_natlog * m_[i_idx] + ck_tile::log(l_[i_idx]);
}
else
{
lse(i_idx) = scale_natlog * m_[i_idx] + ck_tile::log(l_[i_idx]);
}
});
return ck_tile::make_tuple(o_acc, lse);
}
template <typename QDramBlockWindowTmp,