mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Fix typo in fmha_fwd_kernel K-dram unmerge tuple sizes (#7141)
## Summary
The qr_async_trload K-dram lambda's `else (XorLengthFold == 1)` branch
in `fmha_fwd_kernel.hpp` writes the outer-tile dim of its 3-tuple
unmerge/xor/merge as
```cpp
number<FmhaPipeline::kQKHeaddim / kDramTileK / FmhaPipeline::kAlignmentK>{}
```
which divides one extra time. For every fp16/bf16 hdim=128 configuration
the outer length collapses to **0**, e.g. `128 / 128 / 8 == 0`. The
3-tuple product no longer equals `kQKHeaddim`, so unmerge → xor → merge
stops round-tripping the head dimension.
This bug was masked by the async-load path: it only walks the descriptor
via stride and silently absorbs a length=0 outer dim. Any consumer that
actually traverses the descriptor (e.g. the TDM path on gfx1250)
immediately faults on the resulting `tuple<int, constant<0>>`.
The fix drops the extra `/ kAlignmentK` in all three call sites in the
same lambda so the outer dim becomes `kQKHeaddim / kDramTileK` and the
product is restored to `kQKHeaddim`. Strides are unaffected, so the
async path is bit-identical.
| Config (fp16/bf16) | hdim | kDramTileK | kAlignmentK | a (typo) | a
(fixed) | product (typo) | product (fixed) |
|---|---|---|---|---|---|---|---|
| hdim128, kKLoadOnce | 128 | 128 | 8 | 0 | 1 | **0** | **128** |
| hdim128, kK0=32 | 128 | 32 | 8 | 0 | 4 | **0** | **128** |
| hdim64, kKLoadOnce | 64 | 64 | 8 | 0 | 1 | **0** | **64** |
| hdim256, kK0=32 | 256 | 32 | 8 | 1 | 8 | **32** | **256** |
Bug introduced in 2cc0af6a815a (PR #2888 \"[CK_TILE] FMHA FWD bug
fix\"), where the original 2-tuple unmerge was generalized to a 3-tuple
and the typo slipped in.
## Test plan
- [x] Built `test_ck_tile_fmha_fwd` (umbrella, 5 gtest binaries) on
gfx950 native at develop b3bdc63a509 with `dev-gfx950` preset (clang 22,
ROCm 7.2.2). Compiles cleanly with `-Werror -Weverything`.
- [x] Ran `ctest -R test_ck_tile_fmha_fwd` on gfx950 native, baseline vs
patched: identical pass/fail (3 pass / 2 fail), identical failing case
set (114 gtest fails + 2 GPU memory access faults, all in pre-existing
fp16/bf16 group-mode `Alibi`/`Dropout` cases that reproduce on develop
without this patch). Total wall time 403s → 393s. Per-case latency drift
±8% (noise).
- [x] CI to verify on other gfx9 / gfx11 architectures.
This commit is contained in:
@@ -2689,8 +2689,7 @@ struct FmhaFwdKernel
|
||||
k_dram_pad,
|
||||
make_tuple(make_pass_through_transform(height),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
|
||||
FmhaPipeline::kAlignmentK>{},
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK>{},
|
||||
number<kDramTileK / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
@@ -2702,8 +2701,7 @@ struct FmhaFwdKernel
|
||||
make_xor_transform(make_tuple(
|
||||
height, number<kDramTileK / FmhaPipeline::kAlignmentK>{})),
|
||||
make_pass_through_transform(
|
||||
number<FmhaPipeline::kQKHeaddim / kDramTileK /
|
||||
FmhaPipeline::kAlignmentK>{}),
|
||||
number<FmhaPipeline::kQKHeaddim / kDramTileK>{}),
|
||||
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
@@ -2712,8 +2710,7 @@ struct FmhaFwdKernel
|
||||
k_dram_permuted,
|
||||
make_tuple(make_pass_through_transform(height),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
|
||||
FmhaPipeline::kAlignmentK>{},
|
||||
make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK>{},
|
||||
number<kDramTileK / FmhaPipeline::kAlignmentK>{},
|
||||
number<FmhaPipeline::kAlignmentK>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
|
||||
|
||||
Reference in New Issue
Block a user