From fd1060f6fe09695ed0ba75aa2a1897bcb3067c1d Mon Sep 17 00:00:00 2001
From: Hosang Yoon <156028780+hyoon1@users.noreply.github.com>
Date: Mon, 20 Apr 2026 14:52:24 -0400
Subject: [PATCH] [CK_TILE] Enable canonical-NaN BF16 conversion for FMHA on RDNA (#6253)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

- On gfx11/gfx12, the existing float -> bf16 conversion path in FMHA forward adds noticeable overhead and causes a meaningful performance gap versus fp16. The asm-based path (mode 3) does not improve this on RDNA and can perform even worse.
- In particular, on gfx12, bf16 FMHA forward can be up to ~20% slower than the corresponding fp16 path.
- This PR reduces that gap by switching FMHA forward to a different BF16 conversion strategy based on Triton’s canonical-NaN round-to-nearest-even behavior.

## Technical Details

- Add a new `standard_cnan` BF16 conversion mode to CK Tile.
- Implement a canonical-NaN RTN `float -> bf16` conversion path based on the Triton implementation.
- Enable this conversion mode by default for FMHA forward builds targeting gfx11/gfx12.
- Retune gfx11/gfx12 FMHA forward kernel selection thresholds for some `hdim=128` cases to keep kernel selection aligned with the updated conversion behavior.

## Test Plan

./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 -d={hdim} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result

- all tests passed when running `test_ck_tile_fmha`
- BF16 FMHA forward performance improves by up to ~5% on gfx11.
- BF16 FMHA forward performance improves by up to ~10% on gfx12.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
--- example/ck_tile/01_fmha/CMakeLists.txt | 28 ++++++++++++++ .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 6 +-- include/ck_tile/core/config.hpp | 1 + include/ck_tile/core/numeric/bfloat16.hpp | 38 ++++++++++++++++++- 4 files changed, 69 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 35afb1181e..fca8374f3b 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -180,6 +180,34 @@ if(CK_USE_OCP_FP8) list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +set(FMHA_HAS_RDNA_TARGET OFF) +set(FMHA_HAS_NON_RDNA_TARGET OFF) +foreach(inst_target ${INST_TARGETS}) + if(inst_target MATCHES "^(gfx11|gfx12)") + set(FMHA_HAS_RDNA_TARGET ON) + else() + set(FMHA_HAS_NON_RDNA_TARGET ON) + endif() +endforeach() + +if(FMHA_HAS_RDNA_TARGET) + set(FMHA_FWD_RDNA_GEN_BLOBS) + foreach(fwd_blob ${FMHA_FWD_GEN_BLOBS}) + if(fwd_blob MATCHES "_gfx1[12][^/]*\\.cpp$") + list(APPEND FMHA_FWD_RDNA_GEN_BLOBS ${fwd_blob}) + endif() + endforeach() + + if(FMHA_FWD_RDNA_GEN_BLOBS) + set_property(SOURCE ${FMHA_FWD_RDNA_GEN_BLOBS} + APPEND PROPERTY COMPILE_DEFINITIONS CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=5) + endif() + + if(NOT FMHA_HAS_NON_RDNA_TARGET) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=5) + endif() +endif() + # use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 978c9d0a75..542bf2f2fa 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1183,8 +1183,6 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): def get_rules(cls) -> 
List[CompatibilityRule]: rules = super().get_rules() - # For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile: - # the exact-hdim variant (dpad=dvpad=f) is much slower here. def check_d128_tile_pipeline( problem_ctx: ProblemContext, kernel_ctx: KernelContext ) -> bool: @@ -1215,6 +1213,7 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + # max_seqlen_q cutoff retuned after the bf16 standard_cnan change. (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 2048")), FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)], (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -1278,7 +1277,8 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 8192")), + # max_seqlen_q cutoff retuned after the bf16 standard_cnan change. 
+ (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 4096")), FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)], (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 06220d2780..ba195427be 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -74,6 +74,7 @@ #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3 #define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4 +#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_CNAN 5 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_STANDARD diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 3508c0705e..226115df66 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -22,7 +22,8 @@ enum class bf16_rounding_mode truncate_with_nan, truncate, standard_asm, - rta_asm, // round to nearest away + rta_asm, // round to nearest away + standard_cnan, // rtn with canonical NaN }; template (f); + constexpr uint32_t exp_mask = 0x7f800000; + constexpr uint32_t mant_mask = 0x007fffff; + + return (bits & exp_mask) == exp_mask && (bits & mant_mask); +#endif +} + +// Round to nearest even, but canonicalize any NaN input to the canonical quiet bf16 NaN +// (`0x7fff`). Unlike `float_to_bf16_rtn_raw`, this does not preserve signaling NaN +// payload/state. 
+CK_TILE_HOST_DEVICE +constexpr uint16_t float_to_bf16_rtn_cnan_raw(float f) +{ +#if defined(__FAST_MATH__) || (defined(__FINITE_MATH_ONLY__) && __FINITE_MATH_ONLY__) + // Fast/finite-math can fold the NaN predicate away, so fall back to standard RTN. + return float_to_bf16_rtn_raw(f); +#else + // `-fgpu-flush-denormals-to-zero` only affects denormals, not NaN handling. + uint32_t bits = bit_cast(f); + uint32_t tmp = (bits >> 16) & 1; + uint32_t res = float_is_nan_raw(f) ? 0x7fff0000 : bits + tmp + 0x7fff; + + return uint16_t(res >> 16); +#endif +} + // Truncate instead of rounding, preserving SNaN CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_nan_raw(float f) @@ -249,6 +283,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant