Adding SWA implementation + instances

This commit is contained in:
Damien Lejeune
2026-05-08 08:52:25 +00:00
parent 076d505826
commit 5afd97ff5b
12 changed files with 272 additions and 53 deletions

View File

@@ -86,6 +86,12 @@ struct UnifiedAttentionKernel
ck_tile::index_t stride_v_cache_3;
ck_tile::index_t output_stride_0;
ck_tile::index_t output_stride_1;
// Sliding-window-attention parameters. <0 means "unbounded on that side".
// (left=-1, right=0, is_top_left=false) reproduces classical bottom-right causal.
ck_tile::index_t window_size_left = -1;
ck_tile::index_t window_size_right = -1;
bool is_top_left = false;
};
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
@@ -140,7 +146,10 @@ struct UnifiedAttentionKernel
ck_tile::index_t block_table_stride,
const int32_t* seq_lens_ptr,
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs)
ck_tile::index_t num_seqs,
ck_tile::index_t window_size_left = -1,
ck_tile::index_t window_size_right = -1,
bool is_top_left = false)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -167,7 +176,10 @@ struct UnifiedAttentionKernel
stride_v_cache_2,
stride_v_cache_3,
output_stride_0,
output_stride_1},
output_stride_1,
window_size_left,
window_size_right,
is_top_left},
block_tables_ptr,
block_table_stride,
seq_lens_ptr,
@@ -443,17 +455,41 @@ struct UnifiedAttentionKernel
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
-1,
0,
cur_batch_query_len, // y_total
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
// times along x dim of the tile
false);
kargs.window_size_left, // <0 means unbounded on the left
kargs.window_size_right, // <0 means unbounded on the right
cur_batch_query_len, // y_total
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated
// num_queries_per_kv times along x dim of the tile
kargs.is_top_left);
else
return FmhaMask{cur_batch_query_len, seq_len};
}();
// Sliding-window-attention: tighten the KV-block iteration to the row of tiles
// that actually overlap the window for this Q-tile. Without this, blocks wholly
// outside the window would still be loaded, scaled and masked tile-by-tile —
// which is both slow and (because the kernel's softmax accumulator interleaves
// m/l updates with prefetch and warp-group barriers) sensitive to having any
// all-(-inf) blocks in the loop. Skipping them entirely keeps each iterated
// tile either fully-inside-window or a true edge tile that per-pixel masking
// can clean up correctly.
if constexpr(FmhaMask::IsMasking && FmhaMask::IsLocal)
{
const index_t i_y_for_mask = query_pos * num_queries_per_kv;
const auto window_range =
mask.GetTileRangeAlongX(i_y_for_mask,
ck_tile::number<kBlockQ>{},
ck_tile::number<kPageBlockSize>{});
const index_t window_blk_lo = ck_tile::max(
index_t(0), window_range.at(ck_tile::number<0>{}) / kPageBlockSize);
const index_t window_blk_hi = ck_tile::min(
total_num_kv_blocks,
(window_range.at(ck_tile::number<1>{}) + kPageBlockSize - 1) / kPageBlockSize);
num_blocks_start = ck_tile::max(num_blocks_start, window_blk_lo);
num_blocks = ck_tile::min(num_blocks, window_blk_hi);
}
const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size