mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
[rocm-libraries] ROCm/rocm-libraries#8424 (commit debb669)
Add missing constraint in the FMHA qr async pipeline to enforce bk0=bk1 (#8424) ## Motivation The purpose of this change is to add a guardrail to what values bk0 and bk1 can take. This is to avoid ill defined sizes, silently failing and generating NaN (or other error) at runtime. An example of such failure can be obtained using the tile engine: ``` cd rocm-libraries/projects/composablekernel/tile_engine/ops/fmha python fmha_benchmark.py configs/batch_prefill.json \ --problems "1,4,1,8000,8000,256" \ --filter "c.data_type=='bf16' and c.hdim_q==256 and c.pipeline=='qr_async' and c.mode=='group' and c.tile_n0==32 and c.tile_k0==64" ``` ## Technical Details The qr_async pipeline stages data in the K dimensions into LDS using a bk1-descriptor, while the (Q*K^T) gemm0 consumes bk0 ## Test Plan See command above ## Test Result Before the change: (invalid) generate instances, error at runtime After this change: no instance generated ## Submission Checklist - [X] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. Co-authored-by: Damien Lejeune <damien.lejeune@amd.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1b649a8d4b
commit
2c0b7cbb0a
@@ -2396,6 +2396,12 @@ def _expand_batch_prefill(
|
||||
bp_specs = get_batch_prefill_pipelines(dtype, hq, receipt)
|
||||
for tc in tiles:
|
||||
bk1 = _bp_bk1(tc.bm0, tc.bn0, tc.bk0, hq)
|
||||
|
||||
# qr_async stages K into LDS through a bk1-major descriptor while the gemm0
|
||||
# loop reads bk0 chunks, therefore the pipeline requires bk0 == bk1
|
||||
if tc.bk0 != bk1:
|
||||
continue
|
||||
|
||||
for spec in bp_specs:
|
||||
mm = _MASK_MAP.get(spec.mask, spec.mask)
|
||||
mb = _BIAS_MAP.get(spec.bias, spec.bias)
|
||||
|
||||
@@ -460,6 +460,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// K is staged into LDS through a single K-major descriptor whose depth is parameterized
|
||||
// by kK1 (see MakeKLdsStoreBlockDescriptor / GetSingleSmemElementSpaceSize in the custom
|
||||
// policy), while the gemm0 (QK^T) main loop advances the K DRAM window by kK0 and reads
|
||||
// kK0-deep K chunks from that same LDS buffer. The two only agree when kK0 == kK1; with
|
||||
// kK0 != kK1 the async copy mis-strides K into LDS and gemm0 reads past the per-buffer
|
||||
// chunk into neighboring buffers/padding, silently corrupting the result. Enforce the
|
||||
// invariant so illegal tiles (e.g. bk0=64, bk1=32) fail to compile instead of producing
|
||||
// garbage at runtime.
|
||||
static_assert(kK0 == kK1,
|
||||
"qr_ks_vs_async stages K through a bk1-major LDS descriptor; bk0 must "
|
||||
"equal bk1.");
|
||||
|
||||
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
|
||||
|
||||
// K tile in LDS
|
||||
|
||||
Reference in New Issue
Block a user