mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
## Motivation
- On gfx11/gfx12, the existing float -> bf16 conversion path in FMHA
forward adds noticeable overhead and causes a meaningful performance gap
versus fp16. The asm-based path (mode 3) does not improve this on RDNA
and can perform even worse.
- In particular, on gfx12, bf16 FMHA forward can be up to ~20% slower
than the corresponding fp16 path.
- This PR reduces that gap by switching FMHA forward to a different BF16
conversion strategy based on Triton’s canonical-NaN
round-to-nearest-even behavior.
## Technical Details
- Add a new `standard_cnan` BF16 conversion mode to CK Tile.
- Implement a canonical-NaN round-to-nearest-even (RTN) `float -> bf16`
conversion path based on the Triton implementation.
- Enable this conversion mode by default for FMHA forward builds
targeting gfx11/gfx12.
- Retune gfx11/gfx12 FMHA forward kernel selection thresholds for some
`hdim=128` cases to keep kernel selection aligned with the updated
conversion behavior.
## Test Plan
`./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 -d={hdim} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}`
## Test Result
- All tests passed when running `test_ck_tile_fmha`.
- BF16 FMHA forward performance improves by up to ~5% on gfx11.
- BF16 FMHA forward performance improves by up to ~10% on gfx12.
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.