Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-11 08:50:17 +00:00)
[rocm-libraries] ROCm/rocm-libraries#7141 (commit 37e40c3)
[CK_TILE] Fix typo in fmha_fwd_kernel K-dram unmerge tuple sizes (#7141)

## Summary

The qr_async_trload K-dram lambda's `else (XorLengthFold == 1)` branch in `fmha_fwd_kernel.hpp` writes the outer-tile dim of its 3-tuple unmerge/xor/merge as

```cpp
number<FmhaPipeline::kQKHeaddim / kDramTileK / FmhaPipeline::kAlignmentK>{}
```

which divides one extra time. For every fp16/bf16 hdim=128 configuration the outer length collapses to **0**, e.g. `128 / 128 / 8 == 0`. The 3-tuple product no longer equals `kQKHeaddim`, so unmerge → xor → merge stops round-tripping the head dimension.

This bug was masked by the async-load path: it only walks the descriptor via stride and silently absorbs a length=0 outer dim. Any consumer that actually traverses the descriptor (e.g. the TDM path on gfx1250) immediately faults on the resulting `tuple<int, constant<0>>`.

The fix drops the extra `/ kAlignmentK` at all three call sites in the same lambda, so the outer dim becomes `kQKHeaddim / kDramTileK` and the product is restored to `kQKHeaddim`. Strides are unaffected, so the async path is bit-identical.

| Config (fp16/bf16) | hdim | kDramTileK | kAlignmentK | outer dim a (typo) | outer dim a (fixed) | product (typo) | product (fixed) |
|---|---|---|---|---|---|---|---|
| hdim128, kKLoadOnce | 128 | 128 | 8 | 0 | 1 | **0** | **128** |
| hdim128, kK0=32 | 128 | 32 | 8 | 0 | 4 | **0** | **128** |
| hdim64, kKLoadOnce | 64 | 64 | 8 | 0 | 1 | **0** | **64** |
| hdim256, kK0=32 | 256 | 32 | 8 | 1 | 8 | **32** | **256** |

Bug introduced in 2cc0af6a815a (PR #2888 "[CK_TILE] FMHA FWD bug fix"), where the original 2-tuple unmerge was generalized to a 3-tuple and the typo slipped in.

## Test plan

- [x] Built `test_ck_tile_fmha_fwd` (umbrella, 5 gtest binaries) on gfx950 native at develop b3bdc63a509 with the `dev-gfx950` preset (clang 22, ROCm 7.2.2). Compiles cleanly with `-Werror -Weverything`.
- [x] Ran `ctest -R test_ck_tile_fmha_fwd` on gfx950 native, baseline vs patched: identical pass/fail (3 pass / 2 fail), identical failing case set (114 gtest fails + 2 GPU memory access faults, all in pre-existing fp16/bf16 group-mode `Alibi`/`Dropout` cases that reproduce on develop without this patch). Total wall time 403s → 393s. Per-case latency drift ±8% (noise).
- [x] CI to verify on other gfx9 / gfx11 architectures.
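
The arithmetic in the table above can be reproduced with ordinary integer division. The following is a minimal sketch, not code from the kernel: `KDramLengths`, `lengths_typo`, `lengths_fixed`, and `product` are hypothetical names, and plain `int`s stand in for CK's `number<>` compile-time constants.

```cpp
// Illustrative only: hypothetical helpers, not part of composable_kernel.
// Plain ints stand in for number<> constants; division truncates toward zero.
struct KDramLengths {
    int outer; // outer-tile dim ("a" in the table)
    int mid;   // kDramTileK / kAlignmentK
    int inner; // kAlignmentK
};

// Lengths as written before the fix: one extra "/ align_k" on the outer dim.
constexpr KDramLengths lengths_typo(int hdim, int tile_k, int align_k) {
    return {hdim / tile_k / align_k, tile_k / align_k, align_k};
}

// Lengths after the fix: the outer dim is just hdim / tile_k.
constexpr KDramLengths lengths_fixed(int hdim, int tile_k, int align_k) {
    return {hdim / tile_k, tile_k / align_k, align_k};
}

// Product of the three lengths; it should equal hdim for a valid split.
constexpr int product(KDramLengths l) { return l.outer * l.mid * l.inner; }

// hdim128, kKLoadOnce: 128 / 128 / 8 truncates to 0.
static_assert(lengths_typo(128, 128, 8).outer == 0, "typo collapses outer dim");
static_assert(product(lengths_typo(128, 128, 8)) == 0, "typo product is 0");
static_assert(product(lengths_fixed(128, 128, 8)) == 128, "fixed product is hdim");

// hdim128, kK0=32 and hdim64, kKLoadOnce.
static_assert(lengths_fixed(128, 32, 8).outer == 4, "fixed outer dim is 4");
static_assert(product(lengths_fixed(64, 64, 8)) == 64, "fixed product is hdim");

// hdim256, kK0=32: the typo yields a nonzero but still wrong product.
static_assert(product(lengths_typo(256, 32, 8)) == 32, "typo product is 32");
static_assert(product(lengths_fixed(256, 32, 8)) == 256, "fixed product is hdim");
```

The hdim256 row is worth noting: the outer dim does not collapse to zero there, yet the product (32) still falls short of `kQKHeaddim`, so a zero-length check alone would not have caught the typo.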
Committed by assistant-librarian[bot]
Parent: cc29502c28
Commit: 41064d8684
@@ -2661,8 +2661,7 @@ struct FmhaFwdKernel
 k_dram_pad,
 make_tuple(make_pass_through_transform(height),
 make_unmerge_transform(
-make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
-FmhaPipeline::kAlignmentK>{},
+make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK>{},
 number<kDramTileK / FmhaPipeline::kAlignmentK>{},
 number<FmhaPipeline::kAlignmentK>{}))),
 make_tuple(sequence<0>{}, sequence<1>{}),
@@ -2674,8 +2673,7 @@ struct FmhaFwdKernel
 make_xor_transform(make_tuple(
 height, number<kDramTileK / FmhaPipeline::kAlignmentK>{})),
 make_pass_through_transform(
-number<FmhaPipeline::kQKHeaddim / kDramTileK /
-FmhaPipeline::kAlignmentK>{}),
+number<FmhaPipeline::kQKHeaddim / kDramTileK>{}),
 make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
 make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}),
 make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
@@ -2684,8 +2682,7 @@ struct FmhaFwdKernel
 k_dram_permuted,
 make_tuple(make_pass_through_transform(height),
 make_merge_transform_v3_division_mod(
-make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
-FmhaPipeline::kAlignmentK>{},
+make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK>{},
 number<kDramTileK / FmhaPipeline::kAlignmentK>{},
 number<FmhaPipeline::kAlignmentK>{}))),
 make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
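
To see why a consumer that traverses the descriptor is affected while a stride-only walker is not, the unmerge and merge steps in the hunks above can be modeled with plain index arithmetic. This is a simplified sketch under stated assumptions, not CK's implementation: `Idx3`, `unmerge`, `merge`, and `count_round_trip` are hypothetical names, the xor step is omitted, and runtime-style `int` lengths replace `number<>` constants. It needs C++14 or later.

```cpp
#include <array>

// Hypothetical stand-in for a 3-D coordinate or length tuple (not a CK type).
using Idx3 = std::array<int, 3>;

// Split a linear head-dim offset into (outer, mid, inner) coordinates.
constexpr Idx3 unmerge(int linear, const Idx3& len) {
    return {linear / (len[1] * len[2]), (linear / len[2]) % len[1], linear % len[2]};
}

// Fold (outer, mid, inner) coordinates back into a linear offset.
constexpr int merge(const Idx3& idx, const Idx3& len) {
    return (idx[0] * len[1] + idx[1]) * len[2] + idx[2];
}

// Count offsets in [0, hdim) whose coordinates lie inside the declared
// lengths and map back to the same offset after unmerge followed by merge.
constexpr int count_round_trip(int hdim, const Idx3& len) {
    int ok = 0;
    for (int k = 0; k < hdim; ++k) {
        const Idx3 idx = unmerge(k, len);
        const bool in_range = idx[0] < len[0] && idx[1] < len[1] && idx[2] < len[2];
        if (in_range && merge(idx, len) == k) {
            ++ok;
        }
    }
    return ok;
}

// hdim=128, kDramTileK=128, kAlignmentK=8: lengths (0, 16, 8) vs (1, 16, 8).
static_assert(count_round_trip(128, Idx3{0, 16, 8}) == 0, "typo covers nothing");
static_assert(count_round_trip(128, Idx3{1, 16, 8}) == 128, "fix covers all of hdim");

// hdim=256, kDramTileK=32, kAlignmentK=8: lengths (1, 4, 8) vs (8, 4, 8).
static_assert(count_round_trip(256, Idx3{1, 4, 8}) == 32, "typo covers only 32");
static_assert(count_round_trip(256, Idx3{8, 4, 8}) == 256, "fix covers all of hdim");
```

In this model the offset arithmetic never reads the outer length, only the in-range check does. That is consistent with the commit's observation that strides, and therefore the async path, are unchanged by the fix.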