[CK_TILE][FMHA] Fix sink un-mask under right-window and emit fp8bf16 batch_prefill sink kernels (#6914)

## Summary

Two related fixes to `ck_tile` FMHA so that StreamLLM-sink + sliding-window
batch-prefill works correctly for fp8 KV / bf16 compute.

  Review the commits in this order:

  1. `fmha: emit sink kernels for fp8bf16 batch_prefill`
     Extends `example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py` so
     the fp8(KV) / bf16(QO) batch-prefill codegen also emits the
     `mask=mask_enum::generic_with_sink` variant. Without this, the runtime
     could not dispatch to a sink-aware kernel for the fp8bf16 path.

  2. `fmha: respect right-window in IsOutOfSinkBound`
     The sink un-mask in `GenericAttentionMask::IsOutOfSinkBound` (local-mask
     branch) used `(i_y + x) > 1` as the gate, which depends only on the row
     index `i_y` and never on the column index `i_x`. As a result, queries
     `1..sink-1` could attend to *future* sink positions (violating causal /
     right-window masking), while query `0` fell back to the plain causal
     mask. The fix replaces the guard with `i_x < i_y + x`, so every query
     sees sink columns only up to its own right-window boundary.

  3. `fmha: clarify IsOutOfSinkBound predicate comment`
     Doc-only follow-up that rewrites the comment above the predicate as a
     clause-by-clause explanation (`i_x < sink`, `i_x < i_y + x`,
     `y < y_total`, `i_y < x_total`).
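
The difference between the two gates in commit 2 can be seen in a small Python model of the local-mask branch. This is only a sketch: the function and helper names are made up for illustration, and the values `y=2`, `x=1`, `sink=4` on an 8x8 problem are assumptions chosen to make the effect visible, not values taken from the test plan. How `ck_tile` derives `x` and `y` from the window sizes is not modeled here.

```python
def is_out_of_sink_bound(i_y, i_x, *, y, x, sink, y_total, x_total, fixed):
    """Model of the local-mask branch: True means key column i_x is
    masked out for query row i_y."""
    x_start = -y + i_y + 1
    x_end = min(i_y + x, x_total)
    # Old gate looks only at the row; new gate compares the column with
    # the row's right-window boundary.
    gate = (i_x < i_y + x) if fixed else ((i_y + x) > 1)
    if i_x < sink and gate and y < y_total and i_y < x_total:
        return False  # sink column is un-masked
    return i_x < x_start or i_x >= x_end


def visible_columns(i_y, *, fixed, **mask):
    """Hypothetical helper: list the key columns row i_y may attend to."""
    return [
        i_x
        for i_x in range(mask["x_total"])
        if not is_out_of_sink_bound(i_y, i_x, fixed=fixed, **mask)
    ]


# Illustrative values only (assumptions, not from the PR).
MASK = dict(y=2, x=1, sink=4, y_total=8, x_total=8)

for row in (0, 1, 5):
    old = visible_columns(row, fixed=False, **MASK)
    new = visible_columns(row, fixed=True, **MASK)
    print(f"row {row}: old={old} new={new}")
```

With these values, row 1 sees columns 0 through 3 under the old gate (columns 2 and 3 lie in its future) and only columns 0 and 1 under the new one, while row 5 sees the full sink prefix plus its local window in both cases.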

## Test plan

- [x] Repro on aiter `op_tests/test_batch_prefill.py` (fp8 and
      bf16_dequant modes with `sink=4`, `win_left=1023`, `softcap=0.0`,
      `sal=True`) now passes for all parametrized shapes.
- [x] Existing fp16/bf16 batch-prefill paths (no sink) are unchanged: the
      codegen diff only adds the `generic_with_sink` variant for fp8bf16,
      and existing kernel object lists are unaffected.
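
The effect of the fp8bf16 codegen change from commit 1 can be approximated with a reduced model of the enumeration loop. This is a sketch under stated assumptions: only the `qscale`, `mask`, and `sink` axes are modeled, the mask keys `mask_a` and `mask_b` are placeholders for whatever `get_mask_map(mask_impl)` returns, and `fp8bf16_combos` is a hypothetical helper rather than a function in the codegen script.

```python
import itertools

# Placeholder mask keys (assumption); the real keys come from
# get_mask_map(mask_impl) in the codegen script.
MASK_KEYS = ["no", "mask_a", "mask_b"]
QSCALES = ["pertensor", "kv_blockscale"]


def fp8bf16_combos(sink_values):
    """Enumerate (qscale, mask, sink) tuples, skipping sink + nomask."""
    combos = []
    for qscale, mask, sink in itertools.product(QSCALES, MASK_KEYS, sink_values):
        # Same filter as in the diff: a sink kernel without masking could
        # never be dispatched, so it is not emitted.
        if sink == "t" and mask in ("no", "s_no"):
            continue
        combos.append((qscale, mask, sink))
    return combos


before = fp8bf16_combos(["f"])       # sink hard-coded to "f", as before the change
after = fp8bf16_combos(["t", "f"])   # sink enumerated, as after the change
added = [c for c in after if c not in before]
print(len(before), len(after), len(added))
```

In this model every pre-existing combination is still emitted, and the only additions are `sink="t"` variants paired with a masking mode.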

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: fengjunda.aml <fengjunda.aml@bytedance.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: root <root@smci350-rck-g03-f12-31.rck.dcgpu>
Linjun-AMD
2026-05-07 10:39:55 +08:00
committed by GitHub
parent 0398b864c3
commit 1cf336d87a
2 changed files with 30 additions and 5 deletions


@@ -715,14 +715,20 @@ class KernelComponentFactory:
                 SUPPORTED_KV_MEMORY_LAYOUT,
                 SUPPORTED_KV_LOOKUP_TABLE,
             ):
+                # sink tokens are only meaningful when masking is enabled;
+                # skip the sink="t" + nomask combinations to avoid emitting
+                # kernels that can never be dispatched.
+                if sink == "t" and mask in ("no", "s_no"):
+                    continue
                 pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
         elif dtype in ["fp8bf16"]:
-            # no need lse/dropout/sink kernels
+            # no need lse/dropout kernels (sink is supported via kHasSink)
             for (
                 logits,
                 qscale,
                 mask,
                 bias,
+                sink,
                 kv_memory_layout,
                 kv_lookup_table,
             ) in itertools.product(
@@ -730,10 +736,16 @@ class KernelComponentFactory:
                 ["pertensor", "kv_blockscale"],
                 get_mask_map(mask_impl).keys(),
                 ["no"],
+                ["t", "f"],
                 SUPPORTED_KV_MEMORY_LAYOUT,
                 SUPPORTED_KV_LOOKUP_TABLE,
             ):
-                pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip
+                # sink tokens are only meaningful when masking is enabled;
+                # skip the sink="t" + nomask combinations to avoid emitting
+                # kernels that can never be dispatched.
+                if sink == "t" and mask in ("no", "s_no"):
+                    continue
+                pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
         else:
             assert False
         return pipelines


@@ -242,16 +242,27 @@ struct GenericAttentionMask
         index_t x_start = -y + i_y + 1;
         index_t x_end = min(i_y + x, x_total);
+        // Sink un-mask predicate, clause by clause:
+        //   i_x < sink     : the column lives inside the StreamLLM sink prefix.
+        //   i_x < i_y + x  : the column is not in the masked-out future of the
+        //                    window (= < x_end modulo the min with x_total);
+        //                    without this, queries <= sink-1 would be allowed
+        //                    to look at later sink columns/positions than they
+        //                    should under causality / right-window.
+        //   y < y_total    : the local window doesn't already span everything
+        //                    (otherwise sink un-mask is meaningless).
+        //   i_y < x_total  : the query row is in-range vs. the key sequence
+        //                    (handles seqlen_q > seqlen_k padding).
         if constexpr(IsLocal)
         {
-            if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
+            if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total)
                 return false;
             else
                 return i_x < x_start || i_x >= x_end;
         }
         else
         {
-            if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
+            if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total)
                 return false;
             else
                 return i_x >= x_end || i_y >= y_total;
@@ -498,7 +509,9 @@ struct SimplifiedGenericAttentionMask
             return i_x >= x_total;
         index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
         index_t x_end = min(i_y + x, x_total); // need min in case x is padded
-        if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
+        // See note in the local-mask IsOutOfSinkBound: the sink column i_x is
+        // only valid up to the right-window boundary i_y + x.
+        if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total)
             return false;
         else
             return i_x < x_start || i_x >= x_end || i_y >= y_total;