mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Adding SWA implementation + instances
This commit is contained in:
@@ -86,6 +86,12 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t stride_v_cache_3;
|
||||
ck_tile::index_t output_stride_0;
|
||||
ck_tile::index_t output_stride_1;
|
||||
|
||||
// Sliding-window-attention parameters. <0 means "unbounded on that side".
|
||||
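// (left=-1, right=0, is_top_left=false) reproduces classical bottom-right causal.
// Note: the defaults below are (-1, -1, false), i.e. unbounded on both sides, not that causal setting.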
|
||||
ck_tile::index_t window_size_left = -1;
|
||||
ck_tile::index_t window_size_right = -1;
|
||||
bool is_top_left = false;
|
||||
};
|
||||
|
||||
struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs
|
||||
@@ -140,7 +146,10 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t block_table_stride,
|
||||
const int32_t* seq_lens_ptr,
|
||||
const int32_t* query_start_len_ptr,
|
||||
ck_tile::index_t num_seqs)
|
||||
ck_tile::index_t num_seqs,
|
||||
ck_tile::index_t window_size_left = -1,
|
||||
ck_tile::index_t window_size_right = -1,
|
||||
bool is_top_left = false)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -167,7 +176,10 @@ struct UnifiedAttentionKernel
|
||||
stride_v_cache_2,
|
||||
stride_v_cache_3,
|
||||
output_stride_0,
|
||||
output_stride_1},
|
||||
output_stride_1,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
is_top_left},
|
||||
block_tables_ptr,
|
||||
block_table_stride,
|
||||
seq_lens_ptr,
|
||||
@@ -443,17 +455,41 @@ struct UnifiedAttentionKernel
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
-1,
|
||||
0,
|
||||
cur_batch_query_len, // y_total
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
|
||||
// times along x dim of the tile
|
||||
false);
|
||||
kargs.window_size_left, // <0 means unbounded on the left
|
||||
kargs.window_size_right, // <0 means unbounded on the right
|
||||
cur_batch_query_len, // y_total
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated
|
||||
// num_queries_per_kv times along x dim of the tile
|
||||
kargs.is_top_left);
|
||||
else
|
||||
return FmhaMask{cur_batch_query_len, seq_len};
|
||||
}();
|
||||
|
||||
// Sliding-window-attention: tighten the KV-block iteration to the row of tiles
|
||||
// that actually overlap the window for this Q-tile. Without this, blocks wholly
|
||||
// outside the window would still be loaded, scaled and masked tile-by-tile —
|
||||
// which is both slow and (because the kernel's softmax accumulator interleaves
|
||||
// m/l updates with prefetch and warp-group barriers) sensitive to having any
|
||||
// all-(-inf) blocks in the loop. Skipping them entirely keeps each iterated
|
||||
// tile either fully-inside-window or a true edge tile that per-pixel masking
|
||||
// can clean up correctly.
|
||||
if constexpr(FmhaMask::IsMasking && FmhaMask::IsLocal)
|
||||
{
|
||||
const index_t i_y_for_mask = query_pos * num_queries_per_kv;
|
||||
const auto window_range =
|
||||
mask.GetTileRangeAlongX(i_y_for_mask,
|
||||
ck_tile::number<kBlockQ>{},
|
||||
ck_tile::number<kPageBlockSize>{});
|
||||
const index_t window_blk_lo = ck_tile::max(
|
||||
index_t(0), window_range.at(ck_tile::number<0>{}) / kPageBlockSize);
|
||||
const index_t window_blk_hi = ck_tile::min(
|
||||
total_num_kv_blocks,
|
||||
(window_range.at(ck_tile::number<1>{}) + kPageBlockSize - 1) / kPageBlockSize);
|
||||
num_blocks_start = ck_tile::max(num_blocks_start, window_blk_lo);
|
||||
num_blocks = ck_tile::min(num_blocks, window_blk_hi);
|
||||
}
|
||||
|
||||
const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
|
||||
assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size
|
||||
|
||||
|
||||
@@ -360,8 +360,22 @@ struct UnifiedAttentionPipeline
|
||||
const ck_tile::index_t* block_tables_ptr_ =
|
||||
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
|
||||
assert(k_block_idx == v_block_idx); // because of the following line
|
||||
block_table_offset += num_blocks_start;
|
||||
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
|
||||
// num_blocks_start is in kPageBlockSize sub-block units, but block_table_offset
|
||||
// (and block_tables_ptr) are in PageSize-page units. Convert correctly so that
|
||||
// a sub-block-aligned start position lands on the right page AND the right
|
||||
// sub-block within that page. This matters for SWA (tile_lo > 0) and for
|
||||
// split-KV when kv_page_size_in_blocks > 1.
|
||||
const index_t page_advance = num_blocks_start / kv_page_size_in_blocks;
|
||||
const index_t init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks;
|
||||
block_table_offset += page_advance;
|
||||
// k_block_idx counts sub-blocks relative to the first iterated page. Starting it
|
||||
// at init_sub_block_offset lets the existing K_mem_load math (k_block_idx /
|
||||
// kv_page_size_in_blocks for the page index, k_block_idx % kv_page_size_in_blocks
|
||||
// for the within-page sub-block) keep working unchanged.
|
||||
k_block_idx = init_sub_block_offset;
|
||||
v_block_idx = init_sub_block_offset;
|
||||
index_t kv_blk_idx_initial =
|
||||
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
|
||||
|
||||
// When row strides are provided, use pointer rebasing to avoid int32 overflow
|
||||
// in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
|
||||
@@ -376,16 +390,25 @@ struct UnifiedAttentionPipeline
|
||||
auto* k_base_ptr = k_view.buf_.p_data_;
|
||||
auto* v_base_ptr = v_view.buf_.p_data_;
|
||||
|
||||
// Within-page row offset (counted in K/V rows, not bytes) for the very first iterated sub-block.
|
||||
// Non-zero whenever num_blocks_start is not a multiple of kv_page_size_in_blocks.
|
||||
const index_t initial_intra_page_row_offset =
|
||||
init_sub_block_offset * kPageBlockSize;
|
||||
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
long_index_t k_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
|
||||
(static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
|
||||
static_cast<long_index_t>(initial_intra_page_row_offset)) *
|
||||
k_row_stride;
|
||||
k_view.buf_.p_data_ = k_base_ptr + k_off;
|
||||
auto new_k = k_view.buf_.buffer_size_ - k_off;
|
||||
k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
|
||||
|
||||
long_index_t v_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
|
||||
(static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
|
||||
static_cast<long_index_t>(initial_intra_page_row_offset)) *
|
||||
v_row_stride;
|
||||
v_view.buf_.p_data_ = v_base_ptr + v_off;
|
||||
auto new_v = v_view.buf_.buffer_size_ - v_off;
|
||||
v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
|
||||
@@ -397,7 +420,10 @@ struct UnifiedAttentionPipeline
|
||||
v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
}
|
||||
|
||||
const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
|
||||
const index_t init_origin = use_ptr_rebase
|
||||
? 0
|
||||
: kv_blk_idx_initial * PageSize +
|
||||
initial_intra_page_row_offset;
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_view,
|
||||
@@ -1147,11 +1173,20 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
}
|
||||
label_main_loops_exit:
|
||||
if(num_total_loop % 2)
|
||||
// Post-process must consume the *last iteration's* sp/V buffer slot. Pre-stage
|
||||
// always writes to slot 0; the main loop alternates 1, 0, 1, ... So the last
|
||||
// written slot is determined by the number of iterations actually executed
|
||||
// (= num_total_loop - num_blocks_start), not by num_total_loop alone. Keying
|
||||
// on num_total_loop matches when num_blocks_start is even but reads
|
||||
// uninitialised sp(1)/V[1] when it is odd, which silently corrupts o_acc.
|
||||
// This previously only mattered for split-KV with an odd split boundary; SWA
|
||||
// exposes it whenever the per-Q-tile lower-bound clip is odd.
|
||||
const index_t num_iters = num_total_loop - num_blocks_start;
|
||||
if(num_iters % 2)
|
||||
{
|
||||
fmha_post_process(number<1>{});
|
||||
}
|
||||
if(!(num_total_loop % 2))
|
||||
if(!(num_iters % 2))
|
||||
{
|
||||
fmha_post_process(number<0>{});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user