mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Enable canonical-NaN BF16 conversion for FMHA on RDNA (#6253)
## Motivation
- On gfx11/gfx12, the existing float -> bf16 conversion path in FMHA
forward adds noticeable overhead and causes a meaningful performance gap
versus fp16. The asm-based path (mode 3) does not improve this on RDNA
and can perform even worse.
- In particular, on gfx12, bf16 FMHA forward can be up to ~20% slower
than the corresponding fp16 path.
- This PR reduces that gap by switching FMHA forward to a different BF16
conversion strategy based on Triton’s canonical-NaN
round-to-nearest-even behavior.
## Technical Details
- Add a new `standard_cnan` BF16 conversion mode to CK Tile.
- Implement a canonical-NaN RTN `float -> bf16` conversion path based on
the Triton implementation.
- Enable this conversion mode by default for FMHA forward builds
targeting gfx11/gfx12.
- Retune gfx11/gfx12 FMHA forward kernel selection thresholds for some
`hdim=128` cases to keep kernel selection aligned with the updated
conversion behavior.
## Test Plan
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16
-d={hdim} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
## Test Result
- all tests passed when running `test_ck_tile_fmha`
- BF16 FMHA forward performance improves by up to ~5% on gfx11.
- BF16 FMHA forward performance improves by up to ~10% on gfx12.
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -180,6 +180,34 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
set(FMHA_HAS_RDNA_TARGET OFF)
|
||||
set(FMHA_HAS_NON_RDNA_TARGET OFF)
|
||||
foreach(inst_target ${INST_TARGETS})
|
||||
if(inst_target MATCHES "^(gfx11|gfx12)")
|
||||
set(FMHA_HAS_RDNA_TARGET ON)
|
||||
else()
|
||||
set(FMHA_HAS_NON_RDNA_TARGET ON)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(FMHA_HAS_RDNA_TARGET)
|
||||
set(FMHA_FWD_RDNA_GEN_BLOBS)
|
||||
foreach(fwd_blob ${FMHA_FWD_GEN_BLOBS})
|
||||
if(fwd_blob MATCHES "_gfx1[12][^/]*\\.cpp$")
|
||||
list(APPEND FMHA_FWD_RDNA_GEN_BLOBS ${fwd_blob})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(FMHA_FWD_RDNA_GEN_BLOBS)
|
||||
set_property(SOURCE ${FMHA_FWD_RDNA_GEN_BLOBS}
|
||||
APPEND PROPERTY COMPILE_DEFINITIONS CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=5)
|
||||
endif()
|
||||
|
||||
if(NOT FMHA_HAS_NON_RDNA_TARGET)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=5)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream
|
||||
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
|
||||
@@ -1183,8 +1183,6 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = super().get_rules()
|
||||
|
||||
# For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile:
|
||||
# the exact-hdim variant (dpad=dvpad=f) is much slower here.
|
||||
def check_d128_tile_pipeline(
|
||||
problem_ctx: ProblemContext, kernel_ctx: KernelContext
|
||||
) -> bool:
|
||||
@@ -1215,6 +1213,7 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
# max_seqlen_q cutoff retuned after the bf16 standard_cnan change.
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 2048")),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
@@ -1278,7 +1277,8 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 8192")),
|
||||
# max_seqlen_q cutoff retuned after the bf16 standard_cnan change.
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q <= 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
|
||||
@@ -74,6 +74,7 @@
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_CNAN 5
|
||||
|
||||
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_STANDARD
|
||||
|
||||
@@ -22,7 +22,8 @@ enum class bf16_rounding_mode
|
||||
truncate_with_nan,
|
||||
truncate,
|
||||
standard_asm,
|
||||
rta_asm, // round to nearest away
|
||||
rta_asm, // round to nearest away
|
||||
standard_cnan, // rtn with canonical NaN
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
@@ -226,6 +227,39 @@ uint16_t float_to_bf16_rta_asm(float f)
|
||||
return u.hi;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr bool float_is_nan_raw(float f)
|
||||
{
|
||||
#if defined(__has_builtin) && __has_builtin(__builtin_isnan)
|
||||
return __builtin_isnan(f);
|
||||
#else
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
constexpr uint32_t exp_mask = 0x7f800000;
|
||||
constexpr uint32_t mant_mask = 0x007fffff;
|
||||
|
||||
return (bits & exp_mask) == exp_mask && (bits & mant_mask);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Round to nearest even, but canonicalize any NaN input to the canonical quiet bf16 NaN
|
||||
// (`0x7fff`). Unlike `float_to_bf16_rtn_raw`, this does not preserve signaling NaN
|
||||
// payload/state.
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_rtn_cnan_raw(float f)
|
||||
{
|
||||
#if defined(__FAST_MATH__) || (defined(__FINITE_MATH_ONLY__) && __FINITE_MATH_ONLY__)
|
||||
// Fast/finite-math can fold the NaN predicate away, so fall back to standard RTN.
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
#else
|
||||
// `-fgpu-flush-denormals-to-zero` only affects denormals, not NaN handling.
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
uint32_t tmp = (bits >> 16) & 1;
|
||||
uint32_t res = float_is_nan_raw(f) ? 0x7fff0000 : bits + tmp + 0x7fff;
|
||||
|
||||
return uint16_t(res >> 16);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
@@ -249,6 +283,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
|
||||
return float_to_bf16_rtn_asm(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::standard_cnan)
|
||||
return float_to_bf16_rtn_cnan_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
|
||||
|
||||
Reference in New Issue
Block a user