diff --git a/dispatcher/codegen/fmha/instance_gen.py b/dispatcher/codegen/fmha/instance_gen.py index 20536cabdf..9286e0bf20 100644 --- a/dispatcher/codegen/fmha/instance_gen.py +++ b/dispatcher/codegen/fmha/instance_gen.py @@ -2396,6 +2396,12 @@ def _expand_batch_prefill( bp_specs = get_batch_prefill_pipelines(dtype, hq, receipt) for tc in tiles: bk1 = _bp_bk1(tc.bm0, tc.bn0, tc.bk0, hq) + + # qr_async stages K into LDS through a bk1-major descriptor while the gemm0 + # loop reads bk0 chunks, therefore the pipeline requires bk0 == bk1 + if tc.bk0 != bk1: + continue + for spec in bp_specs: mm = _MASK_MAP.get(spec.mask, spec.mask) mb = _BIAS_MAP.get(spec.bias, spec.bias) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 68f54662d4..ff301346fa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -460,6 +460,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + // K is staged into LDS through a single K-major descriptor whose depth is parameterized + // by kK1 (see MakeKLdsStoreBlockDescriptor / GetSingleSmemElementSpaceSize in the custom + // policy), while the gemm0 (QK^T) main loop advances the K DRAM window by kK0 and reads + // kK0-deep K chunks from that same LDS buffer. The two only agree when kK0 == kK1; with + // kK0 != kK1 the async copy mis-strides K into LDS and gemm0 reads past the per-buffer + // chunk into neighboring buffers/padding, silently corrupting the result. Enforce the + // invariant so illegal tiles (e.g. bk0=64, bk1=32) fail to compile instead of producing + // garbage at runtime. + static_assert(kK0 == kK1, + "qr_ks_vs_async stages K through a bk1-major LDS descriptor; bk0 must " + "equal bk1."); + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); // K tile in LDS