From 41064d8684132c9b7332657aca510d3de9039c4d Mon Sep 17 00:00:00 2001 From: Yi DING <28386673+DDEle@users.noreply.github.com> Date: Fri, 8 May 2026 08:51:33 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7141 (commit 37e40c3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [CK_TILE] Fix typo in fmha_fwd_kernel K-dram unmerge tuple sizes (#7141) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary The qr_async_trload K-dram lambda's `else (XorLengthFold == 1)` branch in `fmha_fwd_kernel.hpp` writes the outer-tile dim of its 3-tuple unmerge/xor/merge as ```cpp number{} ``` which divides one extra time. For every fp16/bf16 hdim=128 configuration the outer length collapses to **0**, e.g. `128 / 128 / 8 == 0`. The 3-tuple product no longer equals `kQKHeaddim`, so unmerge → xor → merge stops round-tripping the head dimension. This bug was masked by the async-load path: it only walks the descriptor via stride and silently absorbs a length=0 outer dim. Any consumer that actually traverses the descriptor (e.g. the TDM path on gfx1250) immediately faults on the resulting `tuple>`. The fix drops the extra `/ kAlignmentK` in all three call sites in the same lambda so the outer dim becomes `kQKHeaddim / kDramTileK` and the product is restored to `kQKHeaddim`. Strides are unaffected, so the async path is bit-identical. | Config (fp16/bf16) | hdim | kDramTileK | kAlignmentK | a (typo) | a (fixed) | product (typo) | product (fixed) | |---|---|---|---|---|---|---|---| | hdim128, kKLoadOnce | 128 | 128 | 8 | 0 | 1 | **0** | **128** | | hdim128, kK0=32 | 128 | 32 | 8 | 0 | 4 | **0** | **128** | | hdim64, kKLoadOnce | 64 | 64 | 8 | 0 | 1 | **0** | **64** | | hdim256, kK0=32 | 256 | 32 | 8 | 1 | 8 | **32** | **256** | Bug introduced in 2cc0af6a815a (PR #2888 \"[CK_TILE] FMHA FWD bug fix\"), where the original 2-tuple unmerge was generalized to a 3-tuple and the typo slipped in. ## Test plan - [x] Built `test_ck_tile_fmha_fwd` (umbrella, 5 gtest binaries) on gfx950 native at develop b3bdc63a509 with `dev-gfx950` preset (clang 22, ROCm 7.2.2). Compiles cleanly with `-Werror -Weverything`. - [x] Ran `ctest -R test_ck_tile_fmha_fwd` on gfx950 native, baseline vs patched: identical pass/fail (3 pass / 2 fail), identical failing case set (114 gtest fails + 2 GPU memory access faults, all in pre-existing fp16/bf16 group-mode `Alibi`/`Dropout` cases that reproduce on develop without this patch). Total wall time 403s → 393s. Per-case latency drift ±8% (noise). - [x] CI to verify on other gfx9 / gfx11 architectures. --- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index cf7f9a270e..fcb73c48b7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -2661,8 +2661,7 @@ struct FmhaFwdKernel k_dram_pad, make_tuple(make_pass_through_transform(height), make_unmerge_transform( - make_tuple(number{}, + make_tuple(number{}, number{}, number{}))), make_tuple(sequence<0>{}, sequence<1>{}), @@ -2674,8 +2673,7 @@ struct FmhaFwdKernel make_xor_transform(make_tuple( height, number{})), make_pass_through_transform( - number{}), + number{}), make_pass_through_transform(number{})), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); @@ -2684,8 +2682,7 @@ struct FmhaFwdKernel k_dram_permuted, make_tuple(make_pass_through_transform(height), make_merge_transform_v3_division_mod( - make_tuple(number{}, + make_tuple(number{}, number{}, number{}))), make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),