mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Adding SWA implementation + instances
This commit is contained in:
@@ -86,6 +86,12 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t stride_v_cache_3;
|
||||
ck_tile::index_t output_stride_0;
|
||||
ck_tile::index_t output_stride_1;
|
||||
|
||||
// Sliding-window-attention parameters. <0 means "unbounded on that side".
|
||||
// (left=-1, right=0, is_top_left=false) reproduces classical bottom-right causal.
|
||||
ck_tile::index_t window_size_left = -1;
|
||||
ck_tile::index_t window_size_right = -1;
|
||||
bool is_top_left = false;
|
||||
};
|
||||
|
||||
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
|
||||
@@ -140,7 +146,10 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t block_table_stride,
|
||||
const int32_t* seq_lens_ptr,
|
||||
const int32_t* query_start_len_ptr,
|
||||
ck_tile::index_t num_seqs)
|
||||
ck_tile::index_t num_seqs,
|
||||
ck_tile::index_t window_size_left = -1,
|
||||
ck_tile::index_t window_size_right = -1,
|
||||
bool is_top_left = false)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -167,7 +176,10 @@ struct UnifiedAttentionKernel
|
||||
stride_v_cache_2,
|
||||
stride_v_cache_3,
|
||||
output_stride_0,
|
||||
output_stride_1},
|
||||
output_stride_1,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
is_top_left},
|
||||
block_tables_ptr,
|
||||
block_table_stride,
|
||||
seq_lens_ptr,
|
||||
@@ -443,17 +455,41 @@ struct UnifiedAttentionKernel
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
-1,
|
||||
0,
|
||||
cur_batch_query_len, // y_total
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
|
||||
// times along x dim of the tile
|
||||
false);
|
||||
kargs.window_size_left, // <0 means unbounded on the left
|
||||
kargs.window_size_right, // <0 means unbounded on the right
|
||||
cur_batch_query_len, // y_total
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated
|
||||
// num_queries_per_kv times along x dim of the tile
|
||||
kargs.is_top_left);
|
||||
else
|
||||
return FmhaMask{cur_batch_query_len, seq_len};
|
||||
}();
|
||||
|
||||
// Sliding-window-attention: tighten the KV-block iteration to the row of tiles
|
||||
// that actually overlap the window for this Q-tile. Without this, blocks wholly
|
||||
// outside the window would still be loaded, scaled and masked tile-by-tile —
|
||||
// which is both slow and (because the kernel's softmax accumulator interleaves
|
||||
// m/l updates with prefetch and warp-group barriers) sensitive to having any
|
||||
// all-(-inf) blocks in the loop. Skipping them entirely keeps each iterated
|
||||
// tile either fully-inside-window or a true edge tile that per-pixel masking
|
||||
// can clean up correctly.
|
||||
if constexpr(FmhaMask::IsMasking && FmhaMask::IsLocal)
|
||||
{
|
||||
const index_t i_y_for_mask = query_pos * num_queries_per_kv;
|
||||
const auto window_range =
|
||||
mask.GetTileRangeAlongX(i_y_for_mask,
|
||||
ck_tile::number<kBlockQ>{},
|
||||
ck_tile::number<kPageBlockSize>{});
|
||||
const index_t window_blk_lo = ck_tile::max(
|
||||
index_t(0), window_range.at(ck_tile::number<0>{}) / kPageBlockSize);
|
||||
const index_t window_blk_hi = ck_tile::min(
|
||||
total_num_kv_blocks,
|
||||
(window_range.at(ck_tile::number<1>{}) + kPageBlockSize - 1) / kPageBlockSize);
|
||||
num_blocks_start = ck_tile::max(num_blocks_start, window_blk_lo);
|
||||
num_blocks = ck_tile::min(num_blocks, window_blk_hi);
|
||||
}
|
||||
|
||||
const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
|
||||
assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size
|
||||
|
||||
|
||||
Reference in New Issue
Block a user