From 33b62ed0878369db891a85d743576605e62b3d1c Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 7 May 2026 02:40:45 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6914 (commit b791478) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [CK_TILE][FMHA] Fix sink un-mask under right-window and emit fp8bf16 batch_prefill sink kernels (#6914) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Two related fixes to `ck_tile` FMHA so that StreamLLM-sink + sliding-window batch-prefill works correctly for fp8 KV / bf16 compute. Review the commits in this order: 1. `fmha: emit sink kernels for fp8bf16 batch_prefill` Extends `example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py` so the fp8(KV) / bf16(QO) batch-prefill codegen also emits the `mask=mask_enum::generic_with_sink` variant. Without this the runtime could not dispatch to a sink-aware kernel for the fp8bf16 path. 2. `fmha: respect right-window in IsOutOfSinkBound` The sink un-mask in `GenericAttentionMask::IsOutOfSinkBound` (local-mask branch) used `(i_y + x) > 1` as the gate, which conditioned on the row index instead of the column index. As a result, queries `1..sink-1` could attend to *future* sink positions (violating causal / right-window), while query `0` fell back to the plain causal mask. The fix replaces the guard with `i_x < i_y + x` so every query only sees sink columns up to its own right-window boundary. 3. `fmha: clarify IsOutOfSinkBound predicate comment` Doc-only follow-up that rewrites the comment above the predicate as a clause-by-clause explanation (`i_x < sink`, `i_x < i_y + x`, `y < y_total`, `i_y < x_total`). ## Test plan - [x] Repro on aiter `op_tests/test_batch_prefill.py` (fp8 + bf16_dequant modes with `sink=4`, `win_left=1023`, `softcap=0.0`, `sal=True`) now passes for all parametrized shapes. - [x] Existing fp16/bf16 batch-prefill paths (no sink) unchanged — codegen diff only adds the `generic_with_sink` variant for fp8bf16; existing kernel object lists unaffected. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 16 ++++++++++++++-- .../ck_tile/ops/fmha/block/block_masking.hpp | 19 ++++++++++++++++--- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 8c006c09db..475631a885 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -715,14 +715,20 @@ class KernelComponentFactory: SUPPORTED_KV_MEMORY_LAYOUT, SUPPORTED_KV_LOOKUP_TABLE, ): + # sink tokens are only meaningful when masking is enabled; + # skip the sink="t" + nomask combinations to avoid emitting + # kernels that can never be dispatched. + if sink == "t" and mask in ("no", "s_no"): + continue pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip elif dtype in ["fp8bf16"]: - # no need lse/dropout/sink kernels + # no need lse/dropout kernels (sink is supported via kHasSink) for ( logits, qscale, mask, bias, + sink, kv_memory_layout, kv_lookup_table, ) in itertools.product( @@ -730,10 +736,16 @@ class KernelComponentFactory: ["pertensor", "kv_blockscale"], get_mask_map(mask_impl).keys(), ["no"], + ["t", "f"], SUPPORTED_KV_MEMORY_LAYOUT, SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip + # sink tokens are only meaningful when masking is enabled; + # skip the sink="t" + nomask combinations to avoid emitting + # kernels that can never be dispatched. + if sink == "t" and mask in ("no", "s_no"): + continue + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip else: assert False return pipelines diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 4ffb303812..134cb6acbb 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -242,16 +242,27 @@ struct GenericAttentionMask index_t x_start = -y + i_y + 1; index_t x_end = min(i_y + x, x_total); + // Sink un-mask predicate, clause by clause: + // i_x < sink : the column lives inside the StreamLLM sink prefix. + // i_x < i_y + x : the column is not in the masked-out future of the + // window (= < x_end modulo the min with x_total); + // without this, queries <= sink-1 would be allowed + // to look at later sink columns/positions than they + // should under causality / right-window. + // y < y_total : the local window doesn't already span everything + // (otherwise sink un-mask is meaningless). + // i_y < x_total : the query row is in-range vs. the key sequence + // (handles seqlen_q > seqlen_k padding). if constexpr(IsLocal) { - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total) return false; else return i_x < x_start || i_x >= x_end; } else { - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total) return false; else return i_x >= x_end || i_y >= y_total; @@ -498,7 +509,9 @@ struct SimplifiedGenericAttentionMask return i_x >= x_total; index_t x_start = -y + i_y + 1; // this could be negative, but it's fine index_t x_end = min(i_y + x, x_total); // need min in case x is padded - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + // See note in the local-mask IsOutOfSinkBound: the sink column i_x is + // only valid up to the right-window boundary i_y + x. + if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total) return false; else return i_x < x_start || i_x >= x_end || i_y >= y_total;