mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 16:28:38 +00:00
[CK_TILE] Enable canonical-NaN BF16 conversion for FMHA on RDNA (#6253)

## Motivation

- On gfx11/gfx12, the existing float -> bf16 conversion path in FMHA forward adds noticeable overhead and causes a meaningful performance gap versus fp16. The asm-based path (mode 3) does not improve this on RDNA and can perform even worse.
- In particular, on gfx12, bf16 FMHA forward can be up to ~20% slower than the corresponding fp16 path.
- This PR reduces that gap by switching FMHA forward to a different BF16 conversion strategy based on Triton’s canonical-NaN round-to-nearest-even behavior.

## Technical Details

- Add a new `standard_cnan` BF16 conversion mode to CK Tile.
- Implement a canonical-NaN RTN `float -> bf16` conversion path based on the Triton implementation.
- Enable this conversion mode by default for FMHA forward builds targeting gfx11/gfx12.
- Retune gfx11/gfx12 FMHA forward kernel selection thresholds for some `hdim=128` cases to keep kernel selection aligned with the updated conversion behavior.

## Test Plan

`./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 -d={hdim} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}`

## Test Result

- all tests passed when running `test_ck_tile_fmha`
- BF16 FMHA forward performance improves by up to ~5% on gfx11.
- BF16 FMHA forward performance improves by up to ~10% on gfx12.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.