composable_kernel/include/ck_tile/ops
juuso-oskari 4aff2fa016 CK-UA: fix no-mask multi-Q-block path — was reading too-short K prefix
The kernel's `_max_seq_prefix_len` computation unconditionally applied a
causal upper bound on the KV-tile loop:

    _max_seq_prefix_len = context_len
                        + q_block_local_idx * kBlockQ_dyn
                        + (kBlockQ_dyn - 1) + 1

Under causal masking this is the correct optimisation: a Q-block whose
largest row index is R only needs to read K[0..R], because K rows beyond
R are masked out and get zero softmax weight. Under `mask_type=0` (no
mask) every Q row must attend to all K rows, so the truncation is
incorrect: every Q-block other than the last one reads a too-short
prefix of K, and the resulting softmax / weighted sum is computed over
the wrong support.

Symptoms at sq=sk=512, hq=hk=5, d=128, bf16, no-mask:
  Q-block 0 (rows 0..255):   max diff vs fp32 attention_ref ≈ 0.25
  Q-block 1 (rows 256..511): max diff vs fp32 attention_ref ≈ 1e-3 (bf16 ULP level)
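
The split between the two Q-blocks follows directly from the bound. A
minimal sketch, assuming kBlockQ_dyn = 256 (inferred from the row ranges
listed for each Q-block) and context_len = sk - sq = 0 for this shape;
`causal_prefix_len` is a hypothetical helper that mirrors the formula in
this message, not the kernel's actual code:

```cpp
#include <cstdint>

// Hypothetical helper mirroring the formula from the commit message.
// Returns the exclusive upper bound on the K rows one Q-block visits.
constexpr std::int64_t causal_prefix_len(std::int64_t context_len,
                                         std::int64_t q_block_local_idx,
                                         std::int64_t block_q)
{
    return context_len + q_block_local_idx * block_q + (block_q - 1) + 1;
}

// Assumed shape: sq = sk = 512, block_q = 256, context_len = 0.
static_assert(causal_prefix_len(0, 0, 256) == 256,
              "Q-block 0 stops after K row 255");
static_assert(causal_prefix_len(0, 1, 256) == 512,
              "Q-block 1 reaches all 512 K rows");
```

Under these assumptions Q-block 0 sees only half of K in no-mask mode,
while Q-block 1 happens to see all of it, which is consistent with only
Q-block 0 showing a large error.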

The bug never showed up in the cross-impl sweeps because Triton-UA
asserts causal=True (its only supported mode) and sweep_fp8.sh forwards
that default through.

Fix: gate the truncation behind kHasMask. When kHasMask == false the
loop bound is simply `seq_len`, matching the math.
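
The gating can be sketched as below. This is an illustration under
stated assumptions, not the kernel code: `kv_loop_bound` and its
parameter names are hypothetical, and the clamp to `seq_len` in the
masked branch is an assumption this message does not confirm.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical sketch of the gated KV-loop bound.
// kHasMask == true : keep the causal prefix truncation.
// kHasMask == false: every Q row attends to all K rows, so use seq_len.
template <bool kHasMask>
constexpr std::int64_t kv_loop_bound(std::int64_t seq_len,
                                     std::int64_t context_len,
                                     std::int64_t q_block_local_idx,
                                     std::int64_t block_q)
{
    return kHasMask
               // Clamping to seq_len is an assumption made for this sketch.
               ? std::min(seq_len,
                          context_len + q_block_local_idx * block_q
                              + (block_q - 1) + 1)
               : seq_len;
}
```

With sq = sk = 512 and a 256-row Q-block, `kv_loop_bound<false>(512, 0, 0, 256)`
evaluates to 512, while `kv_loop_bound<true>(512, 0, 0, 256)` stays at
256, so the causal path keeps its shorter loop.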

Validated against `aiter.test_mha_common.attention_ref` across:
  - MHA d={64,128} sq=sk∈{256..2048}  bf16/fp16  no-mask & causal
  - GQA-8 d=128   sq=sk∈{256..1024}   bf16       no-mask & causal
22/22 stages PASS within bf16/fp16 ULP. sweep_fp8.sh (causal) timings
unchanged — the truncation still fires for the causal kernels.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-19 14:29:48 +00:00