diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 8c006c09db..475631a885 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -715,14 +715,20 @@ class KernelComponentFactory: SUPPORTED_KV_MEMORY_LAYOUT, SUPPORTED_KV_LOOKUP_TABLE, ): + # sink tokens are only meaningful when masking is enabled; + # skip the sink="t" + nomask combinations to avoid emitting + # kernels that can never be dispatched. + if sink == "t" and mask in ("no", "s_no"): + continue pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip elif dtype in ["fp8bf16"]: - # no need lse/dropout/sink kernels + # no need lse/dropout kernels (sink is supported via kHasSink) for ( logits, qscale, mask, bias, + sink, kv_memory_layout, kv_lookup_table, ) in itertools.product( @@ -730,10 +736,16 @@ class KernelComponentFactory: ["pertensor", "kv_blockscale"], get_mask_map(mask_impl).keys(), ["no"], + ["t", "f"], SUPPORTED_KV_MEMORY_LAYOUT, SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip + # sink tokens are only meaningful when masking is enabled; + # skip the sink="t" + nomask combinations to avoid emitting + # kernels that can never be dispatched. 
+ if sink == "t" and mask in ("no", "s_no"): + continue + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip else: assert False return pipelines diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 4ffb303812..134cb6acbb 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -242,16 +242,27 @@ struct GenericAttentionMask index_t x_start = -y + i_y + 1; index_t x_end = min(i_y + x, x_total); + // Sink un-mask predicate, clause by clause: + // i_x < sink : the column lives inside the StreamingLLM sink prefix. + // i_x < i_y + x : the column is not in the masked-out future of the + // window (i.e. i_x < x_end, before the min with x_total is applied); + // without this, queries <= sink-1 would be allowed + // to look at later sink columns/positions than they + // should under causality / right-window. + // y < y_total : the local window doesn't already span everything + // (otherwise sink un-mask is meaningless). + // i_y < x_total : the query row is in-range vs. the key sequence + // (handles seqlen_q > seqlen_k padding).
if constexpr(IsLocal) { - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total) return false; else return i_x < x_start || i_x >= x_end; } else { - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total) return false; else return i_x >= x_end || i_y >= y_total; @@ -498,7 +509,9 @@ struct SimplifiedGenericAttentionMask return i_x >= x_total; index_t x_start = -y + i_y + 1; // this could be negative, but it's fine index_t x_end = min(i_y + x, x_total); // need min in case x is padded - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + // See note in the local-mask IsOutOfSinkBound: the sink column i_x is + // only valid up to the right-window boundary i_y + x. + if((i_x < sink) && (i_x < i_y + x) && (y < y_total) && i_y < x_total) return false; else return i_x < x_start || i_x >= x_end || i_y >= y_total;