mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#6914 (commit b791478)
[CK_TILE][FMHA] Fix sink un-mask under right-window and emit fp8bf16 batch_prefill sink kernels (#6914) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Two related fixes to `ck_tile` FMHA so that StreamLLM-sink + sliding-window batch-prefill works correctly for fp8 KV / bf16 compute. Review the commits in this order: 1. `fmha: emit sink kernels for fp8bf16 batch_prefill` Extends `example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py` so the fp8(KV) / bf16(QO) batch-prefill codegen also emits the `mask=mask_enum::generic_with_sink` variant. Without this the runtime could not dispatch to a sink-aware kernel for the fp8bf16 path. 2. `fmha: respect right-window in IsOutOfSinkBound` The sink un-mask in `GenericAttentionMask::IsOutOfSinkBound` (local-mask branch) used `(i_y + x) > 1` as the gate, which conditioned on the row index instead of the column index. As a result, queries `1..sink-1` could attend to *future* sink positions (violating causal / right-window), while query `0` fell back to the plain causal mask. The fix replaces the guard with `i_x < i_y + x` so every query only sees sink columns up to its own right-window boundary. 3. `fmha: clarify IsOutOfSinkBound predicate comment` Doc-only follow-up that rewrites the comment above the predicate as a clause-by-clause explanation (`i_x < sink`, `i_x < i_y + x`, `y < y_total`, `i_y < x_total`). ## Test plan - [x] Repro on aiter `op_tests/test_batch_prefill.py` (fp8 + bf16_dequant modes with `sink=4`, `win_left=1023`, `softcap=0.0`, `sal=True`) now passes for all parametrized shapes. - [x] Existing fp16/bf16 batch-prefill paths (no sink) unchanged — codegen diff only adds the `generic_with_sink` variant for fp8bf16; existing kernel object lists unaffected. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
207a95d5e4
commit
33b62ed087
@@ -715,14 +715,20 @@ class KernelComponentFactory:
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
# sink tokens are only meaningful when masking is enabled;
|
||||
# skip the sink="t" + nomask combinations to avoid emitting
|
||||
# kernels that can never be dispatched.
|
||||
if sink == "t" and mask in ("no", "s_no"):
|
||||
continue
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout/sink kernels
|
||||
# no need lse/dropout kernels (sink is supported via kHasSink)
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
mask,
|
||||
bias,
|
||||
sink,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
@@ -730,10 +736,16 @@ class KernelComponentFactory:
|
||||
["pertensor", "kv_blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
# sink tokens are only meaningful when masking is enabled;
|
||||
# skip the sink="t" + nomask combinations to avoid emitting
|
||||
# kernels that can never be dispatched.
|
||||
if sink == "t" and mask in ("no", "s_no"):
|
||||
continue
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
Reference in New Issue
Block a user