[rocm-libraries] ROCm/rocm-libraries#6479 (commit 0705c2d)

CK][fmha] Add StreamLLM sink support to batch_prefill
 pipeline (#6479)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The existing paged-KV attention pipelines (pagedkv, splitkv) support
  StreamLLM-style sink tokens — a fixed set of initial tokens kept in
  attention alongside the sliding window. The `batch_prefill` pipeline
  (chunked-prefill with VLLM-style block tables) previously hardcoded
  `kHasSink = false`, making it incompatible with sink-based attention
  patterns in LLM serving scenarios.

  This PR extends `batch_prefill` to support `kHasSink` and wires it
into `fmha_fwd_runner` for validation against the existing CPU
reference.

## Technical Details

 **Pipeline** (`block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp`):
- When `kHasSink`, the K/V loop splits into a sink phase [0,
sink_seq_end)
and a window phase [seqlen_k_start, seqlen_k_end), mirroring pagedkv.
  - K advance at the sink→window transition jumps
    `seqlen_k_start - sink_seq_end + kN0` to bridge the gap.
- V scatter-gather offsets are re-initialized at the transition to fix a
window mismatch bug: V was lagging kN0 behind K after the large jump,
    loading from the wrong sequence position.
- Bias window, dropout seq_offset, and mask type (LogitsSinkMask)
updated
    for sink-awareness.

**Traits / codegen** (`tile_fmha_traits.hpp`, `fmha_fwd.hpp`,
`fmha_batch_prefill.py`):
- `TileFmhaBatchPrefillTraits` gains `kHasSink_` (was hardcoded
`false`).
- Codegen adds `F_sink` field; skips batch-mode kernels (group mode
required).
  - CMake test filter broadened from 9 → 33 instances covering
    fp16/bf16 × mask/nmask × lse/nlse × sink/nsink.

  **Runner** (`fmha_fwd_runner.hpp`, `CMakeLists.txt`):
  - `fmha_batch_prefill()` dispatched from `run_fwd` when:
    group mode + paged KV + num_splits == 1.
- K/V strides corrected for runner's [num_pages, nhead_k,
page_block_size, hdim] layout.
  - `page_block_size % 128` check relaxed: batch_prefill supports ps=16.
  - CPU reference paged-KV reordering guards extended with
    `CK_TILE_FMHA_FWD_BATCH_PREFILL_API`.

## Test Plan

Build with `-DFMHA_FWD_ENABLE_APIS="fwd;batch_prefill"`, run
  `tile_example_fmha_fwd` in group mode with page_block_size=16.

  Test matrix:
  - Mask: no-mask, causal, sliding window
  - Sink: nsink, sink=1..128
  - dtype: fp16, bf16
  - LSE output: on/off
  - seqlen ∈ {512,1024,2048,4096} × window ∈ {32,256,512,1024}
  - GQA, chunked prefill, large batch×seqlen
  - page_block_size: 16, 32

## Test Result

171 test cases, all valid:y:
  - nmask + nsink: ✓
  - causal + nsink: ✓
  - causal + sink=8: ✓
  - sliding window + sink=8 (d=128, d=256): ✓
  - bf16, LSE output, GQA: ✓

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Linjun-AMD
2026-04-21 11:05:12 +00:00
committed by assistant-librarian[bot]
parent b75afb4274
commit d22aafb48b
7 changed files with 261 additions and 59 deletions

View File

@@ -759,18 +759,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.sink_ptr != nullptr
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
: -numeric<float>::infinity();
const index_t seqlen_k = [&]() {
// WA i_batch capture structure binding before c++20
const index_t seqlen_k = [&, i_batch_ = i_batch]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
const int32_t page_start = kargs.page_table.kv_indptr[i_batch_];
const int32_t page_end = kargs.page_table.kv_indptr[i_batch_ + 1];
const int32_t num_page_blocks = page_end - page_start;
const int32_t last_page_len = [&]() {
if constexpr(kPageBlockSize == 1)
return static_cast<int32_t>(kPageBlockSize);
else
return kargs.page_table.kv_last_page_lens[i_batch];
return kargs.page_table.kv_last_page_lens[i_batch_];
}();
return num_page_blocks > 0
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
@@ -780,21 +781,22 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
if(kargs.page_table.seqlen_k_ptr != nullptr)
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch_]);
else
return kargs.seqlen_k;
}
}();
const int32_t* page_idx = [&]() {
// WA i_batch capture structure binding before c++20
const int32_t* page_idx = [&, i_batch_ = i_batch]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch_];
}
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
return kargs.page_table.block_table_ptr +
static_cast<long_index_t>(i_batch) *
static_cast<long_index_t>(i_batch_) *
kargs.page_table.batch_stride_block_table;
}
}();

View File

@@ -291,6 +291,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
static constexpr bool kHasSink = Problem::kHasSink;
// For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
// This avoids explicit P *= scale_p and v_descale /= scale_p operations
@@ -546,11 +547,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -576,7 +591,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
{kv_load_start, 0});
auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
auto k_coord = k_dist.calculate_index();
@@ -585,7 +600,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window)
// kPageBlockSize < kN0: global offset, must fit int32
statically_indexed_array<index_t, NRepeat> k_offsets;
index_t current_seq_k = seqlen_k_start;
index_t current_seq_k = kv_load_start;
// Load physical pages first, then compute offsets.
// k_physical_pages can be reused for descale lookup later.
@@ -668,11 +683,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
randval_dram_block_window_tmp, kv_load_start);
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
@@ -895,7 +910,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
v_dist,
v_offsets,
number<1>{}, // HsGatherDim
@@ -1097,6 +1112,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
#endif
}
}
if constexpr(kHasSink)
{
if(i_total_loops == num_sink_loop - 1)
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -1108,19 +1128,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto&&... args) {
return variant.LogitsSinkMask(
std::forward<decltype(args)>(args)...);
});
}
else
{
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
}
@@ -1297,12 +1334,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem>();
index_t seq_offset = [&]() {
if constexpr(kHasSink)
{
const bool in_sink_phase = (num_sink_loop > i_total_loops);
if(i_total_loops == num_sink_loop)
move_tile_window(randval_dram_window,
{0, seqlen_k_start - sink_seq_end});
return in_sink_phase
? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
}
else
return seqlen_k_start + i_total_loops * kN0;
}();
dropout
.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
@@ -1396,9 +1444,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
i_total_loops++;
if(i_total_loops < num_total_loop)
{
current_seq_k += kN0;
// For sink: after the last sink tile, jump K/V to seqlen_k_start;
// otherwise advance by one normal tile.
const index_t k_advance = [&]() -> index_t {
if constexpr(kHasSink)
return (i_total_loops == num_sink_loop)
? (seqlen_k_start - sink_seq_end + kN0)
: kN0;
else
return kN0;
}();
current_seq_k += k_advance;
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
move_tile_window(k_dram_block_window, {k_advance, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
// KV_BLOCKSCALE: reload physical pages for the new tile
@@ -1427,6 +1485,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
// After sink→window transition (i_total_loops == num_sink_loop), V window
// was advanced by kN0 (one normal iter), but current_seq_k jumped by k_advance
// = seqlen_k_start - sink_seq_end + kN0 > kN0. Re-init V to current_seq_k.
if constexpr(kHasSink)
{
if(i_total_loops == num_sink_loop && num_sink_loop > 0)
{
prefetch_v_physical_pages(number<0>{});
update_v_offsets(number<0>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
}
}
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();

View File

@@ -53,6 +53,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
bool kHasSink_ = false, /* StreamLLM sink tokens */
index_t kPageBlockSize_ = 1,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
@@ -70,7 +71,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
QScaleEnum_,
kBlockPerCu_,
kSkipMinSeqlenQ_,
false>
kHasSink_>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;