From 2c0b7cbb0a618ed52f7d2f5263baaffa640b6500 Mon Sep 17 00:00:00 2001 From: damien-lejeune <31985270+damien-lejeune@users.noreply.github.com> Date: Tue, 16 Jun 2026 07:41:58 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#8424 (commit debb669) Add missing constraint in the FMHA qr async pipeline to enforce bk0=bk1 (#8424) ## Motivation The purpose of this change is to add a guardrail to what values bk0 and bk1 can take. This is to avoid ill defined sizes, silently failing and generating NaN (or other error) at runtime. An example of such failure can be obtained using the tile engine: ``` cd rocm-libraries/projects/composablekernel/tile_engine/ops/fmha python fmha_benchmark.py configs/batch_prefill.json \ --problems "1,4,1,8000,8000,256" \ --filter "c.data_type=='bf16' and c.hdim_q==256 and c.pipeline=='qr_async' and c.mode=='group' and c.tile_n0==32 and c.tile_k0==64" ``` ## Technical Details The qr_async pipeline stages data in the K dimensions into LDS using a bk1-descriptor, while the (Q*K^T) gemm0 consumes bk0 ## Test Plan See command above ## Test Result Before the change: (invalid) generate instances, error at runtime After this change: no instance generated ## Submission Checklist - [X] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. Co-authored-by: Damien Lejeune --- dispatcher/codegen/fmha/instance_gen.py | 6 ++++++ ...ck_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/dispatcher/codegen/fmha/instance_gen.py b/dispatcher/codegen/fmha/instance_gen.py index 20536cabdf..9286e0bf20 100644 --- a/dispatcher/codegen/fmha/instance_gen.py +++ b/dispatcher/codegen/fmha/instance_gen.py @@ -2396,6 +2396,12 @@ def _expand_batch_prefill( bp_specs = get_batch_prefill_pipelines(dtype, hq, receipt) for tc in tiles: bk1 = _bp_bk1(tc.bm0, tc.bn0, tc.bk0, hq) + + # qr_async stages K into LDS through a bk1-major descriptor while the gemm0 + # loop reads bk0 chunks, therefore the pipeline requires bk0 == bk1 + if tc.bk0 != bk1: + continue + for spec in bp_specs: mm = _MASK_MAP.get(spec.mask, spec.mask) mb = _BIAS_MAP.get(spec.bias, spec.bias) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 68f54662d4..ff301346fa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -460,6 +460,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + // K is staged into LDS through a single K-major descriptor whose depth is parameterized + // by kK1 (see MakeKLdsStoreBlockDescriptor / GetSingleSmemElementSpaceSize in the custom + // policy), while the gemm0 (QK^T) main loop advances the K DRAM window by kK0 and reads + // kK0-deep K chunks from that same LDS buffer. The two only agree when kK0 == kK1; with + // kK0 != kK1 the async copy mis-strides K into LDS and gemm0 reads past the per-buffer + // chunk into neighboring buffers/padding, silently corrupting the result. Enforce the + // invariant so illegal tiles (e.g. bk0=64, bk1=32) fail to compile instead of producing + // garbage at runtime. + static_assert(kK0 == kK1, + "qr_ks_vs_async stages K through a bk1-major LDS descriptor; bk0 must " + "equal bk1."); + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); // K tile in LDS