diff --git a/example/ck_tile/42_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp index 8b645387a4..023fc3be4e 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention.hpp @@ -67,6 +67,25 @@ struct unified_attention_args index_t num_seqs; // number of batches for q index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown) + + // KV-segment parallelism (split-KV). When num_splits == 1, the kernel + // writes to o_ptr as usual. When num_splits > 1, the kernel is launched + // with a 3D grid whose z-dim is num_splits — each CTA computes its own + // partial (o_acc, lse) and writes them into the FP32 workspaces; a + // separate combine kernel (or a Python-side reduce) merges across + // splits into the final output. + // + // Workspace layout (host-allocated): + // o_acc_ptr : [num_q_heads, num_splits, total_q, hdim_v] (FP32) + // lse_acc_ptr : [num_q_heads, num_splits, total_q] (FP32) + // The corresponding host-set strides are below. + index_t num_splits = 1; + void* o_acc_ptr = nullptr; + void* lse_acc_ptr = nullptr; + index_t split_stride_o_acc = 0; + index_t split_stride_lse_acc = 0; + index_t nhead_stride_o_acc = 0; + index_t nhead_stride_lse_acc = 0; }; std::ostream& operator<<(std::ostream& stream, diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 31e5c4c6ad..0793e0695a 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -414,16 +414,27 @@ float unified_attention_kernel_launch(const unified_attention_args& args, args.block_table_stride, args.seq_lens_ptr, args.query_start_len_ptr, - args.num_seqs); + args.num_seqs, + args.num_splits, + args.lse_acc_ptr, + args.o_acc_ptr, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc); dim3 grids; if constexpr(UseDecodeGrid) { - grids = Kernel::GridSizeDecode(args.num_head_q / args.num_queries_per_kv, args.num_seqs); + grids = Kernel::GridSizeDecode(args.num_head_q / args.num_queries_per_kv, + args.num_seqs, + args.num_splits); } else { - grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); + grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, + total_num_q_blocks, + args.num_splits); } constexpr dim3 blocks = Kernel::BlockSize(); constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 43a9142175..865e62a315 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/unified_attention/block/block_masking.hpp" #include "ck_tile/core/numeric/math.hpp" @@ -97,9 +98,11 @@ struct UnifiedAttentionKernel ck_tile::index_t num_seqs; // number of batches for q - // KV-segment parallelism (split-KV within unified attention) + // KV-segment parallelism (split-KV within unified attention). 
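+    // (Illustrative launch shape: 8 KV heads, 32 sequences and num_splits = 4
+    // give a dim3(8, 32, 4) decode grid; blockIdx.z in [0, 4) selects the
+    // KV segment. See GridSizeDecode/GridSize2D below.)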
+ // Each CTA derives its own `i_split` from `blockIdx.z` — the host + // launches a single 3D grid with z = num_splits and the kernel + // dispatches all splits in parallel. ck_tile::index_t num_splits = 1; - ck_tile::index_t i_split = 0; void* lse_acc_ptr = nullptr; // [nhead, num_splits, total_q] float void* o_acc_ptr = nullptr; // [nhead, num_splits, total_q, hdim_v] float ck_tile::index_t split_stride_lse_acc = 0; @@ -140,7 +143,14 @@ struct UnifiedAttentionKernel ck_tile::index_t block_table_stride, const int32_t* seq_lens_ptr, const int32_t* query_start_len_ptr, - ck_tile::index_t num_seqs) + ck_tile::index_t num_seqs, + ck_tile::index_t num_splits = 1, + void* lse_acc_ptr = nullptr, + void* o_acc_ptr = nullptr, + ck_tile::index_t split_stride_lse_acc = 0, + ck_tile::index_t split_stride_o_acc = 0, + ck_tile::index_t nhead_stride_lse_acc = 0, + ck_tile::index_t nhead_stride_o_acc = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -172,15 +182,25 @@ struct UnifiedAttentionKernel block_table_stride, seq_lens_ptr, query_start_len_ptr, - num_seqs}; + num_seqs, + num_splits, + lse_acc_ptr, + o_acc_ptr, + split_stride_lse_acc, + split_stride_o_acc, + nhead_stride_lse_acc, + nhead_stride_o_acc}; return kargs; } CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads, - ck_tile::index_t total_num_q_blocks) + ck_tile::index_t total_num_q_blocks, + ck_tile::index_t num_splits = 1) { - return dim3(num_kv_heads * total_num_q_blocks); + // z-dim carries the split index; num_splits == 1 is the existing + // (non-split) launch with dim3(N, 1, 1). + return dim3(num_kv_heads * total_num_q_blocks, 1, num_splits); } // Binary search to find the sequence index for a given target index @@ -231,9 +251,10 @@ struct UnifiedAttentionKernel } CK_TILE_HOST static constexpr auto GridSizeDecode(ck_tile::index_t num_kv_heads, - ck_tile::index_t num_seqs) + ck_tile::index_t num_seqs, + ck_tile::index_t num_splits = 1) { - return dim3(num_kv_heads, num_seqs); + return dim3(num_kv_heads, num_seqs, num_splits); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -244,6 +265,11 @@ struct UnifiedAttentionKernel assert(kBlockM / num_queries_per_kv == kBlockQ); + // Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The + // split index lives in z — when num_splits == 1 (the only z value) + // this is just `0` and costs nothing. + const index_t i_split = blockIdx.z; + index_t kv_head_idx; index_t seq_idx; index_t q_block_local_idx; @@ -252,7 +278,7 @@ struct UnifiedAttentionKernel if(gridDim.y > 1) { - // Decode grid: dim3(num_kv_heads, num_seqs) + // Decode grid: dim3(num_kv_heads, num_seqs, num_splits) // Direct mapping, no binary search, no padding CTAs. kv_head_idx = blockIdx.x; seq_idx = blockIdx.y; @@ -263,7 +289,8 @@ struct UnifiedAttentionKernel } else { - // Standard 1D grid with binary search + // Standard 1D grid (x-folded) with binary search; z-dim carries + // the split index just like in the decode branch. ck_tile::index_t pid = blockIdx.x; const auto [kv_head_idx_, q_block_global_idx] = GetTileIndex(pid, kargs); @@ -318,15 +345,17 @@ struct UnifiedAttentionKernel index_t total_num_kv_blocks = amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize); - // KV-segment parallelism: split KV range across workgroups + // KV-segment parallelism: split KV range across workgroups. + // `i_split` came from blockIdx.z above; with num_splits == 1 it's 0 + // and these min/max bounds reduce to [0, total_num_kv_blocks). 
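+        // Worked example (illustrative numbers): total_num_kv_blocks = 10 and
+        // num_splits = 4 give blocks_per_split = 3, so the splits cover
+        // [0,3), [3,6), [6,9), [9,10); a split whose clamped start reaches its
+        // end (possible when num_splits > total_num_kv_blocks) simply returns
+        // below.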
         index_t num_blocks_start = 0;
         index_t num_blocks = total_num_kv_blocks;
         if(kargs.num_splits > 1)
         {
             const index_t blocks_per_split = ck_tile::max(
                 index_t(1), (total_num_kv_blocks + kargs.num_splits - 1) / kargs.num_splits);
-            num_blocks_start = ck_tile::min(blocks_per_split * kargs.i_split, total_num_kv_blocks);
-            num_blocks = ck_tile::min(blocks_per_split * (kargs.i_split + 1), total_num_kv_blocks);
+            num_blocks_start = ck_tile::min(blocks_per_split * i_split, total_num_kv_blocks);
+            num_blocks = ck_tile::min(blocks_per_split * (i_split + 1), total_num_kv_blocks);
             if(num_blocks_start >= num_blocks)
             {
                 return; // this split has no work
@@ -460,56 +489,143 @@ struct UnifiedAttentionKernel
         // kernel tile to fit in a cache page) is gone — tiles may span multiple
         // pages as long as the inner-N step (Y0_step_N from the K/V tile dist)
         // divides page_size cleanly.
+        //
+        // The pipeline returns make_tuple(o_acc, lse), where o_acc is the
+        // normalized attention output (post divide-by-l) and lse is the
+        // per-row log-sum-exp in the natural-log domain. For num_splits == 1
+        // we ignore lse and forward o_acc through the user's epilogue
+        // (bf16/fp16 cast + store to o_ptr). For num_splits > 1 we instead
+        // write o_acc and lse to FP32 workspaces — a separate combine kernel
+        // will merge across splits.

-        auto o_acc_tile = [&]() {
-            return UnifiedAttentionPipeline{}(q_dram_window,
-                                              k_dram_window,
-                                              v_dram_window,
-                                              num_blocks,
-                                              num_blocks_start,
-                                              kargs.block_tables_ptr,
-                                              block_table_offset,
-                                              kargs.page_size,
-                                              mask,
-                                              kargs.scale_s,
-                                              smem_ptr,
-                                              static_cast<long_index_t>(kargs.stride_k_cache_1),
-                                              static_cast<long_index_t>(kargs.stride_v_cache_1));
-        }();
+        auto pipeline_result = UnifiedAttentionPipeline{}(q_dram_window,
+                                                          k_dram_window,
+                                                          v_dram_window,
+                                                          num_blocks,
+                                                          num_blocks_start,
+                                                          kargs.block_tables_ptr,
+                                                          block_table_offset,
+                                                          kargs.page_size,
+                                                          mask,
+                                                          kargs.scale_s,
+                                                          smem_ptr,
+                                                          static_cast<long_index_t>(kargs.stride_k_cache_1),
+                                                          static_cast<long_index_t>(kargs.stride_v_cache_1));
+        auto& o_acc_tile = pipeline_result[number<0>{}];
+        auto& lse_tile = pipeline_result[number<1>{}];

-        // O DRAM and O DRAM window
-        auto o_dram = [&]() {
-            const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
-                o_ptr,
-                make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
-                make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
-                number<kAlignmentO>{},
-                number<1>{});
+        if(kargs.num_splits > 1)
+        {
+            // ----- Split-KV write path -----
+            // Workspaces (FP32) are assumed in layout:
+            //   o_acc_ptr   : [num_q_heads, num_splits, total_q, hdim_v]
+            //   lse_acc_ptr : [num_q_heads, num_splits, total_q]
+            // The host passes nhead/split strides; the q_token axis is
+            // contiguous (= hdim_v for o_acc, = 1 for lse_acc) so we hardcode
+            // that here.
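+            // (Illustrative indexing, matching the pointer math below: element
+            //  (head h, split s, token t, dim d) of o_acc sits at o_acc_ptr +
+            //  h * nhead_stride_o_acc + s * split_stride_o_acc + t * hdim_v + d,
+            //  and lse_acc element (h, s, t) at lse_acc_ptr +
+            //  h * nhead_stride_lse_acc + s * split_stride_lse_acc + t.)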
-            const auto o_dram_pad =
-                pad_tensor_view( // aling cu_seqlen with kBlockQ and head dim with kHeadDimPadded
-                    o_dram_base,
-                    // block sizes
+            const index_t head_q_base = kv_head_idx * num_queries_per_kv;
+
+            float* o_acc_base = reinterpret_cast<float*>(kargs.o_acc_ptr) +
+                                static_cast<long_index_t>(head_q_base) * kargs.nhead_stride_o_acc +
+                                static_cast<long_index_t>(i_split) * kargs.split_stride_o_acc +
+                                static_cast<long_index_t>(cur_batch_in_all_start_index) * kHeadDim;
+
+            auto o_acc_dram = [&]() {
+                const auto o_acc_base_view = make_naive_tensor_view<address_space_enum::global>(
+                    o_acc_base,
+                    make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
+                    make_tuple(static_cast<index_t>(kHeadDim),
+                               static_cast<index_t>(kargs.nhead_stride_o_acc),
+                               static_cast<index_t>(1)),
+                    number<1>{},
+                    number<1>{});
+
+                const auto o_acc_pad = pad_tensor_view(
+                    o_acc_base_view,
                     make_tuple(kBlockQ, 1, kHeadDimPadded),
-                    sequence<true, false, true>{}); // pads to (seq_len_padded, num_head_q,
-                                                    // kHeadDimPadded)
+                    sequence<true, false, true>{});

-            const auto o_dram_merged = transform_tensor_view(
-                o_dram_pad,
-                make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
-                           make_pass_through_transform(kHeadDimPadded)),
-                make_tuple(sequence<0, 1>{}, sequence<2>{}),
-                make_tuple(sequence<0>{}, sequence<1>{}));
+                return transform_tensor_view(
+                    o_acc_pad,
+                    make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
+                               make_pass_through_transform(kHeadDimPadded)),
+                    make_tuple(sequence<0, 1>{}, sequence<2>{}),
+                    make_tuple(sequence<0>{}, sequence<1>{}));
+            }();

-            return o_dram_merged;
-        }();
+            auto o_acc_window =
+                make_tile_window(o_acc_dram,
+                                 make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
+                                 {query_pos * num_queries_per_kv, 0});

-        auto o_dram_window =
-            make_tile_window(o_dram,
-                             make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
-                             {query_pos * num_queries_per_kv, 0});
+            // FP32-out epilogue: cast_tile<float>(o_acc) is a no-op, but the
+            // pad-aware store path (UseRawStore=true) is the same machinery the
+            // user's epilogue uses, so storage semantics are unchanged.
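+            // (A bare store_tile(o_acc_window, o_acc_tile) would also be
+            // enough when no row/column padding is in play; the epilogue is
+            // used so the padded tail rows take the same raw-store path as
+            // the o_ptr write does.)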
+            using SplitOEpilogue = Default2DEpilogue<
+                Default2DEpilogueProblem<float, float, true, true, /*UseRawStore=*/true>>;
+            SplitOEpilogue{}(o_acc_window, o_acc_tile, nullptr);

-        EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
+            // ----- LSE write -----
+            float* lse_acc_base =
+                reinterpret_cast<float*>(kargs.lse_acc_ptr) +
+                static_cast<long_index_t>(head_q_base) * kargs.nhead_stride_lse_acc +
+                static_cast<long_index_t>(i_split) * kargs.split_stride_lse_acc +
+                static_cast<long_index_t>(cur_batch_in_all_start_index);
+
+            auto lse_acc_dram = [&]() {
+                const auto lse_acc_base_view = make_naive_tensor_view<address_space_enum::global>(
+                    lse_acc_base,
+                    make_tuple(cur_batch_query_len, num_queries_per_kv),
+                    make_tuple(static_cast<index_t>(1),
+                               static_cast<index_t>(kargs.nhead_stride_lse_acc)),
+                    number<1>{},
+                    number<1>{});
+
+                const auto lse_acc_pad = pad_tensor_view(
+                    lse_acc_base_view, make_tuple(kBlockQ, 1), sequence<true, false>{});
+
+                return transform_tensor_view(
+                    lse_acc_pad,
+                    make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv))),
+                    make_tuple(sequence<0, 1>{}),
+                    make_tuple(sequence<0>{}));
+            }();
+
+            auto lse_acc_window = make_tile_window(
+                lse_acc_dram, make_tuple(number<kBlockM>{}), {query_pos * num_queries_per_kv});
+
+            store_tile(lse_acc_window, lse_tile);
+        }
+        else
+        {
+            // ----- Non-split (current) path -----
+            auto o_dram = [&]() {
+                const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
+                    o_ptr,
+                    make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
+                    make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
+                    number<kAlignmentO>{},
+                    number<1>{});
+
+                const auto o_dram_pad =
+                    pad_tensor_view(o_dram_base,
+                                    make_tuple(kBlockQ, 1, kHeadDimPadded),
+                                    sequence<true, false, true>{});
+
+                return transform_tensor_view(
+                    o_dram_pad,
+                    make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
+                               make_pass_through_transform(kHeadDimPadded)),
+                    make_tuple(sequence<0, 1>{}, sequence<2>{}),
+                    make_tuple(sequence<0>{}, sequence<1>{}));
+            }();
+
+            auto o_dram_window =
+                make_tile_window(o_dram,
+                                 make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
+                                 {query_pos * num_queries_per_kv, 0});
+
+            EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
+        }
     }
 };
 } // namespace ck_tile
diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
index 98cf70914f..d338eb8627 100644
--- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
+++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
@@ -348,18 +348,28 @@ struct UnifiedAttentionPipeline
     {
         if(num_total_loop - num_blocks_start <= 0)
         {
-
-            // Note: here occ are all cleard, return it
-            // Note: q loaded but no fence, ignore it.
-            return o_acc;
+            // Note: o_acc is already cleared above. q loaded but no fence
+            // (ignored). lse must be -infinity so the split-KV combine
+            // weights this empty partial as zero (exp(-inf) == 0); for
+            // single-split callers the value is harmless (ignored).
+            auto lse_early =
+                make_static_distributed_tensor<float>(m.get_tile_distribution());
+            set_tile(lse_early, -ck_tile::numeric<float>::infinity());
+            return ck_tile::make_tuple(o_acc, lse_early);
         }
     }

     index_t i_total_loops = num_blocks_start;
     const ck_tile::index_t* block_tables_ptr_ =
         reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);

-    assert(k_block_idx == v_block_idx); // because of the following line
-    block_table_offset += num_blocks_start;
+    assert(k_block_idx == v_block_idx);
+    // Split-KV start offset in *tokens* (not in tiles or pages). We add
+    // this to logical_token below so the page-table lookup uses the right
+    // page; we do NOT shift block_table_offset because num_blocks_start is
+    // counted in kPageBlockSize-sized tiles, while block_tables is indexed
+    // in page_size-sized pages — the two differ whenever kPageBlockSize !=
+    // page_size and shifting tiles-as-pages reads the wrong entries.
+    const index_t split_token_offset = num_blocks_start * kPageBlockSize;

     // Pass-2: unified page-offset formula. The kPageBlockSize <= page_size
     // constraint is gone. For every (thread, Y0-iter) pair we compute:
@@ -415,7 +425,8 @@ struct UnifiedAttentionPipeline
     auto refresh_k_offsets = [&](index_t k_tile_idx) {
         static_for<0, KNRepeat, 1>{}([&](auto i) {
-            const index_t logical_token = k_tile_idx * kPageBlockSize + k_thread_n_pos +
+            const index_t logical_token = split_token_offset +
+                                          k_tile_idx * kPageBlockSize + k_thread_n_pos +
                                           static_cast<index_t>(i.value) * KY0_step_N;
             const index_t logical_page = logical_token / page_size;
             const index_t within_page = logical_token - logical_page * page_size;
@@ -427,7 +438,8 @@
     };
     auto refresh_v_offsets = [&](index_t v_tile_idx) {
         static_for<0, VNRepeat, 1>{}([&](auto i) {
-            const index_t logical_token = v_tile_idx * kPageBlockSize + v_thread_n_pos +
+            const index_t logical_token = split_token_offset +
+                                          v_tile_idx * kPageBlockSize + v_thread_n_pos +
                                           static_cast<index_t>(i.value) * VY0_step_N;
             const index_t logical_page = logical_token / page_size;
             const index_t within_page = logical_token - logical_page * page_size;
@@ -1155,16 +1167,24 @@ struct UnifiedAttentionPipeline
         }
     }
 label_main_loops_exit:
-    if(num_total_loop % 2)
+    // The post-process call finalizes whichever SP register was left in
+    // an "alu0-done, alu1-pending" state at the end of the main loop.
+    // Which one that is depends on the parity of the *number of
+    // iterations performed* (= num_total_loop - num_blocks_start), not
+    // num_total_loop itself. For the non-split path num_blocks_start is
+    // always 0, so the two parities coincide; the split-KV path with
+    // num_blocks_start > 0 needs the corrected expression below.
+    const index_t num_iters = num_total_loop - num_blocks_start;
+    if(num_iters % 2)
     {
         fmha_post_process(number<1>{});
     }
-    if(!(num_total_loop % 2))
+    if(!(num_iters % 2))
     {
         fmha_post_process(number<0>{});
     }

-    // finally, O
+    // finally, O — normalize by l
     constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

     sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -1185,7 +1205,37 @@

     o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

-    return o_acc;
+    // Build the log-sum-exp side-output (natural-log domain) for the
+    // split-KV combine kernel. For non-split callers this is ignored.
+    //
+    // Note `m` here is the *unscaled* rowmax of the raw qk dot products
+    // (the pipeline computes `m = block_tile_reduce(sp_compute, max)`
+    // before applying `scale_s`). Likewise `l = sum exp2(scale_s*(s-m))`
+    // is the natural-domain softmax denominator (since `scale_s` carries
+    // a baked-in log2(e), `exp2(scale_s*x) == exp(scale*x)`). Combined,
+    //     LSE = log(sum_k exp(scale * s_k))
+    //         = scale * m + log(l)
+    //         = scale_s/log2(e) * m + log(l).
+    // The combine kernel re-weights partials with exp(lse - lse_max).
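+    // (Sketch of that combine, per output row across splits s, assuming the
+    //  [nhead, num_splits, total_q(, hdim_v)] workspaces above; the combine
+    //  kernel itself ships separately:
+    //      lse_max = max_s lse[s]
+    //      w_s     = exp(lse[s] - lse_max)
+    //      o       = (sum_s w_s * o_acc[s]) / (sum_s w_s)
+    //  i.e. a numerically stable softmax merge of the per-split partials.)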
+    const auto scale_natlog = scale_s / static_cast<float>(C_LOG2E);
+    auto lse = make_static_distributed_tensor<float>(m.get_tile_distribution());
+    sweep_tile_span(o_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
+        constexpr auto i_idx = make_tuple(idx0);
+        if constexpr(FmhaMask::IsMasking)
+        {
+            lse(i_idx) =
+                (l_[i_idx] == 0.f)
+                    ? -ck_tile::numeric<float>::infinity()
+                    : scale_natlog * m_[i_idx] + ck_tile::log(l_[i_idx]);
+        }
+        else
+        {
+            lse(i_idx) = scale_natlog * m_[i_idx] + ck_tile::log(l_[i_idx]);
+        }
+    });
+
+    return ck_tile::make_tuple(o_acc, lse);
 }

 template