Adding SWA implementation + instances

Damien Lejeune
2026-05-08 08:52:25 +00:00
parent 076d505826
commit 5afd97ff5b
12 changed files with 272 additions and 53 deletions


@@ -86,6 +86,12 @@ struct UnifiedAttentionKernel
ck_tile::index_t stride_v_cache_3;
ck_tile::index_t output_stride_0;
ck_tile::index_t output_stride_1;
// Sliding-window-attention parameters. <0 means "unbounded on that side".
// (left=-1, right=0, is_top_left=false) reproduces the classical bottom-right causal mask.
ck_tile::index_t window_size_left = -1;
ck_tile::index_t window_size_right = -1;
bool is_top_left = false;
};
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
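
For reference, a minimal host-side sketch of the window convention these fields encode, assuming the usual flash-attention-style (left, right) semantics; allowed_kv_range and its parameter names are ours, not part of this commit:

// Host-side reference for the window fields above: for query row q_idx, which
// KV columns [lo, hi] may it attend to? Hypothetical helper; <0 on either
// side means "unbounded on that side", matching the comment in the struct.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

using index_t = int32_t;

std::pair<index_t, index_t> allowed_kv_range(index_t q_idx,
                                             index_t seqlen_q,
                                             index_t seqlen_k,
                                             index_t window_left,
                                             index_t window_right,
                                             bool is_top_left)
{
    // Bottom-right alignment anchors the diagonal at the end of both sequences.
    const index_t shift = is_top_left ? 0 : seqlen_k - seqlen_q;
    const index_t lo =
        window_left < 0 ? 0 : std::max<index_t>(0, q_idx + shift - window_left);
    const index_t hi = window_right < 0
                           ? seqlen_k - 1
                           : std::min<index_t>(seqlen_k - 1, q_idx + shift + window_right);
    return {lo, hi}; // inclusive; empty if lo > hi
}

int main()
{
    // (left=-1, right=0, is_top_left=false) == classical bottom-right causal:
    // row q may see columns [0, q + seqlen_k - seqlen_q].
    auto r = allowed_kv_range(/*q_idx=*/2, /*seqlen_q=*/4, /*seqlen_k=*/10, -1, 0, false);
    assert(r.first == 0 && r.second == 8);
}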
@@ -140,7 +146,10 @@ struct UnifiedAttentionKernel
ck_tile::index_t block_table_stride,
const int32_t* seq_lens_ptr,
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs)
ck_tile::index_t num_seqs,
ck_tile::index_t window_size_left = -1,
ck_tile::index_t window_size_right = -1,
bool is_top_left = false)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -167,7 +176,10 @@ struct UnifiedAttentionKernel
stride_v_cache_2,
stride_v_cache_3,
output_stride_0,
output_stride_1},
output_stride_1,
window_size_left,
window_size_right,
is_top_left},
block_tables_ptr,
block_table_stride,
seq_lens_ptr,
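
The three new parameters are appended with defaults that reproduce the old causal behaviour, so existing call sites compile unchanged. A stand-alone sketch of that pattern; make_kargs and the trimmed Kargs below are hypothetical stand-ins for the factory above:

// Trailing defaulted parameters keep pre-SWA callers source-compatible while
// new callers can opt in. Names and the reduced struct are illustrative only.
#include <cstdint>

struct Kargs
{
    int32_t window_size_left  = -1; // <0: unbounded left
    int32_t window_size_right = -1; // <0: unbounded right
    bool    is_top_left       = false;
};

Kargs make_kargs(int32_t num_seqs,
                 int32_t window_size_left  = -1,
                 int32_t window_size_right = -1,
                 bool is_top_left          = false)
{
    (void)num_seqs;
    return Kargs{window_size_left, window_size_right, is_top_left};
}

int main()
{
    auto old_style = make_kargs(8);                 // pre-SWA call site, unchanged
    auto swa       = make_kargs(8, 4095, 0, false); // causal window of 4096 tokens
    (void)old_style;
    (void)swa;
}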
@@ -443,17 +455,41 @@ struct UnifiedAttentionKernel
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
-1,
0,
cur_batch_query_len, // y_total
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
// times along x dim of the tile
false);
kargs.window_size_left, // <0 means unbounded on the left
kargs.window_size_right, // <0 means unbounded on the right
cur_batch_query_len, // y_total
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated
// num_queries_per_kv times along x dim of the tile
kargs.is_top_left);
else
return FmhaMask{cur_batch_query_len, seq_len};
}();
// Sliding-window attention: tighten the KV-block iteration to the row of tiles
// that actually overlap the window for this Q-tile. Without this, blocks wholly
// outside the window would still be loaded, scaled and masked tile-by-tile,
// which is both slow and, because the kernel's softmax accumulator interleaves
// m/l updates with prefetch and warp-group barriers, sensitive to any
// all-(-inf) blocks in the loop. Skipping them entirely keeps each iterated
// tile either fully inside the window or a true edge tile that per-element
// masking can clean up correctly.
if constexpr(FmhaMask::IsMasking && FmhaMask::IsLocal)
{
const index_t i_y_for_mask = query_pos * num_queries_per_kv;
const auto window_range =
mask.GetTileRangeAlongX(i_y_for_mask,
ck_tile::number<kBlockQ>{},
ck_tile::number<kPageBlockSize>{});
const index_t window_blk_lo = ck_tile::max(
index_t(0), window_range.at(ck_tile::number<0>{}) / kPageBlockSize);
const index_t window_blk_hi = ck_tile::min(
total_num_kv_blocks,
(window_range.at(ck_tile::number<1>{}) + kPageBlockSize - 1) / kPageBlockSize);
num_blocks_start = ck_tile::max(num_blocks_start, window_blk_lo);
num_blocks = ck_tile::min(num_blocks, window_blk_hi);
}
const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size
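
The clip above is a floor/ceil divide of the mask's column range into kPageBlockSize-wide tiles, clamped to the blocks the sequence actually has. A self-contained sketch of the same arithmetic, assuming GetTileRangeAlongX returns a half-open [x_lo, x_hi) range (which the ceil-divide in the hunk implies); clip_blocks is ours:

// Clip the KV-block loop to the tiles a [x_lo, x_hi) column window touches.
#include <algorithm>
#include <cassert>
#include <cstdint>

using index_t = int32_t;

void clip_blocks(index_t x_lo, index_t x_hi, index_t block_size,
                 index_t total_blocks, index_t& blk_lo, index_t& blk_hi)
{
    blk_lo = std::max<index_t>(0, x_lo / block_size);                      // floor
    blk_hi = std::min(total_blocks, (x_hi + block_size - 1) / block_size); // ceil
}

int main()
{
    index_t lo, hi;
    // Window covering columns [70, 200) with 64-wide blocks, 16 blocks total:
    clip_blocks(70, 200, 64, 16, lo, hi);
    assert(lo == 1 && hi == 4); // iterate blocks 1..3; blocks 0 and 4..15 are skipped
}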


@@ -360,8 +360,22 @@ struct UnifiedAttentionPipeline
const ck_tile::index_t* block_tables_ptr_ =
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
assert(k_block_idx == v_block_idx); // both are advanced in lockstep below
block_table_offset += num_blocks_start;
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
// num_blocks_start is in kPageBlockSize sub-block units, but block_table_offset
// (and block_tables_ptr) are in PageSize-page units. Convert units so that
// a sub-block-aligned start position lands on the right page AND the right
// sub-block within that page. This matters for SWA (num_blocks_start > 0) and
// for split-KV when kv_page_size_in_blocks > 1.
const index_t page_advance = num_blocks_start / kv_page_size_in_blocks;
const index_t init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks;
block_table_offset += page_advance;
// k_block_idx counts sub-blocks relative to the first iterated page. Starting it
// at init_sub_block_offset lets the existing K_mem_load math (k_block_idx /
// kv_page_size_in_blocks for the page index, k_block_idx % kv_page_size_in_blocks
// for the within-page sub-block) keep working unchanged.
k_block_idx = init_sub_block_offset;
v_block_idx = init_sub_block_offset;
index_t kv_blk_idx_initial =
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
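
Concretely, the conversion splits a global sub-block index with one div/mod, and the first physical KV row then comes out of the block table as page_id * PageSize + sub_block * kPageBlockSize. A small sketch with illustrative constants (PageSize = 64 and kPageBlockSize = 16 are ours, not the kernel's):

// Map a global kPageBlockSize-granular start index onto the PageSize-granular
// block table. Hypothetical mirror of the conversion above.
#include <cassert>
#include <cstdint>
#include <vector>

using index_t = int32_t;

int main()
{
    constexpr index_t PageSize = 64, kPageBlockSize = 16;
    constexpr index_t kv_page_size_in_blocks = PageSize / kPageBlockSize; // 4
    const std::vector<index_t> block_table = {7, 3, 9}; // logical page -> physical page

    const index_t num_blocks_start = 6; // sub-block units, e.g. an odd SWA clip
    const index_t page_advance     = num_blocks_start / kv_page_size_in_blocks; // 1
    const index_t sub_block        = num_blocks_start % kv_page_size_in_blocks; // 2

    // First iterated KV row: physical page 3, third sub-block within it.
    const index_t first_row =
        block_table[page_advance] * PageSize + sub_block * kPageBlockSize;
    assert(first_row == 3 * 64 + 2 * 16); // 224

    // The old code advanced the table by num_blocks_start pages outright,
    // overshooting by a factor of kv_page_size_in_blocks (here it would even
    // read index 6 of a 3-entry table).
}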
// When row strides are provided, use pointer rebasing to avoid int32 overflow
// in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
@@ -376,16 +390,25 @@ struct UnifiedAttentionPipeline
auto* k_base_ptr = k_view.buf_.p_data_;
auto* v_base_ptr = v_view.buf_.p_data_;
// Within-page byte offset (in K/V rows) for the very first iterated sub-block.
// Non-zero whenever num_blocks_start is not a multiple of kv_page_size_in_blocks.
const index_t initial_intra_page_row_offset =
init_sub_block_offset * kPageBlockSize;
if(use_ptr_rebase)
{
long_index_t k_off =
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
(static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
static_cast<long_index_t>(initial_intra_page_row_offset)) *
k_row_stride;
k_view.buf_.p_data_ = k_base_ptr + k_off;
auto new_k = k_view.buf_.buffer_size_ - k_off;
k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
long_index_t v_off =
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
(static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
static_cast<long_index_t>(initial_intra_page_row_offset)) *
v_row_stride;
v_view.buf_.p_data_ = v_base_ptr + v_off;
auto new_v = v_view.buf_.buffer_size_ - v_off;
v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
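
A stand-alone illustration of why the widening casts above are load-bearing: with plausible (made-up) sizes the per-page element offset already exceeds INT32_MAX, so the multiply must happen in 64-bit from the first operand, not after the fact:

// Widen the block index *before* the multiply; (int32 * int32 * int32) would
// overflow first and only then be converted. Sizes below are illustrative.
#include <cassert>
#include <cstdint>

using index_t      = int32_t;
using long_index_t = int64_t;

int main()
{
    const index_t kv_blk_idx   = 200000; // beyond the ~131K-block int32 limit
    const index_t PageSize     = 16;
    const index_t k_row_stride = 1024;

    // 200000 * 16 * 1024 = 3,276,800,000 > INT32_MAX (2,147,483,647).
    const long_index_t k_off =
        static_cast<long_index_t>(kv_blk_idx) * PageSize * k_row_stride;
    assert(k_off == 3'276'800'000LL);
}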
@@ -397,7 +420,10 @@ struct UnifiedAttentionPipeline
v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
}
const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
const index_t init_origin = use_ptr_rebase
? 0
: kv_blk_idx_initial * PageSize +
initial_intra_page_row_offset;
auto k_dram_window =
make_tile_window(k_view,
@@ -1147,11 +1173,20 @@ struct UnifiedAttentionPipeline
}
}
label_main_loops_exit:
if(num_total_loop % 2)
// Post-process must consume the *last iteration's* sp/V buffer slot. Pre-stage
// always writes to slot 0; the main loop alternates 1, 0, 1, ... So the last
// written slot is determined by the number of iterations actually executed
// (= num_total_loop - num_blocks_start), not by num_total_loop alone. Keying
// on num_total_loop gives the correct parity when num_blocks_start is even,
// but reads uninitialised sp(1)/V[1] when it is odd, which silently corrupts
// o_acc. Previously this only mattered for split-KV with an odd split
// boundary; SWA exposes it whenever the per-Q-tile lower-bound clip is odd.
const index_t num_iters = num_total_loop - num_blocks_start;
if(num_iters % 2)
{
fmha_post_process(number<1>{});
}
if(!(num_total_loop % 2))
if(!(num_iters % 2))
{
fmha_post_process(number<0>{});
}
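
The arithmetic behind the fix, as a stand-alone check: the parity of num_iters = num_total_loop - num_blocks_start agrees with the parity of num_total_loop exactly when num_blocks_start is even, which is why the old keying only broke on odd clips:

// Exhaustive check of the parity claim above (hypothetical standalone code).
#include <cassert>
#include <cstdint>

int main()
{
    for (int32_t num_total_loop = 0; num_total_loop < 8; ++num_total_loop)
        for (int32_t num_blocks_start = 0; num_blocks_start <= num_total_loop;
             ++num_blocks_start)
        {
            const int32_t num_iters = num_total_loop - num_blocks_start;
            // Parities agree iff the lower-bound clip is even; an odd SWA clip
            // (or odd split-KV boundary) flips the final buffer slot.
            assert((num_iters % 2 == num_total_loop % 2) ==
                   (num_blocks_start % 2 == 0));
        }
}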