mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Add KV-segment parallelism to CK unified attention pipeline
End-to-end split-KV (FlashDecoding-style) for the CK unified attention
kernel. The host launches a single 3D grid with z == num_splits; each
CTA computes its KV-range slice and writes a normalized (o_acc, lse)
partial to FP32 workspaces, which the caller reduces into the final
output.
Pipeline changes:
- operator() returns ck_tile::make_tuple(o_acc, lse) instead of just
o_acc. The masked-empty early-exit returns lse = -inf so downstream
combine weighs the partial as zero.
- LSE is built in the natural-log domain from the pipeline's *unscaled*
rowmax: lse = (scale_s / log2(e)) * m + log(l). Previously we used
m / log2(e) + log(l), which dropped the per-head scale and produced
LSE values ~1/scale too large.
- Fix post-process parity: which SP register is left in the
alu0-done/alu1-pending state at loop exit depends on the parity of
the *iteration count* (= num_total_loop - num_blocks_start), not on
num_total_loop alone. For non-split (num_blocks_start == 0) the two
parities coincide; for splits starting at an odd tile they don't.
- Fix split-KV page-table offset: num_blocks_start is counted in
kPageBlockSize-sized tiles, but block_tables is indexed in
page_size-sized pages — shifting block_table_offset by num_blocks_start
reads the wrong pages whenever kPageBlockSize != page_size. Replaced
with split_token_offset = num_blocks_start * kPageBlockSize added to
logical_token before /page_size, so the page lookup uses the absolute
token position.
Kernel + dispatcher:
- Drop kargs.i_split; each CTA reads i_split = blockIdx.z.
- GridSize{2D,Decode} now take num_splits and add it as the z-dim
(defaults to 1, so non-split callers see dim3(..., 1, 1)).
- New write path: when num_splits > 1, the kernel skips the user
epilogue and instead writes the FP32 (o_acc, lse) tile pair into
workspace tensors at [head, split, batch_start_token, ...] using
Default2DEpilogue (UseRawStore=true) for o_acc and store_tile for
lse. Host strides come from kargs.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -67,6 +67,25 @@ struct unified_attention_args
|
||||
|
||||
index_t num_seqs; // number of batches for q
|
||||
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
|
||||
|
||||
// KV-segment parallelism (split-KV). When num_splits == 1, the kernel
|
||||
// writes to o_ptr as usual. When num_splits > 1, the kernel is launched
|
||||
// with a 3D grid whose z-dim is num_splits — each CTA computes its own
|
||||
// partial (o_acc, lse) and writes them into the FP32 workspaces; a
|
||||
// separate combine kernel (or a Python-side reduce) merges across
|
||||
// splits into the final output.
|
||||
//
|
||||
// Workspace layout (host-allocated):
|
||||
// o_acc_ptr : [num_q_heads, num_splits, total_q, hdim_v] (FP32)
|
||||
// lse_acc_ptr : [num_q_heads, num_splits, total_q] (FP32)
|
||||
// The corresponding host-set strides are below.
|
||||
index_t num_splits = 1;
|
||||
void* o_acc_ptr = nullptr;
|
||||
void* lse_acc_ptr = nullptr;
|
||||
index_t split_stride_o_acc = 0;
|
||||
index_t split_stride_lse_acc = 0;
|
||||
index_t nhead_stride_o_acc = 0;
|
||||
index_t nhead_stride_lse_acc = 0;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream,
|
||||
|
||||
@@ -414,16 +414,27 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
args.block_table_stride,
|
||||
args.seq_lens_ptr,
|
||||
args.query_start_len_ptr,
|
||||
args.num_seqs);
|
||||
args.num_seqs,
|
||||
args.num_splits,
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(UseDecodeGrid)
|
||||
{
|
||||
grids = Kernel::GridSizeDecode(args.num_head_q / args.num_queries_per_kv, args.num_seqs);
|
||||
grids = Kernel::GridSizeDecode(args.num_head_q / args.num_queries_per_kv,
|
||||
args.num_seqs,
|
||||
args.num_splits);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
|
||||
grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv,
|
||||
total_num_q_blocks,
|
||||
args.num_splits);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
|
||||
@@ -97,9 +98,11 @@ struct UnifiedAttentionKernel
|
||||
|
||||
ck_tile::index_t num_seqs; // number of batches for q
|
||||
|
||||
// KV-segment parallelism (split-KV within unified attention)
|
||||
// KV-segment parallelism (split-KV within unified attention).
|
||||
// Each CTA derives its own `i_split` from `blockIdx.z` — the host
|
||||
// launches a single 3D grid with z = num_splits and the kernel
|
||||
// dispatches all splits in parallel.
|
||||
ck_tile::index_t num_splits = 1;
|
||||
ck_tile::index_t i_split = 0;
|
||||
void* lse_acc_ptr = nullptr; // [nhead, num_splits, total_q] float
|
||||
void* o_acc_ptr = nullptr; // [nhead, num_splits, total_q, hdim_v] float
|
||||
ck_tile::index_t split_stride_lse_acc = 0;
|
||||
@@ -140,7 +143,14 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t block_table_stride,
|
||||
const int32_t* seq_lens_ptr,
|
||||
const int32_t* query_start_len_ptr,
|
||||
ck_tile::index_t num_seqs)
|
||||
ck_tile::index_t num_seqs,
|
||||
ck_tile::index_t num_splits = 1,
|
||||
void* lse_acc_ptr = nullptr,
|
||||
void* o_acc_ptr = nullptr,
|
||||
ck_tile::index_t split_stride_lse_acc = 0,
|
||||
ck_tile::index_t split_stride_o_acc = 0,
|
||||
ck_tile::index_t nhead_stride_lse_acc = 0,
|
||||
ck_tile::index_t nhead_stride_o_acc = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -172,15 +182,25 @@ struct UnifiedAttentionKernel
|
||||
block_table_stride,
|
||||
seq_lens_ptr,
|
||||
query_start_len_ptr,
|
||||
num_seqs};
|
||||
num_seqs,
|
||||
num_splits,
|
||||
lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
|
||||
ck_tile::index_t total_num_q_blocks)
|
||||
ck_tile::index_t total_num_q_blocks,
|
||||
ck_tile::index_t num_splits = 1)
|
||||
{
|
||||
return dim3(num_kv_heads * total_num_q_blocks);
|
||||
// z-dim carries the split index; num_splits == 1 is the existing
|
||||
// (non-split) launch with dim3(N, 1, 1).
|
||||
return dim3(num_kv_heads * total_num_q_blocks, 1, num_splits);
|
||||
}
|
||||
|
||||
// Binary search to find the sequence index for a given target index
|
||||
@@ -231,9 +251,10 @@ struct UnifiedAttentionKernel
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSizeDecode(ck_tile::index_t num_kv_heads,
|
||||
ck_tile::index_t num_seqs)
|
||||
ck_tile::index_t num_seqs,
|
||||
ck_tile::index_t num_splits = 1)
|
||||
{
|
||||
return dim3(num_kv_heads, num_seqs);
|
||||
return dim3(num_kv_heads, num_seqs, num_splits);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -244,6 +265,11 @@ struct UnifiedAttentionKernel
|
||||
|
||||
assert(kBlockM / num_queries_per_kv == kBlockQ);
|
||||
|
||||
// Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The
|
||||
// split index lives in z — when num_splits == 1 (the only z value)
|
||||
// this is just `0` and costs nothing.
|
||||
const index_t i_split = blockIdx.z;
|
||||
|
||||
index_t kv_head_idx;
|
||||
index_t seq_idx;
|
||||
index_t q_block_local_idx;
|
||||
@@ -252,7 +278,7 @@ struct UnifiedAttentionKernel
|
||||
|
||||
if(gridDim.y > 1)
|
||||
{
|
||||
// Decode grid: dim3(num_kv_heads, num_seqs)
|
||||
// Decode grid: dim3(num_kv_heads, num_seqs, num_splits)
|
||||
// Direct mapping, no binary search, no padding CTAs.
|
||||
kv_head_idx = blockIdx.x;
|
||||
seq_idx = blockIdx.y;
|
||||
@@ -263,7 +289,8 @@ struct UnifiedAttentionKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
// Standard 1D grid with binary search
|
||||
// Standard 1D grid (x-folded) with binary search; z-dim carries
|
||||
// the split index just like in the decode branch.
|
||||
ck_tile::index_t pid = blockIdx.x;
|
||||
|
||||
const auto [kv_head_idx_, q_block_global_idx] = GetTileIndex(pid, kargs);
|
||||
@@ -318,15 +345,17 @@ struct UnifiedAttentionKernel
|
||||
index_t total_num_kv_blocks =
|
||||
amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize);
|
||||
|
||||
// KV-segment parallelism: split KV range across workgroups
|
||||
// KV-segment parallelism: split KV range across workgroups.
|
||||
// `i_split` came from blockIdx.z above; with num_splits == 1 it's 0
|
||||
// and these min/max bounds reduce to [0, total_num_kv_blocks).
|
||||
index_t num_blocks_start = 0;
|
||||
index_t num_blocks = total_num_kv_blocks;
|
||||
if(kargs.num_splits > 1)
|
||||
{
|
||||
const index_t blocks_per_split =
|
||||
ck_tile::max(index_t(1), (total_num_kv_blocks + kargs.num_splits - 1) / kargs.num_splits);
|
||||
num_blocks_start = ck_tile::min(blocks_per_split * kargs.i_split, total_num_kv_blocks);
|
||||
num_blocks = ck_tile::min(blocks_per_split * (kargs.i_split + 1), total_num_kv_blocks);
|
||||
num_blocks_start = ck_tile::min(blocks_per_split * i_split, total_num_kv_blocks);
|
||||
num_blocks = ck_tile::min(blocks_per_split * (i_split + 1), total_num_kv_blocks);
|
||||
if(num_blocks_start >= num_blocks)
|
||||
{
|
||||
return; // this split has no work
|
||||
@@ -460,56 +489,143 @@ struct UnifiedAttentionKernel
|
||||
// kernel tile to fit in a cache page) is gone — tiles may span multiple
|
||||
// pages as long as the inner-N step (Y0_step_N from the K/V tile dist)
|
||||
// divides page_size cleanly.
|
||||
//
|
||||
// Pipeline returns make_tuple(o_acc, lse) where o_acc is the normalized
|
||||
// attention output (post divide-by-l) and lse is the per-row log-sum-exp
|
||||
// in natural-log domain. For num_splits == 1 we ignore lse and forward
|
||||
// o_acc through the user's epilogue (bf16/fp16 cast + store to o_ptr).
|
||||
// For num_splits > 1 we instead write o_acc and lse to FP32 workspaces
|
||||
// — a separate combine kernel will merge across splits.
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return UnifiedAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
num_blocks,
|
||||
num_blocks_start,
|
||||
kargs.block_tables_ptr,
|
||||
block_table_offset,
|
||||
kargs.page_size,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
static_cast<long_index_t>(kargs.stride_k_cache_1),
|
||||
static_cast<long_index_t>(kargs.stride_v_cache_1));
|
||||
}();
|
||||
auto pipeline_result = UnifiedAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
num_blocks,
|
||||
num_blocks_start,
|
||||
kargs.block_tables_ptr,
|
||||
block_table_offset,
|
||||
kargs.page_size,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
static_cast<long_index_t>(kargs.stride_k_cache_1),
|
||||
static_cast<long_index_t>(kargs.stride_v_cache_1));
|
||||
auto& o_acc_tile = pipeline_result[number<0>{}];
|
||||
auto& lse_tile = pipeline_result[number<1>{}];
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
auto o_dram = [&]() {
|
||||
const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
|
||||
make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
|
||||
number<UnifiedAttentionPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
if(kargs.num_splits > 1)
|
||||
{
|
||||
// ----- Split-KV write path -----
|
||||
// Workspaces (FP32) are assumed in layout:
|
||||
// o_acc_ptr : [num_q_heads, num_splits, total_q, hdim_v]
|
||||
// lse_acc_ptr : [num_q_heads, num_splits, total_q]
|
||||
// The host passes nhead/split strides; the q_token axis is contiguous
|
||||
// (= hdim_v for o_acc, = 1 for lse_acc) so we hardcode that here.
|
||||
|
||||
const auto o_dram_pad =
|
||||
pad_tensor_view( // align cu_seqlen with kBlockQ and head dim with kHeadDimPadded
|
||||
o_dram_base,
|
||||
// block sizes
|
||||
const index_t head_q_base = kv_head_idx * num_queries_per_kv;
|
||||
|
||||
float* o_acc_base = reinterpret_cast<float*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(head_q_base) * kargs.nhead_stride_o_acc +
|
||||
static_cast<long_index_t>(i_split) * kargs.split_stride_o_acc +
|
||||
static_cast<long_index_t>(cur_batch_in_all_start_index) * kHeadDim;
|
||||
|
||||
auto o_acc_dram = [&]() {
|
||||
const auto o_acc_base_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_base,
|
||||
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
|
||||
make_tuple(static_cast<long_index_t>(kHeadDim),
|
||||
static_cast<long_index_t>(kargs.nhead_stride_o_acc),
|
||||
static_cast<long_index_t>(1)),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
const auto o_acc_pad = pad_tensor_view(
|
||||
o_acc_base_view,
|
||||
make_tuple(kBlockQ, 1, kHeadDimPadded),
|
||||
sequence<true, false, kPadHeadDimQ>{}); // pads to (seq_len_padded, num_head_q,
|
||||
// kHeadDimPadded)
|
||||
sequence<true, false, kPadHeadDimQ>{});
|
||||
|
||||
const auto o_dram_merged = transform_tensor_view(
|
||||
o_dram_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
|
||||
make_pass_through_transform(kHeadDimPadded)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return transform_tensor_view(
|
||||
o_acc_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
|
||||
make_pass_through_transform(kHeadDimPadded)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}();
|
||||
|
||||
return o_dram_merged;
|
||||
}();
|
||||
auto o_acc_window =
|
||||
make_tile_window(o_acc_dram,
|
||||
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
|
||||
{query_pos * num_queries_per_kv, 0});
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
|
||||
{query_pos * num_queries_per_kv, 0});
|
||||
// FP32-out epilogue: cast_tile<float>(o_acc) is a no-op, but the
|
||||
// pad-aware store path (UseRawStore=true) is the same machinery the
|
||||
// user's epilogue uses, so storage semantics are unchanged.
|
||||
using SplitOEpilogue =
|
||||
Default2DEpilogue<Default2DEpilogueProblem<float, float, true, true, true>>;
|
||||
SplitOEpilogue{}(o_acc_window, o_acc_tile, nullptr);
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
|
||||
// ----- LSE write -----
|
||||
float* lse_acc_base =
|
||||
reinterpret_cast<float*>(kargs.lse_acc_ptr) +
|
||||
static_cast<long_index_t>(head_q_base) * kargs.nhead_stride_lse_acc +
|
||||
static_cast<long_index_t>(i_split) * kargs.split_stride_lse_acc +
|
||||
static_cast<long_index_t>(cur_batch_in_all_start_index);
|
||||
|
||||
auto lse_acc_dram = [&]() {
|
||||
const auto lse_acc_base_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_acc_base,
|
||||
make_tuple(cur_batch_query_len, num_queries_per_kv),
|
||||
make_tuple(static_cast<long_index_t>(1),
|
||||
static_cast<long_index_t>(kargs.nhead_stride_lse_acc)),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
const auto lse_acc_pad = pad_tensor_view(
|
||||
lse_acc_base_view, make_tuple(kBlockQ, 1), sequence<true, false>{});
|
||||
|
||||
return transform_tensor_view(
|
||||
lse_acc_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv))),
|
||||
make_tuple(sequence<0, 1>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
}();
|
||||
|
||||
auto lse_acc_window =
|
||||
make_tile_window(lse_acc_dram, make_tuple(number<kBlockM>{}), {query_pos * num_queries_per_kv});
|
||||
|
||||
store_tile(lse_acc_window, lse_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
// ----- Non-split (current) path -----
|
||||
auto o_dram = [&]() {
|
||||
const auto o_dram_base = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim),
|
||||
make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1),
|
||||
number<UnifiedAttentionPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
const auto o_dram_pad =
|
||||
pad_tensor_view(o_dram_base,
|
||||
make_tuple(kBlockQ, 1, kHeadDimPadded),
|
||||
sequence<true, false, kPadHeadDimQ>{});
|
||||
|
||||
return transform_tensor_view(
|
||||
o_dram_pad,
|
||||
make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)),
|
||||
make_pass_through_transform(kHeadDimPadded)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}();
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<kBlockM>{}, number<kHeadDimPadded>{}),
|
||||
{query_pos * num_queries_per_kv, 0});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -348,18 +348,28 @@ struct UnifiedAttentionPipeline
|
||||
{
|
||||
if(num_total_loop - num_blocks_start <= 0)
|
||||
{
|
||||
|
||||
// Note: here o_acc values are all cleared, return it
|
||||
// Note: q loaded but no fence, ignore it.
|
||||
return o_acc;
|
||||
// Note: o_acc is already cleared above. q loaded but no fence
|
||||
// (ignored). lse must be -infinity so the split-KV combine
|
||||
// weighs this empty partial as zero (exp(-inf) == 0); for
|
||||
// single-split callers the value is harmless (ignored).
|
||||
auto lse_early =
|
||||
make_static_distributed_tensor<SMPLComputeDataType>(m.get_tile_distribution());
|
||||
set_tile(lse_early, -ck_tile::numeric<SMPLComputeDataType>::infinity());
|
||||
return ck_tile::make_tuple(o_acc, lse_early);
|
||||
}
|
||||
}
|
||||
|
||||
index_t i_total_loops = num_blocks_start;
|
||||
const ck_tile::index_t* block_tables_ptr_ =
|
||||
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
|
||||
assert(k_block_idx == v_block_idx); // because of the following line
|
||||
block_table_offset += num_blocks_start;
|
||||
assert(k_block_idx == v_block_idx);
|
||||
// Split-KV start offset in *tokens* (not in tiles or pages). We add
|
||||
// this to logical_token below so the page-table lookup uses the right
|
||||
// page; we do NOT shift block_table_offset because num_blocks_start is
|
||||
// counted in kPageBlockSize-sized tiles, while block_tables is indexed
|
||||
// in page_size-sized pages — the two differ whenever kPageBlockSize !=
|
||||
// page_size and shifting tiles-as-pages reads the wrong entries.
|
||||
const index_t split_token_offset = num_blocks_start * kPageBlockSize;
|
||||
|
||||
// Pass-2: unified page-offset formula. The kPageBlockSize <= page_size
|
||||
// constraint is gone. For every (thread, Y0-iter) pair we compute:
|
||||
@@ -415,7 +425,8 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto refresh_k_offsets = [&](index_t k_tile_idx) {
|
||||
static_for<0, KNRepeat, 1>{}([&](auto i) {
|
||||
const index_t logical_token = k_tile_idx * kPageBlockSize + k_thread_n_pos +
|
||||
const index_t logical_token = split_token_offset +
|
||||
k_tile_idx * kPageBlockSize + k_thread_n_pos +
|
||||
static_cast<index_t>(i.value) * KY0_step_N;
|
||||
const index_t logical_page = logical_token / page_size;
|
||||
const index_t within_page = logical_token - logical_page * page_size;
|
||||
@@ -427,7 +438,8 @@ struct UnifiedAttentionPipeline
|
||||
};
|
||||
auto refresh_v_offsets = [&](index_t v_tile_idx) {
|
||||
static_for<0, VNRepeat, 1>{}([&](auto i) {
|
||||
const index_t logical_token = v_tile_idx * kPageBlockSize + v_thread_n_pos +
|
||||
const index_t logical_token = split_token_offset +
|
||||
v_tile_idx * kPageBlockSize + v_thread_n_pos +
|
||||
static_cast<index_t>(i.value) * VY0_step_N;
|
||||
const index_t logical_page = logical_token / page_size;
|
||||
const index_t within_page = logical_token - logical_page * page_size;
|
||||
@@ -1155,16 +1167,24 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
}
|
||||
label_main_loops_exit:
|
||||
if(num_total_loop % 2)
|
||||
// The post-process call finalizes whichever SP register was left in
|
||||
// an "alu0-done, alu1-pending" state at the end of the main loop.
|
||||
// Which one that is depends on the parity of the *number of
|
||||
// iterations performed* (= num_total_loop - num_blocks_start), not
|
||||
// num_total_loop itself. For the non-split path num_blocks_start
|
||||
// is always 0 so the two parities coincide; the split-KV path with
|
||||
// num_blocks_start > 0 needs the corrected expression below.
|
||||
const index_t num_iters = num_total_loop - num_blocks_start;
|
||||
if(num_iters % 2)
|
||||
{
|
||||
fmha_post_process(number<1>{});
|
||||
}
|
||||
if(!(num_total_loop % 2))
|
||||
if(!(num_iters % 2))
|
||||
{
|
||||
fmha_post_process(number<0>{});
|
||||
}
|
||||
|
||||
// finally, O
|
||||
// finally, O — normalize by l
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -1185,7 +1205,37 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
// Build the log-sum-exp side-output (natural-log domain) for the
|
||||
// split-KV combine kernel. For non-split callers this is ignored.
|
||||
//
|
||||
// Note `m` here is the *unscaled* rowmax of the raw qk dot products
|
||||
// (the pipeline computes `m = block_tile_reduce(sp_compute, max)`
|
||||
// before applying `scale_s`). Likewise `l = sum exp2(scale_s*(s-m))`
|
||||
// is the natural-domain softmax denominator (since `scale_s` carries
|
||||
// a baked-in log2(e), `exp2(scale_s*x) == exp(scale*x)`). Combined,
|
||||
// LSE = log(sum exp(scale * s_k))
|
||||
// = scale * m + log(l)
|
||||
// = scale_s/log2(e) * m + log(l).
|
||||
// The combine kernel re-weights partials with exp(lse - lse_max).
|
||||
const auto scale_natlog =
|
||||
scale_s / static_cast<SMPLComputeDataType>(C_LOG2E);
|
||||
auto lse = make_static_distributed_tensor<SMPLComputeDataType>(m.get_tile_distribution());
|
||||
sweep_tile_span(o_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
lse(i_idx) =
|
||||
(l_[i_idx] == 0.f)
|
||||
? -ck_tile::numeric<SMPLComputeDataType>::infinity()
|
||||
: scale_natlog * m_[i_idx] + ck_tile::log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse(i_idx) = scale_natlog * m_[i_idx] + ck_tile::log(l_[i_idx]);
|
||||
}
|
||||
});
|
||||
|
||||
return ck_tile::make_tuple(o_acc, lse);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
|
||||
Reference in New Issue
Block a user