From 3979378e9647af036f572872d16a972055322d9a Mon Sep 17 00:00:00 2001
From: Dan Yao
Date: Fri, 30 Aug 2024 15:38:09 +0800
Subject: [PATCH] [CK_TILE] float -> bf16 inline asm rtn (#1482)

* asm rtn

* add asm rtn macro

* reorder macro

---------

Co-authored-by: carlushuang

[ROCm/composable_kernel commit: b8addae29357f1aaf69321ffe097e72698edaf02]
---
 include/ck_tile/core/config.hpp           |  1 +
 include/ck_tile/core/numeric/bfloat16.hpp | 34 +++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp
index ee47d136d0..a08c4f3811 100644
--- a/include/ck_tile/core/config.hpp
+++ b/include/ck_tile/core/config.hpp
@@ -46,6 +46,7 @@
 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
+#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
 
 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
 #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp
index 4fdf8f9dae..5f4b64466e 100644
--- a/include/ck_tile/core/numeric/bfloat16.hpp
+++ b/include/ck_tile/core/numeric/bfloat16.hpp
@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
     standard = 0, // rtn
     truncate_with_nan,
     truncate,
+    standard_asm,
 };
 
 template > 16);
 }
 
+CK_TILE_HOST // host overload: no inline asm here, simply forwards to float_to_bf16_rtn_raw
+constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
+
+CK_TILE_DEVICE // device overload: fp32 -> bf16 bit pattern, rtn, done in one inline-asm sequence
+uint16_t float_to_bf16_rtn_asm(float f)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {f}; // same 32 bits: passed to the asm as fp32 (%2), read back as int32
+
+    static constexpr uint32_t FP32_NAN = 0x7fff0000; // %4: value selected when f is NaN
+    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff; // %3: bias; plus bit 16 => ties to even
+
+    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2))); // NOTE(review): 2-word "=s" mask presumably assumes 64 lanes - confirm for wave32
+    uint32x2_t check_nan; // %0: result of v_cmp_u_f32(f, f), i.e. set where f is NaN
+    uint32_t tmp; // %1: scratch; written by v_bfe_u32 before the asm ever reads it
+    asm volatile("\n \
+            v_cmp_u_f32 %0, %2, %2 \n \
+            v_bfe_u32 %1, %2, 16, 1 \n \
+            v_add3_u32 %1, %2, %1, %3 \n \
+            v_cndmask_b32 %2, %1, %4, %0 \n \
+            v_lshrrev_b32 %2, 16, %2 \n \
+            "
+                 : "=s"(check_nan), "=&v"(tmp), "+v"(u.fp32) // tmp is write-only scratch: early-clobber output, so no uninitialized read
+                 : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN)); // net effect: %2 = (NaN ? %4 : %2 + bit16(%2) + %3) >> 16
+
+    return uint16_t(u.int32); // the asm already shifted the bf16 bits into the low 16 bits
+}
+
 // Truncate instead of rounding, preserving SNaN
 CK_TILE_HOST_DEVICE
 constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant